From 0763588b23f66166857b5dedf43d8e00fbec603a Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Sat, 12 Jun 2021 06:00:09 +0900 Subject: [PATCH 01/15] Support Dictionary - Implement Dictionary serialization for ArrowStreamReader/Writer --- .../Internal/FlightMessageSerializer.cs | 6 +- csharp/src/Apache.Arrow/Arrays/ArrayData.cs | 11 +- .../Apache.Arrow/Arrays/ArrowArrayFactory.cs | 1 + .../Apache.Arrow/Arrays/DictionaryArray.cs | 67 +++++ .../Ipc/ArrowFileReaderImplementation.cs | 2 +- csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs | 4 +- .../Ipc/ArrowMemoryReaderImplementation.cs | 4 +- .../Ipc/ArrowReaderImplementation.cs | 59 +++- .../Ipc/ArrowStreamReaderImplementation.cs | 146 ++++++--- .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 284 +++++++++++++++--- .../Ipc/ArrowTypeFlatbufferBuilder.cs | 11 +- csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs | 113 +++++++ .../src/Apache.Arrow/Ipc/MessageSerializer.cs | 33 +- csharp/src/Apache.Arrow/RecordBatch.cs | 2 +- .../src/Apache.Arrow/Types/DictionaryType.cs | 46 +++ csharp/src/Apache.Arrow/Types/Int16Type.cs | 4 +- csharp/src/Apache.Arrow/Types/Int32Type.cs | 4 +- csharp/src/Apache.Arrow/Types/Int64Type.cs | 4 +- csharp/src/Apache.Arrow/Types/Int8Type.cs | 2 +- csharp/src/Apache.Arrow/Types/IntegerType.cs | 22 ++ csharp/src/Apache.Arrow/Types/UInt16Type.cs | 4 +- csharp/src/Apache.Arrow/Types/UInt32Type.cs | 4 +- csharp/src/Apache.Arrow/Types/UInt64Type.cs | 4 +- csharp/src/Apache.Arrow/Types/UInt8Type.cs | 4 +- .../ArrowWriterBenchmark.cs | 2 +- .../Apache.Arrow.Tests/ArrowReaderVerifier.cs | 13 +- .../ArrowStreamWriterTests.cs | 116 +++++-- .../DictionaryArrayTests.cs | 67 +++++ csharp/test/Apache.Arrow.Tests/TestData.cs | 32 +- 29 files changed, 918 insertions(+), 153 deletions(-) create mode 100644 csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs create mode 100644 csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs create mode 100644 csharp/src/Apache.Arrow/Types/DictionaryType.cs create mode 100644 csharp/src/Apache.Arrow/Types/IntegerType.cs create mode 100644 csharp/test/Apache.Arrow.Tests/DictionaryArrayTests.cs diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs index 0ac2d19b2971..91919440f900 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs @@ -44,13 +44,15 @@ public static Schema DecodeSchema(ReadOnlyMemory buffer) } ByteBuffer schemaBuffer = ArrowReaderImplementation.CreateByteBuffer(buffer.Slice(bufferPosition)); - var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer)); + //DictionaryBatch not supported for now + var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer), default); return schema; } internal static Schema DecodeSchema(ByteBuffer schemaBuffer) { - var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer)); + //DictionaryBatch not supported for now + var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer), default); return schema; } } diff --git a/csharp/src/Apache.Arrow/Arrays/ArrayData.cs b/csharp/src/Apache.Arrow/Arrays/ArrayData.cs index 93bb5ccf6d8e..595bb53a0aa9 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrayData.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrayData.cs @@ -30,11 +30,12 @@ public sealed class ArrayData : IDisposable public readonly int Offset; public readonly ArrowBuffer[] Buffers; public readonly ArrayData[] Children; + public readonly ArrayData Dictionary; //Only used for dictionary type public ArrayData( IArrowType dataType, int length, int nullCount = 0, int offset = 0, - IEnumerable buffers = null, IEnumerable children = null) + IEnumerable buffers = null, IEnumerable children = null, ArrayData dictionary = null) { DataType = dataType ?? NullType.Default; Length = length; @@ -42,12 +43,13 @@ public ArrayData( Offset = offset; Buffers = buffers?.ToArray(); Children = children?.ToArray(); + Dictionary = dictionary; } public ArrayData( IArrowType dataType, int length, int nullCount = 0, int offset = 0, - ArrowBuffer[] buffers = null, ArrayData[] children = null) + ArrowBuffer[] buffers = null, ArrayData[] children = null, ArrayData dictionary = null) { DataType = dataType ?? NullType.Default; Length = length; @@ -55,6 +57,7 @@ public ArrayData( Offset = offset; Buffers = buffers; Children = children; + Dictionary = dictionary; } public void Dispose() @@ -74,6 +77,8 @@ public void Dispose() child?.Dispose(); } } + + Dictionary?.Dispose(); } public ArrayData Slice(int offset, int length) @@ -86,7 +91,7 @@ public ArrayData Slice(int offset, int length) length = Math.Min(Length - offset, length); offset += Offset; - return new ArrayData(DataType, length, RecalculateNullCount, offset, Buffers, Children); + return new ArrayData(DataType, length, RecalculateNullCount, offset, Buffers, Children, Dictionary); } } } diff --git a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs index c3429230cc65..c8c0b2487f6a 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs @@ -67,6 +67,7 @@ public static IArrowArray BuildArray(ArrayData data) case ArrowTypeId.Decimal256: return new Decimal256Array(data); case ArrowTypeId.Dictionary: + return new DictionaryArray(data); case ArrowTypeId.FixedSizedBinary: case ArrowTypeId.HalfFloat: case ArrowTypeId.Interval: diff --git a/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs b/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs new file mode 100644 index 000000000000..9e90c5279136 --- /dev/null +++ b/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.IO; +using Apache.Arrow.Types; + +namespace Apache.Arrow +{ + public class DictionaryArray : Array + { + public IArrowArray Dictionary { get; } + public IArrowArray Indices { get; } + public ArrowBuffer IndicesBuffer => Data.Buffers[1]; + + public DictionaryArray(IArrowType dataType, int length, + ArrowBuffer valueOffsetsBuffer, IArrowArray value, + ArrowBuffer nullBitmapBuffer, int nullCount = 0, int offset = 0) + : this(new ArrayData(dataType, length, nullCount, offset, + new[] { nullBitmapBuffer, valueOffsetsBuffer }, new[] { value.Data }, value.Data.Dictionary)) + { + } + + public DictionaryArray(ArrayData data) : base(data) + { + data.EnsureBufferCount(2); + data.EnsureDataType(ArrowTypeId.Dictionary); + + var dicType = data.DataType as DictionaryType; + data.Dictionary.EnsureDataType(dicType.ValueType.TypeId); + + ArrayData indicesData = new ArrayData(dicType.IndexType, data.Length, data.NullCount, data.Offset, data.Buffers, data.Children); + + Indices = ArrowArrayFactory.BuildArray(indicesData); + Dictionary = ArrowArrayFactory.BuildArray(data.Dictionary); + } + + public DictionaryArray(IArrowType dataType, IArrowArray indicesArray, IArrowArray dictionary, bool ordered = false) : + base(new ArrayData(dataType, indicesArray.Length, indicesArray.Data.NullCount, indicesArray.Data.Offset, indicesArray.Data.Buffers, indicesArray.Data.Children, dictionary.Data)) + { + Data.EnsureBufferCount(2); + Data.EnsureDataType(ArrowTypeId.Dictionary); + + var dicType = dataType as DictionaryType; + + indicesArray.Data.EnsureDataType(dicType.IndexType.TypeId); + dictionary.Data.EnsureDataType(dicType.ValueType.TypeId); + + Indices = indicesArray; + Dictionary = dictionary; + } + + public override void Accept(IArrowArrayVisitor visitor) => Accept(this, visitor); + } +} diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs index f2bf21a0c4f3..3b27eec7b2d2 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs @@ -135,7 +135,7 @@ private static int ReadFooterLength(Memory buffer) private void ReadSchema(Memory buffer) { // Deserialize the footer from the footer flatbuffer - _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer))); + _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)), _dictionaryMemo); Schema = _footer.Schema; } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs index aa7d7952d3f0..06f58244051f 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs @@ -61,8 +61,8 @@ public ArrowFooter(Schema schema, IEnumerable dictionaries, IEnumerable _buffer; private int _bufferPosition; - public ArrowMemoryReaderImplementation(ReadOnlyMemory buffer) + public ArrowMemoryReaderImplementation(ReadOnlyMemory buffer) : base() { _buffer = buffer; } @@ -111,7 +111,7 @@ private void ReadSchema() } ByteBuffer schemaBuffer = CreateByteBuffer(_buffer.Slice(_bufferPosition)); - Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer)); + Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), _dictionaryMemo); _bufferPosition += schemaMessageLength; } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index 3279f7030557..010226f3efb5 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -29,6 +29,13 @@ internal abstract class ArrowReaderImplementation : IDisposable { public Schema Schema { get; protected set; } protected bool HasReadSchema => Schema != null; + protected bool HasReadInitialDictionary { get; set; } + protected readonly DictionaryMemo _dictionaryMemo; + + public ArrowReaderImplementation() + { + _dictionaryMemo = new DictionaryMemo(); + } public void Dispose() { @@ -88,8 +95,8 @@ protected RecordBatch CreateArrowObjectFromMessage( // TODO: Read schema and verify equality? break; case Flatbuf.MessageHeader.DictionaryBatch: - // TODO: not supported currently - Debug.WriteLine("Dictionaries are not yet supported."); + Flatbuf.DictionaryBatch dictionaryBatch = message.Header().Value; + ReadDictionaryBatch(dictionaryBatch, bodyByteBuffer, memoryOwner); break; case Flatbuf.MessageHeader.RecordBatch: Flatbuf.RecordBatch rb = message.Header().Value; @@ -109,6 +116,36 @@ internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory buffer) return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer), 0); } + private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBuffer bodyByteBuffer, IMemoryOwner memoryOwner) + { + long id = dictionaryBatch.Id; + IArrowType valueType = _dictionaryMemo.GetDictionaryType(id); + Flatbuf.RecordBatch? recordBatch = dictionaryBatch.Data; + + if (!recordBatch.HasValue) + { + throw new InvalidDataException("Dictionary must contain RecordBatch"); + } + + Field valueField = new Field("dummy", valueType, true); + var schema = new Schema(new[] { valueField }, default); + IList arrays = BuildArrays(schema, bodyByteBuffer, recordBatch.Value); + + if (arrays.Count != 1) + { + throw new InvalidDataException("Dictionary record batch must contain only one field"); + } + + if (dictionaryBatch.IsDelta) + { + throw new NotImplementedException("Dictionary delta is not supported yet"); + } + else + { + _dictionaryMemo.AddOrReplaceDictionary(id, arrays[0]); + } + } + private List BuildArrays( Schema schema, ByteBuffer messageBuffer, @@ -179,7 +216,14 @@ private ArrayData LoadPrimitiveField( ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData); - return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children); + IArrowArray dictionary = null; + if (field.DataType.TypeId == ArrowTypeId.Dictionary) + { + long id = _dictionaryMemo.GetId(field); + dictionary = _dictionaryMemo?.GetDictionary(id); + } + + return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); } private ArrayData LoadVariableField( @@ -218,7 +262,14 @@ private ArrayData LoadVariableField( ArrowBuffer[] arrowBuff = new[] { nullArrowBuffer, offsetArrowBuffer, valueArrowBuffer }; ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData); - return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children); + IArrowArray dictionary = null; + if (field.DataType.TypeId == ArrowTypeId.Dictionary) + { + long id = _dictionaryMemo.GetId(field); + dictionary = _dictionaryMemo?.GetDictionary(id); + } + + return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); } private ArrayData[] GetChildren( diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 1fd320903bd0..01fe41ee9e95 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -28,7 +28,7 @@ internal class ArrowStreamReaderImplementation : ArrowReaderImplementation private readonly bool _leaveOpen; private readonly MemoryAllocator _allocator; - public ArrowStreamReaderImplementation(Stream stream, MemoryAllocator allocator, bool leaveOpen) + public ArrowStreamReaderImplementation(Stream stream, MemoryAllocator allocator, bool leaveOpen) : base() { BaseStream = stream; _allocator = allocator ?? MemoryAllocator.Default.Value; @@ -42,6 +42,37 @@ protected override void Dispose(bool disposing) BaseStream.Dispose(); } } + protected void ReadInitialDictionaries() + { + if (HasReadInitialDictionary) + { + return; + } + + int fieldCount = _dictionaryMemo.GetFieldCount(); + for (int i = 0; i < fieldCount; ++i) + { + ReadArrowObject(); + } + + HasReadInitialDictionary = true; + } + + protected async ValueTask ReadInitialDictionariesAsync(CancellationToken cancellationToken = default) + { + if (HasReadInitialDictionary) + { + return; + } + + int fieldCount = _dictionaryMemo.GetFieldCount(); + for (int i = 0; i < fieldCount; ++i) + { + await ReadArrowObjectAsync(cancellationToken).ConfigureAwait(false); + } + + HasReadInitialDictionary = true; + } public override async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken) { @@ -50,6 +81,8 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati return await ReadRecordBatchAsync(cancellationToken).ConfigureAwait(false); } + + public override RecordBatch ReadNextRecordBatch() { return ReadRecordBatch(); @@ -59,6 +92,66 @@ protected async ValueTask ReadRecordBatchAsync(CancellationToken ca { await ReadSchemaAsync().ConfigureAwait(false); + await ReadInitialDictionariesAsync().ConfigureAwait(false); + + return await ReadArrowObjectAsync().ConfigureAwait(false); + } + + + protected RecordBatch ReadRecordBatch() + { + ReadSchema(); + + ReadInitialDictionaries(); + + return ReadArrowObject(); + } + + protected virtual async ValueTask ReadSchemaAsync() + { + if (HasReadSchema) + { + return; + } + + // Figure out length of schema + int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true) + .ConfigureAwait(false); + + await ArrayPool.Shared.RentReturnAsync(schemaMessageLength, async (buff) => + { + // Read in schema + int bytesRead = await BaseStream.ReadFullBufferAsync(buff).ConfigureAwait(false); + EnsureFullRead(buff, bytesRead); + + FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); + Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), _dictionaryMemo); + }).ConfigureAwait(false); + } + + protected virtual void ReadSchema() + { + if (HasReadSchema) + { + return; + } + + // Figure out length of schema + int schemaMessageLength = ReadMessageLength(throwOnFullRead: true); + + ArrayPool.Shared.RentReturn(schemaMessageLength, buff => + { + int bytesRead = BaseStream.ReadFullBuffer(buff); + EnsureFullRead(buff, bytesRead); + + FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); + Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), _dictionaryMemo); + }); + } + + // Note: When the message type is DictionaryBatch, this function adds data to _dictionaryMemo and returns null. + private async ValueTask ReadArrowObjectAsync(CancellationToken cancellationToken = default) + { int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) .ConfigureAwait(false); @@ -68,7 +161,7 @@ protected async ValueTask ReadRecordBatchAsync(CancellationToken ca return null; } - RecordBatch result = null; + RecordBatch result = default; await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) => { int bytesRead = await BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken) @@ -92,10 +185,9 @@ await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) return result; } - protected RecordBatch ReadRecordBatch() + // Note: When the message type is DictionaryBatch, this function adds data to _dictionaryMemo and returns null. + private RecordBatch ReadArrowObject() { - ReadSchema(); - int messageLength = ReadMessageLength(throwOnFullRead: false); if (messageLength == 0) @@ -104,7 +196,7 @@ protected RecordBatch ReadRecordBatch() return null; } - RecordBatch result = null; + RecordBatch result = default; ArrayPool.Shared.RentReturn(messageLength, messageBuff => { int bytesRead = BaseStream.ReadFullBuffer(messageBuff); @@ -126,48 +218,6 @@ protected RecordBatch ReadRecordBatch() return result; } - protected virtual async ValueTask ReadSchemaAsync() - { - if (HasReadSchema) - { - return; - } - - // Figure out length of schema - int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true) - .ConfigureAwait(false); - - await ArrayPool.Shared.RentReturnAsync(schemaMessageLength, async (buff) => - { - // Read in schema - int bytesRead = await BaseStream.ReadFullBufferAsync(buff).ConfigureAwait(false); - EnsureFullRead(buff, bytesRead); - - FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb)); - }).ConfigureAwait(false); - } - - protected virtual void ReadSchema() - { - if (HasReadSchema) - { - return; - } - - // Figure out length of schema - int schemaMessageLength = ReadMessageLength(throwOnFullRead: true); - - ArrayPool.Shared.RentReturn(schemaMessageLength, buff => - { - int bytesRead = BaseStream.ReadFullBuffer(buff); - EnsureFullRead(buff, bytesRead); - - FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb)); - }); - } - private async ValueTask ReadMessageLengthAsync(bool throwOnFullRead, CancellationToken cancellationToken = default) { int messageLength = 0; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index 5f0d16f83068..e38895da4d31 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -48,7 +48,8 @@ internal class ArrowRecordBatchFlatBufferBuilder : IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, - IArrowArrayVisitor + IArrowArrayVisitor, + IArrowArrayVisitor { public readonly struct Buffer { @@ -128,6 +129,15 @@ public void Visit(StructArray array) } } + public void Visit(DictionaryArray array) + { + //Dictionary is serialized separately in Dictionary serialization. + //We are only interested in indexes at this context. + + _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); + _buffers.Add(CreateBuffer(array.IndicesBuffer)); + } + private void CreateBuffers(BooleanArray array) { _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); @@ -165,6 +175,8 @@ public void Visit(IArrowArray array) protected bool HasWrittenSchema { get; set; } + private bool HasWrittenDictionaryBatch { get; set; } + private bool HasWrittenEnd { get; set; } protected Schema Schema { get; } @@ -178,6 +190,8 @@ public void Visit(IArrowArray array) private readonly ArrowTypeFlatbufferBuilder _fieldTypeBuilder; + private protected readonly DictionaryMemo _dictionaryMemo; + public ArrowStreamWriter(Stream baseStream, Schema schema) : this(baseStream, schema, leaveOpen: false) { @@ -200,6 +214,7 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOp _fieldTypeBuilder = new ArrowTypeFlatbufferBuilder(Builder); _options = options ?? IpcOptions.Default; + _dictionaryMemo = new DictionaryMemo(); } @@ -216,10 +231,10 @@ private void CreateSelfAndChildrenFieldNodes(ArrayData data) Flatbuf.FieldNode.CreateFieldNode(Builder, data.Length, data.NullCount); } - private int CountAllNodes() + private int CountAllNodes(IReadOnlyDictionary fields) { int count = 0; - foreach (Field arrowArray in Schema.Fields.Values) + foreach (Field arrowArray in fields.Values) { CountSelfAndChildrenNodes(arrowArray.DataType, ref count); } @@ -248,6 +263,13 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) HasWrittenSchema = true; } + if (!HasWrittenDictionaryBatch) + { + DictionaryCollector.Collect(recordBatch, _dictionaryMemo); + WriteDictionaries(recordBatch); + HasWrittenDictionaryBatch = true; + } + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, VectorOffset fieldNodesVectorOffset) = PreparingWritingRecordBatch(recordBatch); @@ -264,37 +286,9 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) long metadataLength = WriteMessage(Flatbuf.MessageHeader.RecordBatch, recordBatchOffset, recordBatchBuilder.TotalLength); - // Write buffer data - - IReadOnlyList buffers = recordBatchBuilder.Buffers; - - long bodyLength = 0; - - for (int i = 0; i < buffers.Count; i++) - { - ArrowBuffer buffer = buffers[i].DataBuffer; - if (buffer.IsEmpty) - continue; - - WriteBuffer(buffer); - - int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(buffer.Length)); - int padding = paddedLength - buffer.Length; - if (padding > 0) - { - WritePadding(padding); - } - - bodyLength += paddedLength; - } - - // Write padding so the record batch message body length is a multiple of 8 bytes - - int bodyPaddingLength = CalculatePadding(bodyLength); - - WritePadding(bodyPaddingLength); + long bufferLength = WriteBufferData(recordBatchBuilder.Buffers); - FinishedWritingRecordBatch(bodyLength + bodyPaddingLength, metadataLength); + FinishedWritingRecordBatch(bufferLength, metadataLength); } private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch, @@ -308,6 +302,13 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat HasWrittenSchema = true; } + if (!HasWrittenDictionaryBatch) + { + DictionaryCollector.Collect(recordBatch, _dictionaryMemo); + await WriteDictionariesAsync(recordBatch, cancellationToken).ConfigureAwait(false); + HasWrittenDictionaryBatch = true; + } + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, VectorOffset fieldNodesVectorOffset) = PreparingWritingRecordBatch(recordBatch); @@ -325,10 +326,44 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat recordBatchOffset, recordBatchBuilder.TotalLength, cancellationToken).ConfigureAwait(false); - // Write buffer data + long bufferLength = await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false); - IReadOnlyList buffers = recordBatchBuilder.Buffers; + FinishedWritingRecordBatch(bufferLength, metadataLength); + } + + private long WriteBufferData(IReadOnlyList buffers) + { + long bodyLength = 0; + + for (int i = 0; i < buffers.Count; i++) + { + ArrowBuffer buffer = buffers[i].DataBuffer; + if (buffer.IsEmpty) + continue; + + WriteBuffer(buffer); + + int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(buffer.Length)); + int padding = paddedLength - buffer.Length; + if (padding > 0) + { + WritePadding(padding); + } + + bodyLength += paddedLength; + } + + // Write padding so the record batch message body length is a multiple of 8 bytes + + int bodyPaddingLength = CalculatePadding(bodyLength); + + WritePadding(bodyPaddingLength); + return bodyLength + bodyPaddingLength; + } + + private async ValueTask WriteBufferDataAsync(IReadOnlyList buffers, CancellationToken cancellationToken = default) + { long bodyLength = 0; for (int i = 0; i < buffers.Count; i++) @@ -355,23 +390,28 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat await WritePaddingAsync(bodyPaddingLength).ConfigureAwait(false); - FinishedWritingRecordBatch(bodyLength + bodyPaddingLength, metadataLength); + return bodyLength + bodyPaddingLength; } private Tuple PreparingWritingRecordBatch(RecordBatch recordBatch) + { + return PreparingWritingRecordBatch(recordBatch.Schema.Fields, recordBatch._arrays); + } + + private Tuple PreparingWritingRecordBatch(IReadOnlyDictionary fields, IReadOnlyList arrays) { Builder.Clear(); // Serialize field nodes - int fieldCount = Schema.Fields.Count; + int fieldCount = fields.Count; - Flatbuf.RecordBatch.StartNodesVector(Builder, CountAllNodes()); + Flatbuf.RecordBatch.StartNodesVector(Builder, CountAllNodes(fields)); // flatbuffer struct vectors have to be created in reverse order for (int i = fieldCount - 1; i >= 0; i--) { - CreateSelfAndChildrenFieldNodes(recordBatch.Column(i).Data); + CreateSelfAndChildrenFieldNodes(arrays[i].Data); } VectorOffset fieldNodesVectorOffset = Builder.EndVector(); @@ -381,7 +421,7 @@ private Tuple PreparingWritingR var recordBatchBuilder = new ArrowRecordBatchFlatBufferBuilder(); for (int i = 0; i < fieldCount; i++) { - IArrowArray fieldArray = recordBatch.Column(i); + IArrowArray fieldArray = arrays[i]; fieldArray.Accept(recordBatchBuilder); } @@ -399,6 +439,95 @@ private Tuple PreparingWritingR return Tuple.Create(recordBatchBuilder, fieldNodesVectorOffset); } + + private protected void WriteDictionaries(RecordBatch recordBatch) + { + foreach (Field field in recordBatch.Schema.Fields.Values) + { + WriteDictionary(field); + } + } + + private protected void WriteDictionary(Field field) + { + if (field.DataType.TypeId != ArrowTypeId.Dictionary) + { + if (field.DataType is NestedType nestedType) + { + foreach (Field child in nestedType.Fields) + { + WriteDictionary(child); + } + } + return; + } + + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, Offset dictionaryBatchOffset) = + CreateDictionaryBatchOffset(field); + + WriteMessage(Flatbuf.MessageHeader.DictionaryBatch, + dictionaryBatchOffset, recordBatchBuilder.TotalLength); + + WriteBufferData(recordBatchBuilder.Buffers); + } + + private protected async Task WriteDictionariesAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default) + { + foreach (Field field in recordBatch.Schema.Fields.Values) + { + await WriteDictionaryAsync(field, cancellationToken).ConfigureAwait(false); + } + } + + private protected async Task WriteDictionaryAsync(Field field, CancellationToken cancellationToken = default) + { + if (field.DataType.TypeId != ArrowTypeId.Dictionary) + { + if (field.DataType is NestedType nestedType) + { + foreach (Field child in nestedType.Fields) + { + await WriteDictionaryAsync(child, cancellationToken).ConfigureAwait(false); + } + } + return; + } + + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, Offset dictionaryBatchOffset) = + CreateDictionaryBatchOffset(field); + + await WriteMessageAsync(Flatbuf.MessageHeader.DictionaryBatch, + dictionaryBatchOffset, recordBatchBuilder.TotalLength, cancellationToken).ConfigureAwait(false); + + await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false); + } + + private protected Tuple> CreateDictionaryBatchOffset(Field field) + { + Field dictionaryField = new Field("dummy", ((DictionaryType)field.DataType).ValueType, false); + long id = _dictionaryMemo.GetId(field); + IArrowArray dictionary = _dictionaryMemo.GetDictionary(id); + + var fieldsDictionary = new Dictionary { + { dictionaryField.Name, dictionaryField } }; + + var arrays = new List { dictionary }; + + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, VectorOffset fieldNodesVectorOffset) = + PreparingWritingRecordBatch(fieldsDictionary, arrays); + + VectorOffset buffersVectorOffset = Builder.EndVector(); + + // Serialize record batch + Offset recordBatchOffset = Flatbuf.RecordBatch.CreateRecordBatch(Builder, dictionary.Length, + fieldNodesVectorOffset, + buffersVectorOffset); + + //TODO: Support delta. + Offset dictionaryBatchOffset = Flatbuf.DictionaryBatch.CreateDictionaryBatch(Builder, id, recordBatchOffset, false); + return Tuple.Create(recordBatchBuilder, dictionaryBatchOffset); + } + private protected virtual void WriteEndInternal() { WriteIpcMessageLength(length: 0); @@ -475,10 +604,11 @@ private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken ca VectorOffset fieldChildrenVectorOffset = GetChildrenFieldOffset(field); VectorOffset fieldMetadataVectorOffset = GetFieldMetadataOffset(field); + Offset dictionaryOffset = GetDictionaryOffset(field); fieldOffsets[i] = Flatbuf.Field.CreateField(Builder, fieldNameOffset, field.IsNullable, fieldType.Type, fieldType.Offset, - default, fieldChildrenVectorOffset, fieldMetadataVectorOffset); + dictionaryOffset, fieldChildrenVectorOffset, fieldMetadataVectorOffset); } VectorOffset fieldsVectorOffset = Flatbuf.Schema.CreateFieldsVector(Builder, fieldOffsets); @@ -529,6 +659,21 @@ private VectorOffset GetFieldMetadataOffset(Field field) return Flatbuf.Field.CreateCustomMetadataVector(Builder, metadataOffsets); } + private Offset GetDictionaryOffset(Field field) + { + if (field.DataType.TypeId != ArrowTypeId.Dictionary) + { + return default; + } + + long id = _dictionaryMemo.GetOrAssignId(field); + var dicType = field.DataType as DictionaryType; + var indexType = dicType.IndexType as NumberType; + + Offset indexOffset = Flatbuf.Int.CreateInt(Builder, indexType.BitWidth, indexType.IsSigned); + return Flatbuf.DictionaryEncoding.CreateDictionaryEncoding(Builder, id, indexOffset, dicType.Ordered); + } + private Offset[] GetMetadataOffsets(IReadOnlyDictionary metadata) { Debug.Assert(metadata != null); @@ -723,4 +868,61 @@ public virtual void Dispose() } } } + + internal static class DictionaryCollector + { + internal static void Collect(RecordBatch recordBatch, DictionaryMemo dictionaryMemo) + { + Schema schema = recordBatch.Schema; + for (int i = 0; i < schema.Fields.Count; i++) + { + { + Field field = schema.GetFieldByIndex(i); + IArrowArray array = recordBatch.Column(i); + + CollectDictionary(field, array, dictionaryMemo); + } + } + } + + private static void CollectDictionary(Field field, IArrowArray array, DictionaryMemo dictionaryMemo) + { + if (field.DataType.TypeId == ArrowTypeId.Dictionary) + { + IArrowArray dictionary = (array as DictionaryArray).Dictionary; + long id = dictionaryMemo.GetOrAssignId(field); + + dictionaryMemo.AddOrReplaceDictionary(id, dictionary); + WalkChildren(dictionary, dictionaryMemo); + } + else + { + WalkChildren(array, dictionaryMemo); + } + } + + private static void WalkChildren(IArrowArray array, DictionaryMemo dictionaryMemo) + { + ArrayData[] children = array.Data.Children; + + if (children == null) + { + return; + } + + if (!(array.Data.DataType is NestedType nestedType)) + { + return; + } + + for (int i = 0; i < nestedType.Fields.Count; i++) + { + Field childField = nestedType.Fields[i]; + ArrayData child = children[i]; + IArrowArray childArray = ArrowArrayFactory.BuildArray(child); + + CollectDictionary(childField, childArray, dictionaryMemo); + } + } + } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs index d0d2b74e701b..23f5b3f3d9c8 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs @@ -63,7 +63,8 @@ class TypeVisitor : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private FlatBufferBuilder Builder { get; } @@ -201,6 +202,14 @@ private void CreateIntType(NumberType type) Flatbuf.Int.CreateInt(Builder, type.BitWidth, type.IsSigned)); } + public void Visit(DictionaryType type) + { + // In this library, the dictionary "type" is a logical construct. Here we + // pass through to the value type, as we've already captured the index + // type in the DictionaryEncoding metadata in the parent field + type.ValueType.Accept(this); + } + public void Visit(IArrowType type) { throw new NotImplementedException(); diff --git a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs new file mode 100644 index 000000000000..82c837aa2aef --- /dev/null +++ b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using Apache.Arrow.Types; + +namespace Apache.Arrow.Ipc +{ + class DictionaryMemo + { + private Dictionary _idToDictionary; + private Dictionary _idToValueType; + private Dictionary _fieldToId; + + public DictionaryMemo() + { + _idToDictionary = new Dictionary(); + _idToValueType = new Dictionary(); + _fieldToId = new Dictionary(); + } + + public IArrowType GetDictionaryType(long id) + { + if (!_idToValueType.TryGetValue(id, out IArrowType type)) + { + throw new ArgumentException($"Dictionary with id {id} not found"); + } + return type; + } + + public IArrowArray GetDictionary(long id) + { + if (!_idToDictionary.TryGetValue(id, out IArrowArray dictionary)) + { + throw new ArgumentException($"Dictionary with id {id} not found"); + } + return dictionary; + } + + public void AddField(long id, Field field) + { + if (_fieldToId.ContainsKey(field)) + { + throw new ArgumentException($"Field {field.Name} is already in Memo"); + } + + if (field.DataType.TypeId != ArrowTypeId.Dictionary) + { + throw new ArgumentException($"Field type is not DictionaryType: Name={field.Name}, {field.DataType.Name}"); + } + + IArrowType valueType = ((DictionaryType)field.DataType).ValueType; + + if (_idToValueType.TryGetValue(id, out IArrowType valueTypeInDic)) + { + if (valueType != valueTypeInDic) + { + throw new ArgumentException($"Field type {field.DataType.Name} does not match the existing type {valueTypeInDic})"); + } + } + + _fieldToId.Add(field, id); + _idToValueType.Add(id, valueType); + } + + public long GetId(Field field) + { + if (!_fieldToId.TryGetValue(field, out long id)) + { + throw new ArgumentException($"Field with name {field.Name} not found"); + } + return id; + } + + public long GetOrAssignId(Field field) + { + if (!_fieldToId.TryGetValue(field, out long id)) + { + id = _fieldToId.Count + 1; + AddField(id, field); + } + return id; + } + + public void AddOrReplaceDictionary(long id, IArrowArray dictionary) + { + _idToDictionary[id] = dictionary; + } + + public void AddDictionaryDelta(long id, IArrowArray dictionary) + { + throw new NotImplementedException("Dictionary delta is not supported yet."); + } + + public int GetFieldCount() + { + return _fieldToId.Count; + } + } +} diff --git a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs index a4e766089245..9b644a5004c5 100644 --- a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs +++ b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs @@ -17,6 +17,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; +using Apache.Arrow.Types; namespace Apache.Arrow.Ipc { @@ -52,14 +53,13 @@ public static Types.NumberType GetNumberType(int bitWidth, bool signed) $"{(signed ? "signed " : "unsigned")} integer."); } - internal static Schema GetSchema(Flatbuf.Schema schema) + internal static Schema GetSchema(Flatbuf.Schema schema, DictionaryMemo dictionaryMemo) { List fields = new List(); for (int i = 0; i < schema.FieldsLength; i++) { Flatbuf.Field field = schema.Fields(i).GetValueOrDefault(); - - fields.Add(FieldFromFlatbuffer(field)); + fields.Add(FieldFromFlatbuffer(field, dictionaryMemo)); } Dictionary metadata = schema.CustomMetadataLength > 0 ? new Dictionary() : null; @@ -73,13 +73,27 @@ internal static Schema GetSchema(Flatbuf.Schema schema) return new Schema(fields, metadata, copyCollections: false); } - private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField) + private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField, DictionaryMemo dictionaryMemo) { Field[] childFields = flatbufField.ChildrenLength > 0 ? new Field[flatbufField.ChildrenLength] : null; for (int i = 0; i < flatbufField.ChildrenLength; i++) { Flatbuf.Field? childFlatbufField = flatbufField.Children(i); - childFields[i] = FieldFromFlatbuffer(childFlatbufField.Value); + childFields[i] = FieldFromFlatbuffer(childFlatbufField.Value, dictionaryMemo); + } + + Flatbuf.DictionaryEncoding? de = flatbufField.Dictionary; + IArrowType type = GetFieldArrowType(flatbufField, childFields); + + if (de.HasValue) + { + Flatbuf.Int? indexTypeAsInt = de.Value.IndexType; + if (!indexTypeAsInt.HasValue) + { + throw new InvalidDataException("Dictionary type not defined"); + } + IArrowType indexType = GetNumberType(indexTypeAsInt.Value.BitWidth, indexTypeAsInt.Value.IsSigned); + type = new DictionaryType(indexType, type, de.Value.IsOrdered); } Dictionary metadata = flatbufField.CustomMetadataLength > 0 ? new Dictionary() : null; @@ -90,7 +104,14 @@ private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField) metadata[keyValue.Key] = keyValue.Value; } - return new Field(flatbufField.Name, GetFieldArrowType(flatbufField, childFields), flatbufField.Nullable, metadata, copyCollections: false); + var arrowField = new Field(flatbufField.Name, type, flatbufField.Nullable, metadata, copyCollections: false); + + if (de.HasValue) + { + dictionaryMemo.AddField(de.Value.Id, arrowField); + } + + return arrowField; } private static Types.IArrowType GetFieldArrowType(Flatbuf.Field field, Field[] childFields = null) diff --git a/csharp/src/Apache.Arrow/RecordBatch.cs b/csharp/src/Apache.Arrow/RecordBatch.cs index 3cffabee32be..82fef5b1dbf6 100644 --- a/csharp/src/Apache.Arrow/RecordBatch.cs +++ b/csharp/src/Apache.Arrow/RecordBatch.cs @@ -29,7 +29,7 @@ public partial class RecordBatch : IDisposable public int Length { get; } private readonly IMemoryOwner _memoryOwner; - private readonly IList _arrays; + internal readonly IReadOnlyList _arrays; public IArrowArray Column(int i) { diff --git a/csharp/src/Apache.Arrow/Types/DictionaryType.cs b/csharp/src/Apache.Arrow/Types/DictionaryType.cs new file mode 100644 index 000000000000..5c1dd4095eb1 --- /dev/null +++ b/csharp/src/Apache.Arrow/Types/DictionaryType.cs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +using System; + +namespace Apache.Arrow.Types +{ + public sealed class DictionaryType : FixedWidthType + { + public static readonly DictionaryType Default = new DictionaryType(Int64Type.Default, Int64Type.Default, false); + + public DictionaryType(IArrowType indexType, IArrowType valueType, bool ordered) + { + if (!(indexType is IntegerType)) + { + throw new ArgumentException($"{nameof(indexType)} must be integer"); + } + + IndexType = indexType; + ValueType = valueType; + Ordered = ordered; + } + + public override ArrowTypeId TypeId => ArrowTypeId.Dictionary; + public override string Name => "dictionary"; + public override int BitWidth => 64; + public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); + + public IArrowType IndexType { get; private set; } + public IArrowType ValueType { get; private set; } + public bool Ordered { get; private set; } + } +} diff --git a/csharp/src/Apache.Arrow/Types/Int16Type.cs b/csharp/src/Apache.Arrow/Types/Int16Type.cs index f1d6868ba8ae..564ae069206f 100644 --- a/csharp/src/Apache.Arrow/Types/Int16Type.cs +++ b/csharp/src/Apache.Arrow/Types/Int16Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class Int16Type : NumberType + public sealed class Int16Type : IntegerType { public static readonly Int16Type Default = new Int16Type(); @@ -26,4 +26,4 @@ public sealed class Int16Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/Int32Type.cs b/csharp/src/Apache.Arrow/Types/Int32Type.cs index a32c88462983..bc2ad32e4a10 100644 --- a/csharp/src/Apache.Arrow/Types/Int32Type.cs +++ b/csharp/src/Apache.Arrow/Types/Int32Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class Int32Type : NumberType + public sealed class Int32Type : IntegerType { public static readonly Int32Type Default = new Int32Type(); @@ -26,4 +26,4 @@ public sealed class Int32Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/Int64Type.cs b/csharp/src/Apache.Arrow/Types/Int64Type.cs index f45523cfb330..9be7f2161ee8 100644 --- a/csharp/src/Apache.Arrow/Types/Int64Type.cs +++ b/csharp/src/Apache.Arrow/Types/Int64Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class Int64Type : NumberType + public sealed class Int64Type : IntegerType { public static readonly Int64Type Default = new Int64Type(); @@ -26,4 +26,4 @@ public sealed class Int64Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/Int8Type.cs b/csharp/src/Apache.Arrow/Types/Int8Type.cs index 9b3f5b5b4fc9..fd6e471155f5 100644 --- a/csharp/src/Apache.Arrow/Types/Int8Type.cs +++ b/csharp/src/Apache.Arrow/Types/Int8Type.cs @@ -16,7 +16,7 @@ namespace Apache.Arrow.Types { - public sealed class Int8Type : NumberType + public sealed class Int8Type : IntegerType { public static readonly Int8Type Default = new Int8Type(); diff --git a/csharp/src/Apache.Arrow/Types/IntegerType.cs b/csharp/src/Apache.Arrow/Types/IntegerType.cs new file mode 100644 index 000000000000..7a5057c46686 --- /dev/null +++ b/csharp/src/Apache.Arrow/Types/IntegerType.cs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +namespace Apache.Arrow.Types +{ + public abstract class IntegerType: NumberType + { + } +} diff --git a/csharp/src/Apache.Arrow/Types/UInt16Type.cs b/csharp/src/Apache.Arrow/Types/UInt16Type.cs index 1925ffb86b79..7e020d37e756 100644 --- a/csharp/src/Apache.Arrow/Types/UInt16Type.cs +++ b/csharp/src/Apache.Arrow/Types/UInt16Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class UInt16Type : NumberType + public sealed class UInt16Type : IntegerType { public static readonly UInt16Type Default = new UInt16Type(); @@ -26,4 +26,4 @@ public sealed class UInt16Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/UInt32Type.cs b/csharp/src/Apache.Arrow/Types/UInt32Type.cs index 8007025f3061..9015f118b3a5 100644 --- a/csharp/src/Apache.Arrow/Types/UInt32Type.cs +++ b/csharp/src/Apache.Arrow/Types/UInt32Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class UInt32Type : NumberType + public sealed class UInt32Type : IntegerType { public static readonly UInt32Type Default = new UInt32Type(); @@ -26,4 +26,4 @@ public sealed class UInt32Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/UInt64Type.cs b/csharp/src/Apache.Arrow/Types/UInt64Type.cs index 20b51ad44f54..a414e701687b 100644 --- a/csharp/src/Apache.Arrow/Types/UInt64Type.cs +++ b/csharp/src/Apache.Arrow/Types/UInt64Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class UInt64Type : NumberType + public sealed class UInt64Type : IntegerType { public static readonly UInt64Type Default = new UInt64Type(); @@ -26,4 +26,4 @@ public sealed class UInt64Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/UInt8Type.cs b/csharp/src/Apache.Arrow/Types/UInt8Type.cs index e2e53657200e..31121b4e059f 100644 --- a/csharp/src/Apache.Arrow/Types/UInt8Type.cs +++ b/csharp/src/Apache.Arrow/Types/UInt8Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class UInt8Type : NumberType + public sealed class UInt8Type : IntegerType { public static readonly UInt8Type Default = new UInt8Type(); @@ -26,4 +26,4 @@ public sealed class UInt8Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/test/Apache.Arrow.Benchmarks/ArrowWriterBenchmark.cs b/csharp/test/Apache.Arrow.Benchmarks/ArrowWriterBenchmark.cs index f35c2a5d78d7..c791c9969356 100644 --- a/csharp/test/Apache.Arrow.Benchmarks/ArrowWriterBenchmark.cs +++ b/csharp/test/Apache.Arrow.Benchmarks/ArrowWriterBenchmark.cs @@ -38,7 +38,7 @@ public class ArrowWriterBenchmark [GlobalSetup] public void GlobalSetup() { - _batch = TestData.CreateSampleRecordBatch(BatchLength, ColumnSetCount); + _batch = TestData.CreateSampleRecordBatch(BatchLength, ColumnSetCount, false); _memoryStream = new MemoryStream(); } diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs index 78f51a7459c0..071560fe6ad8 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs @@ -80,7 +80,8 @@ private class ArrayComparer : IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, - IArrowArrayVisitor + IArrowArrayVisitor, + IArrowArrayVisitor { private readonly IArrowArray _expectedArray; private readonly ArrayTypeComparer _arrayTypeComparer; @@ -129,6 +130,16 @@ public void Visit(StructArray array) } } + public void Visit(DictionaryArray array) + { + Assert.IsAssignableFrom(_expectedArray); + DictionaryArray expectedArray = (DictionaryArray)_expectedArray; + var indicesComparer = new ArrayComparer(expectedArray.Indices); + var dictionaryComparer = new ArrayComparer(expectedArray.Dictionary); + array.Indices.Accept(indicesComparer); + array.Dictionary.Accept(dictionaryComparer); + } + public void Visit(FixedSizeBinaryType array) => throw new NotImplementedException(); public void Visit(IArrowArray array) => throw new NotImplementedException(); diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs index 44546f11d7ab..f3f79b496238 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs @@ -32,7 +32,7 @@ public class ArrowStreamWriterTests [Fact] public void Ctor_LeaveOpenDefault_StreamClosedOnDispose() { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); var stream = new MemoryStream(); new ArrowStreamWriter(stream, originalBatch.Schema).Dispose(); Assert.Throws(() => stream.Position); @@ -41,7 +41,7 @@ public void Ctor_LeaveOpenDefault_StreamClosedOnDispose() [Fact] public void Ctor_LeaveOpenFalse_StreamClosedOnDispose() { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); var stream = new MemoryStream(); new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: false).Dispose(); Assert.Throws(() => stream.Position); @@ -50,7 +50,7 @@ public void Ctor_LeaveOpenFalse_StreamClosedOnDispose() [Fact] public void Ctor_LeaveOpenTrue_StreamValidOnDispose() { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); var stream = new MemoryStream(); new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true).Dispose(); Assert.Equal(0, stream.Position); @@ -59,7 +59,7 @@ public void Ctor_LeaveOpenTrue_StreamValidOnDispose() [Fact] public void CanWriteToNetworkStream() { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); const int port = 32153; TcpListener listener = new TcpListener(IPAddress.Loopback, port); @@ -93,7 +93,7 @@ public void CanWriteToNetworkStream() [Fact] public async Task CanWriteToNetworkStreamAsync() { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); const int port = 32154; TcpListener listener = new TcpListener(IPAddress.Loopback, port); @@ -196,13 +196,17 @@ public async Task WriteBatchWithNullsAsync() await TestRoundTripRecordBatchAsync(originalBatch); } - private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null) + + private static void TestRoundTripRecordBatches(List originalBatches, IpcOptions options = null) { using (MemoryStream stream = new MemoryStream()) { - using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options)) + using (var writer = new ArrowStreamWriter(stream, originalBatches[0].Schema, leaveOpen: true, options)) { - writer.WriteRecordBatch(originalBatch); + foreach (RecordBatch originalBatch in originalBatches) + { + writer.WriteRecordBatch(originalBatch); + } writer.WriteEnd(); } @@ -210,20 +214,25 @@ private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptio using (var reader = new ArrowStreamReader(stream)) { - RecordBatch newBatch = reader.ReadNextRecordBatch(); - ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + foreach (RecordBatch originalBatch in originalBatches) + { + RecordBatch newBatch = reader.ReadNextRecordBatch(); + ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + } } } } - - private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatch, IpcOptions options = null) + private static async Task TestRoundTripRecordBatchesAsync(List originalBatches, IpcOptions options = null) { using (MemoryStream stream = new MemoryStream()) { - using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options)) + using (var writer = new ArrowStreamWriter(stream, originalBatches[0].Schema, leaveOpen: true, options)) { - await writer.WriteRecordBatchAsync(originalBatch); + foreach (RecordBatch originalBatch in originalBatches) + { + await writer.WriteRecordBatchAsync(originalBatch); + } await writer.WriteEndAsync(); } @@ -231,12 +240,26 @@ private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatc using (var reader = new ArrowStreamReader(stream)) { - RecordBatch newBatch = reader.ReadNextRecordBatch(); - ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + foreach (RecordBatch originalBatch in originalBatches) + { + RecordBatch newBatch = reader.ReadNextRecordBatch(); + ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + } } } } + private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null) + { + TestRoundTripRecordBatches(new List { originalBatch }, options); + } + + + private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatch, IpcOptions options = null) + { + await TestRoundTripRecordBatchesAsync(new List { originalBatch }, options); + } + [Fact] public void WriteBatchWithCorrectPadding() { @@ -372,7 +395,7 @@ public async Task WriteBatchWithCorrectPaddingAsync() [Fact] public void LegacyIpcFormatRoundTrips() { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); TestRoundTripRecordBatch(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); } @@ -380,7 +403,7 @@ public void LegacyIpcFormatRoundTrips() [Fact] public async Task LegacyIpcFormatRoundTripsAsync() { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); await TestRoundTripRecordBatchAsync(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); } @@ -389,7 +412,7 @@ public async Task LegacyIpcFormatRoundTripsAsync() [InlineData(false)] public void WriteLegacyIpcFormat(bool writeLegacyIpcFormat) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat }; using (MemoryStream stream = new MemoryStream()) @@ -429,7 +452,7 @@ public void WriteLegacyIpcFormat(bool writeLegacyIpcFormat) [InlineData(false)] public async Task WriteLegacyIpcFormatAsync(bool writeLegacyIpcFormat) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat }; using (MemoryStream stream = new MemoryStream()) @@ -494,5 +517,58 @@ public void WritesMetadataCorrectly() TestRoundTripRecordBatch(originalBatch); } + + [Fact] + public async Task WriteMultipleDictionaryArraysAsync() + { + List originalRecordBatches = CreateMultipleDictionaryArraysTestData(); + await TestRoundTripRecordBatchesAsync(originalRecordBatches); + } + + [Fact] + public void WriteMultipleDictionaryArrays() + { + List originalRecordBatches = CreateMultipleDictionaryArraysTestData(); + TestRoundTripRecordBatches(originalRecordBatches); + } + + private List CreateMultipleDictionaryArraysTestData() + { + var dictionaryData = new List { "a", "b", "c" }; + int length = dictionaryData.Count; + + var indicesSchema = new Schema(new List { + new Field("int8", Int8Type.Default, true), + new Field("uint8", UInt8Type.Default, true), + new Field("int16", Int16Type.Default, true), + new Field("uint16", UInt16Type.Default, true), + new Field("int32", Int32Type.Default, true), + new Field("uint32", UInt32Type.Default, true), + new Field("int64", Int64Type.Default, true), + new Field("uint64", UInt64Type.Default, true) + }, null); + + StringArray dictionary = new StringArray.Builder().AppendRange(new[] { "a", "b", "c" }).Build(); + IEnumerable indicesArrays = TestData.CreateArrays(indicesSchema, length); + + var fields = new List(capacity: length); + var dictionaryArrays = new List(capacity: length); + + foreach (IArrowArray indices in indicesArrays) + { + var dictionaryArray = new DictionaryArray( + new DictionaryType(indices.Data.DataType, StringType.Default, false), + indices, dictionary); + dictionaryArrays.Add(dictionaryArray); + fields.Add(new Field($"dictionaryField_{indices.Data.DataType.Name}", dictionaryArray.Data.DataType, false)); + } + + var schema = new Schema(fields, null); + + return new List { + new RecordBatch(schema, dictionaryArrays, length), + new RecordBatch(schema, dictionaryArrays, length), + }; + } } } diff --git a/csharp/test/Apache.Arrow.Tests/DictionaryArrayTests.cs b/csharp/test/Apache.Arrow.Tests/DictionaryArrayTests.cs new file mode 100644 index 000000000000..da678563c3e6 --- /dev/null +++ b/csharp/test/Apache.Arrow.Tests/DictionaryArrayTests.cs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using Apache.Arrow.Types; +using Xunit; + +namespace Apache.Arrow.Tests +{ + public class DictionaryArrayTests + { + [Fact] + public void CreateTest() + { + (StringArray originalDictionary, Int32Array originalIndicesArray, DictionaryArray dictionaryArray) = + CreateSimpleTestData(); + + Assert.Equal(dictionaryArray.Dictionary, originalDictionary); + Assert.Equal(dictionaryArray.Indices, originalIndicesArray); + } + + [Fact] + public void SliceTest() + { + (StringArray originalDictionary, Int32Array originalIndicesArray, DictionaryArray dictionaryArray) = + CreateSimpleTestData(); + + int batchLength = originalIndicesArray.Length; + for (int offset = 0; offset < batchLength; offset++) + { + for (int length = 1; offset + length <= batchLength; length++) + { + var sliced = dictionaryArray.Slice(offset, length) as DictionaryArray; + var actualSlicedDictionary = sliced.Dictionary as StringArray; + var actualSlicedIndicesArray = sliced.Indices as Int32Array; + + var expectedSlicedIndicesArray = originalIndicesArray.Slice(offset, length) as Int32Array; + + //Dictionary is not sliced. + Assert.Equal(originalDictionary.Data, actualSlicedDictionary.Data); + Assert.Equal(expectedSlicedIndicesArray.ToList(), actualSlicedIndicesArray.ToList()); + } + } + } + + private Tuple CreateSimpleTestData() + { + StringArray originalDictionary = new StringArray.Builder().AppendRange(new[] { "a", "b", "c" }).Build(); + Int32Array originalIndicesArray = new Int32Array.Builder().AppendRange(new[] { 0, 0, 1, 1, 2, 2 }).Build(); + var dictionaryArray = new DictionaryArray(new DictionaryType(Int32Type.Default, StringType.Default, false), originalIndicesArray, originalDictionary); + + return Tuple.Create(originalDictionary, originalIndicesArray, dictionaryArray); + } + } +} diff --git a/csharp/test/Apache.Arrow.Tests/TestData.cs b/csharp/test/Apache.Arrow.Tests/TestData.cs index 8db2241ff891..a66569ce4ad2 100644 --- a/csharp/test/Apache.Arrow.Tests/TestData.cs +++ b/csharp/test/Apache.Arrow.Tests/TestData.cs @@ -21,12 +21,13 @@ namespace Apache.Arrow.Tests { public static class TestData { - public static RecordBatch CreateSampleRecordBatch(int length) + //TODO: Remove the createDictionaryArray argument after all writer/reader supports DictionaryType serialization + public static RecordBatch CreateSampleRecordBatch(int length, bool createDictionaryArray = false) { - return CreateSampleRecordBatch(length, columnSetCount: 1); + return CreateSampleRecordBatch(length, columnSetCount: 1, createDictionaryArray); } - public static RecordBatch CreateSampleRecordBatch(int length, int columnSetCount) + public static RecordBatch CreateSampleRecordBatch(int length, int columnSetCount, bool createDictionaryArray) { Schema.Builder builder = new Schema.Builder(); for (int i = 0; i < columnSetCount; i++) @@ -50,6 +51,12 @@ public static RecordBatch CreateSampleRecordBatch(int length, int columnSetCount builder.Field(CreateField(new StructType(new List { CreateField(StringType.Default, i), CreateField(Int32Type.Default, i) }), i)); builder.Field(CreateField(new Decimal128Type(10, 6), i)); builder.Field(CreateField(new Decimal256Type(16, 8), i)); + + if (createDictionaryArray) + { + builder.Field(CreateField(new DictionaryType(Int32Type.Default, StringType.Default, false), i)); + } + //builder.Field(CreateField(new FixedSizeBinaryType(16), i)); //builder.Field(CreateField(HalfFloatType.Default)); //builder.Field(CreateField(StringType.Default)); @@ -74,7 +81,7 @@ private static Field CreateField(ArrowType type, int iteration) return new Field(type.Name + iteration, type, nullable: false); } - private static IEnumerable CreateArrays(Schema schema, int length) + public static IEnumerable CreateArrays(Schema schema, int length) { int fieldCount = schema.Fields.Count; List arrays = new List(fieldCount); @@ -114,7 +121,8 @@ private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private int Length { get; } public IArrowArray Array { get; private set; } @@ -250,6 +258,20 @@ public void Visit(StructType type) Array = new StructArray(type, Length, childArrays, nullBitmap.Build()); } + public void Visit(DictionaryType type) + { + Int32Array.Builder indicesBuilder = new Int32Array.Builder().Reserve(Length); + StringArray.Builder valueBuilder = new StringArray.Builder().Reserve(Length); + + for (int i = 0; i < Length; i++) + { + indicesBuilder.Append(i); + valueBuilder.Append($"{i}"); + } + + Array = new DictionaryArray(type, indicesBuilder.Build(), valueBuilder.Build()); + } + private void GenerateArray(IArrowArrayBuilder builder, Func generator) where TArrayBuilder : IArrowArrayBuilder where TArray : IArrowArray From 8544064b944745aa757efe23e3c3b3aaac65e89a Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Thu, 17 Jun 2021 09:33:17 +0900 Subject: [PATCH 02/15] ARROW-6870: Support Dictionary - Separate ArrayData constructors - Remove unused methods / constructors - Load DictionaryMemo lazily - Hide _arrays of RecordBatch - Fix tests --- csharp/src/Apache.Arrow/Arrays/ArrayData.cs | 22 ++++++-- .../Apache.Arrow/Arrays/DictionaryArray.cs | 28 ++++------ .../Ipc/ArrowFileReaderImplementation.cs | 2 +- csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs | 4 +- .../Ipc/ArrowMemoryReaderImplementation.cs | 2 +- .../Ipc/ArrowReaderImplementation.cs | 22 ++++---- .../Ipc/ArrowStreamReaderImplementation.cs | 20 ++++--- .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 40 +++++++------- .../src/Apache.Arrow/Ipc/MessageSerializer.cs | 10 ++-- csharp/src/Apache.Arrow/LazyCreator.cs | 54 +++++++++++++++++++ csharp/src/Apache.Arrow/RecordBatch.cs | 4 +- .../ArrowStreamReaderTests.cs | 8 +-- .../ArrowStreamWriterTests.cs | 2 +- 13 files changed, 143 insertions(+), 75 deletions(-) create mode 100644 csharp/src/Apache.Arrow/LazyCreator.cs diff --git a/csharp/src/Apache.Arrow/Arrays/ArrayData.cs b/csharp/src/Apache.Arrow/Arrays/ArrayData.cs index 595bb53a0aa9..ceb34d9908dc 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrayData.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrayData.cs @@ -35,7 +35,21 @@ public sealed class ArrayData : IDisposable public ArrayData( IArrowType dataType, int length, int nullCount = 0, int offset = 0, - IEnumerable buffers = null, IEnumerable children = null, ArrayData dictionary = null) + IEnumerable buffers = null, IEnumerable children = null) : + this(dataType, null, length, nullCount, offset, buffers, children) + { } + + public ArrayData( + IArrowType dataType, + int length, int nullCount = 0, int offset = 0, + ArrowBuffer[] buffers = null, ArrayData[] children = null) : + this(dataType, null, length, nullCount, offset, buffers, children) + { } + + public ArrayData( + IArrowType dataType, ArrayData dictionary, + int length, int nullCount = 0, int offset = 0, + IEnumerable buffers = null, IEnumerable children = null) { DataType = dataType ?? NullType.Default; Length = length; @@ -47,9 +61,9 @@ public ArrayData( } public ArrayData( - IArrowType dataType, + IArrowType dataType, ArrayData dictionary, int length, int nullCount = 0, int offset = 0, - ArrowBuffer[] buffers = null, ArrayData[] children = null, ArrayData dictionary = null) + ArrowBuffer[] buffers = null, ArrayData[] children = null) { DataType = dataType ?? NullType.Default; Length = length; @@ -91,7 +105,7 @@ public ArrayData Slice(int offset, int length) length = Math.Min(Length - offset, length); offset += Offset; - return new ArrayData(DataType, length, RecalculateNullCount, offset, Buffers, Children, Dictionary); + return new ArrayData(DataType, Dictionary, length, RecalculateNullCount, offset, Buffers, Children); } } } diff --git a/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs b/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs index 9e90c5279136..49505129a036 100644 --- a/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs @@ -25,38 +25,32 @@ public class DictionaryArray : Array public IArrowArray Indices { get; } public ArrowBuffer IndicesBuffer => Data.Buffers[1]; - public DictionaryArray(IArrowType dataType, int length, - ArrowBuffer valueOffsetsBuffer, IArrowArray value, - ArrowBuffer nullBitmapBuffer, int nullCount = 0, int offset = 0) - : this(new ArrayData(dataType, length, nullCount, offset, - new[] { nullBitmapBuffer, valueOffsetsBuffer }, new[] { value.Data }, value.Data.Dictionary)) - { - } - public DictionaryArray(ArrayData data) : base(data) { data.EnsureBufferCount(2); data.EnsureDataType(ArrowTypeId.Dictionary); - var dicType = data.DataType as DictionaryType; + if (data.Dictionary == null) + { + throw new ArgumentException($"{nameof(data.Dictionary)} must be not null"); + } + + var dicType = (DictionaryType)data.DataType; data.Dictionary.EnsureDataType(dicType.ValueType.TypeId); - ArrayData indicesData = new ArrayData(dicType.IndexType, data.Length, data.NullCount, data.Offset, data.Buffers, data.Children); + var indicesData = new ArrayData(dicType.IndexType, data.Length, data.NullCount, data.Offset, data.Buffers, data.Children); Indices = ArrowArrayFactory.BuildArray(indicesData); Dictionary = ArrowArrayFactory.BuildArray(data.Dictionary); } - public DictionaryArray(IArrowType dataType, IArrowArray indicesArray, IArrowArray dictionary, bool ordered = false) : - base(new ArrayData(dataType, indicesArray.Length, indicesArray.Data.NullCount, indicesArray.Data.Offset, indicesArray.Data.Buffers, indicesArray.Data.Children, dictionary.Data)) + public DictionaryArray(DictionaryType dataType, IArrowArray indicesArray, IArrowArray dictionary) : + base(new ArrayData(dataType, dictionary.Data, indicesArray.Length, indicesArray.Data.NullCount, indicesArray.Data.Offset, indicesArray.Data.Buffers, indicesArray.Data.Children)) { Data.EnsureBufferCount(2); - Data.EnsureDataType(ArrowTypeId.Dictionary); - - var dicType = dataType as DictionaryType; - indicesArray.Data.EnsureDataType(dicType.IndexType.TypeId); - dictionary.Data.EnsureDataType(dicType.ValueType.TypeId); + indicesArray.Data.EnsureDataType(dataType.IndexType.TypeId); + dictionary.Data.EnsureDataType(dataType.ValueType.TypeId); Indices = indicesArray; Dictionary = dictionary; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs index 3b27eec7b2d2..c6fb0b1bd68b 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs @@ -135,7 +135,7 @@ private static int ReadFooterLength(Memory buffer) private void ReadSchema(Memory buffer) { // Deserialize the footer from the footer flatbuffer - _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)), _dictionaryMemo); + _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)), _lazyDictionaryMemo); Schema = _footer.Schema; } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs index 06f58244051f..5b2b47f171fb 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs @@ -61,8 +61,8 @@ public ArrowFooter(Schema schema, IEnumerable dictionaries, IEnumerable lazyDictionaryMemo) + : this(Ipc.MessageSerializer.GetSchema(footer.Schema.GetValueOrDefault(), lazyDictionaryMemo), GetDictionaries(footer), GetRecordBatches(footer)) { } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs index c6cb75b53be6..a5725b5b0403 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs @@ -111,7 +111,7 @@ private void ReadSchema() } ByteBuffer schemaBuffer = CreateByteBuffer(_buffer.Slice(_bufferPosition)); - Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), _dictionaryMemo); + Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), _lazyDictionaryMemo); _bufferPosition += schemaMessageLength; } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index 010226f3efb5..00f09a36094f 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -29,12 +29,12 @@ internal abstract class ArrowReaderImplementation : IDisposable { public Schema Schema { get; protected set; } protected bool HasReadSchema => Schema != null; - protected bool HasReadInitialDictionary { get; set; } - protected readonly DictionaryMemo _dictionaryMemo; + + private protected readonly LazyCreator _lazyDictionaryMemo; public ArrowReaderImplementation() { - _dictionaryMemo = new DictionaryMemo(); + _lazyDictionaryMemo = new LazyCreator(); } public void Dispose() @@ -119,7 +119,7 @@ internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory buffer) private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBuffer bodyByteBuffer, IMemoryOwner memoryOwner) { long id = dictionaryBatch.Id; - IArrowType valueType = _dictionaryMemo.GetDictionaryType(id); + IArrowType valueType = _lazyDictionaryMemo.Instance.GetDictionaryType(id); Flatbuf.RecordBatch? recordBatch = dictionaryBatch.Data; if (!recordBatch.HasValue) @@ -142,7 +142,7 @@ private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBu } else { - _dictionaryMemo.AddOrReplaceDictionary(id, arrays[0]); + _lazyDictionaryMemo.Instance.AddOrReplaceDictionary(id, arrays[0]); } } @@ -219,11 +219,11 @@ private ArrayData LoadPrimitiveField( IArrowArray dictionary = null; if (field.DataType.TypeId == ArrowTypeId.Dictionary) { - long id = _dictionaryMemo.GetId(field); - dictionary = _dictionaryMemo?.GetDictionary(id); + long id = _lazyDictionaryMemo.Instance.GetId(field); + dictionary = _lazyDictionaryMemo.Instance.GetDictionary(id); } - return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); + return new ArrayData(field.DataType, dictionary?.Data, fieldLength, fieldNullCount, 0, arrowBuff, children); } private ArrayData LoadVariableField( @@ -265,11 +265,11 @@ private ArrayData LoadVariableField( IArrowArray dictionary = null; if (field.DataType.TypeId == ArrowTypeId.Dictionary) { - long id = _dictionaryMemo.GetId(field); - dictionary = _dictionaryMemo?.GetDictionary(id); + long id = _lazyDictionaryMemo.Instance.GetId(field); + dictionary = _lazyDictionaryMemo.Instance.GetDictionary(id); } - return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); + return new ArrayData(field.DataType, dictionary?.Data, fieldLength, fieldNullCount, 0, arrowBuff, children); } private ArrayData[] GetChildren( diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 01fe41ee9e95..6caa5a4883b6 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -27,6 +27,8 @@ internal class ArrowStreamReaderImplementation : ArrowReaderImplementation public Stream BaseStream { get; } private readonly bool _leaveOpen; private readonly MemoryAllocator _allocator; + private protected bool HasReadInitialDictionary { get; set; } + public ArrowStreamReaderImplementation(Stream stream, MemoryAllocator allocator, bool leaveOpen) : base() { @@ -42,6 +44,7 @@ protected override void Dispose(bool disposing) BaseStream.Dispose(); } } + protected void ReadInitialDictionaries() { if (HasReadInitialDictionary) @@ -49,12 +52,13 @@ protected void ReadInitialDictionaries() return; } - int fieldCount = _dictionaryMemo.GetFieldCount(); - for (int i = 0; i < fieldCount; ++i) - { - ReadArrowObject(); + if (_lazyDictionaryMemo.IsCreated) { + int fieldCount = _lazyDictionaryMemo.Instance.GetFieldCount(); + for (int i = 0; i < fieldCount; ++i) + { + ReadArrowObject(); + } } - HasReadInitialDictionary = true; } @@ -65,7 +69,7 @@ protected async ValueTask ReadInitialDictionariesAsync(CancellationToken cancell return; } - int fieldCount = _dictionaryMemo.GetFieldCount(); + int fieldCount = _lazyDictionaryMemo.Instance.GetFieldCount(); for (int i = 0; i < fieldCount; ++i) { await ReadArrowObjectAsync(cancellationToken).ConfigureAwait(false); @@ -125,7 +129,7 @@ await ArrayPool.Shared.RentReturnAsync(schemaMessageLength, async (buff) = EnsureFullRead(buff, bytesRead); FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), _dictionaryMemo); + Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), _lazyDictionaryMemo); }).ConfigureAwait(false); } @@ -145,7 +149,7 @@ protected virtual void ReadSchema() EnsureFullRead(buff, bytesRead); FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), _dictionaryMemo); + Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), _lazyDictionaryMemo); }); } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index e38895da4d31..8ecc99696752 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -190,7 +190,7 @@ public void Visit(IArrowArray array) private readonly ArrowTypeFlatbufferBuilder _fieldTypeBuilder; - private protected readonly DictionaryMemo _dictionaryMemo; + private protected readonly LazyCreator _lazyDictionaryMemo; public ArrowStreamWriter(Stream baseStream, Schema schema) : this(baseStream, schema, leaveOpen: false) @@ -214,7 +214,7 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOp _fieldTypeBuilder = new ArrowTypeFlatbufferBuilder(Builder); _options = options ?? IpcOptions.Default; - _dictionaryMemo = new DictionaryMemo(); + _lazyDictionaryMemo = new LazyCreator(); } @@ -231,7 +231,7 @@ private void CreateSelfAndChildrenFieldNodes(ArrayData data) Flatbuf.FieldNode.CreateFieldNode(Builder, data.Length, data.NullCount); } - private int CountAllNodes(IReadOnlyDictionary fields) + private static int CountAllNodes(IReadOnlyDictionary fields) { int count = 0; foreach (Field arrowArray in fields.Values) @@ -241,7 +241,7 @@ private int CountAllNodes(IReadOnlyDictionary fields) return count; } - private void CountSelfAndChildrenNodes(IArrowType type, ref int count) + private static void CountSelfAndChildrenNodes(IArrowType type, ref int count) { if (type is NestedType nestedType) { @@ -265,7 +265,7 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) if (!HasWrittenDictionaryBatch) { - DictionaryCollector.Collect(recordBatch, _dictionaryMemo); + DictionaryCollector.Collect(recordBatch, _lazyDictionaryMemo); WriteDictionaries(recordBatch); HasWrittenDictionaryBatch = true; } @@ -304,7 +304,7 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat if (!HasWrittenDictionaryBatch) { - DictionaryCollector.Collect(recordBatch, _dictionaryMemo); + DictionaryCollector.Collect(recordBatch, _lazyDictionaryMemo); await WriteDictionariesAsync(recordBatch, cancellationToken).ConfigureAwait(false); HasWrittenDictionaryBatch = true; } @@ -395,7 +395,7 @@ private async ValueTask WriteBufferDataAsync(IReadOnlyList PreparingWritingRecordBatch(RecordBatch recordBatch) { - return PreparingWritingRecordBatch(recordBatch.Schema.Fields, recordBatch._arrays); + return PreparingWritingRecordBatch(recordBatch.Schema.Fields, recordBatch._Arrays); } private Tuple PreparingWritingRecordBatch(IReadOnlyDictionary fields, IReadOnlyList arrays) @@ -502,11 +502,11 @@ await WriteMessageAsync(Flatbuf.MessageHeader.DictionaryBatch, await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false); } - private protected Tuple> CreateDictionaryBatchOffset(Field field) + private Tuple> CreateDictionaryBatchOffset(Field field) { Field dictionaryField = new Field("dummy", ((DictionaryType)field.DataType).ValueType, false); - long id = _dictionaryMemo.GetId(field); - IArrowArray dictionary = _dictionaryMemo.GetDictionary(id); + long id = _lazyDictionaryMemo.Instance.GetId(field); + IArrowArray dictionary = _lazyDictionaryMemo.Instance.GetDictionary(id); var fieldsDictionary = new Dictionary { { dictionaryField.Name, dictionaryField } }; @@ -666,7 +666,7 @@ private VectorOffset GetFieldMetadataOffset(Field field) return default; } - long id = _dictionaryMemo.GetOrAssignId(field); + long id = _lazyDictionaryMemo.Instance.GetOrAssignId(field); var dicType = field.DataType as DictionaryType; var indexType = dicType.IndexType as NumberType; @@ -871,7 +871,7 @@ public virtual void Dispose() internal static class DictionaryCollector { - internal static void Collect(RecordBatch recordBatch, DictionaryMemo dictionaryMemo) + internal static void Collect(RecordBatch recordBatch, LazyCreator lazyDictionaryMemo) { Schema schema = recordBatch.Schema; for (int i = 0; i < schema.Fields.Count; i++) @@ -880,28 +880,28 @@ internal static void Collect(RecordBatch recordBatch, DictionaryMemo dictionaryM Field field = schema.GetFieldByIndex(i); IArrowArray array = recordBatch.Column(i); - CollectDictionary(field, array, dictionaryMemo); + CollectDictionary(field, array, lazyDictionaryMemo); } } } - private static void CollectDictionary(Field field, IArrowArray array, DictionaryMemo dictionaryMemo) + private static void CollectDictionary(Field field, IArrowArray array, LazyCreator lazyDictionaryMemo) { if (field.DataType.TypeId == ArrowTypeId.Dictionary) { IArrowArray dictionary = (array as DictionaryArray).Dictionary; - long id = dictionaryMemo.GetOrAssignId(field); + long id = lazyDictionaryMemo.Instance.GetOrAssignId(field); - dictionaryMemo.AddOrReplaceDictionary(id, dictionary); - WalkChildren(dictionary, dictionaryMemo); + lazyDictionaryMemo.Instance.AddOrReplaceDictionary(id, dictionary); + WalkChildren(dictionary, lazyDictionaryMemo); } else { - WalkChildren(array, dictionaryMemo); + WalkChildren(array, lazyDictionaryMemo); } } - private static void WalkChildren(IArrowArray array, DictionaryMemo dictionaryMemo) + private static void WalkChildren(IArrowArray array, LazyCreator lazyDictionaryMemo) { ArrayData[] children = array.Data.Children; @@ -921,7 +921,7 @@ private static void WalkChildren(IArrowArray array, DictionaryMemo dictionaryMem ArrayData child = children[i]; IArrowArray childArray = ArrowArrayFactory.BuildArray(child); - CollectDictionary(childField, childArray, dictionaryMemo); + CollectDictionary(childField, childArray, lazyDictionaryMemo); } } } diff --git a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs index 9b644a5004c5..db8589161033 100644 --- a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs +++ b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs @@ -53,13 +53,13 @@ public static Types.NumberType GetNumberType(int bitWidth, bool signed) $"{(signed ? "signed " : "unsigned")} integer."); } - internal static Schema GetSchema(Flatbuf.Schema schema, DictionaryMemo dictionaryMemo) + internal static Schema GetSchema(Flatbuf.Schema schema, LazyCreator lazyDictionaryMemo) { List fields = new List(); for (int i = 0; i < schema.FieldsLength; i++) { Flatbuf.Field field = schema.Fields(i).GetValueOrDefault(); - fields.Add(FieldFromFlatbuffer(field, dictionaryMemo)); + fields.Add(FieldFromFlatbuffer(field, lazyDictionaryMemo)); } Dictionary metadata = schema.CustomMetadataLength > 0 ? new Dictionary() : null; @@ -73,13 +73,13 @@ internal static Schema GetSchema(Flatbuf.Schema schema, DictionaryMemo dictionar return new Schema(fields, metadata, copyCollections: false); } - private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField, DictionaryMemo dictionaryMemo) + private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField, LazyCreator lazyDictionaryMemo) { Field[] childFields = flatbufField.ChildrenLength > 0 ? new Field[flatbufField.ChildrenLength] : null; for (int i = 0; i < flatbufField.ChildrenLength; i++) { Flatbuf.Field? childFlatbufField = flatbufField.Children(i); - childFields[i] = FieldFromFlatbuffer(childFlatbufField.Value, dictionaryMemo); + childFields[i] = FieldFromFlatbuffer(childFlatbufField.Value, lazyDictionaryMemo); } Flatbuf.DictionaryEncoding? de = flatbufField.Dictionary; @@ -108,7 +108,7 @@ private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField, DictionaryM if (de.HasValue) { - dictionaryMemo.AddField(de.Value.Id, arrowField); + lazyDictionaryMemo.Instance.AddField(de.Value.Id, arrowField); } return arrowField; diff --git a/csharp/src/Apache.Arrow/LazyCreator.cs b/csharp/src/Apache.Arrow/LazyCreator.cs new file mode 100644 index 000000000000..17dc6fa7c467 --- /dev/null +++ b/csharp/src/Apache.Arrow/LazyCreator.cs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; + +namespace Apache.Arrow +{ + public class LazyCreator where T : new() + { + private T _instance; + + private readonly Func _instanceCreator; + + public LazyCreator (Func instanceCreator = null){ + _instanceCreator = instanceCreator; + } + + public bool IsCreated => _instance != null; + + public T Instance { + get { + lock (this) { + if(IsCreated) + { + return _instance; + } + + if(_instanceCreator != null) + { + _instance = _instanceCreator(); + } + else + { + _instance = new T(); + } + + return _instance; + } + } + } + } +} diff --git a/csharp/src/Apache.Arrow/RecordBatch.cs b/csharp/src/Apache.Arrow/RecordBatch.cs index 82fef5b1dbf6..971a1faeb1f8 100644 --- a/csharp/src/Apache.Arrow/RecordBatch.cs +++ b/csharp/src/Apache.Arrow/RecordBatch.cs @@ -28,8 +28,10 @@ public partial class RecordBatch : IDisposable public IEnumerable Arrays => _arrays; public int Length { get; } + internal IReadOnlyList _Arrays => (IReadOnlyList)_arrays; + private readonly IMemoryOwner _memoryOwner; - internal readonly IReadOnlyList _arrays; + private readonly IList _arrays; public IArrowArray Column(int i) { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs index a74a27941881..4aad86296990 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs @@ -54,7 +54,7 @@ public void Ctor_LeaveOpenTrue_StreamValidOnDispose() [InlineData(false)] public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); using (MemoryStream stream = new MemoryStream()) { @@ -68,7 +68,7 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) ArrowStreamReader reader = new ArrowStreamReader(stream, memoryPool, shouldLeaveOpen); reader.ReadNextRecordBatch(); - Assert.Equal(1, memoryPool.Statistics.Allocations); + Assert.Equal(2, memoryPool.Statistics.Allocations); Assert.True(memoryPool.Statistics.BytesAllocated > 0); reader.Dispose(); @@ -150,7 +150,7 @@ private static async Task TestReaderFromStream( Func verificationFunc, bool writeEnd) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); using (MemoryStream stream = new MemoryStream()) { @@ -190,7 +190,7 @@ public async Task ReadRecordBatchAsync_PartialReadStream() /// private static async Task TestReaderFromPartialReadStream(Func verificationFunc) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); using (PartialReadStream stream = new PartialReadStream()) { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs index f3f79b496238..6060c4e452f0 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs @@ -548,7 +548,7 @@ private List CreateMultipleDictionaryArraysTestData() new Field("uint64", UInt64Type.Default, true) }, null); - StringArray dictionary = new StringArray.Builder().AppendRange(new[] { "a", "b", "c" }).Build(); + StringArray dictionary = new StringArray.Builder().AppendRange(dictionaryData).Build(); IEnumerable indicesArrays = TestData.CreateArrays(indicesSchema, length); var fields = new List(capacity: length); From 5e2c17d039259cd4ff2c0a14e054bfe1368003ca Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Fri, 18 Jun 2021 03:36:05 +0900 Subject: [PATCH 03/15] ARROW-6870: Support Dictionary - Change a LazyCreator constructor position --- csharp/src/Apache.Arrow/LazyCreator.cs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/csharp/src/Apache.Arrow/LazyCreator.cs b/csharp/src/Apache.Arrow/LazyCreator.cs index 17dc6fa7c467..dc735beeaab9 100644 --- a/csharp/src/Apache.Arrow/LazyCreator.cs +++ b/csharp/src/Apache.Arrow/LazyCreator.cs @@ -23,10 +23,6 @@ namespace Apache.Arrow private readonly Func _instanceCreator; - public LazyCreator (Func instanceCreator = null){ - _instanceCreator = instanceCreator; - } - public bool IsCreated => _instance != null; public T Instance { @@ -50,5 +46,10 @@ public T Instance { } } } + + public LazyCreator(Func instanceCreator = null) + { + _instanceCreator = instanceCreator; + } } } From 79b0015472fd53a57e40763378ba81a1683ec2bc Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Sat, 19 Jun 2021 05:27:01 +0900 Subject: [PATCH 04/15] ARROW-6870: Support Dictionary - Fix a typo --- csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs b/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs index 49505129a036..aaf3e0cf51fc 100644 --- a/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs @@ -32,7 +32,7 @@ public DictionaryArray(ArrayData data) : base(data) if (data.Dictionary == null) { - throw new ArgumentException($"{nameof(data.Dictionary)} must be not null"); + throw new ArgumentException($"{nameof(data.Dictionary)} must not be null"); } var dicType = (DictionaryType)data.DataType; From 4d5745b36d55391101515422889c92e9318db484 Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Sat, 19 Jun 2021 13:07:05 +0900 Subject: [PATCH 05/15] ARROW-6870: Support Dictionary - Slight refactoring for LazyCreator.Instance --- csharp/src/Apache.Arrow/LazyCreator.cs | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/csharp/src/Apache.Arrow/LazyCreator.cs b/csharp/src/Apache.Arrow/LazyCreator.cs index dc735beeaab9..127e2d8e40de 100644 --- a/csharp/src/Apache.Arrow/LazyCreator.cs +++ b/csharp/src/Apache.Arrow/LazyCreator.cs @@ -28,19 +28,9 @@ namespace Apache.Arrow public T Instance { get { lock (this) { - if(IsCreated) - { - return _instance; - } - - if(_instanceCreator != null) - { - _instance = _instanceCreator(); - } - else - { - _instance = new T(); - } + _instance ??= _instanceCreator != null ? + _instanceCreator(): + new T(); return _instance; } From 4ffc321bc027b405f1f240a47877a3059a078cb4 Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Sat, 19 Jun 2021 15:02:27 +0900 Subject: [PATCH 06/15] ARROW-6870: Support Dictionary - refactor ArrayData constructors --- csharp/src/Apache.Arrow/Arrays/ArrayData.cs | 26 +++++++++++-------- .../Apache.Arrow/Arrays/DictionaryArray.cs | 2 +- .../Ipc/ArrowReaderImplementation.cs | 4 +-- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/csharp/src/Apache.Arrow/Arrays/ArrayData.cs b/csharp/src/Apache.Arrow/Arrays/ArrayData.cs index ceb34d9908dc..5bae443e1120 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrayData.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrayData.cs @@ -32,24 +32,28 @@ public sealed class ArrayData : IDisposable public readonly ArrayData[] Children; public readonly ArrayData Dictionary; //Only used for dictionary type + //This is left for compatibility with lower version binaries + //before the dictionary type was supported. public ArrayData( IArrowType dataType, - int length, int nullCount = 0, int offset = 0, - IEnumerable buffers = null, IEnumerable children = null) : - this(dataType, null, length, nullCount, offset, buffers, children) + int length, int nullCount, int offset, + IEnumerable buffers, IEnumerable children) : + this(dataType, length, nullCount, offset, buffers, children, null) { } + //This is left for compatibility with lower version binaries + //before the dictionary type was supported. public ArrayData( IArrowType dataType, - int length, int nullCount = 0, int offset = 0, - ArrowBuffer[] buffers = null, ArrayData[] children = null) : - this(dataType, null, length, nullCount, offset, buffers, children) + int length, int nullCount, int offset, + ArrowBuffer[] buffers, ArrayData[] children) : + this(dataType, length, nullCount, offset, buffers, children, null) { } public ArrayData( - IArrowType dataType, ArrayData dictionary, + IArrowType dataType, int length, int nullCount = 0, int offset = 0, - IEnumerable buffers = null, IEnumerable children = null) + IEnumerable buffers = null, IEnumerable children = null, ArrayData dictionary = null) { DataType = dataType ?? NullType.Default; Length = length; @@ -61,9 +65,9 @@ public ArrayData( } public ArrayData( - IArrowType dataType, ArrayData dictionary, + IArrowType dataType, int length, int nullCount = 0, int offset = 0, - ArrowBuffer[] buffers = null, ArrayData[] children = null) + ArrowBuffer[] buffers = null, ArrayData[] children = null, ArrayData dictionary = null) { DataType = dataType ?? NullType.Default; Length = length; @@ -105,7 +109,7 @@ public ArrayData Slice(int offset, int length) length = Math.Min(Length - offset, length); offset += Offset; - return new ArrayData(DataType, Dictionary, length, RecalculateNullCount, offset, Buffers, Children); + return new ArrayData(DataType, length, RecalculateNullCount, offset, Buffers, Children, Dictionary); } } } diff --git a/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs b/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs index aaf3e0cf51fc..29c0f5c84c75 100644 --- a/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs @@ -45,7 +45,7 @@ public DictionaryArray(ArrayData data) : base(data) } public DictionaryArray(DictionaryType dataType, IArrowArray indicesArray, IArrowArray dictionary) : - base(new ArrayData(dataType, dictionary.Data, indicesArray.Length, indicesArray.Data.NullCount, indicesArray.Data.Offset, indicesArray.Data.Buffers, indicesArray.Data.Children)) + base(new ArrayData(dataType, indicesArray.Length, indicesArray.Data.NullCount, indicesArray.Data.Offset, indicesArray.Data.Buffers, indicesArray.Data.Children, dictionary.Data)) { Data.EnsureBufferCount(2); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index 00f09a36094f..e5f0268963da 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -223,7 +223,7 @@ private ArrayData LoadPrimitiveField( dictionary = _lazyDictionaryMemo.Instance.GetDictionary(id); } - return new ArrayData(field.DataType, dictionary?.Data, fieldLength, fieldNullCount, 0, arrowBuff, children); + return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); } private ArrayData LoadVariableField( @@ -269,7 +269,7 @@ private ArrayData LoadVariableField( dictionary = _lazyDictionaryMemo.Instance.GetDictionary(id); } - return new ArrayData(field.DataType, dictionary?.Data, fieldLength, fieldNullCount, 0, arrowBuff, children); + return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); } private ArrayData[] GetChildren( From 0e195005c4802124179dbe6b79f8efe6668951da Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Sun, 4 Jul 2021 01:40:45 +0900 Subject: [PATCH 07/15] ARROW-6870: Support Dictionary - Fix access modifiers - Fix tests - Change the way of the lazy creation for dictionaryMemo --- .../Internal/FlightMessageSerializer.cs | 6 +- .../Ipc/ArrowFileReaderImplementation.cs | 2 +- csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs | 4 +- .../Ipc/ArrowMemoryReaderImplementation.cs | 2 +- .../Ipc/ArrowReaderImplementation.cs | 17 +++-- .../Ipc/ArrowStreamReaderImplementation.cs | 58 +++++++++----- .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 39 +++++----- csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs | 5 -- .../src/Apache.Arrow/Ipc/MessageSerializer.cs | 23 +++--- csharp/src/Apache.Arrow/LazyCreator.cs | 45 ----------- csharp/src/Apache.Arrow/RecordBatch.cs | 4 +- .../ArrowStreamReaderTests.cs | 56 ++++++++------ .../ArrowStreamWriterTests.cs | 76 +++++++++++-------- csharp/test/Apache.Arrow.Tests/TestData.cs | 3 +- 14 files changed, 168 insertions(+), 172 deletions(-) delete mode 100644 csharp/src/Apache.Arrow/LazyCreator.cs diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs index 91919440f900..36b13a63d3c0 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs @@ -45,14 +45,16 @@ public static Schema DecodeSchema(ReadOnlyMemory buffer) ByteBuffer schemaBuffer = ArrowReaderImplementation.CreateByteBuffer(buffer.Slice(bufferPosition)); //DictionaryBatch not supported for now - var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer), default); + DictionaryMemo dictionaryMemo = null; + var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer), ref dictionaryMemo); return schema; } internal static Schema DecodeSchema(ByteBuffer schemaBuffer) { //DictionaryBatch not supported for now - var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer), default); + DictionaryMemo dictionaryMemo = null; + var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer), ref dictionaryMemo); return schema; } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs index c6fb0b1bd68b..36cd4ddf930c 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs @@ -135,7 +135,7 @@ private static int ReadFooterLength(Memory buffer) private void ReadSchema(Memory buffer) { // Deserialize the footer from the footer flatbuffer - _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)), _lazyDictionaryMemo); + _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)), ref _dictionaryMemo); Schema = _footer.Schema; } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs index 5b2b47f171fb..db269ae019b5 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs @@ -61,8 +61,8 @@ public ArrowFooter(Schema schema, IEnumerable dictionaries, IEnumerable lazyDictionaryMemo) - : this(Ipc.MessageSerializer.GetSchema(footer.Schema.GetValueOrDefault(), lazyDictionaryMemo), GetDictionaries(footer), + public ArrowFooter(Flatbuf.Footer footer, ref DictionaryMemo dictionaryMemo) + : this(Ipc.MessageSerializer.GetSchema(footer.Schema.GetValueOrDefault(), ref dictionaryMemo), GetDictionaries(footer), GetRecordBatches(footer)) { } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs index a5725b5b0403..9e3db0ec3ace 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs @@ -111,7 +111,7 @@ private void ReadSchema() } ByteBuffer schemaBuffer = CreateByteBuffer(_buffer.Slice(_bufferPosition)); - Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), _lazyDictionaryMemo); + Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), ref _dictionaryMemo); _bufferPosition += schemaMessageLength; } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index e5f0268963da..7c8ce34b7bed 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -30,11 +30,12 @@ internal abstract class ArrowReaderImplementation : IDisposable public Schema Schema { get; protected set; } protected bool HasReadSchema => Schema != null; - private protected readonly LazyCreator _lazyDictionaryMemo; + private protected DictionaryMemo _dictionaryMemo; + private protected DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo(); + private protected bool HasCreatedDictionaryMemo => _dictionaryMemo != null; public ArrowReaderImplementation() { - _lazyDictionaryMemo = new LazyCreator(); } public void Dispose() @@ -119,7 +120,7 @@ internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory buffer) private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBuffer bodyByteBuffer, IMemoryOwner memoryOwner) { long id = dictionaryBatch.Id; - IArrowType valueType = _lazyDictionaryMemo.Instance.GetDictionaryType(id); + IArrowType valueType = DictionaryMemo.GetDictionaryType(id); Flatbuf.RecordBatch? recordBatch = dictionaryBatch.Data; if (!recordBatch.HasValue) @@ -142,7 +143,7 @@ private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBu } else { - _lazyDictionaryMemo.Instance.AddOrReplaceDictionary(id, arrays[0]); + DictionaryMemo.AddOrReplaceDictionary(id, arrays[0]); } } @@ -219,8 +220,8 @@ private ArrayData LoadPrimitiveField( IArrowArray dictionary = null; if (field.DataType.TypeId == ArrowTypeId.Dictionary) { - long id = _lazyDictionaryMemo.Instance.GetId(field); - dictionary = _lazyDictionaryMemo.Instance.GetDictionary(id); + long id = DictionaryMemo.GetId(field); + dictionary = DictionaryMemo.GetDictionary(id); } return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); @@ -265,8 +266,8 @@ private ArrayData LoadVariableField( IArrowArray dictionary = null; if (field.DataType.TypeId == ArrowTypeId.Dictionary) { - long id = _lazyDictionaryMemo.Instance.GetId(field); - dictionary = _lazyDictionaryMemo.Instance.GetDictionary(id); + long id = DictionaryMemo.GetId(field); + dictionary = DictionaryMemo.GetDictionary(id); } return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 6caa5a4883b6..2f6149854e01 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -27,10 +27,9 @@ internal class ArrowStreamReaderImplementation : ArrowReaderImplementation public Stream BaseStream { get; } private readonly bool _leaveOpen; private readonly MemoryAllocator _allocator; - private protected bool HasReadInitialDictionary { get; set; } + private bool HasReadInitialDictionary { get; set; } - - public ArrowStreamReaderImplementation(Stream stream, MemoryAllocator allocator, bool leaveOpen) : base() + public ArrowStreamReaderImplementation(Stream stream, MemoryAllocator allocator, bool leaveOpen) { BaseStream = stream; _allocator = allocator ?? MemoryAllocator.Default.Value; @@ -45,15 +44,16 @@ protected override void Dispose(bool disposing) } } - protected void ReadInitialDictionaries() + private void ReadInitialDictionaries() { if (HasReadInitialDictionary) { return; } - if (_lazyDictionaryMemo.IsCreated) { - int fieldCount = _lazyDictionaryMemo.Instance.GetFieldCount(); + if (HasCreatedDictionaryMemo) + { + int fieldCount = DictionaryMemo.GetFieldCount(); for (int i = 0; i < fieldCount; ++i) { ReadArrowObject(); @@ -62,19 +62,21 @@ protected void ReadInitialDictionaries() HasReadInitialDictionary = true; } - protected async ValueTask ReadInitialDictionariesAsync(CancellationToken cancellationToken = default) + private async ValueTask ReadInitialDictionariesAsync(CancellationToken cancellationToken) { if (HasReadInitialDictionary) { return; } - int fieldCount = _lazyDictionaryMemo.Instance.GetFieldCount(); - for (int i = 0; i < fieldCount; ++i) + if (HasCreatedDictionaryMemo) { - await ReadArrowObjectAsync(cancellationToken).ConfigureAwait(false); + int fieldCount = DictionaryMemo.GetFieldCount(); + for (int i = 0; i < fieldCount; ++i) + { + await ReadArrowObjectAsync(cancellationToken).ConfigureAwait(false); + } } - HasReadInitialDictionary = true; } @@ -85,8 +87,6 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati return await ReadRecordBatchAsync(cancellationToken).ConfigureAwait(false); } - - public override RecordBatch ReadNextRecordBatch() { return ReadRecordBatch(); @@ -96,9 +96,9 @@ protected async ValueTask ReadRecordBatchAsync(CancellationToken ca { await ReadSchemaAsync().ConfigureAwait(false); - await ReadInitialDictionariesAsync().ConfigureAwait(false); + await ReadInitialDictionariesAsync(cancellationToken).ConfigureAwait(false); - return await ReadArrowObjectAsync().ConfigureAwait(false); + return await ReadArrowObjectAsync(cancellationToken).ConfigureAwait(false); } @@ -129,7 +129,7 @@ await ArrayPool.Shared.RentReturnAsync(schemaMessageLength, async (buff) = EnsureFullRead(buff, bytesRead); FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), _lazyDictionaryMemo); + Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); }).ConfigureAwait(false); } @@ -149,12 +149,21 @@ protected virtual void ReadSchema() EnsureFullRead(buff, bytesRead); FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), _lazyDictionaryMemo); + Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); }); } - // Note: When the message type is DictionaryBatch, this function adds data to _dictionaryMemo and returns null. - private async ValueTask ReadArrowObjectAsync(CancellationToken cancellationToken = default) + /// + /// Read a record batch or dictionary batch from Flatbuf.Message. + /// + /// + /// This method adds data to _dictionaryMemo and returns null when the message type is DictionaryBatch. + /// > + /// + /// The record batch when the message type is RecordBatch. + /// Null when the message type is DictionaryBatch. + /// + private async ValueTask ReadArrowObjectAsync(CancellationToken cancellationToken) { int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) .ConfigureAwait(false); @@ -189,7 +198,16 @@ await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) return result; } - // Note: When the message type is DictionaryBatch, this function adds data to _dictionaryMemo and returns null. + /// + /// Read a record batch or dictionary batch from Flatbuf.Message. + /// + /// + /// This method adds data to _dictionaryMemo and returns null when the message type is DictionaryBatch. + /// > + /// + /// The record batch when the message type is RecordBatch. + /// Null when the message type is DictionaryBatch. + /// private RecordBatch ReadArrowObject() { int messageLength = ReadMessageLength(throwOnFullRead: false); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index 8ecc99696752..8f0d3ddaa63f 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -190,7 +190,8 @@ public void Visit(IArrowArray array) private readonly ArrowTypeFlatbufferBuilder _fieldTypeBuilder; - private protected readonly LazyCreator _lazyDictionaryMemo; + private DictionaryMemo _dictionaryMemo; + private DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo(); public ArrowStreamWriter(Stream baseStream, Schema schema) : this(baseStream, schema, leaveOpen: false) @@ -214,7 +215,6 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOp _fieldTypeBuilder = new ArrowTypeFlatbufferBuilder(Builder); _options = options ?? IpcOptions.Default; - _lazyDictionaryMemo = new LazyCreator(); } @@ -265,7 +265,7 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) if (!HasWrittenDictionaryBatch) { - DictionaryCollector.Collect(recordBatch, _lazyDictionaryMemo); + DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo); WriteDictionaries(recordBatch); HasWrittenDictionaryBatch = true; } @@ -304,7 +304,7 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat if (!HasWrittenDictionaryBatch) { - DictionaryCollector.Collect(recordBatch, _lazyDictionaryMemo); + DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo); await WriteDictionariesAsync(recordBatch, cancellationToken).ConfigureAwait(false); HasWrittenDictionaryBatch = true; } @@ -395,7 +395,7 @@ private async ValueTask WriteBufferDataAsync(IReadOnlyList PreparingWritingRecordBatch(RecordBatch recordBatch) { - return PreparingWritingRecordBatch(recordBatch.Schema.Fields, recordBatch._Arrays); + return PreparingWritingRecordBatch(recordBatch.Schema.Fields, recordBatch.ArrayList); } private Tuple PreparingWritingRecordBatch(IReadOnlyDictionary fields, IReadOnlyList arrays) @@ -471,7 +471,7 @@ private protected void WriteDictionary(Field field) WriteBufferData(recordBatchBuilder.Buffers); } - private protected async Task WriteDictionariesAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default) + private protected async Task WriteDictionariesAsync(RecordBatch recordBatch, CancellationToken cancellationToken) { foreach (Field field in recordBatch.Schema.Fields.Values) { @@ -479,7 +479,7 @@ private protected async Task WriteDictionariesAsync(RecordBatch recordBatch, Can } } - private protected async Task WriteDictionaryAsync(Field field, CancellationToken cancellationToken = default) + private protected async Task WriteDictionaryAsync(Field field, CancellationToken cancellationToken) { if (field.DataType.TypeId != ArrowTypeId.Dictionary) { @@ -505,8 +505,8 @@ await WriteMessageAsync(Flatbuf.MessageHeader.DictionaryBatch, private Tuple> CreateDictionaryBatchOffset(Field field) { Field dictionaryField = new Field("dummy", ((DictionaryType)field.DataType).ValueType, false); - long id = _lazyDictionaryMemo.Instance.GetId(field); - IArrowArray dictionary = _lazyDictionaryMemo.Instance.GetDictionary(id); + long id = DictionaryMemo.GetId(field); + IArrowArray dictionary = DictionaryMemo.GetDictionary(id); var fieldsDictionary = new Dictionary { { dictionaryField.Name, dictionaryField } }; @@ -666,7 +666,7 @@ private VectorOffset GetFieldMetadataOffset(Field field) return default; } - long id = _lazyDictionaryMemo.Instance.GetOrAssignId(field); + long id = DictionaryMemo.GetOrAssignId(field); var dicType = field.DataType as DictionaryType; var indexType = dicType.IndexType as NumberType; @@ -871,7 +871,7 @@ public virtual void Dispose() internal static class DictionaryCollector { - internal static void Collect(RecordBatch recordBatch, LazyCreator lazyDictionaryMemo) + internal static void Collect(RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo) { Schema schema = recordBatch.Schema; for (int i = 0; i < schema.Fields.Count; i++) @@ -880,28 +880,29 @@ internal static void Collect(RecordBatch recordBatch, LazyCreator lazyDictionaryMemo) + private static void CollectDictionary(Field field, IArrowArray array, ref DictionaryMemo dictionaryMemo) { if (field.DataType.TypeId == ArrowTypeId.Dictionary) { IArrowArray dictionary = (array as DictionaryArray).Dictionary; - long id = lazyDictionaryMemo.Instance.GetOrAssignId(field); + dictionaryMemo ??= new DictionaryMemo(); + long id = dictionaryMemo.GetOrAssignId(field); - lazyDictionaryMemo.Instance.AddOrReplaceDictionary(id, dictionary); - WalkChildren(dictionary, lazyDictionaryMemo); + dictionaryMemo.AddOrReplaceDictionary(id, dictionary); + WalkChildren(dictionary, ref dictionaryMemo); } else { - WalkChildren(array, lazyDictionaryMemo); + WalkChildren(array, ref dictionaryMemo); } } - private static void WalkChildren(IArrowArray array, LazyCreator lazyDictionaryMemo) + private static void WalkChildren(IArrowArray array, ref DictionaryMemo dictionaryMemo) { ArrayData[] children = array.Data.Children; @@ -921,7 +922,7 @@ private static void WalkChildren(IArrowArray array, LazyCreator ArrayData child = children[i]; IArrowArray childArray = ArrowArrayFactory.BuildArray(child); - CollectDictionary(childField, childArray, lazyDictionaryMemo); + CollectDictionary(childField, childArray, ref dictionaryMemo); } } } diff --git a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs index 82c837aa2aef..4952205298f7 100644 --- a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs +++ b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs @@ -100,11 +100,6 @@ public void AddOrReplaceDictionary(long id, IArrowArray dictionary) _idToDictionary[id] = dictionary; } - public void AddDictionaryDelta(long id, IArrowArray dictionary) - { - throw new NotImplementedException("Dictionary delta is not supported yet."); - } - public int GetFieldCount() { return _fieldToId.Count; diff --git a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs index db8589161033..f464895e4e11 100644 --- a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs +++ b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs @@ -53,13 +53,13 @@ public static Types.NumberType GetNumberType(int bitWidth, bool signed) $"{(signed ? "signed " : "unsigned")} integer."); } - internal static Schema GetSchema(Flatbuf.Schema schema, LazyCreator lazyDictionaryMemo) + internal static Schema GetSchema(Flatbuf.Schema schema, ref DictionaryMemo dictionaryMemo) { List fields = new List(); for (int i = 0; i < schema.FieldsLength; i++) { Flatbuf.Field field = schema.Fields(i).GetValueOrDefault(); - fields.Add(FieldFromFlatbuffer(field, lazyDictionaryMemo)); + fields.Add(FieldFromFlatbuffer(field, ref dictionaryMemo)); } Dictionary metadata = schema.CustomMetadataLength > 0 ? new Dictionary() : null; @@ -73,27 +73,27 @@ internal static Schema GetSchema(Flatbuf.Schema schema, LazyCreator lazyDictionaryMemo) + private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField, ref DictionaryMemo dictionaryMemo) { Field[] childFields = flatbufField.ChildrenLength > 0 ? new Field[flatbufField.ChildrenLength] : null; for (int i = 0; i < flatbufField.ChildrenLength; i++) { Flatbuf.Field? childFlatbufField = flatbufField.Children(i); - childFields[i] = FieldFromFlatbuffer(childFlatbufField.Value, lazyDictionaryMemo); + childFields[i] = FieldFromFlatbuffer(childFlatbufField.Value, ref dictionaryMemo); } - Flatbuf.DictionaryEncoding? de = flatbufField.Dictionary; + Flatbuf.DictionaryEncoding? dictionaryEncoding = flatbufField.Dictionary; IArrowType type = GetFieldArrowType(flatbufField, childFields); - if (de.HasValue) + if (dictionaryEncoding.HasValue) { - Flatbuf.Int? indexTypeAsInt = de.Value.IndexType; + Flatbuf.Int? indexTypeAsInt = dictionaryEncoding.Value.IndexType; if (!indexTypeAsInt.HasValue) { - throw new InvalidDataException("Dictionary type not defined"); + throw new InvalidDataException("Dictionary IndexType not defined"); } IArrowType indexType = GetNumberType(indexTypeAsInt.Value.BitWidth, indexTypeAsInt.Value.IsSigned); - type = new DictionaryType(indexType, type, de.Value.IsOrdered); + type = new DictionaryType(indexType, type, dictionaryEncoding.Value.IsOrdered); } Dictionary metadata = flatbufField.CustomMetadataLength > 0 ? new Dictionary() : null; @@ -106,9 +106,10 @@ private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField, LazyCreator var arrowField = new Field(flatbufField.Name, type, flatbufField.Nullable, metadata, copyCollections: false); - if (de.HasValue) + if (dictionaryEncoding.HasValue) { - lazyDictionaryMemo.Instance.AddField(de.Value.Id, arrowField); + dictionaryMemo ??= new DictionaryMemo(); + dictionaryMemo.AddField(dictionaryEncoding.Value.Id, arrowField); } return arrowField; diff --git a/csharp/src/Apache.Arrow/LazyCreator.cs b/csharp/src/Apache.Arrow/LazyCreator.cs deleted file mode 100644 index 127e2d8e40de..000000000000 --- a/csharp/src/Apache.Arrow/LazyCreator.cs +++ /dev/null @@ -1,45 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -using System; - -namespace Apache.Arrow -{ - public class LazyCreator where T : new() - { - private T _instance; - - private readonly Func _instanceCreator; - - public bool IsCreated => _instance != null; - - public T Instance { - get { - lock (this) { - _instance ??= _instanceCreator != null ? - _instanceCreator(): - new T(); - - return _instance; - } - } - } - - public LazyCreator(Func instanceCreator = null) - { - _instanceCreator = instanceCreator; - } - } -} diff --git a/csharp/src/Apache.Arrow/RecordBatch.cs b/csharp/src/Apache.Arrow/RecordBatch.cs index 971a1faeb1f8..6e97100686ee 100644 --- a/csharp/src/Apache.Arrow/RecordBatch.cs +++ b/csharp/src/Apache.Arrow/RecordBatch.cs @@ -28,10 +28,10 @@ public partial class RecordBatch : IDisposable public IEnumerable Arrays => _arrays; public int Length { get; } - internal IReadOnlyList _Arrays => (IReadOnlyList)_arrays; + internal IReadOnlyList ArrayList => _arrays; private readonly IMemoryOwner _memoryOwner; - private readonly IList _arrays; + private readonly List _arrays; public IArrowArray Column(int i) { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs index 4aad86296990..973fc6a0a0e5 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs @@ -50,11 +50,13 @@ public void Ctor_LeaveOpenTrue_StreamValidOnDispose() } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) + [InlineData(true, true, 2)] + [InlineData(true, false, 1)] + [InlineData(false, true, 2)] + [InlineData(false, false, 1)] + public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen, bool createDictionaryArray, int expectedAllocations) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); using (MemoryStream stream = new MemoryStream()) { @@ -68,7 +70,7 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) ArrowStreamReader reader = new ArrowStreamReader(stream, memoryPool, shouldLeaveOpen); reader.ReadNextRecordBatch(); - Assert.Equal(2, memoryPool.Statistics.Allocations); + Assert.Equal(expectedAllocations, memoryPool.Statistics.Allocations); Assert.True(memoryPool.Statistics.BytesAllocated > 0); reader.Dispose(); @@ -127,30 +129,34 @@ private static async Task TestReaderFromMemory( } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task ReadRecordBatch_Stream(bool writeEnd) + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task ReadRecordBatch_Stream(bool writeEnd, bool createDictionaryArray) { await TestReaderFromStream((reader, originalBatch) => { ArrowReaderVerifier.VerifyReader(reader, originalBatch); return Task.CompletedTask; - }, writeEnd); + }, writeEnd, createDictionaryArray); } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task ReadRecordBatchAsync_Stream(bool writeEnd) + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task ReadRecordBatchAsync_Stream(bool writeEnd, bool createDictionaryArray) { - await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync, writeEnd); + await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync, writeEnd, createDictionaryArray); } private static async Task TestReaderFromStream( Func verificationFunc, - bool writeEnd) + bool writeEnd, bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); using (MemoryStream stream = new MemoryStream()) { @@ -168,29 +174,33 @@ private static async Task TestReaderFromStream( } } - [Fact] - public async Task ReadRecordBatch_PartialReadStream() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadRecordBatch_PartialReadStream(bool createDictionaryArray) { await TestReaderFromPartialReadStream((reader, originalBatch) => { ArrowReaderVerifier.VerifyReader(reader, originalBatch); return Task.CompletedTask; - }); + }, createDictionaryArray); } - [Fact] - public async Task ReadRecordBatchAsync_PartialReadStream() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadRecordBatchAsync_PartialReadStream(bool createDictionaryArray) { - await TestReaderFromPartialReadStream(ArrowReaderVerifier.VerifyReaderAsync); + await TestReaderFromPartialReadStream(ArrowReaderVerifier.VerifyReaderAsync, createDictionaryArray); } /// /// Verifies that the stream reader reads multiple times when a stream /// only returns a subset of the data from each Read. /// - private static async Task TestReaderFromPartialReadStream(Func verificationFunc) + private static async Task TestReaderFromPartialReadStream(Func verificationFunc, bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); using (PartialReadStream stream = new PartialReadStream()) { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs index 6060c4e452f0..faea33cf8ff3 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs @@ -32,7 +32,7 @@ public class ArrowStreamWriterTests [Fact] public void Ctor_LeaveOpenDefault_StreamClosedOnDispose() { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); var stream = new MemoryStream(); new ArrowStreamWriter(stream, originalBatch.Schema).Dispose(); Assert.Throws(() => stream.Position); @@ -41,7 +41,7 @@ public void Ctor_LeaveOpenDefault_StreamClosedOnDispose() [Fact] public void Ctor_LeaveOpenFalse_StreamClosedOnDispose() { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); var stream = new MemoryStream(); new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: false).Dispose(); Assert.Throws(() => stream.Position); @@ -50,18 +50,19 @@ public void Ctor_LeaveOpenFalse_StreamClosedOnDispose() [Fact] public void Ctor_LeaveOpenTrue_StreamValidOnDispose() { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); var stream = new MemoryStream(); new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true).Dispose(); Assert.Equal(0, stream.Position); } - [Fact] - public void CanWriteToNetworkStream() + [Theory] + [InlineData(true, 32153)] + [InlineData(false, 32154)] + public void CanWriteToNetworkStream(bool createDictionaryArray, int port) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); - const int port = 32153; TcpListener listener = new TcpListener(IPAddress.Loopback, port); listener.Start(); @@ -90,12 +91,13 @@ public void CanWriteToNetworkStream() } } - [Fact] - public async Task CanWriteToNetworkStreamAsync() + [Theory] + [InlineData(true, 32155)] + [InlineData(false, 32156)] + public async Task CanWriteToNetworkStreamAsync(bool createDictionaryArray, int port) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); - const int port = 32154; TcpListener listener = new TcpListener(IPAddress.Loopback, port); listener.Start(); @@ -124,18 +126,22 @@ public async Task CanWriteToNetworkStreamAsync() } } - [Fact] - public void WriteEmptyBatch() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void WriteEmptyBatch(bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0, createDictionaryArray: createDictionaryArray); TestRoundTripRecordBatch(originalBatch); } - [Fact] - public async Task WriteEmptyBatchAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WriteEmptyBatchAsync(bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0, createDictionaryArray: createDictionaryArray); await TestRoundTripRecordBatchAsync(originalBatch); } @@ -392,27 +398,33 @@ public async Task WriteBatchWithCorrectPaddingAsync() } } - [Fact] - public void LegacyIpcFormatRoundTrips() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void LegacyIpcFormatRoundTrips(bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); TestRoundTripRecordBatch(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); } - [Fact] - public async Task LegacyIpcFormatRoundTripsAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task LegacyIpcFormatRoundTripsAsync(bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); await TestRoundTripRecordBatchAsync(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); } [Theory] - [InlineData(true)] - [InlineData(false)] - public void WriteLegacyIpcFormat(bool writeLegacyIpcFormat) + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public void WriteLegacyIpcFormat(bool writeLegacyIpcFormat, bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat }; using (MemoryStream stream = new MemoryStream()) @@ -448,11 +460,13 @@ public void WriteLegacyIpcFormat(bool writeLegacyIpcFormat) } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task WriteLegacyIpcFormatAsync(bool writeLegacyIpcFormat) + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task WriteLegacyIpcFormatAsync(bool writeLegacyIpcFormat, bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: true); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat }; using (MemoryStream stream = new MemoryStream()) diff --git a/csharp/test/Apache.Arrow.Tests/TestData.cs b/csharp/test/Apache.Arrow.Tests/TestData.cs index a66569ce4ad2..f15e857981ba 100644 --- a/csharp/test/Apache.Arrow.Tests/TestData.cs +++ b/csharp/test/Apache.Arrow.Tests/TestData.cs @@ -21,7 +21,6 @@ namespace Apache.Arrow.Tests { public static class TestData { - //TODO: Remove the createDictionaryArray argument after all writer/reader supports DictionaryType serialization public static RecordBatch CreateSampleRecordBatch(int length, bool createDictionaryArray = false) { return CreateSampleRecordBatch(length, columnSetCount: 1, createDictionaryArray); @@ -254,7 +253,7 @@ public void Visit(StructType type) { nullBitmap.Append(true); } - + Array = new StructArray(type, Length, childArrays, nullBitmap.Build()); } From fae6e28bb118b27366dc94692ba8568ee45a7846 Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Mon, 5 Jul 2021 10:25:01 +0900 Subject: [PATCH 08/15] ARROW-6870: Support Dictionary - change the place of the HasCreatedDictionaryMemo property --- csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs | 1 - csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index 7c8ce34b7bed..f9086f3e9300 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -32,7 +32,6 @@ internal abstract class ArrowReaderImplementation : IDisposable private protected DictionaryMemo _dictionaryMemo; private protected DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo(); - private protected bool HasCreatedDictionaryMemo => _dictionaryMemo != null; public ArrowReaderImplementation() { diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 2f6149854e01..d6b2abb503c4 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -28,6 +28,7 @@ internal class ArrowStreamReaderImplementation : ArrowReaderImplementation private readonly bool _leaveOpen; private readonly MemoryAllocator _allocator; private bool HasReadInitialDictionary { get; set; } + private bool HasCreatedDictionaryMemo => _dictionaryMemo != null; public ArrowStreamReaderImplementation(Stream stream, MemoryAllocator allocator, bool leaveOpen) { From 57d844d462175229d51eace4b09d97ab0bd2556a Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Thu, 8 Jul 2021 23:47:47 +0900 Subject: [PATCH 09/15] ARROW-6870: Support Dictionary - Remove needless extra lines --- csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs | 3 --- 1 file changed, 3 deletions(-) diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs index faea33cf8ff3..29089f7688c0 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs @@ -202,7 +202,6 @@ public async Task WriteBatchWithNullsAsync() await TestRoundTripRecordBatchAsync(originalBatch); } - private static void TestRoundTripRecordBatches(List originalBatches, IpcOptions options = null) { using (MemoryStream stream = new MemoryStream()) @@ -260,7 +259,6 @@ private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptio TestRoundTripRecordBatches(new List { originalBatch }, options); } - private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatch, IpcOptions options = null) { await TestRoundTripRecordBatchesAsync(new List { originalBatch }, options); @@ -407,7 +405,6 @@ public void LegacyIpcFormatRoundTrips(bool createDictionaryArray) TestRoundTripRecordBatch(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); } - [Theory] [InlineData(true)] [InlineData(false)] From 08fc153ff03cbd14620ea662589adba6e722a901 Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Sun, 11 Jul 2021 09:32:25 +0900 Subject: [PATCH 10/15] ARROW-6870: Support Dictionary - Avoid a needless allocation - Remove a needless ctor --- .../Ipc/ArrowReaderImplementation.cs | 4 -- .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 46 ++++++++++--------- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index f9086f3e9300..e7fbae535dbe 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -33,10 +33,6 @@ internal abstract class ArrowReaderImplementation : IDisposable private protected DictionaryMemo _dictionaryMemo; private protected DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo(); - public ArrowReaderImplementation() - { - } - public void Dispose() { Dispose(true); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index 8f0d3ddaa63f..578b5df4ec79 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -876,53 +876,55 @@ internal static void Collect(RecordBatch recordBatch, ref DictionaryMemo diction Schema schema = recordBatch.Schema; for (int i = 0; i < schema.Fields.Count; i++) { - { - Field field = schema.GetFieldByIndex(i); - IArrowArray array = recordBatch.Column(i); + Field field = schema.GetFieldByIndex(i); + IArrowArray array = recordBatch.Column(i); - CollectDictionary(field, array, ref dictionaryMemo); - } + CollectDictionary(field, array.Data, ref dictionaryMemo); } } - private static void CollectDictionary(Field field, IArrowArray array, ref DictionaryMemo dictionaryMemo) + private static void CollectDictionary(Field field, ArrayData arrayData, ref DictionaryMemo dictionaryMemo) { - if (field.DataType.TypeId == ArrowTypeId.Dictionary) + if (field.DataType is DictionaryType dictionaryType) { - IArrowArray dictionary = (array as DictionaryArray).Dictionary; + if (arrayData.Dictionary == null) + { + throw new ArgumentException($"{nameof(arrayData.Dictionary)} must not be null"); + } + arrayData.Dictionary.EnsureDataType(dictionaryType.ValueType.TypeId); + + IArrowArray dictionary = ArrowArrayFactory.BuildArray(arrayData.Dictionary); + dictionaryMemo ??= new DictionaryMemo(); long id = dictionaryMemo.GetOrAssignId(field); dictionaryMemo.AddOrReplaceDictionary(id, dictionary); - WalkChildren(dictionary, ref dictionaryMemo); + WalkChildren(arrayData, ref dictionaryMemo); } else { - WalkChildren(array, ref dictionaryMemo); + WalkChildren(arrayData, ref dictionaryMemo); } } - private static void WalkChildren(IArrowArray array, ref DictionaryMemo dictionaryMemo) + private static void WalkChildren(ArrayData arrayData, ref DictionaryMemo dictionaryMemo) { - ArrayData[] children = array.Data.Children; + ArrayData[] children = arrayData.Children; if (children == null) { return; } - if (!(array.Data.DataType is NestedType nestedType)) + if (arrayData.DataType is NestedType nestedType) { - return; - } - - for (int i = 0; i < nestedType.Fields.Count; i++) - { - Field childField = nestedType.Fields[i]; - ArrayData child = children[i]; - IArrowArray childArray = ArrowArrayFactory.BuildArray(child); + for (int i = 0; i < nestedType.Fields.Count; i++) + { + Field childField = nestedType.Fields[i]; + ArrayData child = children[i]; - CollectDictionary(childField, childArray, ref dictionaryMemo); + CollectDictionary(childField, child, ref dictionaryMemo); + } } } } From f7de40326339ef597cad44d5902d9a7d20f445c9 Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Sun, 11 Jul 2021 09:50:26 +0900 Subject: [PATCH 11/15] ARROW-6870: Support Dictionary - Support reading the replacement dictionaries - Add tests for writing and reading dictionaries used in NestedType arrays - Fix a bug when reading dictionaries used in NestedType arrays --- .../Ipc/ArrowReaderImplementation.cs | 10 + .../Ipc/ArrowStreamReaderImplementation.cs | 199 ++++++------------ .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 5 +- csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs | 6 +- .../ArrowStreamWriterTests.cs | 44 +++- 5 files changed, 121 insertions(+), 143 deletions(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index e7fbae535dbe..2e3a965da0b1 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -82,6 +82,16 @@ private static bool MatchEnum(Flatbuf.MessageHeader messageHeader, Type flatBuff } } + /// + /// Create a record batch or dictionary batch from Flatbuf.Message. + /// + /// + /// This method adds data to _dictionaryMemo and returns null when the message type is DictionaryBatch. + /// > + /// + /// The record batch when the message type is RecordBatch. + /// Null when the message type is not RecordBatch. + /// protected RecordBatch CreateArrowObjectFromMessage( Flatbuf.Message message, ByteBuffer bodyByteBuffer, IMemoryOwner memoryOwner) { diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index d6b2abb503c4..6958c3774b34 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -27,8 +27,6 @@ internal class ArrowStreamReaderImplementation : ArrowReaderImplementation public Stream BaseStream { get; } private readonly bool _leaveOpen; private readonly MemoryAllocator _allocator; - private bool HasReadInitialDictionary { get; set; } - private bool HasCreatedDictionaryMemo => _dictionaryMemo != null; public ArrowStreamReaderImplementation(Stream stream, MemoryAllocator allocator, bool leaveOpen) { @@ -45,42 +43,6 @@ protected override void Dispose(bool disposing) } } - private void ReadInitialDictionaries() - { - if (HasReadInitialDictionary) - { - return; - } - - if (HasCreatedDictionaryMemo) - { - int fieldCount = DictionaryMemo.GetFieldCount(); - for (int i = 0; i < fieldCount; ++i) - { - ReadArrowObject(); - } - } - HasReadInitialDictionary = true; - } - - private async ValueTask ReadInitialDictionariesAsync(CancellationToken cancellationToken) - { - if (HasReadInitialDictionary) - { - return; - } - - if (HasCreatedDictionaryMemo) - { - int fieldCount = DictionaryMemo.GetFieldCount(); - for (int i = 0; i < fieldCount; ++i) - { - await ReadArrowObjectAsync(cancellationToken).ConfigureAwait(false); - } - } - HasReadInitialDictionary = true; - } - public override async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken) { // TODO: Loop until a record batch is read. @@ -97,19 +59,83 @@ protected async ValueTask ReadRecordBatchAsync(CancellationToken ca { await ReadSchemaAsync().ConfigureAwait(false); - await ReadInitialDictionariesAsync(cancellationToken).ConfigureAwait(false); + RecordBatch result = default; + Flatbuf.MessageHeader messageHeaderType = Flatbuf.MessageHeader.NONE; + + do + { + int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) + .ConfigureAwait(false); - return await ReadArrowObjectAsync(cancellationToken).ConfigureAwait(false); - } + if (messageLength == 0) + { + // reached end + return null; + } + await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) => + { + int bytesRead = await BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken) + .ConfigureAwait(false); + EnsureFullRead(messageBuff, bytesRead); + + Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); + + int bodyLength = checked((int)message.BodyLength); + messageHeaderType = message.HeaderType; + + IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); + Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); + bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken) + .ConfigureAwait(false); + EnsureFullRead(bodyBuff, bytesRead); + + FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); + result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); + }).ConfigureAwait(false); + } while (messageHeaderType == Flatbuf.MessageHeader.DictionaryBatch); + + return result; + } protected RecordBatch ReadRecordBatch() { ReadSchema(); - ReadInitialDictionaries(); + RecordBatch result = default; + Flatbuf.MessageHeader messageHeaderType = Flatbuf.MessageHeader.NONE; - return ReadArrowObject(); + do + { + int messageLength = ReadMessageLength(throwOnFullRead: false); + + if (messageLength == 0) + { + // reached end + return null; + } + + ArrayPool.Shared.RentReturn(messageLength, messageBuff => + { + int bytesRead = BaseStream.ReadFullBuffer(messageBuff); + EnsureFullRead(messageBuff, bytesRead); + + Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); + + int bodyLength = checked((int)message.BodyLength); + messageHeaderType = message.HeaderType; + + IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); + Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); + bytesRead = BaseStream.ReadFullBuffer(bodyBuff); + EnsureFullRead(bodyBuff, bytesRead); + + FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); + result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); + }); + } while (messageHeaderType == Flatbuf.MessageHeader.DictionaryBatch); + + return result; } protected virtual async ValueTask ReadSchemaAsync() @@ -154,93 +180,6 @@ protected virtual void ReadSchema() }); } - /// - /// Read a record batch or dictionary batch from Flatbuf.Message. - /// - /// - /// This method adds data to _dictionaryMemo and returns null when the message type is DictionaryBatch. - /// > - /// - /// The record batch when the message type is RecordBatch. - /// Null when the message type is DictionaryBatch. - /// - private async ValueTask ReadArrowObjectAsync(CancellationToken cancellationToken) - { - int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) - .ConfigureAwait(false); - - if (messageLength == 0) - { - // reached end - return null; - } - - RecordBatch result = default; - await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) => - { - int bytesRead = await BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken) - .ConfigureAwait(false); - EnsureFullRead(messageBuff, bytesRead); - - Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); - - int bodyLength = checked((int)message.BodyLength); - - IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); - Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); - bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken) - .ConfigureAwait(false); - EnsureFullRead(bodyBuff, bytesRead); - - FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); - result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); - }).ConfigureAwait(false); - - return result; - } - - /// - /// Read a record batch or dictionary batch from Flatbuf.Message. - /// - /// - /// This method adds data to _dictionaryMemo and returns null when the message type is DictionaryBatch. - /// > - /// - /// The record batch when the message type is RecordBatch. - /// Null when the message type is DictionaryBatch. - /// - private RecordBatch ReadArrowObject() - { - int messageLength = ReadMessageLength(throwOnFullRead: false); - - if (messageLength == 0) - { - // reached end - return null; - } - - RecordBatch result = default; - ArrayPool.Shared.RentReturn(messageLength, messageBuff => - { - int bytesRead = BaseStream.ReadFullBuffer(messageBuff); - EnsureFullRead(messageBuff, bytesRead); - - Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); - - int bodyLength = checked((int)message.BodyLength); - - IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); - Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); - bytesRead = BaseStream.ReadFullBuffer(bodyBuff); - EnsureFullRead(bodyBuff, bytesRead); - - FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); - result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); - }); - - return result; - } - private async ValueTask ReadMessageLengthAsync(bool throwOnFullRead, CancellationToken cancellationToken = default) { int messageLength = 0; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index 578b5df4ec79..9e13bd371a65 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -639,10 +639,11 @@ private VectorOffset GetChildrenFieldOffset(Field field) VectorOffset childFieldChildrenVectorOffset = GetChildrenFieldOffset(childField); VectorOffset childFieldMetadataVectorOffset = GetFieldMetadataOffset(childField); + Offset dictionaryOffset = GetDictionaryOffset(childField); children[i] = Flatbuf.Field.CreateField(Builder, childFieldNameOffset, childField.IsNullable, childFieldType.Type, childFieldType.Offset, - default, childFieldChildrenVectorOffset, childFieldMetadataVectorOffset); + dictionaryOffset, childFieldChildrenVectorOffset, childFieldMetadataVectorOffset); } return Builder.CreateVectorOfTables(children); @@ -899,7 +900,7 @@ private static void CollectDictionary(Field field, ArrayData arrayData, ref Dict long id = dictionaryMemo.GetOrAssignId(field); dictionaryMemo.AddOrReplaceDictionary(id, dictionary); - WalkChildren(arrayData, ref dictionaryMemo); + WalkChildren(dictionary.Data, ref dictionaryMemo); } else { diff --git a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs index 4952205298f7..83a8d88b8fc0 100644 --- a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs +++ b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs @@ -21,9 +21,9 @@ namespace Apache.Arrow.Ipc { class DictionaryMemo { - private Dictionary _idToDictionary; - private Dictionary _idToValueType; - private Dictionary _fieldToId; + private readonly Dictionary _idToDictionary; + private readonly Dictionary _idToValueType; + private readonly Dictionary _fieldToId; public DictionaryMemo() { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs index 29089f7688c0..ee690d73dc4e 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs @@ -548,7 +548,7 @@ private List CreateMultipleDictionaryArraysTestData() var dictionaryData = new List { "a", "b", "c" }; int length = dictionaryData.Count; - var indicesSchema = new Schema(new List { + var schemaForSimpleCase = new Schema(new List { new Field("int8", Int8Type.Default, true), new Field("uint8", UInt8Type.Default, true), new Field("int16", Int16Type.Default, true), @@ -560,26 +560,54 @@ private List CreateMultipleDictionaryArraysTestData() }, null); StringArray dictionary = new StringArray.Builder().AppendRange(dictionaryData).Build(); - IEnumerable indicesArrays = TestData.CreateArrays(indicesSchema, length); + IEnumerable indicesArraysForSimpleCase = TestData.CreateArrays(schemaForSimpleCase, length); - var fields = new List(capacity: length); - var dictionaryArrays = new List(capacity: length); + var fields = new List(capacity: length + 1); + var testTargetArrays = new List(capacity: length + 1); - foreach (IArrowArray indices in indicesArrays) + foreach (IArrowArray indices in indicesArraysForSimpleCase) { var dictionaryArray = new DictionaryArray( new DictionaryType(indices.Data.DataType, StringType.Default, false), indices, dictionary); - dictionaryArrays.Add(dictionaryArray); + testTargetArrays.Add(dictionaryArray); fields.Add(new Field($"dictionaryField_{indices.Data.DataType.Name}", dictionaryArray.Data.DataType, false)); } + (Field listField, ListArray listArray) = CreateDictionaryListArrayTestData(dictionary); + + fields.Add(listField); + testTargetArrays.Add(listArray); + var schema = new Schema(fields, null); return new List { - new RecordBatch(schema, dictionaryArrays, length), - new RecordBatch(schema, dictionaryArrays, length), + new RecordBatch(schema, testTargetArrays, length), + new RecordBatch(schema, testTargetArrays, length), }; } + + private Tuple CreateDictionaryListArrayTestData(StringArray dictionary) + { + List indices = Enumerable.Range(0, dictionary.Length).ToList(); + Int32Array indiceArray = new Int32Array.Builder().AppendRange(indices).Build(); + var dictionaryType = new DictionaryType(Int32Type.Default, StringType.Default, false); + var dictionaryArray = new DictionaryArray(dictionaryType, indiceArray, dictionary); + + var valueOffsetsBufferBuilder = new ArrowBuffer.Builder(); + var validityBufferBuilder = new ArrowBuffer.BitmapBuilder(); + + foreach (int i in Enumerable.Range(0, dictionary.Length + 1)) + { + valueOffsetsBufferBuilder.Append(i); + validityBufferBuilder.Append(true); + } + + var dictionaryField = new Field("dictionaryField_list", dictionaryType, false); + var listType = new ListType(dictionaryField); + var listArray = new ListArray(listType, valueOffsetsBufferBuilder.Length - 1, valueOffsetsBufferBuilder.Build(), dictionaryArray, valueOffsetsBufferBuilder.Build()); + + return Tuple.Create(new Field($"ListField_{listType.ValueDataType.Name}", listType, false), listArray); + } } } From 30c0cd9e4a44cc20b5e2a905bd54e0dd6a52c66b Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Mon, 12 Jul 2021 03:54:12 +0900 Subject: [PATCH 12/15] ARROW-6870: Support Dictionary - Regard indexType as signed int32 if it is null --- csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs index f464895e4e11..d156d3a625f9 100644 --- a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs +++ b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs @@ -88,11 +88,10 @@ private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField, ref Diction if (dictionaryEncoding.HasValue) { Flatbuf.Int? indexTypeAsInt = dictionaryEncoding.Value.IndexType; - if (!indexTypeAsInt.HasValue) - { - throw new InvalidDataException("Dictionary IndexType not defined"); - } - IArrowType indexType = GetNumberType(indexTypeAsInt.Value.BitWidth, indexTypeAsInt.Value.IsSigned); + IArrowType indexType = indexTypeAsInt.HasValue ? + GetNumberType(indexTypeAsInt.Value.BitWidth, indexTypeAsInt.Value.IsSigned) : + GetNumberType(Int32Type.Default.BitWidth, Int32Type.Default.IsSigned); + type = new DictionaryType(indexType, type, dictionaryEncoding.Value.IsOrdered); } From 75970291deeda1faf4c17ae91d0e2db910aa215c Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Tue, 13 Jul 2021 08:56:38 +0900 Subject: [PATCH 13/15] ARROW-6870: Support Dictionary - DictionaryArray with children - Add a test for this - Fix a serialization bug about this --- .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 6 ++- csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs | 5 --- .../ArrowStreamWriterTests.cs | 38 +++++++++++++++---- 3 files changed, 36 insertions(+), 13 deletions(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index 9e13bd371a65..bfb6c05ca380 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -623,7 +623,11 @@ private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken ca private VectorOffset GetChildrenFieldOffset(Field field) { - if (!(field.DataType is NestedType type)) + IArrowType targetDataType = field.DataType is DictionaryType dictionaryType ? + dictionaryType.ValueType : + field.DataType; + + if (!(targetDataType is NestedType type)) { return default; } diff --git a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs index 83a8d88b8fc0..e069be8d9a97 100644 --- a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs +++ b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs @@ -99,10 +99,5 @@ public void AddOrReplaceDictionary(long id, IArrowArray dictionary) { _idToDictionary[id] = dictionary; } - - public int GetFieldCount() - { - return _fieldToId.Count; - } } } diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs index ee690d73dc4e..837fe68a0daf 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs @@ -574,10 +574,15 @@ private List CreateMultipleDictionaryArraysTestData() fields.Add(new Field($"dictionaryField_{indices.Data.DataType.Name}", dictionaryArray.Data.DataType, false)); } - (Field listField, ListArray listArray) = CreateDictionaryListArrayTestData(dictionary); + (Field dictionaryTypeListArrayField, ListArray dictionaryTypeListArray) = CreateDictionaryTypeListArrayTestData(dictionary); - fields.Add(listField); - testTargetArrays.Add(listArray); + fields.Add(dictionaryTypeListArrayField); + testTargetArrays.Add(dictionaryTypeListArray); + + (Field listTypeDictionaryArrayField, DictionaryArray listTypeDictionaryArray) = CreateListTypeDictionaryArrayTestData(dictionaryData); + + fields.Add(listTypeDictionaryArrayField); + testTargetArrays.Add(listTypeDictionaryArray); var schema = new Schema(fields, null); @@ -587,10 +592,11 @@ private List CreateMultipleDictionaryArraysTestData() }; } - private Tuple CreateDictionaryListArrayTestData(StringArray dictionary) + private Tuple CreateDictionaryTypeListArrayTestData(StringArray dictionary) { - List indices = Enumerable.Range(0, dictionary.Length).ToList(); - Int32Array indiceArray = new Int32Array.Builder().AppendRange(indices).Build(); + Int32Array indiceArray = new Int32Array.Builder().AppendRange(Enumerable.Range(0, dictionary.Length)).Build(); + + //DictionaryArray has no Builder for now, so creating ListArray directly. var dictionaryType = new DictionaryType(Int32Type.Default, StringType.Default, false); var dictionaryArray = new DictionaryArray(dictionaryType, indiceArray, dictionary); @@ -607,7 +613,25 @@ private Tuple CreateDictionaryListArrayTestData(StringArray di var listType = new ListType(dictionaryField); var listArray = new ListArray(listType, valueOffsetsBufferBuilder.Length - 1, valueOffsetsBufferBuilder.Build(), dictionaryArray, valueOffsetsBufferBuilder.Build()); - return Tuple.Create(new Field($"ListField_{listType.ValueDataType.Name}", listType, false), listArray); + return Tuple.Create(new Field($"listField_{listType.ValueDataType.Name}", listType, false), listArray); + } + + private Tuple CreateListTypeDictionaryArrayTestData(List dictionaryDataBase) + { + var listBuilder = new ListArray.Builder(StringType.Default); + var valueBuilder = listBuilder.ValueBuilder as StringArray.Builder; + + foreach(string data in dictionaryDataBase) { + listBuilder.Append(); + valueBuilder.Append(data); + } + + ListArray dictionary = listBuilder.Build(); + Int32Array indiceArray = new Int32Array.Builder().AppendRange(Enumerable.Range(0, dictionary.Length)).Build(); + var dictionaryArrayType = new DictionaryType(Int32Type.Default, dictionary.Data.DataType, false); + var dictionaryArray = new DictionaryArray(dictionaryArrayType, indiceArray, dictionary); + + return Tuple.Create(new Field($"dictionaryField_{dictionaryArray.Data.DataType.Name}", dictionaryArrayType, false), dictionaryArray); } } } From 487fc011b907ff83856650ae2b5891a8f7c5b344 Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Mon, 9 Aug 2021 10:51:12 +0900 Subject: [PATCH 14/15] ARROW-6870: Support Dictionary - Fix minor issues --- .../Ipc/ArrowStreamReaderImplementation.cs | 12 ++++-------- csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 6 +++--- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 6958c3774b34..0c12857ba0ae 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -59,8 +59,7 @@ protected async ValueTask ReadRecordBatchAsync(CancellationToken ca { await ReadSchemaAsync().ConfigureAwait(false); - RecordBatch result = default; - Flatbuf.MessageHeader messageHeaderType = Flatbuf.MessageHeader.NONE; + RecordBatch result = null; do { @@ -82,7 +81,6 @@ await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); int bodyLength = checked((int)message.BodyLength); - messageHeaderType = message.HeaderType; IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); @@ -93,7 +91,7 @@ await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); }).ConfigureAwait(false); - } while (messageHeaderType == Flatbuf.MessageHeader.DictionaryBatch); + } while (result == null); return result; } @@ -102,8 +100,7 @@ protected RecordBatch ReadRecordBatch() { ReadSchema(); - RecordBatch result = default; - Flatbuf.MessageHeader messageHeaderType = Flatbuf.MessageHeader.NONE; + RecordBatch result = null; do { @@ -123,7 +120,6 @@ protected RecordBatch ReadRecordBatch() Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); int bodyLength = checked((int)message.BodyLength); - messageHeaderType = message.HeaderType; IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); @@ -133,7 +129,7 @@ protected RecordBatch ReadRecordBatch() FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); }); - } while (messageHeaderType == Flatbuf.MessageHeader.DictionaryBatch); + } while (result == null); return result; } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index bfb6c05ca380..7ef9813265a8 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -131,8 +131,8 @@ public void Visit(StructArray array) public void Visit(DictionaryArray array) { - //Dictionary is serialized separately in Dictionary serialization. - //We are only interested in indexes at this context. + // Dictionary is serialized separately in Dictionary serialization. + // We are only interested in indices at this context. _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); _buffers.Add(CreateBuffer(array.IndicesBuffer)); @@ -523,7 +523,7 @@ await WriteMessageAsync(Flatbuf.MessageHeader.DictionaryBatch, fieldNodesVectorOffset, buffersVectorOffset); - //TODO: Support delta. + // TODO: Support delta. Offset dictionaryBatchOffset = Flatbuf.DictionaryBatch.CreateDictionaryBatch(Builder, id, recordBatchOffset, false); return Tuple.Create(recordBatchBuilder, dictionaryBatchOffset); } From 6bdd277658a39c38c5bca7bd67e1fd079a786c4d Mon Sep 17 00:00:00 2001 From: Takashi Hashida Date: Mon, 9 Aug 2021 19:19:48 +0900 Subject: [PATCH 15/15] ARROW-6870: Support Dictionary - do-while -> while --- .../Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 0c12857ba0ae..cc66c3873f37 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -61,7 +61,7 @@ protected async ValueTask ReadRecordBatchAsync(CancellationToken ca RecordBatch result = null; - do + while (result == null) { int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) .ConfigureAwait(false); @@ -91,7 +91,7 @@ await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); }).ConfigureAwait(false); - } while (result == null); + } return result; } @@ -102,7 +102,7 @@ protected RecordBatch ReadRecordBatch() RecordBatch result = null; - do + while (result == null) { int messageLength = ReadMessageLength(throwOnFullRead: false); @@ -129,7 +129,7 @@ protected RecordBatch ReadRecordBatch() FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); }); - } while (result == null); + } return result; }