Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/Apache.Arrow.Flight/Client/FlightClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,18 @@ public class FlightClient
internal static readonly Empty EmptyInstance = new Empty();

private readonly FlightService.FlightServiceClient _client;
private readonly ArrowContext _context;

public FlightClient(ChannelBase grpcChannel)
public FlightClient(ChannelBase grpcChannel, ArrowContext context = null)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For binary backwards compatibility, can you instead add a second public constructor with the additional parameter? i.e.

public FlightClient(ChannelBase grpcChannel) : this(grpcChannel. null)
{
}

public FlightClient(ChannelBase grpcChannel, ArrowContext context)
{
    _client = new FlightService.FlightServiceClient(grpcChannel);
    _context = context;
}

{
_client = new FlightService.FlightServiceClient(grpcChannel);
_context = context;
}

public FlightClient(CallInvoker callInvoker)
public FlightClient(CallInvoker callInvoker, ArrowContext context = null)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For binary backwards compatibility, can you instead add a second constructor with the additional parameter? This one is trickier because the DI mechanism doesn't like having two constructors that both match what it's looking for, so we need to use a factory method instead:

public FlightClient(CallInvoker callInvoker) : this(callInvoker. null)
{
}

private FlightClient(CallInvoker callInvoker, ArrowContext context)
{
    _client = new FlightService.FlightServiceClient(callInvoker);
    _context = context;
}

public static FlightClient Create(CallInvoker callInvoker, ArrowContext context)
{
    return new FlightClient(callInvoker, context);
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would then also be good to add a test which exercises this code path. You could clone the existing TestIntegrationWithGrpcNetClientFactory test and then replace the client setup with

services.AddGrpcClient<FlightClient>(grpc => grpc.Address = new Uri(_testWebFactory.GetAddress()))
    .ConfigureGrpcClientCreator(invoker =>
{
    return FlightClient.Create(invoker, new ArrowContext());
});

This will also demonstrate to users how to use an ArrowContext with the factory.

{
_client = new FlightService.FlightServiceClient(callInvoker);
_context = context;
}

public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria criteria = null, Metadata headers = null)
Expand Down Expand Up @@ -77,7 +80,7 @@ public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata he
public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var stream = _client.DoGet(ticket.ToProtocol(), headers, deadline, cancellationToken);
var responseStream = new FlightClientRecordBatchStreamReader(stream.ResponseStream);
var responseStream = new FlightClientRecordBatchStreamReader(stream.ResponseStream, _context);
return new FlightRecordBatchStreamingCall(responseStream, stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, stream.Dispose);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace Apache.Arrow.Flight.Client
{
public class FlightClientRecordBatchStreamReader : FlightRecordBatchStreamReader
{
internal FlightClientRecordBatchStreamReader(IAsyncStreamReader<Protocol.FlightData> flightDataStream) : base(flightDataStream)
internal FlightClientRecordBatchStreamReader(IAsyncStreamReader<Protocol.FlightData> flightDataStream, ArrowContext context = null) : base(flightDataStream, context)
{
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ public abstract class FlightRecordBatchStreamReader : IAsyncStreamReader<RecordB

private readonly RecordBatchReaderImplementation _arrowReaderImplementation;

private protected FlightRecordBatchStreamReader(IAsyncStreamReader<Protocol.FlightData> flightDataStream)
private protected FlightRecordBatchStreamReader(IAsyncStreamReader<Protocol.FlightData> flightDataStream, ArrowContext context = null)
{
_arrowReaderImplementation = new RecordBatchReaderImplementation(flightDataStream);
_arrowReaderImplementation = new RecordBatchReaderImplementation(flightDataStream, context);
}

public ValueTask<Schema> Schema => _arrowReaderImplementation.GetSchemaAsync();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ internal class RecordBatchReaderImplementation : ArrowReaderImplementation
private FlightDescriptor _flightDescriptor;
private readonly List<ByteString> _applicationMetadatas;

public RecordBatchReaderImplementation(IAsyncStreamReader<Protocol.FlightData> streamReader)
public RecordBatchReaderImplementation(IAsyncStreamReader<Protocol.FlightData> streamReader, ArrowContext context = null)
: base(context?.Allocator, context?.CompressionCodecFactory, context?.ExtensionRegistry)
{
_flightDataStream = streamReader;
_applicationMetadatas = new List<ByteString>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ public static IEnumerable<Cookie> ParseHeader(this string setCookieHeader)

var cookies = new List<Cookie>();

var segments = setCookieHeader.Split([';'], StringSplitOptions.RemoveEmptyEntries);
var segments = setCookieHeader.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
if (segments.Length == 0)
return cookies;

var nameValue = segments[0].Split(['='], 2);
var nameValue = segments[0].Split(new[] { '=' }, 2);
if (nameValue.Length != 2 || string.IsNullOrWhiteSpace(nameValue[0]))
return cookies;

Expand All @@ -44,7 +44,7 @@ public static IEnumerable<Cookie> ParseHeader(this string setCookieHeader)

foreach (var segment in segments.Skip(1))
{
var kv = segment.Split(['='], 2, StringSplitOptions.RemoveEmptyEntries);
var kv = segment.Split(new[] { '=' }, 2, StringSplitOptions.RemoveEmptyEntries);
var key = kv[0].Trim().ToLowerInvariant();
var val = kv.Length > 1 ? kv[1] : null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ namespace Apache.Arrow.Flight.Server
{
public class FlightServerRecordBatchStreamReader : FlightRecordBatchStreamReader
{
public FlightServerRecordBatchStreamReader(IAsyncStreamReader<FlightData> flightDataStream) : base(new StreamReader<FlightData, Protocol.FlightData>(flightDataStream, data => data.ToProtocol()))
public FlightServerRecordBatchStreamReader(IAsyncStreamReader<FlightData> flightDataStream, ArrowContext context = null) : base(new StreamReader<FlightData, Protocol.FlightData>(flightDataStream, data => data.ToProtocol()), context)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For binary backwards compatibility, can you instead add a second public constructor with the additional parameter? i.e.

public FlightServerRecordBatchStreamReader(IAsyncStreamReader<FlightData> flightDataStream)
    : this(flightDataStream, null)
{
}

public FlightServerRecordBatchStreamReader(IAsyncStreamReader<FlightData> flightDataStream, ArrowContext context)
    : base(new StreamReader<FlightData, Protocol.FlightData>(flightDataStream, data => data.ToProtocol()), context)
{
}

{
}

internal FlightServerRecordBatchStreamReader(IAsyncStreamReader<Protocol.FlightData> flightDataStream) : base(flightDataStream)
internal FlightServerRecordBatchStreamReader(IAsyncStreamReader<Protocol.FlightData> flightDataStream, ArrowContext context = null) : base(flightDataStream, context)
{
}

Expand Down
76 changes: 76 additions & 0 deletions test/Apache.Arrow.Flight.Tests/FlightTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -577,5 +577,81 @@ public async Task TestIntegrationWithGrpcNetClientFactory()

SchemaComparer.Compare(expectedSchema, actualSchema);
}

[Fact]
public async Task TestGetWithArrowContext()
{
// Verify that FlightClient works when constructed with an ArrowContext
var context = new ArrowContext();
var flightClient = new FlightClient(_testWebFactory.GetChannel(), context);

var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test-context");
var expectedBatch = CreateTestBatch(0, 100);

GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch));

var flightInfo = await flightClient.GetInfo(flightDescriptor);
var endpoint = flightInfo.Endpoints.FirstOrDefault();

var getStream = flightClient.GetStream(endpoint.Ticket);
var resultList = await getStream.ResponseStream.ToListAsync();

Assert.Single(resultList);
ArrowReaderVerifier.CompareBatches(expectedBatch, resultList[0]);
}

[Fact]
public async Task TestGetWithNullArrowContext()
{
// Verify that FlightClient works with null ArrowContext (backward compat)
var flightClient = new FlightClient(_testWebFactory.GetChannel(), null);

var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test-null-context");
var expectedBatch = CreateTestBatch(0, 100);

GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch));

var flightInfo = await flightClient.GetInfo(flightDescriptor);
var endpoint = flightInfo.Endpoints.FirstOrDefault();

var getStream = flightClient.GetStream(endpoint.Ticket);
var resultList = await getStream.ResponseStream.ToListAsync();

Assert.Single(resultList);
ArrowReaderVerifier.CompareBatches(expectedBatch, resultList[0]);
}

[Fact]
public async Task TestPutAndGetWithArrowContext()
{
// Verify put + get round-trip works with ArrowContext
var context = new ArrowContext();
var flightClient = new FlightClient(_testWebFactory.GetChannel(), context);

var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test-context-roundtrip");
var expectedBatch = CreateTestBatch(0, 100);

var putStream = await flightClient.StartPut(flightDescriptor, expectedBatch.Schema);
await putStream.RequestStream.WriteAsync(expectedBatch);
await putStream.RequestStream.CompleteAsync();
await putStream.ResponseStream.ToListAsync();

var flightInfo = await flightClient.GetInfo(flightDescriptor);
var endpoint = flightInfo.Endpoints.FirstOrDefault();

var getStream = flightClient.GetStream(endpoint.Ticket);
var resultList = await getStream.ResponseStream.ToListAsync();

Assert.Single(resultList);
ArrowReaderVerifier.CompareBatches(expectedBatch, resultList[0]);
}

[Fact]
public void TestFlightClientDefaultConstructorStillWorks()
{
// Verify the original constructor without ArrowContext still works
var client = new FlightClient(_testWebFactory.GetChannel());
Assert.NotNull(client);
}
}
}