diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index 69f438055da9..94ad154ff7b0 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -46,6 +46,7 @@ import java.util.Arrays; import java.util.Comparator; import java.util.List; +import java.util.OptionalInt; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; @@ -58,12 +59,14 @@ import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Reshuffle; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.Wait; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.windowing.DefaultTrigger; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.Window; @@ -73,11 +76,13 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Stopwatch; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; @@ -352,7 +357,6 @@ public static Write write() { .setBatchSizeBytes(DEFAULT_BATCH_SIZE_BYTES) .setMaxNumMutations(DEFAULT_MAX_NUM_MUTATIONS) .setMaxNumRows(DEFAULT_MAX_NUM_ROWS) - .setGroupingFactor(DEFAULT_GROUPING_FACTOR) .setFailureMode(FailureMode.FAIL_FAST) .build(); } @@ -783,7 +787,7 @@ public abstract static class Write extends PTransform, Spa @Nullable abstract PCollection getSchemaReadySignal(); - abstract int getGroupingFactor(); + abstract OptionalInt getGroupingFactor(); abstract Builder toBuilder(); @@ -967,8 +971,14 @@ private void populateDisplayDataWithParamaters(DisplayData.Builder builder) { builder.add( DisplayData.item("maxNumRows", getMaxNumRows()) .withLabel("Max number of rows in each batch")); + // Grouping factor default value depends on whether it is a batch or streaming pipeline. + // This function is not aware of that state, so use 'DEFAULT' if unset. builder.add( - DisplayData.item("groupingFactor", getGroupingFactor()) + DisplayData.item( + "groupingFactor", + (getGroupingFactor().isPresent() + ? Integer.toString(getGroupingFactor().getAsInt()) + : "DEFAULT")) .withLabel("Number of batches to sort over")); } } @@ -1014,75 +1024,98 @@ public void populateDisplayData(DisplayData.Builder builder) { @Override public SpannerWriteResult expand(PCollection input) { + PCollection> batches; + + if (spec.getBatchSizeBytes() <= 1 + || spec.getMaxNumMutations() <= 1 + || spec.getMaxNumRows() <= 1) { + LOG.info("Batching of mutationGroups is disabled"); + TypeDescriptor> descriptor = + new TypeDescriptor>() {}; + batches = + input.apply(MapElements.into(descriptor).via(element -> ImmutableList.of(element))); + } else { - // First, read the Cloud Spanner schema. - PCollection schemaSeed = - input.getPipeline().apply("Create Seed", Create.of((Void) null)); - if (spec.getSchemaReadySignal() != null) { - // Wait for external signal before reading schema. - schemaSeed = schemaSeed.apply("Wait for schema", Wait.on(spec.getSchemaReadySignal())); + // First, read the Cloud Spanner schema. + PCollection schemaSeed = + input.getPipeline().apply("Create Seed", Create.of((Void) null)); + if (spec.getSchemaReadySignal() != null) { + // Wait for external signal before reading schema. + schemaSeed = schemaSeed.apply("Wait for schema", Wait.on(spec.getSchemaReadySignal())); + } + final PCollectionView schemaView = + schemaSeed + .apply( + "Read information schema", + ParDo.of(new ReadSpannerSchema(spec.getSpannerConfig()))) + .apply("Schema View", View.asSingleton()); + + // Split the mutations into batchable and unbatchable mutations. + // Filter out mutation groups too big to be batched. + PCollectionTuple filteredMutations = + input + .apply( + "RewindowIntoGlobal", + Window.into(new GlobalWindows()) + .triggering(DefaultTrigger.of()) + .discardingFiredPanes()) + .apply( + "Filter Unbatchable Mutations", + ParDo.of( + new BatchableMutationFilterFn( + schemaView, + UNBATCHABLE_MUTATIONS_TAG, + spec.getBatchSizeBytes(), + spec.getMaxNumMutations(), + spec.getMaxNumRows())) + .withSideInputs(schemaView) + .withOutputTags( + BATCHABLE_MUTATIONS_TAG, TupleTagList.of(UNBATCHABLE_MUTATIONS_TAG))); + + // Build a set of Mutation groups from the current bundle, + // sort them by table/key then split into batches. + PCollection> batchedMutations = + filteredMutations + .get(BATCHABLE_MUTATIONS_TAG) + .apply( + "Gather And Sort", + ParDo.of( + new GatherBundleAndSortFn( + spec.getBatchSizeBytes(), + spec.getMaxNumMutations(), + spec.getMaxNumRows(), + // Do not group on streaming unless explicitly set. + spec.getGroupingFactor() + .orElse( + input.isBounded() == IsBounded.BOUNDED + ? DEFAULT_GROUPING_FACTOR + : 1), + schemaView)) + .withSideInputs(schemaView)) + .apply( + "Create Batches", + ParDo.of( + new BatchFn( + spec.getBatchSizeBytes(), + spec.getMaxNumMutations(), + spec.getMaxNumRows(), + schemaView)) + .withSideInputs(schemaView)); + + // Merge the batched and unbatchable mutation PCollections and write to Spanner. + batches = + PCollectionList.of(filteredMutations.get(UNBATCHABLE_MUTATIONS_TAG)) + .and(batchedMutations) + .apply("Merge", Flatten.pCollections()); } - final PCollectionView schemaView = - schemaSeed - .apply( - "Read information schema", - ParDo.of(new ReadSpannerSchema(spec.getSpannerConfig()))) - .apply("Schema View", View.asSingleton()); - - // Split the mutations into batchable and unbatchable mutations. - // Filter out mutation groups too big to be batched. - PCollectionTuple filteredMutations = - input - .apply("To Global Window", Window.into(new GlobalWindows())) - .apply( - "Filter Unbatchable Mutations", - ParDo.of( - new BatchableMutationFilterFn( - schemaView, - UNBATCHABLE_MUTATIONS_TAG, - spec.getBatchSizeBytes(), - spec.getMaxNumMutations(), - spec.getMaxNumRows())) - .withSideInputs(schemaView) - .withOutputTags( - BATCHABLE_MUTATIONS_TAG, TupleTagList.of(UNBATCHABLE_MUTATIONS_TAG))); - - // Build a set of Mutation groups from the current bundle, - // sort them by table/key then split into batches. - PCollection> batchedMutations = - filteredMutations - .get(BATCHABLE_MUTATIONS_TAG) - .apply( - "Gather And Sort", - ParDo.of( - new GatherBundleAndSortFn( - spec.getBatchSizeBytes(), - spec.getMaxNumMutations(), - spec.getMaxNumRows(), - spec.getGroupingFactor(), - schemaView)) - .withSideInputs(schemaView)) - .apply( - "Create Batches", - ParDo.of( - new BatchFn( - spec.getBatchSizeBytes(), - spec.getMaxNumMutations(), - spec.getMaxNumRows(), - schemaView)) - .withSideInputs(schemaView)); - - // Merge the batchable and unbatchable mutation PCollections and write to Spanner. + PCollectionTuple result = - PCollectionList.of(filteredMutations.get(UNBATCHABLE_MUTATIONS_TAG)) - .and(batchedMutations) - .apply("Merge", Flatten.pCollections()) - .apply( - "Write mutations to Spanner", - ParDo.of( - new WriteToSpannerFn( - spec.getSpannerConfig(), spec.getFailureMode(), FAILED_MUTATIONS_TAG)) - .withOutputTags(MAIN_OUT_TAG, TupleTagList.of(FAILED_MUTATIONS_TAG))); + batches.apply( + "Write batches to Spanner", + ParDo.of( + new WriteToSpannerFn( + spec.getSpannerConfig(), spec.getFailureMode(), FAILED_MUTATIONS_TAG)) + .withOutputTags(MAIN_OUT_TAG, TupleTagList.of(FAILED_MUTATIONS_TAG))); return new SpannerWriteResult( input.getPipeline(), diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java index aaca469105c4..816556cf1bc3 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java @@ -263,18 +263,34 @@ private void verifyBatches(Iterable... batches) { @Test public void noBatching() throws Exception { + + // This test uses a different mock/fake because it explicitly does not want to populate the + // Spanner schema. + FakeServiceFactory fakeServiceFactory = new FakeServiceFactory(); + ReadOnlyTransaction tx = mock(ReadOnlyTransaction.class); + when(fakeServiceFactory.mockDatabaseClient().readOnlyTransaction()).thenReturn(tx); + + // Capture batches sent to writeAtLeastOnce. + when(fakeServiceFactory.mockDatabaseClient().writeAtLeastOnce(mutationBatchesCaptor.capture())) + .thenReturn(null); + PCollection mutations = pipeline.apply(Create.of(g(m(1L)), g(m(2L)))); mutations.apply( SpannerIO.write() .withProjectId("test-project") .withInstanceId("test-instance") .withDatabaseId("test-database") - .withServiceFactory(serviceFactory) + .withServiceFactory(fakeServiceFactory) .withBatchSizeBytes(1) .grouped()); pipeline.run(); - verifyBatches(batch(m(1L)), batch(m(2L))); + verify(fakeServiceFactory.mockDatabaseClient(), times(1)) + .writeAtLeastOnce(mutationsInNoOrder(batch(m(1L)))); + verify(fakeServiceFactory.mockDatabaseClient(), times(1)) + .writeAtLeastOnce(mutationsInNoOrder(batch(m(2L)))); + // If no batching then the DB schema is never read. + verify(tx, never()).executeQuery(any()); } @Test @@ -300,6 +316,54 @@ public void streamingWrites() throws Exception { verifyBatches(batch(m(1L), m(2L)), batch(m(3L), m(4L)), batch(m(5L), m(6L))); } + @Test + public void streamingWritesWithGrouping() throws Exception { + + // verify that grouping/sorting occurs when set. + TestStream testStream = + TestStream.create(SerializableCoder.of(Mutation.class)) + .addElements(m(1L), m(5L), m(2L), m(4L), m(3L), m(6L)) + .advanceWatermarkToInfinity(); + pipeline + .apply(testStream) + .apply( + SpannerIO.write() + .withProjectId("test-project") + .withInstanceId("test-instance") + .withDatabaseId("test-database") + .withServiceFactory(serviceFactory) + .withGroupingFactor(40) + .withMaxNumRows(2)); + pipeline.run(); + + // Output should be batches of sorted mutations. + verifyBatches(batch(m(1L), m(2L)), batch(m(3L), m(4L)), batch(m(5L), m(6L))); + } + + @Test + public void streamingWritesNoGrouping() throws Exception { + + // verify that grouping/sorting does not occur - batches should be created in received order. + TestStream testStream = + TestStream.create(SerializableCoder.of(Mutation.class)) + .addElements(m(1L), m(5L), m(2L), m(4L), m(3L), m(6L)) + .advanceWatermarkToInfinity(); + + // verify that grouping/sorting does not occur when notset. + pipeline + .apply(testStream) + .apply( + SpannerIO.write() + .withProjectId("test-project") + .withInstanceId("test-instance") + .withDatabaseId("test-database") + .withServiceFactory(serviceFactory) + .withMaxNumRows(2)); + pipeline.run(); + + verifyBatches(batch(m(1L), m(5L)), batch(m(2L), m(4L)), batch(m(3L), m(6L))); + } + @Test public void reportFailures() throws Exception { @@ -608,7 +672,18 @@ public void displayDataWrite() throws Exception { assertThat(data, hasDisplayItem("batchSizeBytes", 123)); assertThat(data, hasDisplayItem("maxNumMutations", 456)); assertThat(data, hasDisplayItem("maxNumRows", 789)); - assertThat(data, hasDisplayItem("groupingFactor", 100)); + assertThat(data, hasDisplayItem("groupingFactor", "100")); + + // check for default grouping value + write = + SpannerIO.write() + .withProjectId("test-project") + .withInstanceId("test-instance") + .withDatabaseId("test-database"); + + data = DisplayData.from(write); + assertThat(data.items(), hasSize(7)); + assertThat(data, hasDisplayItem("groupingFactor", "DEFAULT")); } @Test @@ -632,7 +707,19 @@ public void displayDataWriteGrouped() throws Exception { assertThat(data, hasDisplayItem("batchSizeBytes", 123)); assertThat(data, hasDisplayItem("maxNumMutations", 456)); assertThat(data, hasDisplayItem("maxNumRows", 789)); - assertThat(data, hasDisplayItem("groupingFactor", 100)); + assertThat(data, hasDisplayItem("groupingFactor", "100")); + + // check for default grouping value + writeGrouped = + SpannerIO.write() + .withProjectId("test-project") + .withInstanceId("test-instance") + .withDatabaseId("test-database") + .grouped(); + + data = DisplayData.from(writeGrouped); + assertThat(data.items(), hasSize(7)); + assertThat(data, hasDisplayItem("groupingFactor", "DEFAULT")); } @Test