From 1b3518d13f96b97ad5d129e0229bf0290d362dd7 Mon Sep 17 00:00:00 2001 From: bvolpato Date: Sun, 22 Mar 2026 00:10:05 -0400 Subject: [PATCH 1/2] [Java IO] Add ArrowFlight IO connector Add a new IO connector for Apache Arrow Flight, enabling high-performance data transfer over gRPC using the Arrow columnar format. Includes read (BoundedSource) and write (DoFn with doPut) support with endpoint-level split parallelism and bearer token authentication. Fixes #20116 --- CHANGES.md | 1 + .../beam/gradle/BeamModulePlugin.groovy | 2 + sdks/java/io/arrow-flight/build.gradle | 39 + .../sdk/io/arrowflight/ArrowFlightIO.java | 806 ++++++++++++++++++ .../beam/sdk/io/arrowflight/package-info.java | 30 + .../sdk/io/arrowflight/ArrowFlightIOTest.java | 211 +++++ settings.gradle.kts | 1 + .../content/en/documentation/io/connectors.md | 16 + 8 files changed, 1106 insertions(+) create mode 100644 sdks/java/io/arrow-flight/build.gradle create mode 100644 sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIO.java create mode 100644 sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/package-info.java create mode 100644 sdks/java/io/arrow-flight/src/test/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIOTest.java diff --git a/CHANGES.md b/CHANGES.md index e91da103c30e..6d676ac59e08 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -65,6 +65,7 @@ ## I/Os * Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Add ArrowFlight IO (Java) ([#20116](https://github.com/apache/beam/issues/20116)). * DebeziumIO (Java): added `OffsetRetainer` interface and `FileSystemOffsetRetainer` implementation to persist and restore CDC offsets across pipeline restarts, and exposed `withStartOffset` / `withOffsetRetainer` on `DebeziumIO.Read` and the cross-language `ReadBuilder` ([#28248](https://github.com/apache/beam/issues/28248)). 
## New Features / Improvements diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 10aa127ed806..96cc2b3fc152 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -925,6 +925,8 @@ class BeamModulePlugin implements Plugin { arrow_vector : "org.apache.arrow:arrow-vector:$arrow_version", arrow_memory_core : "org.apache.arrow:arrow-memory-core:$arrow_version", arrow_memory_netty : "org.apache.arrow:arrow-memory-netty:$arrow_version", + arrow_flight_core : "org.apache.arrow:flight-core:$arrow_version", + arrow_flight_sql : "org.apache.arrow:flight-sql:$arrow_version", ], groovy: [ groovy_all: "org.codehaus.groovy:groovy-all:2.4.13", diff --git a/sdks/java/io/arrow-flight/build.gradle b/sdks/java/io/arrow-flight/build.gradle new file mode 100644 index 000000000000..f587e59c19d9 --- /dev/null +++ b/sdks/java/io/arrow-flight/build.gradle @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +plugins { id 'org.apache.beam.module' } +applyJavaNature(automaticModuleName: 'org.apache.beam.sdk.io.arrowflight') + +description = "Apache Beam :: SDKs :: Java :: IO :: Arrow Flight" +ext.summary = "IO to read and write data using Apache Arrow Flight RPC." + +dependencies { + implementation project(path: ":sdks:java:core", configuration: "shadow") + implementation project(path: ":sdks:java:extensions:arrow") + implementation library.java.joda_time + implementation library.java.slf4j_api + implementation library.java.vendored_guava_32_1_2_jre + implementation(library.java.arrow_flight_core) + implementation(library.java.arrow_memory_core) + implementation(library.java.arrow_vector) + testImplementation library.java.hamcrest + testImplementation library.java.junit + testImplementation(library.java.arrow_memory_netty) + testRuntimeOnly library.java.slf4j_simple + testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") +} diff --git a/sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIO.java b/sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIO.java new file mode 100644 index 000000000000..cf1607641778 --- /dev/null +++ b/sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIO.java @@ -0,0 +1,806 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.arrowflight; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + +import com.google.auto.value.AutoValue; +import java.io.IOException; +import java.io.Serializable; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import org.apache.arrow.flight.FlightCallHeaders; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.HeaderCallOption; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarBinaryVector; +import 
org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.extensions.arrow.ArrowConversion; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.Row; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * IO to read and write data using Apache + * Arrow Flight. + * + *

Arrow Flight is a high-performance RPC framework for transferring Arrow-formatted data over + * gRPC. It enables 10-50x faster data transfer compared to JDBC/ODBC by avoiding + * serialization/deserialization overhead. + * + *

<h3>Reading from an Arrow Flight server</h3>

+ * + *

{@link ArrowFlightIO#read()} returns a bounded {@link PCollection} of {@link Row} elements. + * Each row is converted from Arrow record batches using the existing {@link ArrowConversion} + * extension. + * + *

<pre>{@code
+ * PCollection<Row> rows = pipeline.apply(
+ *     ArrowFlightIO.read()
+ *         .withHost("localhost")
+ *         .withPort(47470)
+ *         .withCommand("SELECT * FROM my_table"));
+ * }
+ * + *

<h3>Writing to an Arrow Flight server</h3>

+ * + *

{@link ArrowFlightIO#write()} accepts a {@link PCollection} of {@link Row} elements and + * streams them to a Flight server using {@code doPut}. + * + *

<pre>{@code
+ * rows.apply(
+ *     ArrowFlightIO.write()
+ *         .withHost("localhost")
+ *         .withPort(47470)
+ *         .withDescriptor("my_table")
+ *         .withBatchSize(1024));
+ * }
+ */ +public class ArrowFlightIO { + + private static final Logger LOG = LoggerFactory.getLogger(ArrowFlightIO.class); + + private ArrowFlightIO() {} + + public static Read read() { + return new AutoValue_ArrowFlightIO_Read.Builder().setPort(47470).setUseTls(false).build(); + } + + public static Write write() { + return new AutoValue_ArrowFlightIO_Write.Builder() + .setPort(47470) + .setUseTls(false) + .setBatchSize(1024) + .build(); + } + + /** + * Creates a {@link FlightClient} from the given connection parameters. + * + *

The client uses a {@link RootAllocator} for Arrow memory management and connects to the + * specified host and port using either plaintext or TLS. + */ + static FlightClient createClient( + BufferAllocator allocator, String host, int port, boolean useTls) { + Location location; + if (useTls) { + location = Location.forGrpcTls(host, port); + } else { + location = Location.forGrpcInsecure(host, port); + } + return FlightClient.builder(allocator, location).build(); + } + + /** A serializable wrapper around Flight endpoint information for use in BoundedSource splits. */ + static class SerializableEndpoint implements Serializable { + private static final long serialVersionUID = 1L; + + private final byte[] ticketBytes; + private final @Nullable String host; + private final int port; + + SerializableEndpoint(byte[] ticketBytes, @Nullable String host, int port) { + this.ticketBytes = ticketBytes; + this.host = host; + this.port = port; + } + + static SerializableEndpoint fromFlightEndpoint( + FlightEndpoint endpoint, String defaultHost, int defaultPort) { + byte[] ticket = endpoint.getTicket().getBytes(); + List locations = endpoint.getLocations(); + if (locations != null && !locations.isEmpty()) { + URI uri = locations.get(0).getUri(); + return new SerializableEndpoint(ticket, uri.getHost(), uri.getPort()); + } + return new SerializableEndpoint(ticket, defaultHost, defaultPort); + } + + Ticket getTicket() { + return new Ticket(ticketBytes); + } + + String getHost(String defaultHost) { + return host != null ? host : defaultHost; + } + + int getPort(int defaultPort) { + return port > 0 ? 
port : defaultPort; + } + } + + // ======================== READ ======================== + + @AutoValue + public abstract static class Read extends PTransform> { + + abstract String host(); + + abstract int port(); + + abstract boolean useTls(); + + abstract @Nullable String command(); + + @SuppressWarnings("mutable") + abstract byte @Nullable [] token(); + + abstract Builder builder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setHost(String host); + + abstract Builder setPort(int port); + + abstract Builder setUseTls(boolean useTls); + + abstract Builder setCommand(String command); + + abstract Builder setToken(byte[] token); + + abstract Read build(); + } + + /** Sets the Flight server host. */ + public Read withHost(String host) { + return builder().setHost(host).build(); + } + + /** Sets the Flight server port. */ + public Read withPort(int port) { + return builder().setPort(port).build(); + } + + /** Enables TLS for the connection. */ + public Read withUseTls(boolean useTls) { + return builder().setUseTls(useTls).build(); + } + + /** Sets the command (e.g., a SQL query or table name) to request from the Flight server. */ + public Read withCommand(String command) { + return builder().setCommand(command).build(); + } + + /** Sets a bearer token for authentication. 
*/ + public Read withToken(byte[] token) { + return builder().setToken(token).build(); + } + + @Override + public PCollection expand(PBegin input) { + checkArgument(host() != null, "withHost() is required"); + checkArgument(command() != null, "withCommand() is required"); + + Schema beamSchema; + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + FlightClient client = createClient(allocator, host(), port(), useTls())) { + FlightInfo info = + client.getInfo( + FlightDescriptor.command( + checkNotNull(command(), "command").getBytes(StandardCharsets.UTF_8)), + callOptions()); + beamSchema = ArrowConversion.ArrowSchemaTranslator.toBeamSchema(info.getSchema()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while fetching Flight schema", e); + } + + return input + .apply(org.apache.beam.sdk.io.Read.from(new FlightBoundedSource(this, beamSchema))) + .setRowSchema(beamSchema); + } + + HeaderCallOption[] callOptions() { + if (token() != null) { + FlightCallHeaders headers = new FlightCallHeaders(); + headers.insert( + "authorization", + "Bearer " + new String(checkNotNull(token(), "token"), StandardCharsets.UTF_8)); + return new HeaderCallOption[] {new HeaderCallOption(headers)}; + } + return new HeaderCallOption[0]; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder.add(DisplayData.item("host", host())); + builder.add(DisplayData.item("port", port())); + builder.add(DisplayData.item("useTls", useTls())); + builder.addIfNotNull(DisplayData.item("command", command())); + } + } + + /** A {@link BoundedSource} that reads rows from Arrow Flight endpoints. 
*/ + static class FlightBoundedSource extends BoundedSource { + private final Read spec; + private final Schema beamSchema; + private final @Nullable SerializableEndpoint endpoint; + + FlightBoundedSource(Read spec, Schema beamSchema) { + this(spec, beamSchema, null); + } + + FlightBoundedSource(Read spec, Schema beamSchema, @Nullable SerializableEndpoint endpoint) { + this.spec = spec; + this.beamSchema = beamSchema; + this.endpoint = endpoint; + } + + @Override + public List> split( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + if (endpoint != null) { + return Collections.singletonList(this); + } + + List> sources = new ArrayList<>(); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + FlightClient client = createClient(allocator, spec.host(), spec.port(), spec.useTls())) { + FlightInfo info = + client.getInfo( + FlightDescriptor.command( + checkNotNull(spec.command(), "command").getBytes(StandardCharsets.UTF_8)), + spec.callOptions()); + for (FlightEndpoint fe : info.getEndpoints()) { + SerializableEndpoint se = + SerializableEndpoint.fromFlightEndpoint(fe, spec.host(), spec.port()); + sources.add(new FlightBoundedSource(spec, beamSchema, se)); + } + } + + if (sources.isEmpty()) { + sources.add(this); + } + return sources; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + if (endpoint != null) { + return -1; + } + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + FlightClient client = createClient(allocator, spec.host(), spec.port(), spec.useTls())) { + FlightInfo info = + client.getInfo( + FlightDescriptor.command( + checkNotNull(spec.command(), "command").getBytes(StandardCharsets.UTF_8)), + spec.callOptions()); + return info.getBytes(); + } + } + + @Override + public BoundedReader createReader(PipelineOptions options) { + return new FlightBoundedReader(this); + } + + @Override + public void validate() { + checkArgument(spec.host() != null, 
"host is required"); + checkArgument(spec.command() != null, "command is required"); + } + + @Override + public Coder getOutputCoder() { + return RowCoder.of(beamSchema); + } + } + + /** Reader that streams Arrow record batches from a Flight endpoint and emits Beam Rows. */ + @SuppressWarnings("initialization.fields.uninitialized") + static class FlightBoundedReader extends BoundedSource.BoundedReader { + private static final Counter RECORDS_READ = Metrics.counter(ArrowFlightIO.class, "recordsRead"); + + private final FlightBoundedSource source; + private transient BufferAllocator allocator; + private transient FlightClient client; + private transient FlightStream stream; + private transient Iterator currentBatchIterator; + private transient Row current; + + FlightBoundedReader(FlightBoundedSource source) { + this.source = source; + } + + @Override + public boolean start() throws IOException { + allocator = new RootAllocator(Long.MAX_VALUE); + Read spec = source.spec; + + if (source.endpoint != null) { + String host = source.endpoint.getHost(spec.host()); + int port = source.endpoint.getPort(spec.port()); + client = createClient(allocator, host, port, spec.useTls()); + stream = client.getStream(source.endpoint.getTicket(), spec.callOptions()); + } else { + client = createClient(allocator, spec.host(), spec.port(), spec.useTls()); + FlightInfo info = + client.getInfo( + FlightDescriptor.command( + checkNotNull(spec.command(), "command").getBytes(StandardCharsets.UTF_8)), + spec.callOptions()); + List endpoints = info.getEndpoints(); + if (endpoints.isEmpty()) { + return false; + } + stream = client.getStream(endpoints.get(0).getTicket(), spec.callOptions()); + } + + currentBatchIterator = Collections.emptyIterator(); + return advance(); + } + + @Override + public boolean advance() throws IOException { + while (true) { + if (currentBatchIterator.hasNext()) { + current = currentBatchIterator.next(); + RECORDS_READ.inc(); + return true; + } + if (stream.next()) { + 
VectorSchemaRoot root = stream.getRoot(); + if (root.getRowCount() > 0) { + currentBatchIterator = ArrowConversion.rowsFromRecordBatch(source.beamSchema, root); + } + } else { + return false; + } + } + } + + @Override + public Row getCurrent() { + return current; + } + + @Override + public void close() throws IOException { + try { + if (stream != null) { + stream.close(); + } + } catch (Exception e) { + LOG.warn("Error closing FlightStream", e); + } + try { + if (client != null) { + client.close(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted closing FlightClient", e); + } + try { + if (allocator != null) { + allocator.close(); + } + } catch (Exception e) { + LOG.warn("Error closing BufferAllocator", e); + } + } + + @Override + public BoundedSource getCurrentSource() { + return source; + } + } + + // ======================== WRITE ======================== + + @AutoValue + public abstract static class Write extends PTransform, PDone> { + + abstract String host(); + + abstract int port(); + + abstract boolean useTls(); + + abstract @Nullable String descriptor(); + + abstract int batchSize(); + + @SuppressWarnings("mutable") + abstract byte @Nullable [] token(); + + abstract Builder builder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setHost(String host); + + abstract Builder setPort(int port); + + abstract Builder setUseTls(boolean useTls); + + abstract Builder setDescriptor(String descriptor); + + abstract Builder setBatchSize(int batchSize); + + abstract Builder setToken(byte[] token); + + abstract Write build(); + } + + /** Sets the Flight server host. */ + public Write withHost(String host) { + return builder().setHost(host).build(); + } + + /** Sets the Flight server port. */ + public Write withPort(int port) { + return builder().setPort(port).build(); + } + + /** Enables TLS for the connection. 
*/ + public Write withUseTls(boolean useTls) { + return builder().setUseTls(useTls).build(); + } + + /** Sets the Flight descriptor (table name or path) for the write target. */ + public Write withDescriptor(String descriptor) { + return builder().setDescriptor(descriptor).build(); + } + + /** Sets the batch size for writing. Rows are buffered and flushed in batches. */ + public Write withBatchSize(int batchSize) { + checkArgument(batchSize > 0, "batchSize must be positive"); + return builder().setBatchSize(batchSize).build(); + } + + /** Sets a bearer token for authentication. */ + public Write withToken(byte[] token) { + return builder().setToken(token).build(); + } + + @Override + public PDone expand(PCollection input) { + checkArgument(host() != null, "withHost() is required"); + checkArgument(descriptor() != null, "withDescriptor() is required"); + + input.apply(ParDo.of(new FlightWriteFn(this))); + return PDone.in(input.getPipeline()); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder.add(DisplayData.item("host", host())); + builder.add(DisplayData.item("port", port())); + builder.add(DisplayData.item("useTls", useTls())); + builder.addIfNotNull(DisplayData.item("descriptor", descriptor())); + builder.add(DisplayData.item("batchSize", batchSize())); + } + } + + /** DoFn that buffers Beam Rows and streams them as Arrow record batches to a Flight server. 
*/ + @SuppressWarnings("initialization.fields.uninitialized") + static class FlightWriteFn extends DoFn { + private static final Counter RECORDS_WRITTEN = + Metrics.counter(ArrowFlightIO.class, "recordsWritten"); + private static final Counter BATCHES_WRITTEN = + Metrics.counter(ArrowFlightIO.class, "batchesWritten"); + + private final Write spec; + private transient BufferAllocator allocator; + private transient FlightClient client; + private transient FlightClient.ClientStreamListener listener; + private transient VectorSchemaRoot root; + private transient org.apache.arrow.vector.types.pojo.Schema arrowSchema; + private transient List batch; + private transient @Nullable Schema beamSchema; + + FlightWriteFn(Write spec) { + this.spec = spec; + } + + @StartBundle + public void startBundle() { + batch = new ArrayList<>(); + } + + @ProcessElement + public void processElement(@Element Row row) { + if (beamSchema == null) { + beamSchema = row.getSchema(); + } + batch.add(row); + if (batch.size() >= spec.batchSize()) { + flush(); + } + } + + @FinishBundle + public void finishBundle() { + flush(); + closeConnection(); + } + + @Teardown + public void teardown() { + closeConnection(); + } + + @SuppressWarnings("nullness") + private void ensureConnection() { + if (client == null) { + allocator = new RootAllocator(Long.MAX_VALUE); + client = createClient(allocator, spec.host(), spec.port(), spec.useTls()); + + List arrowFields = new ArrayList<>(); + for (Schema.Field beamField : beamSchema.getFields()) { + arrowFields.add(toArrowField(beamField)); + } + arrowSchema = new org.apache.arrow.vector.types.pojo.Schema(arrowFields); + root = VectorSchemaRoot.create(arrowSchema, allocator); + + FlightDescriptor descriptor = FlightDescriptor.path(spec.descriptor()); + listener = client.startPut(descriptor, root, new AsyncPutListener()); + } + } + + private Field toArrowField(Schema.Field beamField) { + ArrowType arrowType = beamTypeToArrowType(beamField.getType()); + FieldType 
fieldType = + beamField.getType().getNullable() + ? FieldType.nullable(arrowType) + : FieldType.notNullable(arrowType); + return new Field(beamField.getName(), fieldType, Collections.emptyList()); + } + + private ArrowType beamTypeToArrowType(Schema.FieldType beamType) { + switch (beamType.getTypeName()) { + case BYTE: + return new ArrowType.Int(8, true); + case INT16: + return new ArrowType.Int(16, true); + case INT32: + return new ArrowType.Int(32, true); + case INT64: + return new ArrowType.Int(64, true); + case FLOAT: + return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + case DOUBLE: + return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + case STRING: + return ArrowType.Utf8.INSTANCE; + case BOOLEAN: + return ArrowType.Bool.INSTANCE; + case BYTES: + return ArrowType.Binary.INSTANCE; + case DATETIME: + return new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC"); + default: + LOG.warn("Unsupported Beam type {}, falling back to Utf8", beamType.getTypeName()); + return ArrowType.Utf8.INSTANCE; + } + } + + @SuppressWarnings("nullness") + private void flush() { + if (batch == null || batch.isEmpty()) { + return; + } + ensureConnection(); + + root.setRowCount(batch.size()); + for (int colIdx = 0; colIdx < beamSchema.getFieldCount(); colIdx++) { + FieldVector vector = root.getVector(colIdx); + vector.allocateNew(); + Schema.Field field = beamSchema.getField(colIdx); + for (int rowIdx = 0; rowIdx < batch.size(); rowIdx++) { + Object value = batch.get(rowIdx).getValue(colIdx); + if (value == null) { + vector.setNull(rowIdx); + } else { + setVectorValue(vector, rowIdx, value, field.getType()); + } + } + vector.setValueCount(batch.size()); + } + + listener.putNext(); + RECORDS_WRITTEN.inc(batch.size()); + BATCHES_WRITTEN.inc(); + root.clear(); + batch.clear(); + } + + @SuppressWarnings("nullness") + private void setVectorValue( + FieldVector vector, int index, Object value, Schema.FieldType type) { + switch (type.getTypeName()) { + case 
BYTE: + ((TinyIntVector) vector).setSafe(index, ((Number) value).byteValue()); + break; + case INT16: + ((SmallIntVector) vector).setSafe(index, ((Number) value).shortValue()); + break; + case INT32: + ((IntVector) vector).setSafe(index, ((Number) value).intValue()); + break; + case INT64: + ((BigIntVector) vector).setSafe(index, ((Number) value).longValue()); + break; + case FLOAT: + ((Float4Vector) vector).setSafe(index, ((Number) value).floatValue()); + break; + case DOUBLE: + ((Float8Vector) vector).setSafe(index, ((Number) value).doubleValue()); + break; + case BOOLEAN: + ((BitVector) vector).setSafe(index, ((Boolean) value) ? 1 : 0); + break; + case STRING: + ((VarCharVector) vector) + .setSafe(index, value.toString().getBytes(StandardCharsets.UTF_8)); + break; + case BYTES: + ((VarBinaryVector) vector).setSafe(index, (byte[]) value); + break; + case DATETIME: + long millis; + if (value instanceof org.joda.time.ReadableInstant) { + millis = ((org.joda.time.ReadableInstant) value).getMillis(); + } else { + millis = ((Number) value).longValue(); + } + ((TimeStampMilliTZVector) vector).setSafe(index, millis); + break; + default: + ((VarCharVector) vector) + .setSafe(index, value.toString().getBytes(StandardCharsets.UTF_8)); + break; + } + } + + @SuppressWarnings("nullness") + private void closeConnection() { + try { + if (listener != null) { + listener.completed(); + listener.getResult(); + listener = null; + } + } catch (Exception e) { + LOG.warn("Error completing Flight put", e); + } + try { + if (root != null) { + root.close(); + root = null; + } + } catch (Exception e) { + LOG.warn("Error closing VectorSchemaRoot", e); + } + try { + if (client != null) { + client.close(); + client = null; + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted closing FlightClient", e); + } + try { + if (allocator != null) { + allocator.close(); + allocator = null; + } + } catch (Exception e) { + LOG.warn("Error closing 
BufferAllocator", e); + } + } + } + + /** A no-op listener for async put operations. */ + static class AsyncPutListener implements FlightClient.PutListener { + private volatile @Nullable Throwable error; + + @Override + public void onNext(PutResult val) {} + + @Override + public void onError(Throwable t) { + this.error = t; + } + + @Override + public void onCompleted() {} + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public void getResult() { + Throwable t = error; + if (t != null) { + throw new RuntimeException("Error during Flight put", t); + } + } + } +} diff --git a/sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/package-info.java b/sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/package-info.java new file mode 100644 index 000000000000..34f9b95670d9 --- /dev/null +++ b/sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/package-info.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * I/O connector for Apache Arrow + * Flight. + * + *

Arrow Flight is a high-performance RPC framework for fast data transport using the Apache + * Arrow columnar format over gRPC. This connector enables Beam pipelines to read from and write to + * any Arrow Flight-compatible data system, including Dremio, ClickHouse, Apache Doris, InfluxDB 3, + * DataFusion, and custom Flight servers. + * + * @see org.apache.beam.sdk.io.arrowflight.ArrowFlightIO + */ +package org.apache.beam.sdk.io.arrowflight; diff --git a/sdks/java/io/arrow-flight/src/test/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIOTest.java b/sdks/java/io/arrow-flight/src/test/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIOTest.java new file mode 100644 index 000000000000..e45a0be3db0e --- /dev/null +++ b/sdks/java/io/arrow-flight/src/test/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIOTest.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.arrowflight; + +import static org.junit.Assert.assertEquals; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.ActionType; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link ArrowFlightIO}. 
*
+ * <p>Spins up an in-process Flight server on an ephemeral port and exercises the connector's
+ * read and write paths end to end with the test pipeline.
+ */
+@RunWith(JUnit4.class)
+public class ArrowFlightIOTest {
+
+  @Rule public final TestPipeline pipeline = TestPipeline.create();
+
+  // Per-test Flight server state, created in setUp() and released in tearDown().
+  private BufferAllocator allocator;
+  private FlightServer server;
+  private TestFlightProducer producer;
+  private int port;
+
+  /** Starts an in-process Flight server backed by {@link TestFlightProducer}. */
+  @Before
+  public void setUp() throws Exception {
+    allocator = new RootAllocator(Long.MAX_VALUE);
+    producer = new TestFlightProducer(allocator);
+
+    // Bind to any available port
+    Location location = Location.forGrpcInsecure("localhost", 0);
+    server = FlightServer.builder(allocator, location, producer).build();
+    server.start();
+
+    // The ephemeral port actually bound by the server; used to point the IO at it.
+    port = server.getPort();
+  }
+
+  /** Stops the server before closing the allocator so its child allocators are freed first. */
+  @After
+  public void tearDown() throws Exception {
+    if (server != null) {
+      server.close();
+    }
+    if (allocator != null) {
+      allocator.close();
+    }
+  }
+
+  /** Reads via the Flight getFlightInfo/getStream path and checks the producer's two fixed rows. */
+  @Test
+  public void testRead() {
+    PCollection output =
+        pipeline.apply(
+            "Read from Flight",
+            ArrowFlightIO.read().withHost("localhost").withPort(port).withCommand("test_query"));
+
+    Schema expectedSchema = Schema.builder().addStringField("name").build();
+    Row expectedRow1 = Row.withSchema(expectedSchema).addValue("Alice").build();
+    Row expectedRow2 = Row.withSchema(expectedSchema).addValue("Bob").build();
+
+    PAssert.that(output).containsInAnyOrder(expectedRow1, expectedRow2);
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  /** Writes two rows via doPut and checks the count observed on the producer side. */
+  @Test
+  public void testWrite() throws Exception {
+    Schema expectedSchema = Schema.builder().addStringField("name").build();
+    Row row1 = Row.withSchema(expectedSchema).addValue("Charlie").build();
+    Row row2 = Row.withSchema(expectedSchema).addValue("Dave").build();
+
+    pipeline
+        .apply(Create.of(row1, row2).withRowSchema(expectedSchema))
+        .apply(
+            "Write to Flight",
+            ArrowFlightIO.write()
+                .withHost("localhost")
+                .withPort(port)
+                .withDescriptor("test_table"));
+
+    pipeline.run().waitUntilFinish();
+
+    // NOTE(review): writtenRecords is mutated on a server-side thread; this read relies on
+    // waitUntilFinish() having synchronized with the server's stream completion — confirm (an
+    // AtomicInteger would make the hand-off explicit).
+    assertEquals(2, producer.writtenRecords);
+  }
+
+  /** A simple FlightProducer that returns predefined data for reads and counts writes.
*/
+  private static class TestFlightProducer implements FlightProducer {
+
+    private final BufferAllocator allocator;
+    // NOTE(review): incremented on a server-side thread and read from the test thread without
+    // synchronization — consider an AtomicInteger.
+    int writtenRecords = 0;
+
+    TestFlightProducer(BufferAllocator allocator) {
+      this.allocator = allocator;
+    }
+
+    /** Serves one fixed batch for a single Utf8 column "name" containing "Alice" and "Bob". */
+    @Override
+    public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+      org.apache.arrow.vector.types.pojo.Schema schema =
+          new org.apache.arrow.vector.types.pojo.Schema(
+              Collections.singletonList(
+                  new Field(
+                      "name", FieldType.nullable(new ArrowType.Utf8()), Collections.emptyList())));
+
+      try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+        listener.start(root);
+
+        VarCharVector vector = (VarCharVector) root.getVector("name");
+        vector.allocateNew();
+        vector.setSafe(0, "Alice".getBytes(StandardCharsets.UTF_8));
+        vector.setSafe(1, "Bob".getBytes(StandardCharsets.UTF_8));
+        vector.setValueCount(2);
+        root.setRowCount(2);
+
+        listener.putNext();
+        listener.completed();
+      } catch (Exception e) {
+        listener.error(e);
+      }
+    }
+
+    // No flights to enumerate; immediately signals an empty, completed listing.
+    @Override
+    public void listFlights(
+        CallContext context, Criteria criteria, StreamListener listener) {
+      listener.onCompleted();
+    }
+
+    /**
+     * Returns a single endpoint whose ticket echoes the descriptor's command bytes; the trailing
+     * -1/-1 mean unknown record and byte counts.
+     */
+    @Override
+    public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
+      org.apache.arrow.vector.types.pojo.Schema schema =
+          new org.apache.arrow.vector.types.pojo.Schema(
+              Collections.singletonList(
+                  new Field(
+                      "name", FieldType.nullable(new ArrowType.Utf8()), Collections.emptyList())));
+      return new FlightInfo(
+          schema,
+          descriptor,
+          Collections.singletonList(
+              new FlightEndpoint(
+                  // Port 0 here is a placeholder; presumably the reader falls back to its own
+                  // configured host/port — TODO confirm against ArrowFlightIO's endpoint handling.
+                  new Ticket(descriptor.getCommand()), Location.forGrpcInsecure("localhost", 0))),
+          -1,
+          -1);
+    }
+
+    /** Counts every record received via doPut into {@link #writtenRecords}. */
+    @Override
+    public Runnable acceptPut(
+        CallContext context, FlightStream flightStream, StreamListener ackStream) {
+      return () -> {
+        try {
+          while (flightStream.next()) {
+            VectorSchemaRoot root = flightStream.getRoot();
+            writtenRecords += root.getRowCount();
+          }
+          ackStream.onCompleted();
+        }
catch (Exception e) { + ackStream.onError(e); + } + }; + } + + @Override + public void doAction(CallContext context, Action action, StreamListener listener) { + listener.onCompleted(); + } + + @Override + public void listActions(CallContext context, StreamListener listener) { + listener.onCompleted(); + } + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index a37c57d043c3..6a5829d24d1b 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -212,6 +212,7 @@ include(":sdks:java:io:azure") include(":sdks:java:io:azure-cosmos") include(":sdks:java:io:cassandra") include(":sdks:java:io:clickhouse") +include(":sdks:java:io:arrow-flight") include(":sdks:java:io:common") include(":sdks:java:io:components") include(":sdks:java:io:contextualtextio") diff --git a/website/www/site/content/en/documentation/io/connectors.md b/website/www/site/content/en/documentation/io/connectors.md index 856675f0e0f3..26fc01100198 100644 --- a/website/www/site/content/en/documentation/io/connectors.md +++ b/website/www/site/content/en/documentation/io/connectors.md @@ -878,6 +878,22 @@ This table provides a consolidated, at-a-glance overview of the available built- ✘ ✘ + + ArrowFlightIO + ✔ + ✔ + + ✔ + native + + Not available + Not available + Not available + Not available + ✔ + ✔ + ✘ + DatabaseIO ✔ From e1e26fdbf76f03265227a5bf164b74b36ab378ae Mon Sep 17 00:00:00 2001 From: Bruno Volpato Date: Fri, 27 Mar 2026 20:33:05 -0400 Subject: [PATCH 2/2] Add arrow-flight to javaioPreCommit and fix test issues - Register :sdks:java:io:arrow-flight in javaioPreCommit task (build.gradle.kts) - Add --add-opens JVM arg for Arrow native memory on JDK 17+ - Make host() nullable at AutoValue level to fix factory method NPE - Eagerly materialize rows from Arrow buffers to prevent stale access - Move root.setRowCount() after vector population for correct ordering - Use AtomicInteger for thread-safe write record counting in tests --- build.gradle.kts | 1 + 
sdks/java/io/arrow-flight/build.gradle | 4 ++ .../sdk/io/arrowflight/ArrowFlightIO.java | 39 +++++++++++++------ .../sdk/io/arrowflight/ArrowFlightIOTest.java | 7 ++-- 4 files changed, 36 insertions(+), 15 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 2465d581228a..13a61b6faf87 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -343,6 +343,7 @@ tasks.register("javaPreCommit") { // a precommit task build multiple IOs (except those splitting into single jobs) tasks.register("javaioPreCommit") { dependsOn(":sdks:java:io:amqp:build") + dependsOn(":sdks:java:io:arrow-flight:build") dependsOn(":sdks:java:io:cassandra:build") dependsOn(":sdks:java:io:csv:build") dependsOn(":sdks:java:io:cdap:build") diff --git a/sdks/java/io/arrow-flight/build.gradle b/sdks/java/io/arrow-flight/build.gradle index f587e59c19d9..020dc61e4096 100644 --- a/sdks/java/io/arrow-flight/build.gradle +++ b/sdks/java/io/arrow-flight/build.gradle @@ -37,3 +37,7 @@ dependencies { testRuntimeOnly library.java.slf4j_simple testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") } + +test { + jvmArgs '--add-opens=java.base/java.nio=ALL-UNNAMED' +} diff --git a/sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIO.java b/sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIO.java index cf1607641778..a57c3bf59140 100644 --- a/sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIO.java +++ b/sdks/java/io/arrow-flight/src/main/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIO.java @@ -192,7 +192,7 @@ int getPort(int defaultPort) { @AutoValue public abstract static class Read extends PTransform> { - abstract String host(); + abstract @Nullable String host(); abstract int port(); @@ -252,7 +252,8 @@ public PCollection expand(PBegin input) { Schema beamSchema; try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - FlightClient client = 
createClient(allocator, host(), port(), useTls())) { + FlightClient client = + createClient(allocator, checkNotNull(host(), "host"), port(), useTls())) { FlightInfo info = client.getInfo( FlightDescriptor.command( @@ -283,7 +284,7 @@ HeaderCallOption[] callOptions() { @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - builder.add(DisplayData.item("host", host())); + builder.addIfNotNull(DisplayData.item("host", host())); builder.add(DisplayData.item("port", port())); builder.add(DisplayData.item("useTls", useTls())); builder.addIfNotNull(DisplayData.item("command", command())); @@ -315,7 +316,9 @@ public List> split( List> sources = new ArrayList<>(); try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - FlightClient client = createClient(allocator, spec.host(), spec.port(), spec.useTls())) { + FlightClient client = + createClient( + allocator, checkNotNull(spec.host(), "host"), spec.port(), spec.useTls())) { FlightInfo info = client.getInfo( FlightDescriptor.command( @@ -323,7 +326,8 @@ public List> split( spec.callOptions()); for (FlightEndpoint fe : info.getEndpoints()) { SerializableEndpoint se = - SerializableEndpoint.fromFlightEndpoint(fe, spec.host(), spec.port()); + SerializableEndpoint.fromFlightEndpoint( + fe, checkNotNull(spec.host(), "host"), spec.port()); sources.add(new FlightBoundedSource(spec, beamSchema, se)); } } @@ -340,7 +344,9 @@ public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { return -1; } try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - FlightClient client = createClient(allocator, spec.host(), spec.port(), spec.useTls())) { + FlightClient client = + createClient( + allocator, checkNotNull(spec.host(), "host"), spec.port(), spec.useTls())) { FlightInfo info = client.getInfo( FlightDescriptor.command( @@ -388,13 +394,14 @@ public boolean start() throws IOException { allocator = new RootAllocator(Long.MAX_VALUE); Read 
spec = source.spec; + String hostName = checkNotNull(spec.host(), "host"); if (source.endpoint != null) { - String host = source.endpoint.getHost(spec.host()); + String host = source.endpoint.getHost(hostName); int port = source.endpoint.getPort(spec.port()); client = createClient(allocator, host, port, spec.useTls()); stream = client.getStream(source.endpoint.getTicket(), spec.callOptions()); } else { - client = createClient(allocator, spec.host(), spec.port(), spec.useTls()); + client = createClient(allocator, hostName, spec.port(), spec.useTls()); FlightInfo info = client.getInfo( FlightDescriptor.command( @@ -422,7 +429,15 @@ public boolean advance() throws IOException { if (stream.next()) { VectorSchemaRoot root = stream.getRoot(); if (root.getRowCount() > 0) { - currentBatchIterator = ArrowConversion.rowsFromRecordBatch(source.beamSchema, root); + Iterator lazyIterator = + ArrowConversion.rowsFromRecordBatch(source.beamSchema, root); + List materializedRows = new ArrayList<>(); + while (lazyIterator.hasNext()) { + Row lazyRow = lazyIterator.next(); + materializedRows.add( + Row.withSchema(source.beamSchema).addValues(lazyRow.getValues()).build()); + } + currentBatchIterator = materializedRows.iterator(); } } else { return false; @@ -472,7 +487,7 @@ public BoundedSource getCurrentSource() { @AutoValue public abstract static class Write extends PTransform, PDone> { - abstract String host(); + abstract @Nullable String host(); abstract int port(); @@ -547,7 +562,7 @@ public PDone expand(PCollection input) { @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - builder.add(DisplayData.item("host", host())); + builder.addIfNotNull(DisplayData.item("host", host())); builder.add(DisplayData.item("port", port())); builder.add(DisplayData.item("useTls", useTls())); builder.addIfNotNull(DisplayData.item("descriptor", descriptor())); @@ -665,7 +680,6 @@ private void flush() { } ensureConnection(); - 
root.setRowCount(batch.size()); for (int colIdx = 0; colIdx < beamSchema.getFieldCount(); colIdx++) { FieldVector vector = root.getVector(colIdx); vector.allocateNew(); @@ -680,6 +694,7 @@ private void flush() { } vector.setValueCount(batch.size()); } + root.setRowCount(batch.size()); listener.putNext(); RECORDS_WRITTEN.inc(batch.size()); diff --git a/sdks/java/io/arrow-flight/src/test/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIOTest.java b/sdks/java/io/arrow-flight/src/test/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIOTest.java index e45a0be3db0e..b35a8f60827c 100644 --- a/sdks/java/io/arrow-flight/src/test/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIOTest.java +++ b/sdks/java/io/arrow-flight/src/test/java/org/apache/beam/sdk/io/arrowflight/ArrowFlightIOTest.java @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets; import java.util.Collections; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.arrow.flight.Action; import org.apache.arrow.flight.ActionType; import org.apache.arrow.flight.Criteria; @@ -121,14 +122,14 @@ public void testWrite() throws Exception { pipeline.run().waitUntilFinish(); - assertEquals(2, producer.writtenRecords); + assertEquals(2, producer.writtenRecords.get()); } /** A simple FlightProducer that returns predefined data for reads and counts writes. */ private static class TestFlightProducer implements FlightProducer { private final BufferAllocator allocator; - int writtenRecords = 0; + final AtomicInteger writtenRecords = new AtomicInteger(); TestFlightProducer(BufferAllocator allocator) { this.allocator = allocator; @@ -189,7 +190,7 @@ public Runnable acceptPut( try { while (flightStream.next()) { VectorSchemaRoot root = flightStream.getRoot(); - writtenRecords += root.getRowCount(); + writtenRecords.addAndGet(root.getRowCount()); } ackStream.onCompleted(); } catch (Exception e) {