diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java index 1f4b7535d89e..85483fd517a9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java @@ -239,8 +239,9 @@ public void setup() { } @ProcessElement + @SuppressWarnings("nullness") public void processElement(ProcessContext c) throws Exception { - java.util.Map> decryptedKvs = new java.util.HashMap<>(); + java.util.HashMap> decryptedKvs = new java.util.HashMap<>(); for (KV encryptedKv : c.element().getValue()) { byte[] iv = Arrays.copyOfRange(encryptedKv.getKey(), 0, 12); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(128, iv); @@ -251,24 +252,19 @@ public void processElement(ProcessContext c) throws Exception { byte[] decryptedKeyBytes = this.cipher.doFinal(encryptedKey); K key = decode(this.keyCoder, decryptedKeyBytes); - if (key != null) { - if (!decryptedKvs.containsKey(key)) { - decryptedKvs.put(key, new java.util.ArrayList<>()); - } + if (!decryptedKvs.containsKey(key)) { + decryptedKvs.put(key, new java.util.ArrayList<>()); + } - iv = Arrays.copyOfRange(encryptedKv.getValue(), 0, 12); - gcmParameterSpec = new GCMParameterSpec(128, iv); - this.cipher.init(Cipher.DECRYPT_MODE, this.secretKeySpec, gcmParameterSpec); + iv = Arrays.copyOfRange(encryptedKv.getValue(), 0, 12); + gcmParameterSpec = new GCMParameterSpec(128, iv); + this.cipher.init(Cipher.DECRYPT_MODE, this.secretKeySpec, gcmParameterSpec); - byte[] encryptedValue = - Arrays.copyOfRange(encryptedKv.getValue(), 12, encryptedKv.getValue().length); - byte[] decryptedValueBytes = this.cipher.doFinal(encryptedValue); - V value = decode(this.valueCoder, decryptedValueBytes); - decryptedKvs.get(key).add(value); - } else { - throw new RuntimeException( - "Found null key when decoding " + Arrays.toString(decryptedKeyBytes)); - } + byte[] encryptedValue = + Arrays.copyOfRange(encryptedKv.getValue(), 12, encryptedKv.getValue().length); + byte[] decryptedValueBytes = this.cipher.doFinal(encryptedValue); + V value = decode(this.valueCoder, decryptedValueBytes); + decryptedKvs.get(key).add(value); } for (java.util.Map.Entry> entry : decryptedKvs.entrySet()) { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java index 3a2fc2f08c04..31064470bd38 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java @@ -33,6 +33,7 @@ import java.util.stream.Collectors; import java.util.stream.StreamSupport; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.NullableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.testing.NeedsRunner; @@ -42,6 +43,7 @@ import org.apache.beam.sdk.util.Secret; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.checkerframework.checker.nullness.qual.Nullable; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Rule; @@ -141,20 +143,22 @@ public static void tearDown() throws IOException { @Test @Category(NeedsRunner.class) public void testGroupByKeyGcpSecret() { - List> ungroupedPairs = + List> ungroupedPairs = Arrays.asList( + KV.of(null, 3), KV.of("k1", 3), KV.of("k5", Integer.MAX_VALUE), KV.of("k5", Integer.MIN_VALUE), KV.of("k2", 66), KV.of("k1", 4), + KV.of(null, 5), KV.of("k2", -33), KV.of("k3", 0)); PCollection> input = p.apply( Create.of(ungroupedPairs) - .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))); + .withCoder(KvCoder.of(NullableCoder.of(StringUtf8Coder.of()), VarIntCoder.of()))); PCollection>> output = input.apply(GroupByEncryptedKey.create(gcpSecret)); @@ -162,6 +166,7 @@ public void testGroupByKeyGcpSecret() { PAssert.that(output.apply("Sort", MapElements.via(new SortValues()))) .containsInAnyOrder( KV.of("k1", Arrays.asList(3, 4)), + KV.of(null, Arrays.asList(3, 5)), KV.of("k5", Arrays.asList(Integer.MIN_VALUE, Integer.MAX_VALUE)), KV.of("k2", Arrays.asList(-33, 66)), KV.of("k3", Arrays.asList(0)));