Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.common.primitives.Longs;
import com.google.crypto.tink.subtle.*;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.util.ReferenceCounted;
Expand Down Expand Up @@ -82,25 +83,16 @@ public void addToChannel(Channel ch) throws GeneralSecurityException {

@VisibleForTesting
class EncryptionHandler extends ChannelOutboundHandlerAdapter {
private final ByteBuffer plaintextBuffer;
private final ByteBuffer ciphertextBuffer;
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;

EncryptionHandler() throws InvalidAlgorithmParameterException {
aesGcmHkdfStreaming = getAesGcmHkdfStreaming();
plaintextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
}

@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage(
aesGcmHkdfStreaming,
msg,
plaintextBuffer,
ciphertextBuffer);
ctx.write(encryptedMessage, promise);
ctx.write(new GcmEncryptedMessage(aesGcmHkdfStreaming, msg), promise);
}
}

Expand All @@ -116,23 +108,32 @@ static class GcmEncryptedMessage extends AbstractFileRegion {
private final long encryptedCount;

GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming,
Object plaintextMessage,
ByteBuffer plaintextBuffer,
ByteBuffer ciphertextBuffer) throws GeneralSecurityException {
Object plaintextMessage) throws GeneralSecurityException {
JavaUtils.checkArgument(
plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion,
"Unrecognized message type: %s", plaintextMessage.getClass().getName());
this.plaintextMessage = plaintextMessage;
this.plaintextBuffer = plaintextBuffer;
this.ciphertextBuffer = ciphertextBuffer;
this.plaintextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
this.ciphertextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
// If the ciphertext buffer cannot be fully written the target, transferTo may
// return with it containing some unwritten data. The initial call we'll explicitly
// set its limit to 0 to indicate the first call to transferTo.
this.ciphertextBuffer.limit(0);

this.bytesToRead = getReadableBytes();
this.encryptedCount =
LENGTH_HEADER_BYTES + aesGcmHkdfStreaming.expectedCiphertextSize(bytesToRead);
// expectedCiphertextSize(P) adds getCiphertextOffset() when counting segments
// (fullSegments = (P + getCiphertextOffset()) / plaintextSegmentSize), assuming
// the Tink header fills part of segment 0. Because transferTo() writes the header
// separately via headerByteBuffer, all P bytes occupy full-capacity segments;
// subtracting getCiphertextOffset() cancels that internal addition and gives the
// correct count. The function excludes the header from its result, so
// getHeaderLength() adds it back for the bytes headerByteBuffer actually writes.
this.encryptedCount = LENGTH_HEADER_BYTES
+ aesGcmHkdfStreaming.getHeaderLength()
+ aesGcmHkdfStreaming.expectedCiphertextSize(
bytesToRead - aesGcmHkdfStreaming.getCiphertextOffset());
byte[] lengthAad = Longs.toByteArray(encryptedCount);
this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(lengthAad);
this.headerByteBuffer = createHeaderByteBuffer();
Expand Down Expand Up @@ -289,27 +290,67 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter {
private final ByteBuffer headerBuffer;
private final ByteBuffer ciphertextBuffer;
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
private final StreamSegmentDecrypter decrypter;
private StreamSegmentDecrypter decrypter;
private final int headerLength;
private final int plaintextSegmentSize;
private boolean decrypterInit = false;
private boolean completed = false;
private int segmentNumber = 0;
private long expectedLength = -1;
private long ciphertextRead = 0;
// Accumulates all decrypted segments for the current GCM message. Each segment is
// appended as a zero-copy component via addComponent(true, segment). A single
// ctx.fireChannelRead() fires when the message is complete, reducing N EventLoop
// callbacks (one per 32 KB segment) to 1. This prevents the EventLoop from being
// monopolised by large messages, which would starve other channels sharing the
// thread (including the executor–driver heartbeat channel) under concurrent shuffle
// load. Null between messages; ownership is transferred to downstream on
// fireChannelRead().
private CompositeByteBuf plaintextAccumulator = null;

DecryptionHandler() throws GeneralSecurityException {
aesGcmHkdfStreaming = getAesGcmHkdfStreaming();
headerLength = aesGcmHkdfStreaming.getHeaderLength();
expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES);
headerBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength());
headerBuffer = ByteBuffer.allocate(headerLength);
ciphertextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter();
plaintextSegmentSize = aesGcmHkdfStreaming.getPlaintextSegmentSize();
}

private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) {
/**
* Resets all per-message state so that the next incoming GCM message can be
* decoded through the same channel handler instance. This must be called after
* every successfully completed message because AesGcmHkdfStreaming is a one-shot
* streaming primitive: each encrypted message carries its own random IV and must
* be decrypted with a fresh StreamSegmentDecrypter.
*/
private void resetForNextMessage() throws GeneralSecurityException {
expectedLength = -1;
expectedLengthBuffer.clear();
headerBuffer.clear();
ciphertextBuffer.clear();
decrypterInit = false;
completed = false;
segmentNumber = 0;
ciphertextRead = 0;
decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter();
plaintextAccumulator = null; // defensive; should already be null after fireChannelRead
}

private boolean initializeExpectedLength(ByteBuf ciphertextNettyBuf) {
if (expectedLength < 0) {
ciphertextNettyBuf.readBytes(expectedLengthBuffer);
// ByteBuf.readBytes(ByteBuffer) throws if fewer than dst.remaining() bytes
// are available, so temporarily narrow the limit to what is actually present.
int toRead = Math.min(ciphertextNettyBuf.readableBytes(),
expectedLengthBuffer.remaining());
if (toRead > 0) {
int savedLimit = expectedLengthBuffer.limit();
expectedLengthBuffer.limit(expectedLengthBuffer.position() + toRead);
ciphertextNettyBuf.readBytes(expectedLengthBuffer);
expectedLengthBuffer.limit(savedLimit);
}
if (expectedLengthBuffer.hasRemaining()) {
// We did not read enough bytes to initialize the expected length.
return false;
Expand All @@ -324,12 +365,22 @@ private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) {
return true;
}

private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf)
private boolean initializeDecrypter(ByteBuf ciphertextNettyBuf)
throws GeneralSecurityException {
// Check if the ciphertext header has been read. This contains
// the IV and other internal metadata.
if (!decrypterInit) {
ciphertextNettyBuf.readBytes(headerBuffer);
// ByteBuf.readBytes(ByteBuffer) throws if fewer than dst.remaining() bytes
// are available. Under TCP fragmentation the header can arrive in multiple
// chunks, so temporarily narrow the limit to what is actually present.
int toRead = Math.min(ciphertextNettyBuf.readableBytes(),
headerBuffer.remaining());
if (toRead > 0) {
int savedLimit = headerBuffer.limit();
headerBuffer.limit(headerBuffer.position() + toRead);
ciphertextNettyBuf.readBytes(headerBuffer);
headerBuffer.limit(savedLimit);
}
if (headerBuffer.hasRemaining()) {
// We did not read enough bytes to initialize the header.
return false;
Expand All @@ -338,7 +389,7 @@ private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf)
byte[] lengthAad = Longs.toByteArray(expectedLength);
decrypter.init(headerBuffer, lengthAad);
decrypterInit = true;
ciphertextRead += aesGcmHkdfStreaming.getHeaderLength();
ciphertextRead += headerLength;
if (expectedLength == ciphertextRead) {
// If the expected length is just the header, the ciphertext is 0 length.
completed = true;
Expand All @@ -354,60 +405,116 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
"Unrecognized message type: %s",
ciphertextMessage.getClass().getName());
ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage;
// The format of the output is:
// The format of each message is:
// [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
//
// A single channelRead() call may deliver bytes from multiple back-to-back
// GCM messages (common under shuffle load when TCP coalesces writes). The
// outer loop processes as many complete messages as possible from the buffer
// before releasing it, so that bytes belonging to the next message are never
// discarded mid-stream.
try {
if (!initalizeExpectedLength(ciphertextNettyBuf)) {
// We have not read enough bytes to initialize the expected length.
return;
}
if (!initalizeDecrypter(ciphertextNettyBuf)) {
// We have not read enough bytes to initialize a header, needed to
// initialize a decrypter.
return;
}
int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
while (nettyBufReadableBytes > 0 && !completed) {
// Read the ciphertext into the local buffer
int readableBytes = Integer.min(
nettyBufReadableBytes,
ciphertextBuffer.remaining());
int expectedRemaining = (int) (expectedLength - ciphertextRead);
int bytesToRead = Integer.min(readableBytes, expectedRemaining);
// The smallest ciphertext size is 16 bytes for the auth tag
ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead);
ciphertextNettyBuf.readBytes(ciphertextBuffer);
ciphertextRead += bytesToRead;
// Check if this is the last segment
if (ciphertextRead == expectedLength) {
completed = true;
} else if (ciphertextRead > expectedLength) {
throw new IllegalStateException("Read more ciphertext than expected.");
while (true) {
if (!initializeExpectedLength(ciphertextNettyBuf)) {
// We have not read enough bytes to initialize the expected length.
break;
}
if (!initializeDecrypter(ciphertextNettyBuf)) {
// We have not read enough bytes to initialize a header, needed to
// initialize a decrypter.
break;
}
int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
while (nettyBufReadableBytes > 0 && !completed) {
// Read the ciphertext into the local buffer
int readableBytes = Math.min(
nettyBufReadableBytes,
ciphertextBuffer.remaining());
int expectedRemaining = (int) (expectedLength - ciphertextRead);
int bytesToRead = Math.min(readableBytes, expectedRemaining);
// The smallest ciphertext size is 16 bytes for the auth tag
ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead);
ciphertextNettyBuf.readBytes(ciphertextBuffer);
ciphertextRead += bytesToRead;
// Check if this is the last segment
if (ciphertextRead == expectedLength) {
completed = true;
} else if (ciphertextRead > expectedLength) {
throw new IllegalStateException("Read more ciphertext than expected.");
}
// If the ciphertext buffer is full, or this is the last segment,
// then decrypt it and fire a read.
if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) {
ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize);
ciphertextBuffer.flip();
decrypter.decryptSegment(
ciphertextBuffer,
segmentNumber,
completed,
plaintextBuffer);
segmentNumber++;
// Clear the ciphertext buffer because it's been read
ciphertextBuffer.clear();
plaintextBuffer.flip();
if (plaintextAccumulator == null) {
// Integer.MAX_VALUE disables consolidation entirely.
// CompositeByteBuf.newCompArray() always initialises the
// backing array to min(16, maxNumComponents) regardless of
// this value, so there is no upfront memory cost.
plaintextAccumulator =
Unpooled.compositeBuffer(Integer.MAX_VALUE);
}
// Zero-copy append: addComponent(true, ...) increases writerIndex
// so the component is immediately readable from the composite.
plaintextAccumulator.addComponent(
true, Unpooled.wrappedBuffer(plaintextBuffer));
} else {
// Set the ciphertext buffer up to read the next chunk
ciphertextBuffer.limit(ciphertextBuffer.capacity());
}
nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
}
// If the ciphertext buffer is full, or this is the last segment,
// then decrypt it and fire a read.
if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) {
ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize);
ciphertextBuffer.flip();
decrypter.decryptSegment(
ciphertextBuffer,
segmentNumber,
completed,
plaintextBuffer);
segmentNumber++;
// Clear the ciphertext buffer because it's been read
ciphertextBuffer.clear();
plaintextBuffer.flip();
ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer));
} else {
// Set the ciphertext buffer up to read the next chunk
ciphertextBuffer.limit(ciphertextBuffer.capacity());
if (!completed) {
// Partial message: more bytes needed from the next channelRead() call.
break;
}
nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
// Fire the entire plaintext as a single event so that downstream
// handlers receive one callback per Spark message instead of one per
// 32 KB segment.
if (plaintextAccumulator != null) {
ctx.fireChannelRead(plaintextAccumulator);
plaintextAccumulator = null; // ownership transferred to downstream
}
// Current message is fully decoded. Reset state so the handler can
// decode the next independent GCM message on the same channel.
resetForNextMessage();
if (ciphertextNettyBuf.readableBytes() == 0) {
break;
}
// Remaining bytes may belong to another message; loop to process them.
}
} finally {
ciphertextNettyBuf.release();
}
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
releaseAccumulator();
ctx.fireChannelInactive();
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
releaseAccumulator();
ctx.close();
}

private void releaseAccumulator() {
if (plaintextAccumulator != null) {
plaintextAccumulator.release();
plaintextAccumulator = null;
}
}
}
}
Loading