From e6911317f4ee4040c334320ecdb777545f20d1e7 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Tue, 14 Oct 2025 11:05:45 -0400 Subject: [PATCH 1/4] Handle null keys in gbek --- .../org/apache/beam/sdk/transforms/GroupByEncryptedKey.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java index 1f4b7535d89e..6a9f0cb78a93 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java @@ -251,7 +251,8 @@ public void processElement(ProcessContext c) throws Exception { byte[] decryptedKeyBytes = this.cipher.doFinal(encryptedKey); K key = decode(this.keyCoder, decryptedKeyBytes); - if (key != null) { + // If somehow the key was decoded to null, but the byte string is non-empty, throw. + if (key != null || decryptedKeyBytes == null || decryptedKeyBytes.length == 0) { if (!decryptedKvs.containsKey(key)) { decryptedKvs.put(key, new java.util.ArrayList<>()); } From 19bd68d9f1b76f9b94124a91cc09e31e3fb542c6 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Tue, 14 Oct 2025 11:40:11 -0400 Subject: [PATCH 2/4] Allow null values with hashmap --- .../org/apache/beam/sdk/transforms/GroupByEncryptedKey.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java index 6a9f0cb78a93..f7023a4ca3d2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java @@ -239,8 +239,9 @@ public void setup() { } @ProcessElement + @SuppressWarnings("nullness") public void processElement(ProcessContext c) throws Exception { - java.util.Map> decryptedKvs = new java.util.HashMap<>(); + java.util.HashMap> decryptedKvs = new java.util.HashMap<>(); for (KV encryptedKv : c.element().getValue()) { byte[] iv = Arrays.copyOfRange(encryptedKv.getKey(), 0, 12); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(128, iv); From 79e32171c706d5b702a79513f77ea2a5edb0918d Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Tue, 14 Oct 2025 13:18:02 -0400 Subject: [PATCH 3/4] add a test --- .../apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java index 3a2fc2f08c04..3d27cca62c45 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java @@ -143,11 +143,13 @@ public static void tearDown() throws IOException { public void testGroupByKeyGcpSecret() { List> ungroupedPairs = Arrays.asList( + KV.of(null, 3), KV.of("k1", 3), KV.of("k5", Integer.MAX_VALUE), KV.of("k5", Integer.MIN_VALUE), KV.of("k2", 66), KV.of("k1", 4), + KV.of(null, 5), KV.of("k2", -33), KV.of("k3", 0)); @@ -162,6 +164,7 @@ public void testGroupByKeyGcpSecret() { PAssert.that(output.apply("Sort", MapElements.via(new SortValues()))) .containsInAnyOrder( KV.of("k1", Arrays.asList(3, 4)), + KV.of(null, Arrays.asList(3, 5)), KV.of("k5", Arrays.asList(Integer.MIN_VALUE, Integer.MAX_VALUE)), KV.of("k2", Arrays.asList(-33, 66)), KV.of("k3", Arrays.asList(0))); From d968b35b4e39c4d379096eded9d855d308221cc0 Mon Sep 17 00:00:00 2001 From: Danny Mccormick Date: Tue, 14 Oct 2025 15:58:08 -0400 Subject: [PATCH 4/4] Test + remove check entirely --- .../sdk/transforms/GroupByEncryptedKey.java | 28 ++++++++----------- .../transforms/GroupByEncryptedKeyTest.java | 6 ++-- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java index f7023a4ca3d2..85483fd517a9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java @@ -252,25 +252,19 @@ public void processElement(ProcessContext c) throws Exception { byte[] decryptedKeyBytes = this.cipher.doFinal(encryptedKey); K key = decode(this.keyCoder, decryptedKeyBytes); - // If somehow the key was decoded to null, but the byte string is non-empty, throw. - if (key != null || decryptedKeyBytes == null || decryptedKeyBytes.length == 0) { - if (!decryptedKvs.containsKey(key)) { - decryptedKvs.put(key, new java.util.ArrayList<>()); - } + if (!decryptedKvs.containsKey(key)) { + decryptedKvs.put(key, new java.util.ArrayList<>()); + } - iv = Arrays.copyOfRange(encryptedKv.getValue(), 0, 12); - gcmParameterSpec = new GCMParameterSpec(128, iv); - this.cipher.init(Cipher.DECRYPT_MODE, this.secretKeySpec, gcmParameterSpec); + iv = Arrays.copyOfRange(encryptedKv.getValue(), 0, 12); + gcmParameterSpec = new GCMParameterSpec(128, iv); + this.cipher.init(Cipher.DECRYPT_MODE, this.secretKeySpec, gcmParameterSpec); - byte[] encryptedValue = - Arrays.copyOfRange(encryptedKv.getValue(), 12, encryptedKv.getValue().length); - byte[] decryptedValueBytes = this.cipher.doFinal(encryptedValue); - V value = decode(this.valueCoder, decryptedValueBytes); - decryptedKvs.get(key).add(value); - } else { - throw new RuntimeException( - "Found null key when decoding " + Arrays.toString(decryptedKeyBytes)); - } + byte[] encryptedValue = + Arrays.copyOfRange(encryptedKv.getValue(), 12, encryptedKv.getValue().length); + byte[] decryptedValueBytes = this.cipher.doFinal(encryptedValue); + V value = decode(this.valueCoder, decryptedValueBytes); + decryptedKvs.get(key).add(value); } for (java.util.Map.Entry> entry : decryptedKvs.entrySet()) { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java index 3d27cca62c45..31064470bd38 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java @@ -33,6 +33,7 @@ import java.util.stream.Collectors; import java.util.stream.StreamSupport; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.NullableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.testing.NeedsRunner; @@ -42,6 +43,7 @@ import org.apache.beam.sdk.util.Secret; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.checkerframework.checker.nullness.qual.Nullable; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Rule; @@ -141,7 +143,7 @@ public static void tearDown() throws IOException { @Test @Category(NeedsRunner.class) public void testGroupByKeyGcpSecret() { - List> ungroupedPairs = + List> ungroupedPairs = Arrays.asList( KV.of(null, 3), KV.of("k1", 3), @@ -156,7 +158,7 @@ public void testGroupByKeyGcpSecret() { PCollection> input = p.apply( Create.of(ungroupedPairs) - .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))); + .withCoder(KvCoder.of(NullableCoder.of(StringUtf8Coder.of()), VarIntCoder.of()))); PCollection>> output = input.apply(GroupByEncryptedKey.create(gcpSecret));