Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
import org.apache.beam.model.pipeline.v1.RunnerApi.ReadPayload;
import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.Read;
import org.apache.beam.sdk.io.Read.Unbounded;
import org.apache.beam.sdk.io.Source;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.runners.AppliedPTransform;
Expand All @@ -43,21 +41,24 @@
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;

/**
* Methods for translating {@link Read.Bounded} and {@link Read.Unbounded} {@link PTransform
* PTransformTranslation} into {@link ReadPayload} protos.
* Methods for translating {@link SplittableParDo.PrimitiveBoundedRead} and {@link
* SplittableParDo.PrimitiveUnboundedRead} {@link PTransform PTransformTranslation} into {@link
* ReadPayload} protos.
*/
public class ReadTranslation {
private static final String JAVA_SERIALIZED_BOUNDED_SOURCE = "beam:java:boundedsource:v1";
private static final String JAVA_SERIALIZED_UNBOUNDED_SOURCE = "beam:java:unboundedsource:v1";

public static ReadPayload toProto(Read.Bounded<?> read, SdkComponents components) {
public static ReadPayload toProto(
SplittableParDo.PrimitiveBoundedRead<?> read, SdkComponents components) {
return ReadPayload.newBuilder()
.setIsBounded(IsBounded.Enum.BOUNDED)
.setSource(toProto(read.getSource(), components))
.build();
}

public static ReadPayload toProto(Unbounded<?> read, SdkComponents components) {
public static ReadPayload toProto(
SplittableParDo.PrimitiveUnboundedRead<?> read, SdkComponents components) {
return ReadPayload.newBuilder()
.setIsBounded(IsBounded.Enum.UNBOUNDED)
.setSource(toProto(read.getSource(), components))
Expand Down Expand Up @@ -141,23 +142,25 @@ public static PCollection.IsBounded sourceIsBounded(AppliedPTransform<?, ?, ?> t
}
}

/** A {@link TransformPayloadTranslator} for {@link Read.Unbounded}. */
/** A {@link TransformPayloadTranslator} for {@link SplittableParDo.PrimitiveUnboundedRead}. */
public static class UnboundedReadPayloadTranslator
implements PTransformTranslation.TransformPayloadTranslator<Read.Unbounded<?>> {
implements PTransformTranslation.TransformPayloadTranslator<
SplittableParDo.PrimitiveUnboundedRead<?>> {
public static TransformPayloadTranslator create() {
return new UnboundedReadPayloadTranslator();
}

private UnboundedReadPayloadTranslator() {}

@Override
public String getUrn(Read.Unbounded<?> transform) {
public String getUrn(SplittableParDo.PrimitiveUnboundedRead<?> transform) {
return PTransformTranslation.READ_TRANSFORM_URN;
}

@Override
public FunctionSpec translate(
AppliedPTransform<?, ?, Read.Unbounded<?>> transform, SdkComponents components) {
AppliedPTransform<?, ?, SplittableParDo.PrimitiveUnboundedRead<?>> transform,
SdkComponents components) {
ReadPayload payload = toProto(transform.getTransform(), components);
return RunnerApi.FunctionSpec.newBuilder()
.setUrn(getUrn(transform.getTransform()))
Expand All @@ -166,23 +169,25 @@ public FunctionSpec translate(
}
}

/** A {@link TransformPayloadTranslator} for {@link Read.Bounded}. */
/** A {@link TransformPayloadTranslator} for {@link SplittableParDo.PrimitiveBoundedRead}. */
public static class BoundedReadPayloadTranslator
implements PTransformTranslation.TransformPayloadTranslator<Read.Bounded<?>> {
implements PTransformTranslation.TransformPayloadTranslator<
SplittableParDo.PrimitiveBoundedRead<?>> {
public static TransformPayloadTranslator create() {
return new BoundedReadPayloadTranslator();
}

private BoundedReadPayloadTranslator() {}

@Override
public String getUrn(Read.Bounded<?> transform) {
public String getUrn(SplittableParDo.PrimitiveBoundedRead<?> transform) {
return PTransformTranslation.READ_TRANSFORM_URN;
}

@Override
public FunctionSpec translate(
AppliedPTransform<?, ?, Read.Bounded<?>> transform, SdkComponents components) {
AppliedPTransform<?, ?, SplittableParDo.PrimitiveBoundedRead<?>> transform,
SdkComponents components) {
ReadPayload payload = toProto(transform.getTransform(), components);
return RunnerApi.FunctionSpec.newBuilder()
.setUrn(getUrn(transform.getTransform()))
Expand All @@ -198,8 +203,8 @@ public static class Registrar implements TransformPayloadTranslatorRegistrar {
public Map<? extends Class<? extends PTransform>, ? extends TransformPayloadTranslator>
getTransformPayloadTranslators() {
return ImmutableMap.<Class<? extends PTransform>, TransformPayloadTranslator>builder()
.put(Read.Unbounded.class, new UnboundedReadPayloadTranslator())
.put(Read.Bounded.class, new BoundedReadPayloadTranslator())
.put(SplittableParDo.PrimitiveUnboundedRead.class, new UnboundedReadPayloadTranslator())
.put(SplittableParDo.PrimitiveBoundedRead.class, new BoundedReadPayloadTranslator())
.build();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import com.google.auto.service.AutoService;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand All @@ -37,22 +36,26 @@
import org.apache.beam.runners.core.construction.ReadTranslation.BoundedReadPayloadTranslator;
import org.apache.beam.runners.core.construction.ReadTranslation.UnboundedReadPayloadTranslator;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.Pipeline.PipelineVisitor;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.Read;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark;
import org.apache.beam.sdk.options.ExperimentalOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverride;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
import org.apache.beam.sdk.runners.TransformHierarchy.Node;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.ArgumentProvider;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.BaseArgumentProvider;
Expand All @@ -62,7 +65,9 @@
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.NameUtils;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
Expand All @@ -71,6 +76,7 @@
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.checkerframework.checker.nullness.qual.Nullable;
Expand Down Expand Up @@ -638,43 +644,161 @@ public void tearDown() {
}

/**
* Throws an {@link IllegalArgumentException} if the pipeline contains any primitive read
* transforms that have not been expanded to be executed as {@link DoFn splittable DoFns} as long
* as the experiment {@code use_deprecated_read} is not specified.
* Converts {@link Read} based Splittable DoFn expansions to primitive reads implemented by {@link
* PrimitiveBoundedRead} and {@link PrimitiveUnboundedRead} if either of the experiments {@code
* use_deprecated_read} or {@code beam_fn_api_use_deprecated_read} is specified.
*
* <p>TODO(BEAM-10670): Remove the primitive Read and make the splittable DoFn the only option.
*/
public static void convertReadBasedSplittableDoFnsToPrimitiveReadsIfNecessary(Pipeline pipeline) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this should be controlled by the runner choosing to invoke the method, not by a global flag. It can have the same status as other runner-internal overrides, such as implementing GroupByKey (GBK) via GroupByKeyOnly (GBKO).

If you really believe it should be a flag, it should be the runner that reads the flag and decides what to do. This utility library should not change its behavior based on pipeline options. Only runners should opt in to particular behaviors.

if (ExperimentalOptions.hasExperiment(pipeline.getOptions(), "beam_fn_api_use_deprecated_read")
|| ExperimentalOptions.hasExperiment(pipeline.getOptions(), "use_deprecated_read")) {
convertReadBasedSplittableDoFnsToPrimitiveReads(pipeline);
}
}

/**
 * Unconditionally rewrites {@code pipeline} in place, applying the overrides that turn {@link
 * Read.Bounded} into {@link PrimitiveBoundedRead} and {@link Read.Unbounded} into {@link
 * PrimitiveUnboundedRead}.
 *
 * <p>No pipeline options or experiments are consulted here; callers decide whether to invoke it.
 *
 * <p>TODO(BEAM-10670): Remove the primitive Read and make the splittable DoFn the only option.
 *
 * @param pipeline the pipeline whose {@link Read} transforms are replaced
 */
public static void convertReadBasedSplittableDoFnsToPrimitiveReads(Pipeline pipeline) {
  List<PTransformOverride> primitiveReadOverrides =
      ImmutableList.of(PRIMITIVE_BOUNDED_READ_OVERRIDE, PRIMITIVE_UNBOUNDED_READ_OVERRIDE);
  pipeline.replaceAll(primitiveReadOverrides);
}

/**
* A transform override for {@link Read.Bounded} that converts it to a {@link
* PrimitiveBoundedRead}.
*
* <p>Transforms are selected with {@code PTransformMatchers.classEqualTo(Read.Bounded.class)} and
* the replacement is produced by {@code BoundedReadOverrideFactory}, which wraps the matched
* transform in a new {@link PrimitiveBoundedRead} applied to the pipeline's begin.
*
* <p>This is one of the two overrides passed to {@code Pipeline#replaceAll} by {@code
* convertReadBasedSplittableDoFnsToPrimitiveReads}.
*/
public static final PTransformOverride PRIMITIVE_BOUNDED_READ_OVERRIDE =
PTransformOverride.of(
PTransformMatchers.classEqualTo(Read.Bounded.class), new BoundedReadOverrideFactory<>());
/**
* A transform override for {@link Read.Unbounded} that converts it to a {@link
* PrimitiveUnboundedRead}.
*/
public static void validateNoPrimitiveReads(Pipeline pipeline) {
// TODO(BEAM-10670): Remove the deprecated Read and make the splittable DoFn the only option.
if (!(ExperimentalOptions.hasExperiment(
pipeline.getOptions(), "beam_fn_api_use_deprecated_read")
|| ExperimentalOptions.hasExperiment(pipeline.getOptions(), "use_deprecated_read"))) {
public static final PTransformOverride PRIMITIVE_UNBOUNDED_READ_OVERRIDE =
PTransformOverride.of(
PTransformMatchers.classEqualTo(Read.Unbounded.class),
new UnboundedReadOverrideFactory<>());

pipeline.traverseTopologically(new ValidateNoPrimitiveReads());
/**
 * A {@link PTransformOverrideFactory} that replaces a {@link Read.Bounded} with a {@link
 * PrimitiveBoundedRead} wrapping that same transform.
 */
private static class BoundedReadOverrideFactory<T>
    implements PTransformOverrideFactory<PBegin, PCollection<T>, Read.Bounded<T>> {

  /**
   * Builds the replacement: a {@link PrimitiveBoundedRead} around the matched transform, taking
   * the pipeline's {@link PBegin} as its input.
   */
  @Override
  public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
      AppliedPTransform<PBegin, PCollection<T>, Read.Bounded<T>> transform) {
    PBegin replacementInput = transform.getPipeline().begin();
    Read.Bounded<T> originalRead = transform.getTransform();
    PrimitiveBoundedRead<T> primitiveRead = new PrimitiveBoundedRead<>(originalRead);
    return PTransformReplacement.of(replacementInput, primitiveRead);
  }

  /** Maps the original outputs to {@code newOutput} via {@code ReplacementOutputs.singleton}. */
  @Override
  public Map<PValue, ReplacementOutput> mapOutputs(
      Map<TupleTag<?>, PValue> outputs, PCollection<T> newOutput) {
    return ReplacementOutputs.singleton(outputs, newOutput);
  }
}

/**
 * A {@link PTransformOverrideFactory} that replaces a {@link Read.Unbounded} with a {@link
 * PrimitiveUnboundedRead} wrapping that same transform.
 */
private static class UnboundedReadOverrideFactory<T>
    implements PTransformOverrideFactory<PBegin, PCollection<T>, Read.Unbounded<T>> {

  /**
   * Builds the replacement: a {@link PrimitiveUnboundedRead} around the matched transform, taking
   * the pipeline's {@link PBegin} as its input.
   */
  @Override
  public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
      AppliedPTransform<PBegin, PCollection<T>, Read.Unbounded<T>> transform) {
    PBegin replacementInput = transform.getPipeline().begin();
    Read.Unbounded<T> originalRead = transform.getTransform();
    PrimitiveUnboundedRead<T> primitiveRead = new PrimitiveUnboundedRead<>(originalRead);
    return PTransformReplacement.of(replacementInput, primitiveRead);
  }

  /** Maps the original outputs to {@code newOutput} via {@code ReplacementOutputs.singleton}. */
  @Override
  public Map<PValue, ReplacementOutput> mapOutputs(
      Map<TupleTag<?>, PValue> outputs, PCollection<T> newOutput) {
    return ReplacementOutputs.singleton(outputs, newOutput);
  }
}

/**
* A {@link org.apache.beam.sdk.Pipeline.PipelineVisitor} that ensures that the pipeline does not
* contain any primitive reads.
* Base class that ensures the overridden transform has the same contract as if interacting with
* the original {@link Read.Bounded Read.Bounded}/{@link Read.Unbounded Read.Unbounded}
* implementations.
*/
private static class ValidateNoPrimitiveReads extends PipelineVisitor.Defaults {
public final List<PTransform<?, ?>> foundPrimitiveReads = new ArrayList<>();
private abstract static class PrimitiveRead<T> extends PTransform<PBegin, PCollection<T>> {
private final PTransform<PBegin, PCollection<T>> originalTransform;
protected final Object source;

public PrimitiveRead(PTransform<PBegin, PCollection<T>> originalTransform, Object source) {
this.originalTransform = originalTransform;
this.source = source;
}

@Override
public void visitPrimitiveTransform(Node node) {
if (node.getTransform() instanceof Read.Bounded
|| node.getTransform() instanceof Read.Unbounded) {
foundPrimitiveReads.add(node.getTransform());
}
public void validate(@Nullable PipelineOptions options) {
originalTransform.validate(options);
}

@Override
public void leavePipeline(Pipeline pipeline) {
if (!foundPrimitiveReads.isEmpty()) {
throw new IllegalArgumentException(
String.format(
"Found primitive read transforms %s within the pipeline when only Splittable DoFns were expected. If you would like to use the deprecated behavior, please specify the experiment 'use_deprecated_read'. For example '--experiements=use_deprecated_read' on the command line.",
foundPrimitiveReads));
}
public Map<TupleTag<?>, PValue> getAdditionalInputs() {
return originalTransform.getAdditionalInputs();
}

@Override
public <CoderT> Coder<CoderT> getDefaultOutputCoder(PBegin input, PCollection<CoderT> output)
throws CannotProvideCoderException {
return originalTransform.getDefaultOutputCoder(input, output);
}

@Override
public String getName() {
return originalTransform.getName();
}

@Override
public void populateDisplayData(DisplayData.Builder builder) {
originalTransform.populateDisplayData(builder);
}

@Override
protected String getKindString() {
return String.format("Read(%s)", NameUtils.approximateSimpleName(source));
}
}

/**
 * The original primitive based {@link Read.Bounded Read.Bounded} expansion.
 *
 * <p>Expanding this transform does not apply any sub-transforms; it directly creates a primitive
 * output {@link PCollection} marked {@link PCollection.IsBounded#BOUNDED}.
 */
public static class PrimitiveBoundedRead<T> extends PrimitiveRead<T> {

  /**
   * @param originalTransform the {@link Read.Bounded} being replaced; its source is captured at
   *     construction time
   */
  public PrimitiveBoundedRead(Read.Bounded<T> originalTransform) {
    super(originalTransform, originalTransform.getSource());
  }

  /**
   * Creates the bounded primitive output in the globally-default windowing strategy, using the
   * source's output coder.
   */
  @Override
  public PCollection<T> expand(PBegin input) {
    Pipeline pipeline = input.getPipeline();
    WindowingStrategy<?, ?> windowing = WindowingStrategy.globalDefault();
    Coder<T> outputCoder = getSource().getOutputCoder();
    return PCollection.createPrimitiveOutputInternal(
        pipeline, windowing, PCollection.IsBounded.BOUNDED, outputCoder);
  }

  /** Returns the source captured by the constructor, cast back to {@link BoundedSource}. */
  public BoundedSource<T> getSource() {
    // The constructor only ever stores the result of Read.Bounded<T>#getSource() in this field.
    return (BoundedSource<T>) source;
  }
}

/**
 * The original primitive based {@link Read.Unbounded Read.Unbounded} expansion.
 *
 * <p>Expanding this transform does not apply any sub-transforms; it directly creates a primitive
 * output {@link PCollection} marked {@link PCollection.IsBounded#UNBOUNDED}.
 */
public static class PrimitiveUnboundedRead<T> extends PrimitiveRead<T> {

  /**
   * @param originalTransform the {@link Read.Unbounded} being replaced; its source is captured at
   *     construction time
   */
  public PrimitiveUnboundedRead(Read.Unbounded<T> originalTransform) {
    super(originalTransform, originalTransform.getSource());
  }

  /**
   * Creates the unbounded primitive output in the globally-default windowing strategy, using the
   * source's output coder.
   */
  @Override
  public PCollection<T> expand(PBegin input) {
    Pipeline pipeline = input.getPipeline();
    WindowingStrategy<?, ?> windowing = WindowingStrategy.globalDefault();
    Coder<T> outputCoder = getSource().getOutputCoder();
    return PCollection.createPrimitiveOutputInternal(
        pipeline, windowing, PCollection.IsBounded.UNBOUNDED, outputCoder);
  }

  /** Returns the source captured by the constructor, cast back to {@link UnboundedSource}. */
  public UnboundedSource<T, ? extends CheckpointMark> getSource() {
    // The constructor only ever stores the result of Read.Unbounded<T>#getSource() in this field.
    return (UnboundedSource<T, ? extends CheckpointMark>) source;
  }
}
}
Loading