diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs index 0ac2d19b2971..36b13a63d3c0 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs @@ -44,13 +44,17 @@ public static Schema DecodeSchema(ReadOnlyMemory buffer) } ByteBuffer schemaBuffer = ArrowReaderImplementation.CreateByteBuffer(buffer.Slice(bufferPosition)); - var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer)); + //DictionaryBatch not supported for now + DictionaryMemo dictionaryMemo = null; + var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer), ref dictionaryMemo); return schema; } internal static Schema DecodeSchema(ByteBuffer schemaBuffer) { - var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer)); + //DictionaryBatch not supported for now + DictionaryMemo dictionaryMemo = null; + var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer), ref dictionaryMemo); return schema; } } diff --git a/csharp/src/Apache.Arrow/Arrays/ArrayData.cs b/csharp/src/Apache.Arrow/Arrays/ArrayData.cs index 93bb5ccf6d8e..5bae443e1120 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrayData.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrayData.cs @@ -30,11 +30,30 @@ public sealed class ArrayData : IDisposable public readonly int Offset; public readonly ArrowBuffer[] Buffers; public readonly ArrayData[] Children; + public readonly ArrayData Dictionary; //Only used for dictionary type + + //This is left for compatibility with lower version binaries + //before the dictionary type was supported. 
+ public ArrayData( + IArrowType dataType, + int length, int nullCount, int offset, + IEnumerable buffers, IEnumerable children) : + this(dataType, length, nullCount, offset, buffers, children, null) + { } + + //This is left for compatibility with lower version binaries + //before the dictionary type was supported. + public ArrayData( + IArrowType dataType, + int length, int nullCount, int offset, + ArrowBuffer[] buffers, ArrayData[] children) : + this(dataType, length, nullCount, offset, buffers, children, null) + { } public ArrayData( IArrowType dataType, int length, int nullCount = 0, int offset = 0, - IEnumerable buffers = null, IEnumerable children = null) + IEnumerable buffers = null, IEnumerable children = null, ArrayData dictionary = null) { DataType = dataType ?? NullType.Default; Length = length; @@ -42,12 +61,13 @@ public ArrayData( Offset = offset; Buffers = buffers?.ToArray(); Children = children?.ToArray(); + Dictionary = dictionary; } public ArrayData( IArrowType dataType, int length, int nullCount = 0, int offset = 0, - ArrowBuffer[] buffers = null, ArrayData[] children = null) + ArrowBuffer[] buffers = null, ArrayData[] children = null, ArrayData dictionary = null) { DataType = dataType ?? 
NullType.Default; Length = length; @@ -55,6 +75,7 @@ public ArrayData( Offset = offset; Buffers = buffers; Children = children; + Dictionary = dictionary; } public void Dispose() @@ -74,6 +95,8 @@ public void Dispose() child?.Dispose(); } } + + Dictionary?.Dispose(); } public ArrayData Slice(int offset, int length) @@ -86,7 +109,7 @@ public ArrayData Slice(int offset, int length) length = Math.Min(Length - offset, length); offset += Offset; - return new ArrayData(DataType, length, RecalculateNullCount, offset, Buffers, Children); + return new ArrayData(DataType, length, RecalculateNullCount, offset, Buffers, Children, Dictionary); } } } diff --git a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs index c3429230cc65..c8c0b2487f6a 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs @@ -67,6 +67,7 @@ public static IArrowArray BuildArray(ArrayData data) case ArrowTypeId.Decimal256: return new Decimal256Array(data); case ArrowTypeId.Dictionary: + return new DictionaryArray(data); case ArrowTypeId.FixedSizedBinary: case ArrowTypeId.HalfFloat: case ArrowTypeId.Interval: diff --git a/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs b/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs new file mode 100644 index 000000000000..29c0f5c84c75 --- /dev/null +++ b/csharp/src/Apache.Arrow/Arrays/DictionaryArray.cs @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.IO;
+using Apache.Arrow.Types;
+
+namespace Apache.Arrow
+{
+    /// <summary>
+    /// An array of <see cref="ArrowTypeId.Dictionary"/> type: an index array paired with
+    /// the dictionary array that those indices refer to.
+    /// </summary>
+    public class DictionaryArray : Array
+    {
+        /// <summary>The dictionary, exposed as an array.</summary>
+        public IArrowArray Dictionary { get; }
+
+        /// <summary>The indices, exposed as an array of the dictionary type's index type.</summary>
+        public IArrowArray Indices { get; }
+
+        /// <summary>The second buffer of the underlying array data, exposed as the indices buffer.</summary>
+        public ArrowBuffer IndicesBuffer => Data.Buffers[1];
+
+        /// <summary>
+        /// Creates a dictionary array from array data that carries its dictionary in
+        /// <see cref="ArrayData.Dictionary"/>.
+        /// </summary>
+        /// <param name="data">Array data with two buffers, a dictionary data type and a dictionary.</param>
+        /// <exception cref="ArgumentException">Thrown when <c>data.Dictionary</c> is null.</exception>
+        public DictionaryArray(ArrayData data) : base(data)
+        {
+            data.EnsureBufferCount(2);
+            data.EnsureDataType(ArrowTypeId.Dictionary);
+
+            if (data.Dictionary == null)
+            {
+                throw new ArgumentException($"{nameof(data.Dictionary)} must not be null");
+            }
+
+            var dicType = (DictionaryType)data.DataType;
+            data.Dictionary.EnsureDataType(dicType.ValueType.TypeId);
+
+            // The indices reuse this array's buffers and children, re-typed with the index type.
+            var indicesData = new ArrayData(dicType.IndexType, data.Length, data.NullCount, data.Offset, data.Buffers, data.Children);
+
+            Indices = ArrowArrayFactory.BuildArray(indicesData);
+            Dictionary = ArrowArrayFactory.BuildArray(data.Dictionary);
+        }
+
+        /// <summary>
+        /// Creates a dictionary array from an existing index array and dictionary array.
+        /// </summary>
+        /// <param name="dataType">Dictionary type describing the index and value types.</param>
+        /// <param name="indicesArray">Index array; its type must match the index type of <paramref name="dataType"/>.</param>
+        /// <param name="dictionary">Dictionary array; its type must match the value type of <paramref name="dataType"/>.</param>
+        /// <exception cref="ArgumentNullException">Thrown when any argument is null.</exception>
+        public DictionaryArray(DictionaryType dataType, IArrowArray indicesArray, IArrowArray dictionary) :
+                base(BuildDictionaryData(dataType, indicesArray, dictionary))
+        {
+            Indices = indicesArray;
+            Dictionary = dictionary;
+        }
+
+        /// <summary>
+        /// Validates the constructor arguments and combines the index array's data with the dictionary's data.
+        /// </summary>
+        private static ArrayData BuildDictionaryData(DictionaryType dataType, IArrowArray indicesArray, IArrowArray dictionary)
+        {
+            // Runs before the base constructor, so a null argument is reported by name
+            // instead of surfacing as a NullReferenceException.
+            if (dataType == null)
+            {
+                throw new ArgumentNullException(nameof(dataType));
+            }
+
+            if (indicesArray == null)
+            {
+                throw new ArgumentNullException(nameof(indicesArray));
+            }
+
+            if (dictionary == null)
+            {
+                throw new ArgumentNullException(nameof(dictionary));
+            }
+
+            ArrayData indicesData = indicesArray.Data;
+
+            // Same checks, in the same order, as were previously made after construction.
+            indicesData.EnsureBufferCount(2);
+            indicesData.EnsureDataType(dataType.IndexType.TypeId);
+            dictionary.Data.EnsureDataType(dataType.ValueType.TypeId);
+
+            return new ArrayData(dataType, indicesData.Length, indicesData.NullCount, indicesData.Offset, indicesData.Buffers, indicesData.Children, dictionary.Data);
+        }
+
+        /// <summary>Dispatches this array to <paramref name="visitor"/>.</summary>
+        public override void Accept(IArrowArrayVisitor visitor) => Accept(this, visitor);
+    }
+}
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs 
b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs index f2bf21a0c4f3..36cd4ddf930c 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs @@ -135,7 +135,7 @@ private static int ReadFooterLength(Memory buffer) private void ReadSchema(Memory buffer) { // Deserialize the footer from the footer flatbuffer - _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer))); + _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)), ref _dictionaryMemo); Schema = _footer.Schema; } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs index aa7d7952d3f0..db269ae019b5 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs @@ -61,8 +61,8 @@ public ArrowFooter(Schema schema, IEnumerable dictionaries, IEnumerable _buffer; private int _bufferPosition; - public ArrowMemoryReaderImplementation(ReadOnlyMemory buffer) + public ArrowMemoryReaderImplementation(ReadOnlyMemory buffer) : base() { _buffer = buffer; } @@ -111,7 +111,7 @@ private void ReadSchema() } ByteBuffer schemaBuffer = CreateByteBuffer(_buffer.Slice(_bufferPosition)); - Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer)); + Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), ref _dictionaryMemo); _bufferPosition += schemaMessageLength; } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index 3279f7030557..2e3a965da0b1 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -30,6 +30,9 @@ internal abstract class ArrowReaderImplementation : IDisposable public Schema Schema { get; protected set; } protected bool HasReadSchema => Schema != null; + private protected DictionaryMemo _dictionaryMemo; + 
private protected DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo(); + public void Dispose() { Dispose(true); @@ -79,6 +82,16 @@ private static bool MatchEnum(Flatbuf.MessageHeader messageHeader, Type flatBuff } } + /// + /// Create a record batch or dictionary batch from Flatbuf.Message. + /// + /// + /// This method adds data to _dictionaryMemo and returns null when the message type is DictionaryBatch. + /// > + /// + /// The record batch when the message type is RecordBatch. + /// Null when the message type is not RecordBatch. + /// protected RecordBatch CreateArrowObjectFromMessage( Flatbuf.Message message, ByteBuffer bodyByteBuffer, IMemoryOwner memoryOwner) { @@ -88,8 +101,8 @@ protected RecordBatch CreateArrowObjectFromMessage( // TODO: Read schema and verify equality? break; case Flatbuf.MessageHeader.DictionaryBatch: - // TODO: not supported currently - Debug.WriteLine("Dictionaries are not yet supported."); + Flatbuf.DictionaryBatch dictionaryBatch = message.Header().Value; + ReadDictionaryBatch(dictionaryBatch, bodyByteBuffer, memoryOwner); break; case Flatbuf.MessageHeader.RecordBatch: Flatbuf.RecordBatch rb = message.Header().Value; @@ -109,6 +122,36 @@ internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory buffer) return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer), 0); } + private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBuffer bodyByteBuffer, IMemoryOwner memoryOwner) + { + long id = dictionaryBatch.Id; + IArrowType valueType = DictionaryMemo.GetDictionaryType(id); + Flatbuf.RecordBatch? 
recordBatch = dictionaryBatch.Data; + + if (!recordBatch.HasValue) + { + throw new InvalidDataException("Dictionary must contain RecordBatch"); + } + + Field valueField = new Field("dummy", valueType, true); + var schema = new Schema(new[] { valueField }, default); + IList arrays = BuildArrays(schema, bodyByteBuffer, recordBatch.Value); + + if (arrays.Count != 1) + { + throw new InvalidDataException("Dictionary record batch must contain only one field"); + } + + if (dictionaryBatch.IsDelta) + { + throw new NotImplementedException("Dictionary delta is not supported yet"); + } + else + { + DictionaryMemo.AddOrReplaceDictionary(id, arrays[0]); + } + } + private List BuildArrays( Schema schema, ByteBuffer messageBuffer, @@ -179,7 +222,14 @@ private ArrayData LoadPrimitiveField( ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData); - return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children); + IArrowArray dictionary = null; + if (field.DataType.TypeId == ArrowTypeId.Dictionary) + { + long id = DictionaryMemo.GetId(field); + dictionary = DictionaryMemo.GetDictionary(id); + } + + return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); } private ArrayData LoadVariableField( @@ -218,7 +268,14 @@ private ArrayData LoadVariableField( ArrowBuffer[] arrowBuff = new[] { nullArrowBuffer, offsetArrowBuffer, valueArrowBuffer }; ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData); - return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children); + IArrowArray dictionary = null; + if (field.DataType.TypeId == ArrowTypeId.Dictionary) + { + long id = DictionaryMemo.GetId(field); + dictionary = DictionaryMemo.GetDictionary(id); + } + + return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, children, dictionary?.Data); } private ArrayData[] GetChildren( diff --git 
a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 1fd320903bd0..cc66c3873f37 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -59,35 +59,39 @@ protected async ValueTask ReadRecordBatchAsync(CancellationToken ca { await ReadSchemaAsync().ConfigureAwait(false); - int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) - .ConfigureAwait(false); - - if (messageLength == 0) - { - // reached end - return null; - } - RecordBatch result = null; - await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) => + + while (result == null) { - int bytesRead = await BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken) + int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken) .ConfigureAwait(false); - EnsureFullRead(messageBuff, bytesRead); - Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); + if (messageLength == 0) + { + // reached end + return null; + } + + await ArrayPool.Shared.RentReturnAsync(messageLength, async (messageBuff) => + { + int bytesRead = await BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken) + .ConfigureAwait(false); + EnsureFullRead(messageBuff, bytesRead); - int bodyLength = checked((int)message.BodyLength); + Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); - IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); - Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); - bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken) - .ConfigureAwait(false); - EnsureFullRead(bodyBuff, bytesRead); + int bodyLength = checked((int)message.BodyLength); - FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); - result = CreateArrowObjectFromMessage(message, 
bodybb, bodyBuffOwner); - }).ConfigureAwait(false); + IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); + Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); + bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken) + .ConfigureAwait(false); + EnsureFullRead(bodyBuff, bytesRead); + + FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); + result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); + }).ConfigureAwait(false); + } return result; } @@ -96,32 +100,36 @@ protected RecordBatch ReadRecordBatch() { ReadSchema(); - int messageLength = ReadMessageLength(throwOnFullRead: false); + RecordBatch result = null; - if (messageLength == 0) + while (result == null) { - // reached end - return null; - } + int messageLength = ReadMessageLength(throwOnFullRead: false); - RecordBatch result = null; - ArrayPool.Shared.RentReturn(messageLength, messageBuff => - { - int bytesRead = BaseStream.ReadFullBuffer(messageBuff); - EnsureFullRead(messageBuff, bytesRead); + if (messageLength == 0) + { + // reached end + return null; + } + + ArrayPool.Shared.RentReturn(messageLength, messageBuff => + { + int bytesRead = BaseStream.ReadFullBuffer(messageBuff); + EnsureFullRead(messageBuff, bytesRead); - Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); + Flatbuf.Message message = Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff)); - int bodyLength = checked((int)message.BodyLength); + int bodyLength = checked((int)message.BodyLength); - IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); - Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); - bytesRead = BaseStream.ReadFullBuffer(bodyBuff); - EnsureFullRead(bodyBuff, bytesRead); + IMemoryOwner bodyBuffOwner = _allocator.Allocate(bodyLength); + Memory bodyBuff = bodyBuffOwner.Memory.Slice(0, bodyLength); + bytesRead = BaseStream.ReadFullBuffer(bodyBuff); + EnsureFullRead(bodyBuff, bytesRead); - 
FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); - result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); - }); + FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff); + result = CreateArrowObjectFromMessage(message, bodybb, bodyBuffOwner); + }); + } return result; } @@ -144,7 +152,7 @@ await ArrayPool.Shared.RentReturnAsync(schemaMessageLength, async (buff) = EnsureFullRead(buff, bytesRead); FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb)); + Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); }).ConfigureAwait(false); } @@ -164,7 +172,7 @@ protected virtual void ReadSchema() EnsureFullRead(buff, bytesRead); FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb)); + Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); }); } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index 5f0d16f83068..7ef9813265a8 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -48,7 +48,8 @@ internal class ArrowRecordBatchFlatBufferBuilder : IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, - IArrowArrayVisitor + IArrowArrayVisitor, + IArrowArrayVisitor { public readonly struct Buffer { @@ -128,6 +129,15 @@ public void Visit(StructArray array) } } + public void Visit(DictionaryArray array) + { + // Dictionary is serialized separately in Dictionary serialization. + // We are only interested in indices at this context. 
+ + _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); + _buffers.Add(CreateBuffer(array.IndicesBuffer)); + } + private void CreateBuffers(BooleanArray array) { _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); @@ -165,6 +175,8 @@ public void Visit(IArrowArray array) protected bool HasWrittenSchema { get; set; } + private bool HasWrittenDictionaryBatch { get; set; } + private bool HasWrittenEnd { get; set; } protected Schema Schema { get; } @@ -178,6 +190,9 @@ public void Visit(IArrowArray array) private readonly ArrowTypeFlatbufferBuilder _fieldTypeBuilder; + private DictionaryMemo _dictionaryMemo; + private DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo(); + public ArrowStreamWriter(Stream baseStream, Schema schema) : this(baseStream, schema, leaveOpen: false) { @@ -216,17 +231,17 @@ private void CreateSelfAndChildrenFieldNodes(ArrayData data) Flatbuf.FieldNode.CreateFieldNode(Builder, data.Length, data.NullCount); } - private int CountAllNodes() + private static int CountAllNodes(IReadOnlyDictionary fields) { int count = 0; - foreach (Field arrowArray in Schema.Fields.Values) + foreach (Field arrowArray in fields.Values) { CountSelfAndChildrenNodes(arrowArray.DataType, ref count); } return count; } - private void CountSelfAndChildrenNodes(IArrowType type, ref int count) + private static void CountSelfAndChildrenNodes(IArrowType type, ref int count) { if (type is NestedType nestedType) { @@ -248,6 +263,13 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) HasWrittenSchema = true; } + if (!HasWrittenDictionaryBatch) + { + DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo); + WriteDictionaries(recordBatch); + HasWrittenDictionaryBatch = true; + } + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, VectorOffset fieldNodesVectorOffset) = PreparingWritingRecordBatch(recordBatch); @@ -264,37 +286,9 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) long metadataLength = 
WriteMessage(Flatbuf.MessageHeader.RecordBatch, recordBatchOffset, recordBatchBuilder.TotalLength); - // Write buffer data + long bufferLength = WriteBufferData(recordBatchBuilder.Buffers); - IReadOnlyList buffers = recordBatchBuilder.Buffers; - - long bodyLength = 0; - - for (int i = 0; i < buffers.Count; i++) - { - ArrowBuffer buffer = buffers[i].DataBuffer; - if (buffer.IsEmpty) - continue; - - WriteBuffer(buffer); - - int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(buffer.Length)); - int padding = paddedLength - buffer.Length; - if (padding > 0) - { - WritePadding(padding); - } - - bodyLength += paddedLength; - } - - // Write padding so the record batch message body length is a multiple of 8 bytes - - int bodyPaddingLength = CalculatePadding(bodyLength); - - WritePadding(bodyPaddingLength); - - FinishedWritingRecordBatch(bodyLength + bodyPaddingLength, metadataLength); + FinishedWritingRecordBatch(bufferLength, metadataLength); } private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch, @@ -308,6 +302,13 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat HasWrittenSchema = true; } + if (!HasWrittenDictionaryBatch) + { + DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo); + await WriteDictionariesAsync(recordBatch, cancellationToken).ConfigureAwait(false); + HasWrittenDictionaryBatch = true; + } + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, VectorOffset fieldNodesVectorOffset) = PreparingWritingRecordBatch(recordBatch); @@ -325,10 +326,44 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat recordBatchOffset, recordBatchBuilder.TotalLength, cancellationToken).ConfigureAwait(false); - // Write buffer data + long bufferLength = await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false); - IReadOnlyList buffers = recordBatchBuilder.Buffers; + FinishedWritingRecordBatch(bufferLength, metadataLength); + } 
+ private long WriteBufferData(IReadOnlyList buffers) + { + long bodyLength = 0; + + for (int i = 0; i < buffers.Count; i++) + { + ArrowBuffer buffer = buffers[i].DataBuffer; + if (buffer.IsEmpty) + continue; + + WriteBuffer(buffer); + + int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(buffer.Length)); + int padding = paddedLength - buffer.Length; + if (padding > 0) + { + WritePadding(padding); + } + + bodyLength += paddedLength; + } + + // Write padding so the record batch message body length is a multiple of 8 bytes + + int bodyPaddingLength = CalculatePadding(bodyLength); + + WritePadding(bodyPaddingLength); + + return bodyLength + bodyPaddingLength; + } + + private async ValueTask WriteBufferDataAsync(IReadOnlyList buffers, CancellationToken cancellationToken = default) + { long bodyLength = 0; for (int i = 0; i < buffers.Count; i++) @@ -355,23 +390,28 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat await WritePaddingAsync(bodyPaddingLength).ConfigureAwait(false); - FinishedWritingRecordBatch(bodyLength + bodyPaddingLength, metadataLength); + return bodyLength + bodyPaddingLength; } private Tuple PreparingWritingRecordBatch(RecordBatch recordBatch) + { + return PreparingWritingRecordBatch(recordBatch.Schema.Fields, recordBatch.ArrayList); + } + + private Tuple PreparingWritingRecordBatch(IReadOnlyDictionary fields, IReadOnlyList arrays) { Builder.Clear(); // Serialize field nodes - int fieldCount = Schema.Fields.Count; + int fieldCount = fields.Count; - Flatbuf.RecordBatch.StartNodesVector(Builder, CountAllNodes()); + Flatbuf.RecordBatch.StartNodesVector(Builder, CountAllNodes(fields)); // flatbuffer struct vectors have to be created in reverse order for (int i = fieldCount - 1; i >= 0; i--) { - CreateSelfAndChildrenFieldNodes(recordBatch.Column(i).Data); + CreateSelfAndChildrenFieldNodes(arrays[i].Data); } VectorOffset fieldNodesVectorOffset = Builder.EndVector(); @@ -381,7 +421,7 @@ private Tuple 
PreparingWritingR var recordBatchBuilder = new ArrowRecordBatchFlatBufferBuilder(); for (int i = 0; i < fieldCount; i++) { - IArrowArray fieldArray = recordBatch.Column(i); + IArrowArray fieldArray = arrays[i]; fieldArray.Accept(recordBatchBuilder); } @@ -399,6 +439,95 @@ private Tuple PreparingWritingR return Tuple.Create(recordBatchBuilder, fieldNodesVectorOffset); } + + private protected void WriteDictionaries(RecordBatch recordBatch) + { + foreach (Field field in recordBatch.Schema.Fields.Values) + { + WriteDictionary(field); + } + } + + private protected void WriteDictionary(Field field) + { + if (field.DataType.TypeId != ArrowTypeId.Dictionary) + { + if (field.DataType is NestedType nestedType) + { + foreach (Field child in nestedType.Fields) + { + WriteDictionary(child); + } + } + return; + } + + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, Offset dictionaryBatchOffset) = + CreateDictionaryBatchOffset(field); + + WriteMessage(Flatbuf.MessageHeader.DictionaryBatch, + dictionaryBatchOffset, recordBatchBuilder.TotalLength); + + WriteBufferData(recordBatchBuilder.Buffers); + } + + private protected async Task WriteDictionariesAsync(RecordBatch recordBatch, CancellationToken cancellationToken) + { + foreach (Field field in recordBatch.Schema.Fields.Values) + { + await WriteDictionaryAsync(field, cancellationToken).ConfigureAwait(false); + } + } + + private protected async Task WriteDictionaryAsync(Field field, CancellationToken cancellationToken) + { + if (field.DataType.TypeId != ArrowTypeId.Dictionary) + { + if (field.DataType is NestedType nestedType) + { + foreach (Field child in nestedType.Fields) + { + await WriteDictionaryAsync(child, cancellationToken).ConfigureAwait(false); + } + } + return; + } + + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, Offset dictionaryBatchOffset) = + CreateDictionaryBatchOffset(field); + + await WriteMessageAsync(Flatbuf.MessageHeader.DictionaryBatch, + dictionaryBatchOffset, recordBatchBuilder.TotalLength, 
cancellationToken).ConfigureAwait(false); + + await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false); + } + + private Tuple> CreateDictionaryBatchOffset(Field field) + { + Field dictionaryField = new Field("dummy", ((DictionaryType)field.DataType).ValueType, false); + long id = DictionaryMemo.GetId(field); + IArrowArray dictionary = DictionaryMemo.GetDictionary(id); + + var fieldsDictionary = new Dictionary { + { dictionaryField.Name, dictionaryField } }; + + var arrays = new List { dictionary }; + + (ArrowRecordBatchFlatBufferBuilder recordBatchBuilder, VectorOffset fieldNodesVectorOffset) = + PreparingWritingRecordBatch(fieldsDictionary, arrays); + + VectorOffset buffersVectorOffset = Builder.EndVector(); + + // Serialize record batch + Offset recordBatchOffset = Flatbuf.RecordBatch.CreateRecordBatch(Builder, dictionary.Length, + fieldNodesVectorOffset, + buffersVectorOffset); + + // TODO: Support delta. + Offset dictionaryBatchOffset = Flatbuf.DictionaryBatch.CreateDictionaryBatch(Builder, id, recordBatchOffset, false); + return Tuple.Create(recordBatchBuilder, dictionaryBatchOffset); + } + private protected virtual void WriteEndInternal() { WriteIpcMessageLength(length: 0); @@ -475,10 +604,11 @@ private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken ca VectorOffset fieldChildrenVectorOffset = GetChildrenFieldOffset(field); VectorOffset fieldMetadataVectorOffset = GetFieldMetadataOffset(field); + Offset dictionaryOffset = GetDictionaryOffset(field); fieldOffsets[i] = Flatbuf.Field.CreateField(Builder, fieldNameOffset, field.IsNullable, fieldType.Type, fieldType.Offset, - default, fieldChildrenVectorOffset, fieldMetadataVectorOffset); + dictionaryOffset, fieldChildrenVectorOffset, fieldMetadataVectorOffset); } VectorOffset fieldsVectorOffset = Flatbuf.Schema.CreateFieldsVector(Builder, fieldOffsets); @@ -493,7 +623,11 @@ private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken ca 
private VectorOffset GetChildrenFieldOffset(Field field) { - if (!(field.DataType is NestedType type)) + IArrowType targetDataType = field.DataType is DictionaryType dictionaryType ? + dictionaryType.ValueType : + field.DataType; + + if (!(targetDataType is NestedType type)) { return default; } @@ -509,10 +643,11 @@ private VectorOffset GetChildrenFieldOffset(Field field) VectorOffset childFieldChildrenVectorOffset = GetChildrenFieldOffset(childField); VectorOffset childFieldMetadataVectorOffset = GetFieldMetadataOffset(childField); + Offset dictionaryOffset = GetDictionaryOffset(childField); children[i] = Flatbuf.Field.CreateField(Builder, childFieldNameOffset, childField.IsNullable, childFieldType.Type, childFieldType.Offset, - default, childFieldChildrenVectorOffset, childFieldMetadataVectorOffset); + dictionaryOffset, childFieldChildrenVectorOffset, childFieldMetadataVectorOffset); } return Builder.CreateVectorOfTables(children); @@ -529,6 +664,21 @@ private VectorOffset GetFieldMetadataOffset(Field field) return Flatbuf.Field.CreateCustomMetadataVector(Builder, metadataOffsets); } + private Offset GetDictionaryOffset(Field field) + { + if (field.DataType.TypeId != ArrowTypeId.Dictionary) + { + return default; + } + + long id = DictionaryMemo.GetOrAssignId(field); + var dicType = field.DataType as DictionaryType; + var indexType = dicType.IndexType as NumberType; + + Offset indexOffset = Flatbuf.Int.CreateInt(Builder, indexType.BitWidth, indexType.IsSigned); + return Flatbuf.DictionaryEncoding.CreateDictionaryEncoding(Builder, id, indexOffset, dicType.Ordered); + } + private Offset[] GetMetadataOffsets(IReadOnlyDictionary metadata) { Debug.Assert(metadata != null); @@ -723,4 +873,64 @@ public virtual void Dispose() } } } + + internal static class DictionaryCollector + { + internal static void Collect(RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo) + { + Schema schema = recordBatch.Schema; + for (int i = 0; i < schema.Fields.Count; i++) + { + Field 
field = schema.GetFieldByIndex(i); + IArrowArray array = recordBatch.Column(i); + + CollectDictionary(field, array.Data, ref dictionaryMemo); + } + } + + private static void CollectDictionary(Field field, ArrayData arrayData, ref DictionaryMemo dictionaryMemo) + { + if (field.DataType is DictionaryType dictionaryType) + { + if (arrayData.Dictionary == null) + { + throw new ArgumentException($"{nameof(arrayData.Dictionary)} must not be null"); + } + arrayData.Dictionary.EnsureDataType(dictionaryType.ValueType.TypeId); + + IArrowArray dictionary = ArrowArrayFactory.BuildArray(arrayData.Dictionary); + + dictionaryMemo ??= new DictionaryMemo(); + long id = dictionaryMemo.GetOrAssignId(field); + + dictionaryMemo.AddOrReplaceDictionary(id, dictionary); + WalkChildren(dictionary.Data, ref dictionaryMemo); + } + else + { + WalkChildren(arrayData, ref dictionaryMemo); + } + } + + private static void WalkChildren(ArrayData arrayData, ref DictionaryMemo dictionaryMemo) + { + ArrayData[] children = arrayData.Children; + + if (children == null) + { + return; + } + + if (arrayData.DataType is NestedType nestedType) + { + for (int i = 0; i < nestedType.Fields.Count; i++) + { + Field childField = nestedType.Fields[i]; + ArrayData child = children[i]; + + CollectDictionary(childField, child, ref dictionaryMemo); + } + } + } + } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs index d0d2b74e701b..23f5b3f3d9c8 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs @@ -63,7 +63,8 @@ class TypeVisitor : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private FlatBufferBuilder Builder { get; } @@ -201,6 +202,14 @@ private void CreateIntType(NumberType type) Flatbuf.Int.CreateInt(Builder, type.BitWidth, type.IsSigned)); } + public void Visit(DictionaryType 
type) + { + // In this library, the dictionary "type" is a logical construct. Here we + // pass through to the value type, as we've already captured the index + // type in the DictionaryEncoding metadata in the parent field + type.ValueType.Accept(this); + } + public void Visit(IArrowType type) { throw new NotImplementedException(); diff --git a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs new file mode 100644 index 000000000000..e069be8d9a97 --- /dev/null +++ b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +using System; +using System.Collections.Generic; +using Apache.Arrow.Types; + +namespace Apache.Arrow.Ipc +{ + class DictionaryMemo + { + private readonly Dictionary _idToDictionary; + private readonly Dictionary _idToValueType; + private readonly Dictionary _fieldToId; + + public DictionaryMemo() + { + _idToDictionary = new Dictionary(); + _idToValueType = new Dictionary(); + _fieldToId = new Dictionary(); + } + + public IArrowType GetDictionaryType(long id) + { + if (!_idToValueType.TryGetValue(id, out IArrowType type)) + { + throw new ArgumentException($"Dictionary with id {id} not found"); + } + return type; + } + + public IArrowArray GetDictionary(long id) + { + if (!_idToDictionary.TryGetValue(id, out IArrowArray dictionary)) + { + throw new ArgumentException($"Dictionary with id {id} not found"); + } + return dictionary; + } + + public void AddField(long id, Field field) + { + if (_fieldToId.ContainsKey(field)) + { + throw new ArgumentException($"Field {field.Name} is already in Memo"); + } + + if (field.DataType.TypeId != ArrowTypeId.Dictionary) + { + throw new ArgumentException($"Field type is not DictionaryType: Name={field.Name}, {field.DataType.Name}"); + } + + IArrowType valueType = ((DictionaryType)field.DataType).ValueType; + + if (_idToValueType.TryGetValue(id, out IArrowType valueTypeInDic)) + { + if (valueType != valueTypeInDic) + { + throw new ArgumentException($"Field type {field.DataType.Name} does not match the existing type {valueTypeInDic}"); + } + } + + _fieldToId.Add(field, id); + _idToValueType[id] = valueType; /* indexer, not Add: several fields may share one dictionary id, and Add would throw on the already-validated existing entry */ + } + + public long GetId(Field field) + { + if (!_fieldToId.TryGetValue(field, out long id)) + { + throw new ArgumentException($"Field with name {field.Name} not found"); + } + return id; + } + + public long GetOrAssignId(Field field) + { + if (!_fieldToId.TryGetValue(field, out long id)) + { + id = _fieldToId.Count + 1; + AddField(id, field); + } + return id; + } + + public void AddOrReplaceDictionary(long id, IArrowArray 
dictionary) + { + _idToDictionary[id] = dictionary; + } + } +} diff --git a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs index a4e766089245..d156d3a625f9 100644 --- a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs +++ b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs @@ -17,6 +17,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; +using Apache.Arrow.Types; namespace Apache.Arrow.Ipc { @@ -52,14 +53,13 @@ public static Types.NumberType GetNumberType(int bitWidth, bool signed) $"{(signed ? "signed " : "unsigned")} integer."); } - internal static Schema GetSchema(Flatbuf.Schema schema) + internal static Schema GetSchema(Flatbuf.Schema schema, ref DictionaryMemo dictionaryMemo) { List fields = new List(); for (int i = 0; i < schema.FieldsLength; i++) { Flatbuf.Field field = schema.Fields(i).GetValueOrDefault(); - - fields.Add(FieldFromFlatbuffer(field)); + fields.Add(FieldFromFlatbuffer(field, ref dictionaryMemo)); } Dictionary metadata = schema.CustomMetadataLength > 0 ? new Dictionary() : null; @@ -73,13 +73,26 @@ internal static Schema GetSchema(Flatbuf.Schema schema) return new Schema(fields, metadata, copyCollections: false); } - private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField) + private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField, ref DictionaryMemo dictionaryMemo) { Field[] childFields = flatbufField.ChildrenLength > 0 ? new Field[flatbufField.ChildrenLength] : null; for (int i = 0; i < flatbufField.ChildrenLength; i++) { Flatbuf.Field? childFlatbufField = flatbufField.Children(i); - childFields[i] = FieldFromFlatbuffer(childFlatbufField.Value); + childFields[i] = FieldFromFlatbuffer(childFlatbufField.Value, ref dictionaryMemo); + } + + Flatbuf.DictionaryEncoding? dictionaryEncoding = flatbufField.Dictionary; + IArrowType type = GetFieldArrowType(flatbufField, childFields); + + if (dictionaryEncoding.HasValue) + { + Flatbuf.Int? 
indexTypeAsInt = dictionaryEncoding.Value.IndexType; + IArrowType indexType = indexTypeAsInt.HasValue ? + GetNumberType(indexTypeAsInt.Value.BitWidth, indexTypeAsInt.Value.IsSigned) : + GetNumberType(Int32Type.Default.BitWidth, Int32Type.Default.IsSigned); + + type = new DictionaryType(indexType, type, dictionaryEncoding.Value.IsOrdered); } Dictionary metadata = flatbufField.CustomMetadataLength > 0 ? new Dictionary() : null; @@ -90,7 +103,15 @@ private static Field FieldFromFlatbuffer(Flatbuf.Field flatbufField) metadata[keyValue.Key] = keyValue.Value; } - return new Field(flatbufField.Name, GetFieldArrowType(flatbufField, childFields), flatbufField.Nullable, metadata, copyCollections: false); + var arrowField = new Field(flatbufField.Name, type, flatbufField.Nullable, metadata, copyCollections: false); + + if (dictionaryEncoding.HasValue) + { + dictionaryMemo ??= new DictionaryMemo(); + dictionaryMemo.AddField(dictionaryEncoding.Value.Id, arrowField); + } + + return arrowField; } private static Types.IArrowType GetFieldArrowType(Flatbuf.Field field, Field[] childFields = null) diff --git a/csharp/src/Apache.Arrow/RecordBatch.cs b/csharp/src/Apache.Arrow/RecordBatch.cs index 3cffabee32be..6e97100686ee 100644 --- a/csharp/src/Apache.Arrow/RecordBatch.cs +++ b/csharp/src/Apache.Arrow/RecordBatch.cs @@ -28,8 +28,10 @@ public partial class RecordBatch : IDisposable public IEnumerable Arrays => _arrays; public int Length { get; } + internal IReadOnlyList ArrayList => _arrays; + private readonly IMemoryOwner _memoryOwner; - private readonly IList _arrays; + private readonly List _arrays; public IArrowArray Column(int i) { diff --git a/csharp/src/Apache.Arrow/Types/DictionaryType.cs b/csharp/src/Apache.Arrow/Types/DictionaryType.cs new file mode 100644 index 000000000000..5c1dd4095eb1 --- /dev/null +++ b/csharp/src/Apache.Arrow/Types/DictionaryType.cs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license 
agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +using System; + +namespace Apache.Arrow.Types +{ + public sealed class DictionaryType : FixedWidthType + { + public static readonly DictionaryType Default = new DictionaryType(Int64Type.Default, Int64Type.Default, false); + + public DictionaryType(IArrowType indexType, IArrowType valueType, bool ordered) + { + if (!(indexType is IntegerType)) + { + throw new ArgumentException($"{nameof(indexType)} must be integer"); + } + + IndexType = indexType; + ValueType = valueType; + Ordered = ordered; + } + + public override ArrowTypeId TypeId => ArrowTypeId.Dictionary; + public override string Name => "dictionary"; + public override int BitWidth => 64; + public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); + + public IArrowType IndexType { get; private set; } + public IArrowType ValueType { get; private set; } + public bool Ordered { get; private set; } + } +} diff --git a/csharp/src/Apache.Arrow/Types/Int16Type.cs b/csharp/src/Apache.Arrow/Types/Int16Type.cs index f1d6868ba8ae..564ae069206f 100644 --- a/csharp/src/Apache.Arrow/Types/Int16Type.cs +++ b/csharp/src/Apache.Arrow/Types/Int16Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class Int16Type : NumberType + public sealed class Int16Type : IntegerType { public static 
readonly Int16Type Default = new Int16Type(); @@ -26,4 +26,4 @@ public sealed class Int16Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/Int32Type.cs b/csharp/src/Apache.Arrow/Types/Int32Type.cs index a32c88462983..bc2ad32e4a10 100644 --- a/csharp/src/Apache.Arrow/Types/Int32Type.cs +++ b/csharp/src/Apache.Arrow/Types/Int32Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class Int32Type : NumberType + public sealed class Int32Type : IntegerType { public static readonly Int32Type Default = new Int32Type(); @@ -26,4 +26,4 @@ public sealed class Int32Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/Int64Type.cs b/csharp/src/Apache.Arrow/Types/Int64Type.cs index f45523cfb330..9be7f2161ee8 100644 --- a/csharp/src/Apache.Arrow/Types/Int64Type.cs +++ b/csharp/src/Apache.Arrow/Types/Int64Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class Int64Type : NumberType + public sealed class Int64Type : IntegerType { public static readonly Int64Type Default = new Int64Type(); @@ -26,4 +26,4 @@ public sealed class Int64Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/Int8Type.cs b/csharp/src/Apache.Arrow/Types/Int8Type.cs index 9b3f5b5b4fc9..fd6e471155f5 100644 --- a/csharp/src/Apache.Arrow/Types/Int8Type.cs +++ b/csharp/src/Apache.Arrow/Types/Int8Type.cs @@ -16,7 +16,7 @@ namespace Apache.Arrow.Types { - public sealed class Int8Type : NumberType + public sealed class Int8Type : IntegerType { public static readonly Int8Type Default = new Int8Type(); diff --git a/csharp/src/Apache.Arrow/Types/IntegerType.cs b/csharp/src/Apache.Arrow/Types/IntegerType.cs 
new file mode 100644 index 000000000000..7a5057c46686 --- /dev/null +++ b/csharp/src/Apache.Arrow/Types/IntegerType.cs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +namespace Apache.Arrow.Types +{ + public abstract class IntegerType: NumberType + { + } +} diff --git a/csharp/src/Apache.Arrow/Types/UInt16Type.cs b/csharp/src/Apache.Arrow/Types/UInt16Type.cs index 1925ffb86b79..7e020d37e756 100644 --- a/csharp/src/Apache.Arrow/Types/UInt16Type.cs +++ b/csharp/src/Apache.Arrow/Types/UInt16Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class UInt16Type : NumberType + public sealed class UInt16Type : IntegerType { public static readonly UInt16Type Default = new UInt16Type(); @@ -26,4 +26,4 @@ public sealed class UInt16Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/UInt32Type.cs b/csharp/src/Apache.Arrow/Types/UInt32Type.cs index 8007025f3061..9015f118b3a5 100644 --- a/csharp/src/Apache.Arrow/Types/UInt32Type.cs +++ b/csharp/src/Apache.Arrow/Types/UInt32Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class UInt32Type : 
NumberType + public sealed class UInt32Type : IntegerType { public static readonly UInt32Type Default = new UInt32Type(); @@ -26,4 +26,4 @@ public sealed class UInt32Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/UInt64Type.cs b/csharp/src/Apache.Arrow/Types/UInt64Type.cs index 20b51ad44f54..a414e701687b 100644 --- a/csharp/src/Apache.Arrow/Types/UInt64Type.cs +++ b/csharp/src/Apache.Arrow/Types/UInt64Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class UInt64Type : NumberType + public sealed class UInt64Type : IntegerType { public static readonly UInt64Type Default = new UInt64Type(); @@ -26,4 +26,4 @@ public sealed class UInt64Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Types/UInt8Type.cs b/csharp/src/Apache.Arrow/Types/UInt8Type.cs index e2e53657200e..31121b4e059f 100644 --- a/csharp/src/Apache.Arrow/Types/UInt8Type.cs +++ b/csharp/src/Apache.Arrow/Types/UInt8Type.cs @@ -15,7 +15,7 @@ namespace Apache.Arrow.Types { - public sealed class UInt8Type : NumberType + public sealed class UInt8Type : IntegerType { public static readonly UInt8Type Default = new UInt8Type(); @@ -26,4 +26,4 @@ public sealed class UInt8Type : NumberType public override void Accept(IArrowTypeVisitor visitor) => Accept(this, visitor); } -} \ No newline at end of file +} diff --git a/csharp/test/Apache.Arrow.Benchmarks/ArrowWriterBenchmark.cs b/csharp/test/Apache.Arrow.Benchmarks/ArrowWriterBenchmark.cs index f35c2a5d78d7..c791c9969356 100644 --- a/csharp/test/Apache.Arrow.Benchmarks/ArrowWriterBenchmark.cs +++ b/csharp/test/Apache.Arrow.Benchmarks/ArrowWriterBenchmark.cs @@ -38,7 +38,7 @@ public class ArrowWriterBenchmark [GlobalSetup] public void GlobalSetup() { - _batch = 
TestData.CreateSampleRecordBatch(BatchLength, ColumnSetCount); + _batch = TestData.CreateSampleRecordBatch(BatchLength, ColumnSetCount, false); _memoryStream = new MemoryStream(); } diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs index 78f51a7459c0..071560fe6ad8 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs @@ -80,7 +80,8 @@ private class ArrayComparer : IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, - IArrowArrayVisitor + IArrowArrayVisitor, + IArrowArrayVisitor { private readonly IArrowArray _expectedArray; private readonly ArrayTypeComparer _arrayTypeComparer; @@ -129,6 +130,16 @@ public void Visit(StructArray array) } } + public void Visit(DictionaryArray array) + { + Assert.IsAssignableFrom(_expectedArray); + DictionaryArray expectedArray = (DictionaryArray)_expectedArray; + var indicesComparer = new ArrayComparer(expectedArray.Indices); + var dictionaryComparer = new ArrayComparer(expectedArray.Dictionary); + array.Indices.Accept(indicesComparer); + array.Dictionary.Accept(dictionaryComparer); + } + public void Visit(FixedSizeBinaryType array) => throw new NotImplementedException(); public void Visit(IArrowArray array) => throw new NotImplementedException(); diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs index a74a27941881..973fc6a0a0e5 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs @@ -50,11 +50,13 @@ public void Ctor_LeaveOpenTrue_StreamValidOnDispose() } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) + [InlineData(true, true, 2)] + [InlineData(true, false, 1)] + [InlineData(false, true, 2)] + [InlineData(false, false, 1)] + public async Task 
Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen, bool createDictionaryArray, int expectedAllocations) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); using (MemoryStream stream = new MemoryStream()) { @@ -68,7 +70,7 @@ public async Task Ctor_MemoryPool_AllocatesFromPool(bool shouldLeaveOpen) ArrowStreamReader reader = new ArrowStreamReader(stream, memoryPool, shouldLeaveOpen); reader.ReadNextRecordBatch(); - Assert.Equal(1, memoryPool.Statistics.Allocations); + Assert.Equal(expectedAllocations, memoryPool.Statistics.Allocations); Assert.True(memoryPool.Statistics.BytesAllocated > 0); reader.Dispose(); @@ -127,30 +129,34 @@ private static async Task TestReaderFromMemory( } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task ReadRecordBatch_Stream(bool writeEnd) + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task ReadRecordBatch_Stream(bool writeEnd, bool createDictionaryArray) { await TestReaderFromStream((reader, originalBatch) => { ArrowReaderVerifier.VerifyReader(reader, originalBatch); return Task.CompletedTask; - }, writeEnd); + }, writeEnd, createDictionaryArray); } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task ReadRecordBatchAsync_Stream(bool writeEnd) + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task ReadRecordBatchAsync_Stream(bool writeEnd, bool createDictionaryArray) { - await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync, writeEnd); + await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync, writeEnd, createDictionaryArray); } private static async Task TestReaderFromStream( Func verificationFunc, - bool writeEnd) + bool writeEnd, bool createDictionaryArray) { - 
RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); using (MemoryStream stream = new MemoryStream()) { @@ -168,29 +174,33 @@ private static async Task TestReaderFromStream( } } - [Fact] - public async Task ReadRecordBatch_PartialReadStream() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadRecordBatch_PartialReadStream(bool createDictionaryArray) { await TestReaderFromPartialReadStream((reader, originalBatch) => { ArrowReaderVerifier.VerifyReader(reader, originalBatch); return Task.CompletedTask; - }); + }, createDictionaryArray); } - [Fact] - public async Task ReadRecordBatchAsync_PartialReadStream() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadRecordBatchAsync_PartialReadStream(bool createDictionaryArray) { - await TestReaderFromPartialReadStream(ArrowReaderVerifier.VerifyReaderAsync); + await TestReaderFromPartialReadStream(ArrowReaderVerifier.VerifyReaderAsync, createDictionaryArray); } /// /// Verifies that the stream reader reads multiple times when a stream /// only returns a subset of the data from each Read. 
/// - private static async Task TestReaderFromPartialReadStream(Func verificationFunc) + private static async Task TestReaderFromPartialReadStream(Func verificationFunc, bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); using (PartialReadStream stream = new PartialReadStream()) { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs index 44546f11d7ab..837fe68a0daf 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs @@ -56,12 +56,13 @@ public void Ctor_LeaveOpenTrue_StreamValidOnDispose() Assert.Equal(0, stream.Position); } - [Fact] - public void CanWriteToNetworkStream() + [Theory] + [InlineData(true, 32153)] + [InlineData(false, 32154)] + public void CanWriteToNetworkStream(bool createDictionaryArray, int port) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); - const int port = 32153; TcpListener listener = new TcpListener(IPAddress.Loopback, port); listener.Start(); @@ -90,12 +91,13 @@ public void CanWriteToNetworkStream() } } - [Fact] - public async Task CanWriteToNetworkStreamAsync() + [Theory] + [InlineData(true, 32155)] + [InlineData(false, 32156)] + public async Task CanWriteToNetworkStreamAsync(bool createDictionaryArray, int port) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); - const int port = 32154; TcpListener listener = new TcpListener(IPAddress.Loopback, port); listener.Start(); @@ -124,18 +126,22 @@ public async Task 
CanWriteToNetworkStreamAsync() } } - [Fact] - public void WriteEmptyBatch() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void WriteEmptyBatch(bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0, createDictionaryArray: createDictionaryArray); TestRoundTripRecordBatch(originalBatch); } - [Fact] - public async Task WriteEmptyBatchAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task WriteEmptyBatchAsync(bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0, createDictionaryArray: createDictionaryArray); await TestRoundTripRecordBatchAsync(originalBatch); } @@ -196,13 +202,16 @@ public async Task WriteBatchWithNullsAsync() await TestRoundTripRecordBatchAsync(originalBatch); } - private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null) + private static void TestRoundTripRecordBatches(List originalBatches, IpcOptions options = null) { using (MemoryStream stream = new MemoryStream()) { - using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options)) + using (var writer = new ArrowStreamWriter(stream, originalBatches[0].Schema, leaveOpen: true, options)) { - writer.WriteRecordBatch(originalBatch); + foreach (RecordBatch originalBatch in originalBatches) + { + writer.WriteRecordBatch(originalBatch); + } writer.WriteEnd(); } @@ -210,20 +219,25 @@ private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptio using (var reader = new ArrowStreamReader(stream)) { - RecordBatch newBatch = reader.ReadNextRecordBatch(); - ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + foreach (RecordBatch originalBatch in originalBatches) + { + RecordBatch newBatch = reader.ReadNextRecordBatch(); + 
ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + } } } } - - private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatch, IpcOptions options = null) + private static async Task TestRoundTripRecordBatchesAsync(List originalBatches, IpcOptions options = null) { using (MemoryStream stream = new MemoryStream()) { - using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options)) + using (var writer = new ArrowStreamWriter(stream, originalBatches[0].Schema, leaveOpen: true, options)) { - await writer.WriteRecordBatchAsync(originalBatch); + foreach (RecordBatch originalBatch in originalBatches) + { + await writer.WriteRecordBatchAsync(originalBatch); + } await writer.WriteEndAsync(); } @@ -231,12 +245,25 @@ private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatc using (var reader = new ArrowStreamReader(stream)) { - RecordBatch newBatch = reader.ReadNextRecordBatch(); - ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + foreach (RecordBatch originalBatch in originalBatches) + { + RecordBatch newBatch = reader.ReadNextRecordBatch(); + ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + } } } } + private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null) + { + TestRoundTripRecordBatches(new List { originalBatch }, options); + } + + private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatch, IpcOptions options = null) + { + await TestRoundTripRecordBatchesAsync(new List { originalBatch }, options); + } + [Fact] public void WriteBatchWithCorrectPadding() { @@ -369,27 +396,32 @@ public async Task WriteBatchWithCorrectPaddingAsync() } } - [Fact] - public void LegacyIpcFormatRoundTrips() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void LegacyIpcFormatRoundTrips(bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch 
originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); TestRoundTripRecordBatch(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); } - - [Fact] - public async Task LegacyIpcFormatRoundTripsAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task LegacyIpcFormatRoundTripsAsync(bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); await TestRoundTripRecordBatchAsync(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true }); } [Theory] - [InlineData(true)] - [InlineData(false)] - public void WriteLegacyIpcFormat(bool writeLegacyIpcFormat) + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public void WriteLegacyIpcFormat(bool writeLegacyIpcFormat, bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat }; using (MemoryStream stream = new MemoryStream()) @@ -425,11 +457,13 @@ public void WriteLegacyIpcFormat(bool writeLegacyIpcFormat) } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task WriteLegacyIpcFormatAsync(bool writeLegacyIpcFormat) + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task WriteLegacyIpcFormatAsync(bool writeLegacyIpcFormat, bool createDictionaryArray) { - RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100); + RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray); var options = new 
IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat }; using (MemoryStream stream = new MemoryStream()) @@ -494,5 +528,110 @@ public void WritesMetadataCorrectly() TestRoundTripRecordBatch(originalBatch); } + + [Fact] + public async Task WriteMultipleDictionaryArraysAsync() + { + List originalRecordBatches = CreateMultipleDictionaryArraysTestData(); + await TestRoundTripRecordBatchesAsync(originalRecordBatches); + } + + [Fact] + public void WriteMultipleDictionaryArrays() + { + List originalRecordBatches = CreateMultipleDictionaryArraysTestData(); + TestRoundTripRecordBatches(originalRecordBatches); + } + + private List CreateMultipleDictionaryArraysTestData() + { + var dictionaryData = new List { "a", "b", "c" }; + int length = dictionaryData.Count; + + var schemaForSimpleCase = new Schema(new List { + new Field("int8", Int8Type.Default, true), + new Field("uint8", UInt8Type.Default, true), + new Field("int16", Int16Type.Default, true), + new Field("uint16", UInt16Type.Default, true), + new Field("int32", Int32Type.Default, true), + new Field("uint32", UInt32Type.Default, true), + new Field("int64", Int64Type.Default, true), + new Field("uint64", UInt64Type.Default, true) + }, null); + + StringArray dictionary = new StringArray.Builder().AppendRange(dictionaryData).Build(); + IEnumerable indicesArraysForSimpleCase = TestData.CreateArrays(schemaForSimpleCase, length); + + var fields = new List(capacity: length + 1); + var testTargetArrays = new List(capacity: length + 1); + + foreach (IArrowArray indices in indicesArraysForSimpleCase) + { + var dictionaryArray = new DictionaryArray( + new DictionaryType(indices.Data.DataType, StringType.Default, false), + indices, dictionary); + testTargetArrays.Add(dictionaryArray); + fields.Add(new Field($"dictionaryField_{indices.Data.DataType.Name}", dictionaryArray.Data.DataType, false)); + } + + (Field dictionaryTypeListArrayField, ListArray dictionaryTypeListArray) = CreateDictionaryTypeListArrayTestData(dictionary); 
+ + fields.Add(dictionaryTypeListArrayField); + testTargetArrays.Add(dictionaryTypeListArray); + + (Field listTypeDictionaryArrayField, DictionaryArray listTypeDictionaryArray) = CreateListTypeDictionaryArrayTestData(dictionaryData); + + fields.Add(listTypeDictionaryArrayField); + testTargetArrays.Add(listTypeDictionaryArray); + + var schema = new Schema(fields, null); + + return new List { + new RecordBatch(schema, testTargetArrays, length), + new RecordBatch(schema, testTargetArrays, length), + }; + } + + private Tuple CreateDictionaryTypeListArrayTestData(StringArray dictionary) + { + Int32Array indiceArray = new Int32Array.Builder().AppendRange(Enumerable.Range(0, dictionary.Length)).Build(); + + //DictionaryArray has no Builder for now, so creating ListArray directly. + var dictionaryType = new DictionaryType(Int32Type.Default, StringType.Default, false); + var dictionaryArray = new DictionaryArray(dictionaryType, indiceArray, dictionary); + + var valueOffsetsBufferBuilder = new ArrowBuffer.Builder(); + var validityBufferBuilder = new ArrowBuffer.BitmapBuilder(); + + foreach (int i in Enumerable.Range(0, dictionary.Length + 1)) + { + valueOffsetsBufferBuilder.Append(i); + validityBufferBuilder.Append(true); + } + + var dictionaryField = new Field("dictionaryField_list", dictionaryType, false); + var listType = new ListType(dictionaryField); + var listArray = new ListArray(listType, valueOffsetsBufferBuilder.Length - 1, valueOffsetsBufferBuilder.Build(), dictionaryArray, valueOffsetsBufferBuilder.Build()); + + return Tuple.Create(new Field($"listField_{listType.ValueDataType.Name}", listType, false), listArray); + } + + private Tuple CreateListTypeDictionaryArrayTestData(List dictionaryDataBase) + { + var listBuilder = new ListArray.Builder(StringType.Default); + var valueBuilder = listBuilder.ValueBuilder as StringArray.Builder; + + foreach(string data in dictionaryDataBase) { + listBuilder.Append(); + valueBuilder.Append(data); + } + + ListArray dictionary = 
listBuilder.Build(); + Int32Array indiceArray = new Int32Array.Builder().AppendRange(Enumerable.Range(0, dictionary.Length)).Build(); + var dictionaryArrayType = new DictionaryType(Int32Type.Default, dictionary.Data.DataType, false); + var dictionaryArray = new DictionaryArray(dictionaryArrayType, indiceArray, dictionary); + + return Tuple.Create(new Field($"dictionaryField_{dictionaryArray.Data.DataType.Name}", dictionaryArrayType, false), dictionaryArray); + } } } diff --git a/csharp/test/Apache.Arrow.Tests/DictionaryArrayTests.cs b/csharp/test/Apache.Arrow.Tests/DictionaryArrayTests.cs new file mode 100644 index 000000000000..da678563c3e6 --- /dev/null +++ b/csharp/test/Apache.Arrow.Tests/DictionaryArrayTests.cs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +using System; +using Apache.Arrow.Types; +using Xunit; + +namespace Apache.Arrow.Tests +{ + public class DictionaryArrayTests + { + [Fact] + public void CreateTest() + { + (StringArray originalDictionary, Int32Array originalIndicesArray, DictionaryArray dictionaryArray) = + CreateSimpleTestData(); + + Assert.Equal(dictionaryArray.Dictionary, originalDictionary); + Assert.Equal(dictionaryArray.Indices, originalIndicesArray); + } + + [Fact] + public void SliceTest() + { + (StringArray originalDictionary, Int32Array originalIndicesArray, DictionaryArray dictionaryArray) = + CreateSimpleTestData(); + + int batchLength = originalIndicesArray.Length; + for (int offset = 0; offset < batchLength; offset++) + { + for (int length = 1; offset + length <= batchLength; length++) + { + var sliced = dictionaryArray.Slice(offset, length) as DictionaryArray; + var actualSlicedDictionary = sliced.Dictionary as StringArray; + var actualSlicedIndicesArray = sliced.Indices as Int32Array; + + var expectedSlicedIndicesArray = originalIndicesArray.Slice(offset, length) as Int32Array; + + //Dictionary is not sliced. 
+ Assert.Equal(originalDictionary.Data, actualSlicedDictionary.Data); + Assert.Equal(expectedSlicedIndicesArray.ToList(), actualSlicedIndicesArray.ToList()); + } + } + } + + private Tuple CreateSimpleTestData() + { + StringArray originalDictionary = new StringArray.Builder().AppendRange(new[] { "a", "b", "c" }).Build(); + Int32Array originalIndicesArray = new Int32Array.Builder().AppendRange(new[] { 0, 0, 1, 1, 2, 2 }).Build(); + var dictionaryArray = new DictionaryArray(new DictionaryType(Int32Type.Default, StringType.Default, false), originalIndicesArray, originalDictionary); + + return Tuple.Create(originalDictionary, originalIndicesArray, dictionaryArray); + } + } +} diff --git a/csharp/test/Apache.Arrow.Tests/TestData.cs b/csharp/test/Apache.Arrow.Tests/TestData.cs index 8db2241ff891..f15e857981ba 100644 --- a/csharp/test/Apache.Arrow.Tests/TestData.cs +++ b/csharp/test/Apache.Arrow.Tests/TestData.cs @@ -21,12 +21,12 @@ namespace Apache.Arrow.Tests { public static class TestData { - public static RecordBatch CreateSampleRecordBatch(int length) + public static RecordBatch CreateSampleRecordBatch(int length, bool createDictionaryArray = false) { - return CreateSampleRecordBatch(length, columnSetCount: 1); + return CreateSampleRecordBatch(length, columnSetCount: 1, createDictionaryArray); } - public static RecordBatch CreateSampleRecordBatch(int length, int columnSetCount) + public static RecordBatch CreateSampleRecordBatch(int length, int columnSetCount, bool createDictionaryArray) { Schema.Builder builder = new Schema.Builder(); for (int i = 0; i < columnSetCount; i++) @@ -50,6 +50,12 @@ public static RecordBatch CreateSampleRecordBatch(int length, int columnSetCount builder.Field(CreateField(new StructType(new List { CreateField(StringType.Default, i), CreateField(Int32Type.Default, i) }), i)); builder.Field(CreateField(new Decimal128Type(10, 6), i)); builder.Field(CreateField(new Decimal256Type(16, 8), i)); + + if (createDictionaryArray) + { + 
builder.Field(CreateField(new DictionaryType(Int32Type.Default, StringType.Default, false), i)); + } + //builder.Field(CreateField(new FixedSizeBinaryType(16), i)); //builder.Field(CreateField(HalfFloatType.Default)); //builder.Field(CreateField(StringType.Default)); @@ -74,7 +80,7 @@ private static Field CreateField(ArrowType type, int iteration) return new Field(type.Name + iteration, type, nullable: false); } - private static IEnumerable CreateArrays(Schema schema, int length) + public static IEnumerable CreateArrays(Schema schema, int length) { int fieldCount = schema.Fields.Count; List arrays = new List(fieldCount); @@ -114,7 +120,8 @@ private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private int Length { get; } public IArrowArray Array { get; private set; } @@ -246,10 +253,24 @@ public void Visit(StructType type) { nullBitmap.Append(true); } - + Array = new StructArray(type, Length, childArrays, nullBitmap.Build()); } + public void Visit(DictionaryType type) + { + Int32Array.Builder indicesBuilder = new Int32Array.Builder().Reserve(Length); + StringArray.Builder valueBuilder = new StringArray.Builder().Reserve(Length); + + for (int i = 0; i < Length; i++) + { + indicesBuilder.Append(i); + valueBuilder.Append($"{i}"); + } + + Array = new DictionaryArray(type, indicesBuilder.Build(), valueBuilder.Build()); + } + private void GenerateArray(IArrowArrayBuilder builder, Func generator) where TArrayBuilder : IArrowArrayBuilder where TArray : IArrowArray