From f2cc92663ad8ae685183e076cdb652d8fc3ba4e0 Mon Sep 17 00:00:00 2001
From: Boyuan Zhang
Date: Fri, 2 Apr 2021 15:20:53 -0700
Subject: [PATCH] Eliminate beam_fn_api from KafkaIO expansion

---
 .../google-cloud-dataflow-java/build.gradle   |   1 +
 .../beam/runners/dataflow/DataflowRunner.java |   4 +
 .../SparkStructuredStreamingRunner.java       |   3 +
 runners/spark/spark_runner.gradle             |   1 +
 .../beam/runners/spark/SparkRunner.java       |   3 +
 .../runners/spark/SparkRunnerDebugger.java    |   3 +
 sdks/java/io/kafka/build.gradle               |   1 +
 .../org/apache/beam/sdk/io/kafka/KafkaIO.java | 176 +++++++++++++-----
 .../sdk/io/kafka/KafkaIOExternalTest.java     |   4 +-
 9 files changed, 147 insertions(+), 49 deletions(-)

diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle
index 290ea94daddd..476e8c561c3c 100644
--- a/runners/google-cloud-dataflow-java/build.gradle
+++ b/runners/google-cloud-dataflow-java/build.gradle
@@ -72,6 +72,7 @@ dependencies {
   compile project(path: ":model:pipeline", configuration: "shadow")
   compile project(path: ":sdks:java:core", configuration: "shadow")
   compile project(":sdks:java:extensions:google-cloud-platform-core")
+  compile project(":sdks:java:io:kafka")
   compile project(":sdks:java:io:google-cloud-platform")
   compile project(":runners:core-construction-java")
   compile library.java.avro
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index c81631e6d22e..0a79cd9c60de 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -115,6 +115,7 @@
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessageWithAttributesCoder;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubUnboundedSink;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubUnboundedSource;
+import org.apache.beam.sdk.io.kafka.KafkaIO;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
 import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
@@ -491,6 +492,9 @@ private List<PTransformOverride> getOverrides(boolean streaming) {
             new StreamingPubsubIOWriteOverrideFactory(this)));
       }
     }
+    if (useUnifiedWorker(options)) {
+      overridesBuilder.add(KafkaIO.Read.KAFKA_READ_OVERRIDE);
+    }
     overridesBuilder.add(
         PTransformOverride.of(
             PTransformMatchers.writeWithRunnerDeterminedSharding(),
diff --git a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
index f08c36be97d2..5d8230eccfa2 100644
--- a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
+++ b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
@@ -37,12 +37,14 @@
 import org.apache.beam.runners.spark.structuredstreaming.translation.streaming.PipelineTranslatorStreaming;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.PipelineRunner;
+import org.apache.beam.sdk.io.kafka.KafkaIO;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.metrics.MetricsOptions;
 import org.apache.beam.sdk.options.ExperimentalOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.spark.SparkEnv$;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.metrics.MetricsSystem;
@@ -193,6 +195,7 @@ private TranslationContext translatePipeline(Pipeline pipeline) {
         || ExperimentalOptions.hasExperiment(
             pipeline.getOptions(), "beam_fn_api_use_deprecated_read")
         || ExperimentalOptions.hasExperiment(pipeline.getOptions(), "use_deprecated_read")) {
+      pipeline.replaceAll(ImmutableList.of(KafkaIO.Read.KAFKA_READ_OVERRIDE));
       SplittableParDo.convertReadBasedSplittableDoFnsToPrimitiveReads(pipeline);
     }
 
diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle
index fe2393ef6770..38519d06ad60 100644
--- a/runners/spark/spark_runner.gradle
+++ b/runners/spark/spark_runner.gradle
@@ -152,6 +152,7 @@ dependencies {
   compile project(":runners:core-java")
   compile project(":runners:java-fn-execution")
   compile project(":runners:java-job-service")
+  compile project(":sdks:java:io:kafka")
   compile project(":sdks:java:extensions:google-cloud-platform-core")
   compile library.java.jackson_annotations
   compile library.java.slf4j_api
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
index 364e54978a8c..60a113ec6f1b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
@@ -48,6 +48,7 @@
 import org.apache.beam.runners.spark.util.SparkCompat;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.PipelineRunner;
+import org.apache.beam.sdk.io.kafka.KafkaIO;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.metrics.MetricsOptions;
 import org.apache.beam.sdk.options.ExperimentalOptions;
@@ -66,6 +67,7 @@
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.apache.spark.SparkEnv$;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -181,6 +183,7 @@ public SparkPipelineResult run(final Pipeline pipeline) {
         || ExperimentalOptions.hasExperiment(
             pipeline.getOptions(), "beam_fn_api_use_deprecated_read")
         || ExperimentalOptions.hasExperiment(pipeline.getOptions(), "use_deprecated_read")) {
+      pipeline.replaceAll(ImmutableList.of(KafkaIO.Read.KAFKA_READ_OVERRIDE));
       SplittableParDo.convertReadBasedSplittableDoFnsToPrimitiveReads(pipeline);
     }
 
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java
index 33b8408fa5be..37d9d54c4802 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunnerDebugger.java
@@ -26,9 +26,11 @@
 import org.apache.beam.runners.spark.translation.streaming.StreamingTransformTranslator;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.PipelineRunner;
+import org.apache.beam.sdk.io.kafka.KafkaIO;
 import org.apache.beam.sdk.options.ExperimentalOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.streaming.api.java.JavaStreamingContext;
 import org.joda.time.Duration;
@@ -85,6 +87,7 @@ public SparkPipelineResult run(Pipeline pipeline) {
         || ExperimentalOptions.hasExperiment(
             pipeline.getOptions(), "beam_fn_api_use_deprecated_read")
         || ExperimentalOptions.hasExperiment(pipeline.getOptions(), "use_deprecated_read")) {
+      pipeline.replaceAll(ImmutableList.of(KafkaIO.Read.KAFKA_READ_OVERRIDE));
       SplittableParDo.convertReadBasedSplittableDoFnsToPrimitiveReads(pipeline);
     }
     JavaSparkContext jsc = new JavaSparkContext("local[1]", "Debug_Pipeline");
diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle
index df76a82ec8f9..7a4ca21f1b06 100644
--- a/sdks/java/io/kafka/build.gradle
+++ b/sdks/java/io/kafka/build.gradle
@@ -48,6 +48,7 @@ kafkaVersions.each{k,v -> configurations.create("kafkaVersion$k")}
 dependencies {
   compile library.java.vendored_guava_26_0_jre
   compile project(path: ":sdks:java:core", configuration: "shadow")
+  compile project(":runners:core-construction-java")
   compile project(":sdks:java:expansion-service")
   permitUnusedDeclared project(":sdks:java:expansion-service") // BEAM-11761
   compile library.java.avro
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 06d4a5dbe9f0..8b6058c6c5fd 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -33,8 +33,11 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import org.apache.beam.runners.core.construction.PTransformMatchers;
+import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
+import org.apache.beam.sdk.annotations.Internal;
 import org.apache.beam.sdk.coders.AtomicCoder;
 import org.apache.beam.sdk.coders.AvroCoder;
 import org.apache.beam.sdk.coders.ByteArrayCoder;
@@ -51,6 +54,9 @@
 import org.apache.beam.sdk.options.ExperimentalOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.runners.PTransformOverride;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.schemas.NoSuchSchemaException;
 import org.apache.beam.sdk.schemas.transforms.Convert;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -72,6 +78,7 @@
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.sdk.values.TypeDescriptors;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
@@ -1209,67 +1216,144 @@ public PCollection<KafkaRecord<K, V>> expand(PBegin input) {
       Coder<K> keyCoder = getKeyCoder(coderRegistry);
       Coder<V> valueCoder = getValueCoder(coderRegistry);
 
-      // The Read will be expanded into SDF transform when "beam_fn_api" is enabled.
-      if (!ExperimentalOptions.hasExperiment(input.getPipeline().getOptions(), "beam_fn_api")
-          || ExperimentalOptions.hasExperiment(
+      // Reading from an unbounded source in a bounded manner does not go through Read or SDF.
+      if (ExperimentalOptions.hasExperiment(
               input.getPipeline().getOptions(), "beam_fn_api_use_deprecated_read")
+          || ExperimentalOptions.hasExperiment(
+              input.getPipeline().getOptions(), "use_deprecated_read")
           || getMaxNumRecords() < Long.MAX_VALUE
           || getMaxReadTime() != null) {
+        return input.apply(new ReadFromKafkaViaUnbounded<>(this, keyCoder, valueCoder));
+      }
+      return input.apply(new ReadFromKafkaViaSDF<>(this, keyCoder, valueCoder));
+    }
+
+    /**
+     * A {@link PTransformOverride} for runners to swap {@link ReadFromKafkaViaSDF} to the legacy
+     * Kafka read when the runner lacks good support for executing unbounded Splittable DoFn.
+     */
+    @Internal
+    public static final PTransformOverride KAFKA_READ_OVERRIDE =
+        PTransformOverride.of(
+            PTransformMatchers.classEqualTo(ReadFromKafkaViaSDF.class),
+            new KafkaReadOverrideFactory<>());
+
+    private static class KafkaReadOverrideFactory<K, V>
+        implements PTransformOverrideFactory<
+            PBegin, PCollection<KafkaRecord<K, V>>, ReadFromKafkaViaSDF<K, V>> {
+
+      @Override
+      public PTransformReplacement<PBegin, PCollection<KafkaRecord<K, V>>> getReplacementTransform(
+          AppliedPTransform<PBegin, PCollection<KafkaRecord<K, V>>, ReadFromKafkaViaSDF<K, V>>
+              transform) {
+        return PTransformReplacement.of(
+            transform.getPipeline().begin(),
+            new ReadFromKafkaViaUnbounded<>(
+                transform.getTransform().kafkaRead,
+                transform.getTransform().keyCoder,
+                transform.getTransform().valueCoder));
+      }
+
+      @Override
+      public Map<PCollection<?>, ReplacementOutput> mapOutputs(
+          Map<TupleTag<?>, PCollection<?>> outputs, PCollection<KafkaRecord<K, V>> newOutput) {
+        return ReplacementOutputs.singleton(outputs, newOutput);
+      }
+    }
+
+    private static class ReadFromKafkaViaUnbounded<K, V>
+        extends PTransform<PBegin, PCollection<KafkaRecord<K, V>>> {
+      Read<K, V> kafkaRead;
+      Coder<K> keyCoder;
+      Coder<V> valueCoder;
+
+      ReadFromKafkaViaUnbounded(Read<K, V> kafkaRead, Coder<K> keyCoder, Coder<V> valueCoder) {
+        this.kafkaRead = kafkaRead;
+        this.keyCoder = keyCoder;
+        this.valueCoder = valueCoder;
+      }
+
+      @Override
+      public PCollection<KafkaRecord<K, V>> expand(PBegin input) {
         // Handles unbounded source to bounded conversion if maxNumRecords or maxReadTime is set.
         Unbounded<KafkaRecord<K, V>> unbounded =
             org.apache.beam.sdk.io.Read.from(
-                toBuilder().setKeyCoder(keyCoder).setValueCoder(valueCoder).build().makeSource());
+                kafkaRead
+                    .toBuilder()
+                    .setKeyCoder(keyCoder)
+                    .setValueCoder(valueCoder)
+                    .build()
+                    .makeSource());
 
         PTransform<PBegin, PCollection<KafkaRecord<K, V>>> transform = unbounded;
 
-        if (getMaxNumRecords() < Long.MAX_VALUE || getMaxReadTime() != null) {
+        if (kafkaRead.getMaxNumRecords() < Long.MAX_VALUE || kafkaRead.getMaxReadTime() != null) {
           transform =
-              unbounded.withMaxReadTime(getMaxReadTime()).withMaxNumRecords(getMaxNumRecords());
+              unbounded
+                  .withMaxReadTime(kafkaRead.getMaxReadTime())
+                  .withMaxNumRecords(kafkaRead.getMaxNumRecords());
         }
 
         return input.getPipeline().apply(transform);
       }
-
-      ReadSourceDescriptors<K, V> readTransform =
-          ReadSourceDescriptors.<K, V>read()
-              .withConsumerConfigOverrides(getConsumerConfig())
-              .withOffsetConsumerConfigOverrides(getOffsetConsumerConfig())
-              .withConsumerFactoryFn(getConsumerFactoryFn())
-              .withKeyDeserializerProvider(getKeyDeserializerProvider())
-              .withValueDeserializerProvider(getValueDeserializerProvider())
-              .withManualWatermarkEstimator()
-              .withTimestampPolicyFactory(getTimestampPolicyFactory())
-              .withCheckStopReadingFn(getCheckStopReadingFn());
-      if (isCommitOffsetsInFinalizeEnabled()) {
-        readTransform = readTransform.commitOffsets();
+    }
+
+    static class ReadFromKafkaViaSDF<K, V>
+        extends PTransform<PBegin, PCollection<KafkaRecord<K, V>>> {
+      Read<K, V> kafkaRead;
+      Coder<K> keyCoder;
+      Coder<V> valueCoder;
+
+      ReadFromKafkaViaSDF(Read<K, V> kafkaRead, Coder<K> keyCoder, Coder<V> valueCoder) {
+        this.kafkaRead = kafkaRead;
+        this.keyCoder = keyCoder;
+        this.valueCoder = valueCoder;
       }
-      PCollection<KafkaSourceDescriptor> output;
-      if (isDynamicRead()) {
-        output =
-            input
-                .getPipeline()
-                .apply(Impulse.create())
-                .apply(
-                    MapElements.into(
-                            TypeDescriptors.kvs(
-                                new TypeDescriptor<byte[]>() {}, new TypeDescriptor<byte[]>() {}))
-                        .via(element -> KV.of(element, element)))
-                .apply(
-                    ParDo.of(
-                        new WatchKafkaTopicPartitionDoFn(
-                            getWatchTopicPartitionDuration(),
-                            getConsumerFactoryFn(),
-                            getCheckStopReadingFn(),
-                            getConsumerConfig(),
-                            getStartReadTime())));
-      } else {
-        output =
-            input
-                .getPipeline()
-                .apply(Impulse.create())
-                .apply(ParDo.of(new GenerateKafkaSourceDescriptor(this)));
+
+      @Override
+      public PCollection<KafkaRecord<K, V>> expand(PBegin input) {
+        ReadSourceDescriptors<K, V> readTransform =
+            ReadSourceDescriptors.<K, V>read()
+                .withConsumerConfigOverrides(kafkaRead.getConsumerConfig())
+                .withOffsetConsumerConfigOverrides(kafkaRead.getOffsetConsumerConfig())
+                .withConsumerFactoryFn(kafkaRead.getConsumerFactoryFn())
+                .withKeyDeserializerProvider(kafkaRead.getKeyDeserializerProvider())
+                .withValueDeserializerProvider(kafkaRead.getValueDeserializerProvider())
+                .withManualWatermarkEstimator()
+                .withTimestampPolicyFactory(kafkaRead.getTimestampPolicyFactory())
+                .withCheckStopReadingFn(kafkaRead.getCheckStopReadingFn());
+        if (kafkaRead.isCommitOffsetsInFinalizeEnabled()) {
+          readTransform = readTransform.commitOffsets();
+        }
+        PCollection<KafkaSourceDescriptor> output;
+        if (kafkaRead.isDynamicRead()) {
+          output =
+              input
+                  .getPipeline()
+                  .apply(Impulse.create())
+                  .apply(
+                      MapElements.into(
+                              TypeDescriptors.kvs(
+                                  new TypeDescriptor<byte[]>() {}, new TypeDescriptor<byte[]>() {}))
+                          .via(element -> KV.of(element, element)))
+                  .apply(
+                      ParDo.of(
+                          new WatchKafkaTopicPartitionDoFn(
+                              kafkaRead.getWatchTopicPartitionDuration(),
+                              kafkaRead.getConsumerFactoryFn(),
+                              kafkaRead.getCheckStopReadingFn(),
+                              kafkaRead.getConsumerConfig(),
+                              kafkaRead.getStartReadTime())));
+
+        } else {
+          output =
+              input
+                  .getPipeline()
+                  .apply(Impulse.create())
+                  .apply(ParDo.of(new GenerateKafkaSourceDescriptor(kafkaRead)));
+        }
+        return output.apply(readTransform).setCoder(KafkaRecordCoder.of(keyCoder, valueCoder));
       }
-      return output.apply(readTransform).setCoder(KafkaRecordCoder.of(keyCoder, valueCoder));
     }
 
     /**
@@ -1798,10 +1882,6 @@ ReadSourceDescriptors<K, V> withTimestampPolicyFactory(
 
     @Override
     public PCollection<KafkaRecord<K, V>> expand(PCollection<KafkaSourceDescriptor> input) {
-      checkArgument(
-          ExperimentalOptions.hasExperiment(input.getPipeline().getOptions(), "beam_fn_api"),
-          "The ReadSourceDescriptors can only used when beam_fn_api is enabled.");
-
       checkArgument(getKeyDeserializerProvider() != null, "withKeyDeserializer() is required");
       checkArgument(getValueDeserializerProvider() != null, "withValueDeserializer() is required");
 
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
index 108d85832509..bb15e42461d9 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java
@@ -128,8 +128,10 @@ public void testConstructKafkaRead() throws Exception {
     assertThat(transform.getInputsCount(), Matchers.is(0));
     assertThat(transform.getOutputsCount(), Matchers.is(1));
 
-    RunnerApi.PTransform kafkaComposite =
+    RunnerApi.PTransform kafkaReadComposite =
         result.getComponents().getTransformsOrThrow(transform.getSubtransforms(0));
+    RunnerApi.PTransform kafkaComposite =
+        result.getComponents().getTransformsOrThrow(kafkaReadComposite.getSubtransforms(0));
     assertThat(
         kafkaComposite.getSubtransformsList(),
         Matchers.hasItem(MatchesPattern.matchesPattern(".*Impulse.*")));
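
Note (not part of the patch): a runner that lacks support for unbounded Splittable DoFn opts back into the legacy UnboundedSource-based expansion by applying the new override before translating the pipeline, exactly as the Spark runner changes above do. A minimal Java sketch of that call follows; the wrapper class and method names are illustrative only.

    // Illustrative sketch: replace every ReadFromKafkaViaSDF composite with
    // ReadFromKafkaViaUnbounded (the UnboundedSource-based Kafka read) by applying
    // the KAFKA_READ_OVERRIDE introduced in this patch.
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.kafka.KafkaIO;
    import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;

    class LegacyKafkaReadFallback {
      static void applyLegacyKafkaRead(Pipeline pipeline) {
        // Runs the override factory over the pipeline graph before translation.
        pipeline.replaceAll(ImmutableList.of(KafkaIO.Read.KAFKA_READ_OVERRIDE));
      }
    }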