diff --git a/Apache.Arrow.sln b/Apache.Arrow.sln index 0dd6853a..3481f174 100644 --- a/Apache.Arrow.sln +++ b/Apache.Arrow.sln @@ -29,6 +29,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Integra EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.IntegrationTest", "test\Apache.Arrow.IntegrationTest\Apache.Arrow.IntegrationTest.csproj", "{E8264B7F-B680-4A55-939B-85DB628164BB}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Operations", "src\Apache.Arrow.Operations\Apache.Arrow.Operations.csproj", "{BA6B2B0D-EAAE-4183-8A39-1B9CF571F71F}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Operations.Tests", "test\Apache.Arrow.Operations.Tests\Apache.Arrow.Operations.Tests.csproj", "{BA6B2B0D-EAAE-4183-8A39-1B9CF571F71F}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU diff --git a/src/Apache.Arrow.Operations/Apache.Arrow.Operations.csproj b/src/Apache.Arrow.Operations/Apache.Arrow.Operations.csproj new file mode 100644 index 00000000..51796055 --- /dev/null +++ b/src/Apache.Arrow.Operations/Apache.Arrow.Operations.csproj @@ -0,0 +1,12 @@ + + + net8.0 + enable + enable + + + + + + + diff --git a/src/Apache.Arrow.Operations/Bitops.cs b/src/Apache.Arrow.Operations/Bitops.cs new file mode 100644 index 00000000..12cf82ee --- /dev/null +++ b/src/Apache.Arrow.Operations/Bitops.cs @@ -0,0 +1,305 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using System;
using System.Runtime.Intrinsics;

using Apache.Arrow.Memory;

namespace Apache.Arrow.Operations;

/// <summary>
/// Byte-wise bitwise operations over the contents of <see cref="ArrowBuffer"/> instances.
/// Every operation writes its result into a newly built buffer; inputs are never modified.
/// </summary>
/// <remarks>
/// The empty-operand shortcuts in <see cref="And"/>, <see cref="Or"/> and <see cref="Xor"/>
/// are mutually consistent with reading an empty buffer as "every bit set"
/// (NOTE(review): presumably an absent validity bitmap - confirm with callers).
/// </remarks>
internal static class BitVectorOps
{
    /// <summary>Selects the operator applied by <see cref="Combine"/>.</summary>
    private enum BinaryOp
    {
        And,
        Or,
        Xor,
    }

    /// <summary>
    /// Builds a buffer in which every byte is <c>0xFF</c>, by complementing an all-zero buffer.
    /// </summary>
    /// <param name="numBytes">Requested size in bytes; must not be negative.</param>
    /// <param name="allocator">Allocator forwarded to the buffer builders.</param>
    internal static ArrowBuffer AllOnes(int numBytes, MemoryAllocator? allocator = default)
    {
        // The zero buffer is only a stepping stone; release it once the complement exists.
        using var zeros = AllZeros(numBytes, allocator);
        return OnesComplement(zeros, allocator);
    }

    /// <summary>
    /// Builds a buffer of <paramref name="numBytes"/> zero bytes.
    /// </summary>
    /// <param name="numBytes">Requested size in bytes; must not be negative.</param>
    /// <param name="allocator">Allocator forwarded to the buffer builder.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="numBytes"/> is negative.</exception>
    internal static ArrowBuffer AllZeros(int numBytes, MemoryAllocator? allocator = default)
    {
        ArgumentOutOfRangeException.ThrowIfNegative(numBytes);

        int bitCount = checked(numBytes * 8);
        var builder = new ArrowBuffer.BitmapBuilder(bitCount);
        if (bitCount > 0)
        {
            // Append the zero bits explicitly instead of poking the last bit with Set():
            // that indexed bit -1 for numBytes == 0 and relied on Set() accepting an index
            // past the builder's current length.
            builder.AppendRange(false, bitCount);
        }
        return builder.Build(allocator);
    }

    /// <summary>
    /// Returns a new buffer holding the bitwise complement of every byte of <paramref name="buffer"/>.
    /// </summary>
    /// <param name="buffer">Buffer to complement; it is not modified.</param>
    /// <param name="allocator">Allocator forwarded to the buffer builder.</param>
    internal static ArrowBuffer OnesComplement(ArrowBuffer buffer, MemoryAllocator? allocator = default)
    {
        // Read the span once; the original re-read buffer.Span on every loop iteration.
        ReadOnlySpan<byte> source = buffer.Span;
        int size = source.Length;

        // The builder is used purely as scratch storage that is written through its span.
        var builder = new ArrowBuffer.BitmapBuilder(checked(buffer.Length * 8));
        Span<byte> store = builder.Span;
        int offset = 0;

        // Widest vectors first; each narrower stage consumes what the previous one left over.
        if (Vector512.IsHardwareAccelerated)
        {
            for (; size - offset >= 64; offset += 64)
            {
                (~Vector512.Create(source.Slice(offset, 64))).CopyTo(store.Slice(offset, 64));
            }
        }
        if (Vector256.IsHardwareAccelerated)
        {
            for (; size - offset >= 32; offset += 32)
            {
                (~Vector256.Create(source.Slice(offset, 32))).CopyTo(store.Slice(offset, 32));
            }
        }
        if (Vector128.IsHardwareAccelerated)
        {
            for (; size - offset >= 16; offset += 16)
            {
                (~Vector128.Create(source.Slice(offset, 16))).CopyTo(store.Slice(offset, 16));
            }
        }
        if (Vector64.IsHardwareAccelerated)
        {
            for (; size - offset >= 8; offset += 8)
            {
                (~Vector64.Create(source.Slice(offset, 8))).CopyTo(store.Slice(offset, 8));
            }
        }

        // Scalar tail (and the whole buffer when no vector width is accelerated).
        for (; offset < size; offset++)
        {
            store[offset] = (byte)~source[offset];
        }
        return builder.Build(allocator);
    }

    /// <summary>
    /// Byte-wise AND of two buffers. When one operand is empty the result is a copy of the
    /// other operand; when both are empty the result is <see cref="ArrowBuffer.Empty"/>.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="rhs"/> is non-empty but shorter than <paramref name="lhs"/>.</exception>
    internal static ArrowBuffer And(ArrowBuffer lhs, ArrowBuffer rhs, MemoryAllocator? allocator = default)
    {
        if (lhs.IsEmpty)
        {
            // Copy rather than hand back the caller's buffer, so the result can be owned
            // and disposed independently of the input.
            return rhs.IsEmpty ? ArrowBuffer.Empty : rhs.Clone(allocator);
        }
        if (rhs.IsEmpty)
        {
            return lhs.Clone(allocator);
        }
        return Combine(lhs, rhs, BinaryOp.And, allocator);
    }

    /// <summary>
    /// Byte-wise OR of two buffers. When either operand is empty the result is empty.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="rhs"/> is non-empty but shorter than <paramref name="lhs"/>.</exception>
    internal static ArrowBuffer Or(ArrowBuffer lhs, ArrowBuffer rhs, MemoryAllocator? allocator = default)
    {
        if (lhs.IsEmpty || rhs.IsEmpty)
        {
            return ArrowBuffer.Empty;
        }
        return Combine(lhs, rhs, BinaryOp.Or, allocator);
    }

    /// <summary>
    /// Byte-wise XOR of two buffers. When exactly one operand is empty the result is the
    /// complement of the other; when both are empty the result is <see cref="ArrowBuffer.Empty"/>.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="rhs"/> is non-empty but shorter than <paramref name="lhs"/>.</exception>
    internal static ArrowBuffer Xor(ArrowBuffer lhs, ArrowBuffer rhs, MemoryAllocator? allocator = default)
    {
        if (lhs.IsEmpty)
        {
            return rhs.IsEmpty ? ArrowBuffer.Empty : OnesComplement(rhs, allocator);
        }
        if (rhs.IsEmpty)
        {
            return OnesComplement(lhs, allocator);
        }
        return Combine(lhs, rhs, BinaryOp.Xor, allocator);
    }

    /// <summary>
    /// Applies <paramref name="op"/> to the first <c>lhs.Length</c> bytes of both buffers and
    /// returns the result as a new buffer. Bytes of <paramref name="rhs"/> beyond that are ignored.
    /// </summary>
    private static ArrowBuffer Combine(ArrowBuffer lhs, ArrowBuffer rhs, BinaryOp op, MemoryAllocator? allocator)
    {
        ReadOnlySpan<byte> left = lhs.Span;
        ReadOnlySpan<byte> right = rhs.Span;
        int size = left.Length;

        if (right.Length < size)
        {
            // Fail up front with a clear message instead of an opaque slice/index error mid-loop.
            throw new ArgumentException(
                $"Right-hand buffer ({right.Length} bytes) is shorter than left-hand buffer ({size} bytes).",
                nameof(rhs));
        }

        var builder = new ArrowBuffer.BitmapBuilder(checked(lhs.Length * 8));
        Span<byte> store = builder.Span;
        int offset = 0;

        if (Vector512.IsHardwareAccelerated)
        {
            for (; size - offset >= 64; offset += 64)
            {
                Apply(op, Vector512.Create(left.Slice(offset, 64)), Vector512.Create(right.Slice(offset, 64)))
                    .CopyTo(store.Slice(offset, 64));
            }
        }
        if (Vector256.IsHardwareAccelerated)
        {
            for (; size - offset >= 32; offset += 32)
            {
                Apply(op, Vector256.Create(left.Slice(offset, 32)), Vector256.Create(right.Slice(offset, 32)))
                    .CopyTo(store.Slice(offset, 32));
            }
        }
        if (Vector128.IsHardwareAccelerated)
        {
            for (; size - offset >= 16; offset += 16)
            {
                Apply(op, Vector128.Create(left.Slice(offset, 16)), Vector128.Create(right.Slice(offset, 16)))
                    .CopyTo(store.Slice(offset, 16));
            }
        }
        if (Vector64.IsHardwareAccelerated)
        {
            for (; size - offset >= 8; offset += 8)
            {
                Apply(op, Vector64.Create(left.Slice(offset, 8)), Vector64.Create(right.Slice(offset, 8)))
                    .CopyTo(store.Slice(offset, 8));
            }
        }

        // Scalar tail (and the whole buffer when no vector width is accelerated).
        for (; offset < size; offset++)
        {
            store[offset] = Apply(op, left[offset], right[offset]);
        }
        return builder.Build(allocator);
    }

    /// <summary>Applies <paramref name="op"/> to two 512-bit vectors.</summary>
    private static Vector512<byte> Apply(BinaryOp op, Vector512<byte> a, Vector512<byte> b) => op switch
    {
        BinaryOp.And => a & b,
        BinaryOp.Or => a | b,
        _ => a ^ b,
    };

    /// <summary>Applies <paramref name="op"/> to two 256-bit vectors.</summary>
    private static Vector256<byte> Apply(BinaryOp op, Vector256<byte> a, Vector256<byte> b) => op switch
    {
        BinaryOp.And => a & b,
        BinaryOp.Or => a | b,
        _ => a ^ b,
    };

    /// <summary>Applies <paramref name="op"/> to two 128-bit vectors.</summary>
    private static Vector128<byte> Apply(BinaryOp op, Vector128<byte> a, Vector128<byte> b) => op switch
    {
        BinaryOp.And => a & b,
        BinaryOp.Or => a | b,
        _ => a ^ b,
    };

    /// <summary>Applies <paramref name="op"/> to two 64-bit vectors.</summary>
    private static Vector64<byte> Apply(BinaryOp op, Vector64<byte> a, Vector64<byte> b) => op switch
    {
        BinaryOp.And => a & b,
        BinaryOp.Or => a | b,
        _ => a ^ b,
    };

    /// <summary>Applies <paramref name="op"/> to two single bytes.</summary>
    private static byte Apply(BinaryOp op, byte a, byte b) => op switch
    {
        BinaryOp.And => (byte)(a & b),
        BinaryOp.Or => (byte)(a | b),
        _ => (byte)(a ^ b),
    };
}
diff --git a/src/Apache.Arrow.Operations/Comparison.cs b/src/Apache.Arrow.Operations/Comparison.cs
new file mode 100644
index 00000000..0de4d80f
--- /dev/null
+++ b/src/Apache.Arrow.Operations/Comparison.cs
@@ -0,0 +1,474 @@
// Licensed to the Apache Software Foundation
// (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using System;
using System.Numerics;

using Apache.Arrow.Memory;
using Apache.Arrow.Types;

namespace Apache.Arrow.Operations;


/// <summary>
/// Specifies how null values should be handled in comparison operations.
/// </summary>
public enum ComparisonNullHandling
{
    /// <summary>
    /// If both values are null, they are equal. This is the default behavior in C#.
    /// </summary>
    Equality,

    /// <summary>
    /// Propagate null: if any value in the comparison is null, return null, as in SQL.
    /// </summary>
    Propagate,
}

/// <summary>
/// Element-wise logical and comparison operations that produce <see cref="BooleanArray"/> masks.
/// </summary>
public static class Comparison
{
    /// <summary>
    /// Negate a boolean array, flipping true to false and false to true. Nulls remain null.
    /// </summary>
    /// <param name="mask">The array to negate.</param>
    /// <param name="allocator">Allocator used for the result buffers.</param>
    /// <returns>A new array with the same length, offset and null count as <paramref name="mask"/>.</returns>
    public static BooleanArray Invert(BooleanArray mask, MemoryAllocator? allocator = null)
    {
        var inverted = BitVectorOps.OnesComplement(mask.ValueBuffer, allocator);
        // The whole value buffer is complemented in place-order, so bit positions are
        // unchanged and the source offset must be carried over (it used to be forced to 0).
        return new BooleanArray(inverted, mask.NullBitmapBuffer.Clone(allocator), mask.Length, mask.NullCount, mask.Offset);
    }

    /// <summary>
    /// An alias for <see cref="Invert"/>.
    /// </summary>
    /// <param name="mask">The array to negate.</param>
    /// <param name="allocator">Allocator used for the result buffers.</param>
    public static BooleanArray OnesComplement(BooleanArray mask, MemoryAllocator? allocator = null) => Invert(mask, allocator);

    /// <summary>
    /// Perform a pairwise boolean AND operation. A position is null when either input is null there.
    /// </summary>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    public static BooleanArray And(BooleanArray lhs, BooleanArray rhs, MemoryAllocator? allocator = null)
        => Combine(lhs, rhs, static (a, b, alloc) => BitVectorOps.And(a, b, alloc), static (a, b) => a && b, allocator);

    /// <summary>
    /// Perform a pairwise boolean OR operation. A position is null when either input is null there.
    /// </summary>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    public static BooleanArray Or(BooleanArray lhs, BooleanArray rhs, MemoryAllocator? allocator = null)
        => Combine(lhs, rhs, static (a, b, alloc) => BitVectorOps.Or(a, b, alloc), static (a, b) => a || b, allocator);

    /// <summary>
    /// Perform a pairwise boolean equality operation. A position is null when either input is null there.
    /// </summary>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    public static BooleanArray Equals(BooleanArray lhs, BooleanArray rhs, MemoryAllocator? allocator = null)
        => Combine(
            lhs,
            rhs,
            static (a, b, alloc) =>
            {
                // Equality is NOT(XOR); dispose the intermediate XOR buffer instead of leaking it.
                using var differing = BitVectorOps.Xor(a, b, alloc);
                return BitVectorOps.OnesComplement(differing, alloc);
            },
            static (a, b) => a == b,
            allocator);

    /// <summary>
    /// Perform a pairwise boolean XOR operation. A position is null when either input is null there.
    /// </summary>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    public static BooleanArray Xor(BooleanArray lhs, BooleanArray rhs, MemoryAllocator? allocator = null)
        => Combine(lhs, rhs, static (a, b, alloc) => BitVectorOps.Xor(a, b, alloc), static (a, b) => a != b, allocator);

    /// <summary>
    /// Compare each value in <paramref name="lhs"/> to a scalar <paramref name="rhs"/>, returning a boolean mask.
    /// Nulls compare equal to a null scalar and unequal to everything else.
    /// </summary>
    public static BooleanArray Equal<T>(PrimitiveArray<T> lhs, T? rhs, MemoryAllocator? allocator = null) where T : struct, INumber<T>
    {
        // No special case for a null scalar: lifted equality already yields true exactly for
        // null elements. The old shortcut reused the validity bitmap as the values, which
        // reported the opposite and produced an empty value buffer for arrays without nulls.
        var cmp = new BooleanArray.Builder(lhs.Length);
        for (int i = 0; i < lhs.Length; i++)
        {
            cmp.Append(lhs.GetValue(i) == rhs);
        }
        return cmp.Build(allocator);
    }

    /// <summary>
    /// Perform a pairwise comparison between each position in <paramref name="lhs"/> and
    /// <paramref name="rhs"/>, returning a boolean mask.
    /// </summary>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    /// <exception cref="NotImplementedException"><paramref name="nullHandling"/> is not a known mode.</exception>
    public static BooleanArray Equal<T>(PrimitiveArray<T> lhs, PrimitiveArray<T> rhs, MemoryAllocator? allocator = null, ComparisonNullHandling nullHandling = ComparisonNullHandling.Equality) where T : struct, INumber<T>
    {
        EnsureSameLength(lhs.Length, rhs.Length);
        bool propagate = IsPropagating(nullHandling);

        var cmp = new BooleanArray.Builder(lhs.Length);
        for (int i = 0; i < lhs.Length; i++)
        {
            T? a = lhs.GetValue(i);
            T? b = rhs.GetValue(i);
            if (propagate && (a is null || b is null))
            {
                cmp.AppendNull();
            }
            else
            {
                cmp.Append(a == b);
            }
        }
        return cmp.Build(allocator);
    }

    /// <summary>
    /// Compare each value in <paramref name="lhs"/> to a scalar <paramref name="rhs"/>, returning a boolean mask.
    /// </summary>
    public static BooleanArray Equal(StringArray lhs, string? rhs, MemoryAllocator? allocator = null, ComparisonNullHandling nullHandling = ComparisonNullHandling.Equality)
        => EqualToScalar(lhs.Length, i => lhs.GetString(i), rhs, nullHandling, allocator);

    /// <summary>
    /// Perform a pairwise comparison between each position in <paramref name="lhs"/> and
    /// <paramref name="rhs"/>, returning a boolean mask.
    /// </summary>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    /// <exception cref="NotImplementedException"><paramref name="nullHandling"/> is not a known mode.</exception>
    public static BooleanArray Equal(StringArray lhs, StringArray rhs, MemoryAllocator? allocator = null, ComparisonNullHandling nullHandling = ComparisonNullHandling.Equality)
    {
        EnsureSameLength(lhs.Length, rhs.Length);
        return EqualPairwise(lhs.Length, i => lhs.GetString(i), i => rhs.GetString(i), nullHandling, allocator);
    }

    /// <summary>
    /// Compare each value in <paramref name="lhs"/> to a scalar <paramref name="rhs"/>, returning a boolean mask.
    /// </summary>
    public static BooleanArray Equal(LargeStringArray lhs, string? rhs, MemoryAllocator? allocator = null, ComparisonNullHandling nullHandling = ComparisonNullHandling.Equality)
        => EqualToScalar(lhs.Length, i => lhs.GetString(i), rhs, nullHandling, allocator);

    /// <summary>
    /// Perform a pairwise comparison between each position in <paramref name="lhs"/> and
    /// <paramref name="rhs"/>, returning a boolean mask.
    /// </summary>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    /// <exception cref="NotImplementedException"><paramref name="nullHandling"/> is not a known mode.</exception>
    public static BooleanArray Equal(LargeStringArray lhs, LargeStringArray rhs, MemoryAllocator? allocator = null, ComparisonNullHandling nullHandling = ComparisonNullHandling.Equality)
    {
        EnsureSameLength(lhs.Length, rhs.Length);
        return EqualPairwise(lhs.Length, i => lhs.GetString(i), i => rhs.GetString(i), nullHandling, allocator);
    }

    /// <summary>
    /// A dispatching comparison between a string array and a single string. If <paramref name="lhs"/>
    /// is not some flavor of string array, an exception is thrown.
    /// </summary>
    /// <exception cref="InvalidDataException"><paramref name="lhs"/> is neither a String nor a LargeString array.</exception>
    public static BooleanArray Equal(IArrowArray lhs, string? rhs, MemoryAllocator? allocator = null, ComparisonNullHandling nullHandling = ComparisonNullHandling.Equality)
    {
        switch (lhs.Data.DataType.TypeId)
        {
            case ArrowTypeId.String:
                return Equal((StringArray)lhs, rhs, allocator, nullHandling);
            case ArrowTypeId.LargeString:
                return Equal((LargeStringArray)lhs, rhs, allocator, nullHandling);
            default:
                throw new InvalidDataException("Unsupported data type " + lhs.Data.DataType.Name);
        }
    }

    /// <summary>Element-wise <c>lhs &gt; rhs</c>; any comparison involving null yields false.</summary>
    public static BooleanArray GreaterThan<T>(PrimitiveArray<T> lhs, T? rhs, MemoryAllocator? allocator = null) where T : struct, INumber<T>
        => CompareToScalar(lhs, rhs, static (a, b) => a > b, allocator);

    /// <summary>Pairwise <c>lhs &gt; rhs</c>; any comparison involving null yields false.</summary>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    public static BooleanArray GreaterThan<T>(PrimitiveArray<T> lhs, PrimitiveArray<T> rhs, MemoryAllocator? allocator = null) where T : struct, INumber<T>
        => ComparePairwise(lhs, rhs, static (a, b) => a > b, allocator);

    /// <summary>Element-wise <c>lhs &lt; rhs</c>; any comparison involving null yields false.</summary>
    public static BooleanArray LessThan<T>(PrimitiveArray<T> lhs, T? rhs, MemoryAllocator? allocator = null) where T : struct, INumber<T>
        => CompareToScalar(lhs, rhs, static (a, b) => a < b, allocator);

    /// <summary>Pairwise <c>lhs &lt; rhs</c>; any comparison involving null yields false.</summary>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    public static BooleanArray LessThan<T>(PrimitiveArray<T> lhs, PrimitiveArray<T> rhs, MemoryAllocator? allocator = null) where T : struct, INumber<T>
        => ComparePairwise(lhs, rhs, static (a, b) => a < b, allocator);

    /// <summary>Element-wise <c>lhs &gt;= rhs</c>; any comparison involving null yields false.</summary>
    public static BooleanArray GreaterThanOrEqual<T>(PrimitiveArray<T> lhs, T? rhs, MemoryAllocator? allocator = null) where T : struct, INumber<T>
        => CompareToScalar(lhs, rhs, static (a, b) => a >= b, allocator);

    /// <summary>Pairwise <c>lhs &gt;= rhs</c>; any comparison involving null yields false.</summary>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    public static BooleanArray GreaterThanOrEqual<T>(PrimitiveArray<T> lhs, PrimitiveArray<T> rhs, MemoryAllocator? allocator = null) where T : struct, INumber<T>
        => ComparePairwise(lhs, rhs, static (a, b) => a >= b, allocator);

    /// <summary>Element-wise <c>lhs &lt;= rhs</c>; any comparison involving null yields false.</summary>
    public static BooleanArray LessThanOrEqual<T>(PrimitiveArray<T> lhs, T? rhs, MemoryAllocator? allocator = null) where T : struct, INumber<T>
        => CompareToScalar(lhs, rhs, static (a, b) => a <= b, allocator);

    /// <summary>Pairwise <c>lhs &lt;= rhs</c>; any comparison involving null yields false.</summary>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    public static BooleanArray LessThanOrEqual<T>(PrimitiveArray<T> lhs, PrimitiveArray<T> rhs, MemoryAllocator? allocator = null) where T : struct, INumber<T>
        => ComparePairwise(lhs, rhs, static (a, b) => a <= b, allocator);

    /// <summary>Throws when two array lengths differ.</summary>
    private static void EnsureSameLength(int lhsLength, int rhsLength)
    {
        if (lhsLength != rhsLength) throw new ArgumentException("Arrays must have the same length");
    }

    /// <summary>Maps a null-handling mode to "propagate nulls?", rejecting unknown modes.</summary>
    private static bool IsPropagating(ComparisonNullHandling nullHandling) => nullHandling switch
    {
        ComparisonNullHandling.Equality => false,
        ComparisonNullHandling.Propagate => true,
        _ => throw new NotImplementedException($"{nullHandling}"),
    };

    /// <summary>
    /// Number of null slots described by a validity bitmap: slots whose bit is CLEAR within
    /// the first <paramref name="length"/> bits. An empty bitmap means no nulls.
    /// </summary>
    private static int CountNulls(ArrowBuffer validity, int length)
        => validity.IsEmpty ? 0 : length - BitUtility.CountBits(validity.Span, 0, length);

    /// <summary>
    /// Intersects two validity bitmaps, always returning a buffer the result array can own
    /// (an empty side is treated as "all valid", so the other side is copied).
    /// </summary>
    private static ArrowBuffer IntersectValidity(ArrowBuffer lhs, ArrowBuffer rhs, MemoryAllocator? allocator)
    {
        if (lhs.IsEmpty) return rhs.Clone(allocator);
        if (rhs.IsEmpty) return lhs.Clone(allocator);
        return BitVectorOps.And(lhs, rhs, allocator);
    }

    /// <summary>
    /// Shared implementation of the pairwise boolean operators: a buffer-level fast path for
    /// unsliced arrays and an element-wise path for everything else.
    /// </summary>
    private static BooleanArray Combine(
        BooleanArray lhs,
        BooleanArray rhs,
        Func<ArrowBuffer, ArrowBuffer, MemoryAllocator?, ArrowBuffer> bufferOp,
        Func<bool, bool, bool> elementOp,
        MemoryAllocator? allocator)
    {
        EnsureSameLength(lhs.Length, rhs.Length);

        // Whole-buffer operations line bits up by position, so they are only valid when
        // neither array is a slice with a non-zero offset.
        if (lhs.Length > 0 && lhs.Offset == 0 && rhs.Offset == 0)
        {
            var values = bufferOp(lhs.ValueBuffer, rhs.ValueBuffer, allocator);
            var validity = IntersectValidity(lhs.NullBitmapBuffer, rhs.NullBitmapBuffer, allocator);
            return new BooleanArray(values, validity, lhs.Length, CountNulls(validity, lhs.Length), 0);
        }

        var builder = new BooleanArray.Builder(lhs.Length);
        for (int i = 0; i < lhs.Length; i++)
        {
            bool? a = lhs.GetValue(i);
            bool? b = rhs.GetValue(i);
            if (a is null || b is null)
            {
                builder.AppendNull();
            }
            else
            {
                builder.Append(elementOp(a.Value, b.Value));
            }
        }
        return builder.Build(allocator);
    }

    /// <summary>Compares each string produced by <paramref name="getString"/> with a scalar.</summary>
    private static BooleanArray EqualToScalar(int length, Func<int, string?> getString, string? rhs, ComparisonNullHandling nullHandling, MemoryAllocator? allocator)
    {
        // Unknown modes are treated like Equality here, as the scalar overloads always did.
        bool propagate = nullHandling == ComparisonNullHandling.Propagate;

        var cmp = new BooleanArray.Builder(length);
        for (int i = 0; i < length; i++)
        {
            string? a = getString(i);
            if (propagate && (a is null || rhs is null))
            {
                cmp.AppendNull();
            }
            else
            {
                cmp.Append(string.Equals(a, rhs, StringComparison.Ordinal));
            }
        }
        return cmp.Build(allocator);
    }

    /// <summary>Compares the strings produced by two accessors position by position.</summary>
    private static BooleanArray EqualPairwise(int length, Func<int, string?> left, Func<int, string?> right, ComparisonNullHandling nullHandling, MemoryAllocator? allocator)
    {
        bool propagate = IsPropagating(nullHandling);

        var cmp = new BooleanArray.Builder(length);
        for (int i = 0; i < length; i++)
        {
            string? a = left(i);
            string? b = right(i);
            if (propagate && (a is null || b is null))
            {
                cmp.AppendNull();
            }
            else
            {
                cmp.Append(string.Equals(a, b, StringComparison.Ordinal));
            }
        }
        return cmp.Build(allocator);
    }

    /// <summary>Evaluates <paramref name="predicate"/> between every element and a scalar.</summary>
    private static BooleanArray CompareToScalar<T>(PrimitiveArray<T> lhs, T? rhs, Func<T?, T?, bool> predicate, MemoryAllocator? allocator) where T : struct, INumber<T>
    {
        var cmp = new BooleanArray.Builder(lhs.Length);
        for (int i = 0; i < lhs.Length; i++)
        {
            cmp.Append(predicate(lhs.GetValue(i), rhs));
        }
        return cmp.Build(allocator);
    }

    /// <summary>Evaluates <paramref name="predicate"/> between the elements of two arrays position by position.</summary>
    private static BooleanArray ComparePairwise<T>(PrimitiveArray<T> lhs, PrimitiveArray<T> rhs, Func<T?, T?, bool> predicate, MemoryAllocator? allocator) where T : struct, INumber<T>
    {
        EnsureSameLength(lhs.Length, rhs.Length);

        var cmp = new BooleanArray.Builder(lhs.Length);
        for (int i = 0; i < lhs.Length; i++)
        {
            cmp.Append(predicate(lhs.GetValue(i), rhs.GetValue(i)));
        }
        return cmp.Build(allocator);
    }
}

diff --git a/src/Apache.Arrow.Operations/Conversion.cs b/src/Apache.Arrow.Operations/Conversion.cs
new file mode 100644
index 00000000..1a15c91e
--- /dev/null
+++ b/src/Apache.Arrow.Operations/Conversion.cs
@@ -0,0 +1,633 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +using System.Numerics; + +using Apache.Arrow.Memory; +using Apache.Arrow.Types; + +namespace Apache.Arrow.Operations; + + +/// +/// Copy primitive arraays between types to explicitly known numerical types. When the type already +/// matches, no copy is performed. +/// +public static partial class Conversion +{ + static void NullToZero(PrimitiveArray array, IArrowArrayBuilder, TBuilder> accumulator) + where T : struct, INumber where TBuilder : IArrowArrayBuilder> + { + accumulator.Reserve(array.Length); + foreach (var value in array) + { + accumulator.Append(value == null ? T.Zero : (T)value); + } + } + + public static Array NullToZero(PrimitiveArray array, MemoryAllocator? 
allocator = null) where T : struct, INumber + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + { + var builder = new DoubleArray.Builder(); + NullToZero((DoubleArray)(IArrowArray)array, builder); + return builder.Build(allocator); + } + case ArrowTypeId.Float: + { + var builder = new FloatArray.Builder(); + NullToZero((FloatArray)(IArrowArray)array, builder); + return builder.Build(allocator); + } + case ArrowTypeId.Int32: + { + var builder = new Int32Array.Builder(); + NullToZero((Int32Array)(IArrowArray)array, builder); + return builder.Build(allocator); + } + case ArrowTypeId.Int64: + { + var builder = new Int64Array.Builder(); + NullToZero((Int64Array)(IArrowArray)array, builder); + return builder.Build(allocator); + } + case ArrowTypeId.UInt32: + { + var builder = new UInt32Array.Builder(); + NullToZero((UInt32Array)(IArrowArray)array, builder); + return builder.Build(allocator); + } + case ArrowTypeId.UInt64: + { + var builder = new UInt64Array.Builder(); + NullToZero((UInt64Array)(IArrowArray)array, builder); + return builder.Build(allocator); + } + case ArrowTypeId.Int16: + { + var builder = new Int16Array.Builder(); + NullToZero((Int16Array)(IArrowArray)array, builder); + return builder.Build(allocator); + } + case ArrowTypeId.Int8: + { + var builder = new Int8Array.Builder(); + NullToZero((Int8Array)(IArrowArray)array, builder); + return builder.Build(allocator); + } + case ArrowTypeId.UInt16: + { + var builder = new UInt16Array.Builder(); + NullToZero((UInt16Array)(IArrowArray)array, builder); + return builder.Build(allocator); + } + case ArrowTypeId.UInt8: + { + var builder = new UInt8Array.Builder(); + NullToZero((UInt8Array)(IArrowArray)array, builder); + return builder.Build(allocator); + } + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + public static DoubleArray CastDouble(IList array, MemoryAllocator? 
allocator = null) where T : struct, INumber + { + var builder = new DoubleArray.Builder(); + builder.Reserve(array.Count); + foreach (var val in array) + builder.Append(double.CreateChecked(val)); + return builder.Build(allocator); + } + + public static FloatArray CastFloat(IList array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new FloatArray.Builder(); + builder.Reserve(array.Count); + foreach (var val in array) + builder.Append(float.CreateChecked(val)); + return builder.Build(allocator); + } + + public static Int32Array CastInt32(IList array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new Int32Array.Builder(); + builder.Reserve(array.Count); + foreach (var val in array) + builder.Append(int.CreateChecked(val)); + return builder.Build(allocator); + } + + public static Int64Array CastInt64(IList array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new Int64Array.Builder(); + builder.Reserve(array.Count); + foreach (var val in array) + builder.Append(long.CreateChecked(val)); + return builder.Build(allocator); + } + + public static UInt16Array CastUInt16(IList array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new UInt16Array.Builder(); + builder.Reserve(array.Count); + foreach (var val in array) + builder.Append(ushort.CreateChecked(val)); + return builder.Build(allocator); + } + + public static Int16Array CastInt16(IList array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new Int16Array.Builder(); + builder.Reserve(array.Count); + foreach (var val in array) + builder.Append(short.CreateChecked(val)); + return builder.Build(allocator); + } + + public static UInt8Array CastUInt8(IList array, MemoryAllocator? 
allocator = null) where T : struct, INumber + { + var builder = new UInt8Array.Builder(); + builder.Reserve(array.Count); + foreach (var val in array) + builder.Append(byte.CreateChecked(val)); + return builder.Build(allocator); + } + + public static Int8Array CastInt8(IList array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new Int8Array.Builder(); + builder.Reserve(array.Count); + foreach (var val in array) + builder.Append(sbyte.CreateChecked(val)); + return builder.Build(allocator); + } + + public static BooleanArray CastBool(PrimitiveArray array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new BooleanArray.Builder(); + builder.Reserve(array.Length); + foreach (var val in array) + { + if (val != null) builder.Append(val.Value != T.Zero); + else builder.AppendNull(); + } + return builder.Build(allocator); + } + + public static Int64Array CastInt64(PrimitiveArray array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new Int64Array.Builder(); + builder.Reserve(array.Length); + foreach (var val in array) + { + if (val != null) builder.Append(long.CreateChecked((T)val)); + else builder.AppendNull(); + } + return builder.Build(allocator); + } + + public static Int32Array CastInt32(PrimitiveArray array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new Int32Array.Builder(); + builder.Reserve(array.Length); + foreach (var val in array) + { + if (val != null) builder.Append(int.CreateChecked((T)val)); + else builder.AppendNull(); + } + return builder.Build(allocator); + } + + public static Int16Array CastInt16(PrimitiveArray array, MemoryAllocator? 
allocator = null) where T : struct, INumber + { + var builder = new Int16Array.Builder(); + builder.Reserve(array.Length); + foreach (var val in array) + { + if (val != null) builder.Append(short.CreateChecked((T)val)); + else builder.AppendNull(); + } + return builder.Build(allocator); + } + + public static Int8Array CastInt8(PrimitiveArray array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new Int8Array.Builder(); + builder.Reserve(array.Length); + foreach (var val in array) + { + if (val != null) builder.Append(sbyte.CreateChecked((T)val)); + else builder.AppendNull(); + } + return builder.Build(allocator); + } + + public static UInt64Array CastUInt64(PrimitiveArray array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new UInt64Array.Builder(); + builder.Reserve(array.Length); + foreach (var val in array) + { + if (val != null) builder.Append(ulong.CreateChecked((T)val)); + else builder.AppendNull(); + } + return builder.Build(allocator); + } + + public static UInt32Array CastUInt32(PrimitiveArray array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new UInt32Array.Builder(); + builder.Reserve(array.Length); + foreach (var val in array) + { + if (val != null) builder.Append(uint.CreateChecked((T)val)); + else builder.AppendNull(); + } + return builder.Build(allocator); + } + + public static UInt16Array CastUInt16(PrimitiveArray array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new UInt16Array.Builder(); + builder.Reserve(array.Length); + foreach (var val in array) + { + if (val != null) builder.Append(ushort.CreateChecked((T)val)); + else builder.AppendNull(); + } + return builder.Build(allocator); + } + + public static UInt8Array CastUInt8(PrimitiveArray array, MemoryAllocator? 
allocator = null) where T : struct, INumber + { + var builder = new UInt8Array.Builder(); + builder.Reserve(array.Length); + foreach (var val in array) + { + if (val != null) builder.Append(byte.CreateChecked((T)val)); + else builder.AppendNull(); + } + return builder.Build(allocator); + } + + public static FloatArray CastFloat(PrimitiveArray array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new FloatArray.Builder(); + builder.Reserve(array.Length); + foreach (var val in array) + { + if (val != null) builder.Append(float.CreateChecked((T)val)); + else builder.AppendNull(); + } + return builder.Build(allocator); + } + + public static DoubleArray CastDouble(PrimitiveArray array, MemoryAllocator? allocator = null) where T : struct, INumber + { + var builder = new DoubleArray.Builder(); + builder.Reserve(array.Length); + foreach (var val in array) + { + if (val != null) builder.Append(double.CreateChecked((T)val)); + else builder.AppendNull(); + } + return builder.Build(allocator); + } + + public static Int64Array CastInt64(IArrowArray array, MemoryAllocator? 
allocator = null) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return CastInt64((DoubleArray)array, allocator); + case ArrowTypeId.Float: + return CastInt64((FloatArray)array, allocator); + case ArrowTypeId.Int32: + return CastInt64((Int32Array)array, allocator); + case ArrowTypeId.Int64: + return (Int64Array)array; + case ArrowTypeId.UInt32: + return CastInt64((UInt32Array)array, allocator); + case ArrowTypeId.UInt64: + return CastInt64((UInt64Array)array, allocator); + case ArrowTypeId.Int16: + return CastInt64((Int16Array)array, allocator); + case ArrowTypeId.Int8: + return CastInt64((Int8Array)array, allocator); + case ArrowTypeId.UInt16: + return CastInt64((UInt16Array)array, allocator); + case ArrowTypeId.UInt8: + return CastInt64((UInt8Array)array, allocator); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + public static Int32Array CastInt32(IArrowArray array, MemoryAllocator? allocator = null) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return CastInt32((DoubleArray)array, allocator); + case ArrowTypeId.Float: + return CastInt32((FloatArray)array, allocator); + case ArrowTypeId.Int32: + return (Int32Array)array; + case ArrowTypeId.Int64: + return CastInt32((Int64Array)array, allocator); + case ArrowTypeId.UInt32: + return CastInt32((UInt32Array)array, allocator); + case ArrowTypeId.UInt64: + return CastInt32((UInt64Array)array, allocator); + case ArrowTypeId.Int16: + return CastInt32((Int16Array)array, allocator); + case ArrowTypeId.Int8: + return CastInt32((Int8Array)array, allocator); + case ArrowTypeId.UInt16: + return CastInt32((UInt16Array)array, allocator); + case ArrowTypeId.UInt8: + return CastInt32((UInt8Array)array, allocator); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + public static Int16Array CastInt16(IArrowArray array, MemoryAllocator? 
allocator = null) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return CastInt16((DoubleArray)array, allocator); + case ArrowTypeId.Float: + return CastInt16((FloatArray)array, allocator); + case ArrowTypeId.Int32: + return CastInt16((Int32Array)array, allocator); + case ArrowTypeId.Int64: + return CastInt16((Int64Array)array, allocator); + case ArrowTypeId.UInt32: + return CastInt16((UInt32Array)array, allocator); + case ArrowTypeId.UInt64: + return CastInt16((UInt64Array)array, allocator); + case ArrowTypeId.Int16: + return CastInt16((Int16Array)array, allocator); + case ArrowTypeId.Int8: + return CastInt16((Int8Array)array, allocator); + case ArrowTypeId.UInt16: + return CastInt16((UInt16Array)array, allocator); + case ArrowTypeId.UInt8: + return CastInt16((UInt8Array)array, allocator); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + public static Int8Array CastInt8(IArrowArray array, MemoryAllocator? 
allocator = null) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return CastInt8((DoubleArray)array, allocator); + case ArrowTypeId.Float: + return CastInt8((FloatArray)array, allocator); + case ArrowTypeId.Int32: + return CastInt8((Int32Array)array, allocator); + case ArrowTypeId.Int64: + return CastInt8((Int64Array)array, allocator); + case ArrowTypeId.UInt32: + return CastInt8((UInt32Array)array, allocator); + case ArrowTypeId.UInt64: + return CastInt8((UInt64Array)array, allocator); + case ArrowTypeId.Int16: + return CastInt8((Int16Array)array, allocator); + case ArrowTypeId.Int8: + return CastInt8((Int8Array)array, allocator); + case ArrowTypeId.UInt16: + return CastInt8((UInt16Array)array, allocator); + case ArrowTypeId.UInt8: + return CastInt8((UInt8Array)array, allocator); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + public static FloatArray CastFloat(IArrowArray array, MemoryAllocator? allocator = null) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return CastFloat((DoubleArray)array, allocator); + case ArrowTypeId.Float: + return (FloatArray)array; + case ArrowTypeId.Int32: + return CastFloat((Int32Array)array, allocator); + case ArrowTypeId.Int64: + return CastFloat((Int64Array)array, allocator); + case ArrowTypeId.UInt32: + return CastFloat((UInt32Array)array, allocator); + case ArrowTypeId.UInt64: + return CastFloat((UInt64Array)array, allocator); + case ArrowTypeId.Int16: + return CastFloat((Int16Array)array, allocator); + case ArrowTypeId.Int8: + return CastFloat((Int8Array)array, allocator); + case ArrowTypeId.UInt16: + return CastFloat((UInt16Array)array, allocator); + case ArrowTypeId.UInt8: + return CastFloat((UInt8Array)array, allocator); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + public static DoubleArray CastDouble(IArrowArray array, MemoryAllocator? 
allocator = null) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return (DoubleArray)array; + case ArrowTypeId.Float: + return CastDouble((FloatArray)array, allocator); + case ArrowTypeId.Int32: + return CastDouble((Int32Array)array, allocator); + case ArrowTypeId.Int64: + return CastDouble((Int64Array)array, allocator); + case ArrowTypeId.UInt32: + return CastDouble((UInt32Array)array, allocator); + case ArrowTypeId.UInt64: + return CastDouble((UInt64Array)array, allocator); + case ArrowTypeId.Int16: + return CastDouble((Int16Array)array, allocator); + case ArrowTypeId.Int8: + return CastDouble((Int8Array)array, allocator); + case ArrowTypeId.UInt16: + return CastDouble((UInt16Array)array, allocator); + case ArrowTypeId.UInt8: + return CastDouble((UInt8Array)array, allocator); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + public static UInt64Array CastUInt64(IArrowArray array, MemoryAllocator? allocator = null) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return CastUInt64((DoubleArray)array, allocator); + case ArrowTypeId.Float: + return CastUInt64((FloatArray)array, allocator); + case ArrowTypeId.Int32: + return CastUInt64((Int32Array)array, allocator); + case ArrowTypeId.Int64: + return CastUInt64((Int64Array)array, allocator); + case ArrowTypeId.UInt32: + return CastUInt64((UInt32Array)array, allocator); + case ArrowTypeId.UInt64: + return CastUInt64((UInt64Array)array, allocator); + case ArrowTypeId.Int16: + return CastUInt64((Int16Array)array, allocator); + case ArrowTypeId.Int8: + return CastUInt64((Int8Array)array, allocator); + case ArrowTypeId.UInt16: + return CastUInt64((UInt16Array)array, allocator); + case ArrowTypeId.UInt8: + return CastUInt64((UInt8Array)array, allocator); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + public static UInt32Array CastUInt32(IArrowArray 
array, MemoryAllocator? allocator = null) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return CastUInt32((DoubleArray)array, allocator); + case ArrowTypeId.Float: + return CastUInt32((FloatArray)array, allocator); + case ArrowTypeId.Int32: + return CastUInt32((Int32Array)array, allocator); + case ArrowTypeId.Int64: + return CastUInt32((Int64Array)array, allocator); + case ArrowTypeId.UInt32: + return CastUInt32((UInt32Array)array, allocator); + case ArrowTypeId.UInt64: + return CastUInt32((UInt64Array)array, allocator); + case ArrowTypeId.Int16: + return CastUInt32((Int16Array)array, allocator); + case ArrowTypeId.Int8: + return CastUInt32((Int8Array)array, allocator); + case ArrowTypeId.UInt16: + return CastUInt32((UInt16Array)array, allocator); + case ArrowTypeId.UInt8: + return CastUInt32((UInt8Array)array, allocator); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + public static UInt16Array CastUInt16(IArrowArray array, MemoryAllocator? 
allocator = null) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return CastUInt16((DoubleArray)array, allocator); + case ArrowTypeId.Float: + return CastUInt16((FloatArray)array, allocator); + case ArrowTypeId.Int32: + return CastUInt16((Int32Array)array, allocator); + case ArrowTypeId.Int64: + return CastUInt16((Int64Array)array, allocator); + case ArrowTypeId.UInt32: + return CastUInt16((UInt32Array)array, allocator); + case ArrowTypeId.UInt64: + return CastUInt16((UInt64Array)array, allocator); + case ArrowTypeId.Int16: + return CastUInt16((Int16Array)array, allocator); + case ArrowTypeId.Int8: + return CastUInt16((Int8Array)array, allocator); + case ArrowTypeId.UInt16: + return CastUInt16((UInt16Array)array, allocator); + case ArrowTypeId.UInt8: + return CastUInt16((UInt8Array)array, allocator); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + public static UInt8Array CastUInt8(IArrowArray array, MemoryAllocator? 
allocator = null) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return CastUInt8((DoubleArray)array, allocator); + case ArrowTypeId.Float: + return CastUInt8((FloatArray)array, allocator); + case ArrowTypeId.Int32: + return CastUInt8((Int32Array)array, allocator); + case ArrowTypeId.Int64: + return CastUInt8((Int64Array)array, allocator); + case ArrowTypeId.UInt32: + return CastUInt8((UInt32Array)array, allocator); + case ArrowTypeId.UInt64: + return CastUInt8((UInt64Array)array, allocator); + case ArrowTypeId.Int16: + return CastUInt8((Int16Array)array, allocator); + case ArrowTypeId.Int8: + return CastUInt8((Int8Array)array, allocator); + case ArrowTypeId.UInt16: + return CastUInt8((UInt16Array)array, allocator); + case ArrowTypeId.UInt8: + return CastUInt8((UInt8Array)array, allocator); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + public static BooleanArray CastBool(IArrowArray array, MemoryAllocator? 
allocator = null) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return CastBool((DoubleArray)array, allocator); + case ArrowTypeId.Float: + return CastBool((FloatArray)array, allocator); + case ArrowTypeId.Int32: + return CastBool((Int32Array)array, allocator); + case ArrowTypeId.Int64: + return CastBool((Int64Array)array, allocator); + case ArrowTypeId.UInt32: + return CastBool((UInt32Array)array, allocator); + case ArrowTypeId.UInt64: + return CastBool((UInt64Array)array, allocator); + case ArrowTypeId.Int16: + return CastBool((Int16Array)array, allocator); + case ArrowTypeId.Int8: + return CastBool((Int8Array)array, allocator); + case ArrowTypeId.UInt16: + return CastBool((UInt16Array)array, allocator); + case ArrowTypeId.UInt8: + return CastBool((UInt8Array)array, allocator); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } +} + diff --git a/src/Apache.Arrow.Operations/Select.cs b/src/Apache.Arrow.Operations/Select.cs new file mode 100644 index 00000000..e707da41 --- /dev/null +++ b/src/Apache.Arrow.Operations/Select.cs @@ -0,0 +1,722 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 

using System.Numerics;

using Apache.Arrow;
using Apache.Arrow.Memory;
using Apache.Arrow.Types;

namespace Apache.Arrow.Operations;


/// <summary>
/// Row-selection helpers: copy a subset of positions out of an array, a list of arrays,
/// a keyed set of arrays or a record batch.
/// </summary>
public static class Select
{
    /// <summary>
    /// Returns a copy of the positions in the array where the mask is true. All other values in the array will be
    /// excluded.
    ///
    /// This internally reduces to building a true-value run index map and calling `Take`
    /// </summary>
    /// <param name="array">The array to select from</param>
    /// <param name="mask">The mask defining which values to keep or exclude</param>
    /// <param name="allocator">The memory allocator to build the new array from</param>
    /// <returns>A new array holding only the selected positions, in their original order</returns>
    /// <exception cref="ArgumentException">If the mask and the array are not of equal size</exception>
    /// <remarks>A null mask slot is not "true", so the position it covers is excluded.</remarks>
    public static Array Filter(Array array, BooleanArray mask, MemoryAllocator? allocator = null)
    {
        if (array.Length != mask.Length) throw new ArgumentException("Array and mask must have the same length");
        return Take(array, MaskToSpans(mask), allocator);
    }

    /// <summary>
    /// Returns a copy of the positions in the array included in the provided start-end spans. All other values in the array will be
    /// excluded.
    /// </summary>
    /// <param name="array">The array to select from</param>
    /// <param name="spans">The index ranges to select; both ends of each range are inclusive</param>
    /// <param name="allocator">The memory allocator to build the new array from</param>
    /// <returns>The concatenation of the selected ranges, or an empty slice when no spans are given</returns>
    /// <exception cref="InvalidOperationException">If a span has a negative bound or ends before it starts</exception>
    public static Array Take(Array array, IList<(int, int)> spans, MemoryAllocator? allocator = null)
    {
        if (spans.Count == 0)
        {
            return array.Slice(0, 0);
        }

        List<IArrowArray> chunks = new(spans.Count);
        foreach (var (start, end) in spans)
        {
            if (end < start || end < 0 || start < 0)
            {
                throw new InvalidOperationException(string.Format("Invalid span: {0} {1}", start, end));
            }

            // Spans are inclusive on both ends, Slice takes a length.
            chunks.Add(array.Slice(start, end - start + 1));
        }
        return (Array)ArrowArrayConcatenator.Concatenate(chunks, allocator);
    }

    /// <summary>
    /// Returns a copy of the positions in the array included in the provided indices list. All other values in the array will be
    /// excluded.
    /// </summary>
    /// <param name="array">The array to select from</param>
    /// <param name="indices">The indices to select</param>
    /// <param name="allocator">The memory allocator to build the new array from</param>
    /// <returns>One element per index, in the order the indices were given</returns>
    public static Array Take(Array array, IList<int> indices, MemoryAllocator? allocator = null)
    {
        if (indices.Count == 0)
        {
            return array.Slice(0, 0);
        }

        List<IArrowArray> chunks = new(indices.Count);
        foreach (var index in indices)
        {
            chunks.Add(array.Slice(index, 1));
        }
        return (Array)ArrowArrayConcatenator.Concatenate(chunks, allocator);
    }

    /// <summary>
    /// Apply `Take` to each array in `batch` using the same `indices`
    /// </summary>
    /// <param name="batch">The arrays to select from</param>
    /// <param name="indices">The indices to select from every array</param>
    /// <param name="allocator">The memory allocator to build the new arrays from</param>
    /// <returns>The selected arrays, in the same order as `batch`</returns>
    public static List<Array> Take(List<Array> batch, IList<int> indices, MemoryAllocator? allocator = null)
    {
        return batch.Select(arr => Take(arr, indices, allocator)).ToList();
    }

    /// <summary>
    /// Apply `Filter` to each array in `batch` using the same `mask`
    /// </summary>
    /// <param name="batch">The arrays to filter</param>
    /// <param name="mask">The mask defining which positions to keep</param>
    /// <param name="allocator">The memory allocator to build the new arrays from</param>
    /// <returns>The filtered arrays, in the same order as `batch`</returns>
    public static List<Array> Filter(List<Array> batch, BooleanArray mask, MemoryAllocator? allocator = null)
    {
        return batch.Select(arr => Filter(arr, mask, allocator)).ToList();
    }

    /// <summary>
    /// Apply `Take` to each array in `batch` using the same `indices`
    /// </summary>
    /// <param name="batch">The keyed arrays to select from</param>
    /// <param name="indices">The indices to select from every array</param>
    /// <param name="allocator">The memory allocator to build the new arrays from</param>
    /// <returns>A new dictionary with the same keys and the selected arrays as values</returns>
    public static Dictionary<T, Array> Take<T>(Dictionary<T, Array> batch, IList<int> indices, MemoryAllocator? allocator = null) where T : notnull
    {
        Dictionary<T, Array> result = new(batch.Count);
        foreach (var (key, column) in batch)
        {
            result[key] = Take(column, indices, allocator);
        }
        return result;
    }

    /// <summary>
    /// Apply `Filter` to each array in `batch` using the same `mask`
    /// </summary>
    /// <param name="batch">The keyed arrays to filter</param>
    /// <param name="mask">The mask defining which positions to keep</param>
    /// <param name="allocator">The memory allocator to build the new arrays from</param>
    /// <returns>A new dictionary with the same keys and the filtered arrays as values</returns>
    public static Dictionary<T, Array> Filter<T>(Dictionary<T, Array> batch, BooleanArray mask, MemoryAllocator? allocator = null) where T : notnull
    {
        Dictionary<T, Array> result = new(batch.Count);
        foreach (var (key, column) in batch)
        {
            result[key] = Filter(column, mask, allocator);
        }
        return result;
    }

    /// <summary>
    /// Apply `Filter` to each array in `batch` using the same `mask`
    /// </summary>
    /// <param name="batch">The record batch to filter</param>
    /// <param name="mask">The mask defining which rows to keep</param>
    /// <param name="allocator">The memory allocator to build the new columns from</param>
    /// <returns>A record batch with the same schema holding only the selected rows</returns>
    /// <exception cref="ArgumentException">If the mask and the batch are not of equal length</exception>
    /// <remarks>A null mask slot is not "true", so the row it covers is excluded.</remarks>
    public static RecordBatch Filter(RecordBatch batch, BooleanArray mask, MemoryAllocator? allocator = null)
    {
        if (batch.Length != mask.Length) throw new ArgumentException("Array and mask must have the same length");
        return Take(batch, MaskToSpans(mask), allocator);
    }

    /// <summary>
    /// Apply `Take` to each array in `batch` using the same `spans`
    /// </summary>
    /// <param name="batch">The record batch to select from</param>
    /// <param name="spans">The row ranges to select; both ends of each range are inclusive</param>
    /// <param name="allocator">The memory allocator to build the new columns from</param>
    /// <returns>A record batch with the same schema holding only the selected rows</returns>
    public static RecordBatch Take(RecordBatch batch, IList<(int, int)> spans, MemoryAllocator? allocator = null)
    {
        if (spans.Count == 0)
        {
            return batch.Slice(0, 0);
        }

        List<Array> columns = new();
        var size = 0;
        foreach (var col in batch.Arrays)
        {
            var taken = Take((Array)col, spans, allocator);
            columns.Add(taken);
            // Every column is cut with the same spans, so the last length is the row count.
            size = taken.Length;
        }
        return new RecordBatch(batch.Schema, columns, size);
    }

    /// <summary>
    /// Apply `Take` to each array in `batch` using the same `indices`
    /// </summary>
    /// <param name="batch">The record batch to select from</param>
    /// <param name="indices">The row indices to select</param>
    /// <param name="allocator">The memory allocator to build the new columns from</param>
    /// <returns>A record batch with the same schema holding only the selected rows</returns>
    public static RecordBatch Take(RecordBatch batch, IList<int> indices, MemoryAllocator? allocator = null)
    {
        var spans = IndicesToSpans(indices);
        return Take(batch, spans, allocator);
    }

    /// <summary>
    /// Convert a list of indices into a list of index start-end spans for ease-of selection
    /// </summary>
    /// <typeparam name="T">The numeric type of the indices</typeparam>
    /// <param name="indices">The indices to coalesce</param>
    /// <returns>
    /// Inclusive (start, end) pairs; each maximal run of indices that increase by exactly one becomes a
    /// single span, every other index starts a new span. Empty input yields an empty list.
    /// </returns>
    public static List<(T, T)> IndicesToSpans<T>(IList<T> indices) where T : struct, INumber<T>
    {
        List<(T, T)> acc = new();
        if (indices.Count == 0)
        {
            return acc;
        }

        T start = indices[0];
        T last = start;
        for (int k = 1; k < indices.Count; k++)
        {
            T current = indices[k];
            // The ordering check keeps a wrapped subtraction (e.g. 0 - 255 on bytes) from
            // being mistaken for a step of one.
            bool continuesRun = current > last && current - last == T.One;
            if (!continuesRun)
            {
                acc.Add((start, last));
                start = current;
            }
            last = current;
        }
        acc.Add((start, last));
        return acc;
    }

    /// <summary>
    /// Collapses the true slots of `mask` into inclusive (start, end) runs.
    /// False and null slots both terminate the current run.
    /// </summary>
    private static List<(int, int)> MaskToSpans(BooleanArray mask)
    {
        List<(int, int)> spans = new();
        int runStart = -1;
        for (int i = 0; i < mask.Length; i++)
        {
            // GetValue returns null for a null slot; only an explicit true selects the position.
            if (mask.GetValue(i) == true)
            {
                if (runStart < 0) runStart = i;
            }
            else if (runStart >= 0)
            {
                // Slices in Take include the trailing index
                spans.Add((runStart, i - 1));
                runStart = -1;
            }
        }
        if (runStart >= 0)
        {
            spans.Add((runStart, mask.Length - 1));
        }
        return spans;
    }
}



/// <summary>
/// Specifies how null values should be handled in aggregate computations.
/// </summary>
public enum AggregateNullHandling
{
    /// <summary>
    /// Skip null values when computing the result.
    /// Returns null only if the array is empty or all values are null.
    /// </summary>
    Skip,

    /// <summary>
    /// Propagate null: if any value in the array is null, return null.
    /// </summary>
    Propagate
}


public static class Aggregate
{

    /// <summary>
    /// Returns the minimum value in the array.
    /// </summary>
    /// <typeparam name="T">The numeric type of array elements.</typeparam>
    /// <param name="array">The input array.</param>
    /// <param name="nullHandling">How to handle null values.</param>
    /// <returns>The minimum value, or null if the array is empty, all values are null,
    /// or nullHandling is Propagate and any null exists.</returns>
    public static T? Min<T>(PrimitiveArray<T> array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip)
        where T : struct, INumber<T>
    {
        T? min = null;
        for (int i = 0; i < array.Length; i++)
        {
            if (array.GetValue(i) is not T current)
            {
                if (nullHandling == AggregateNullHandling.Propagate)
                    return null;
                continue;
            }

            if (min is not T best || current < best)
                min = current;
        }
        return min;
    }

    /// <summary>
    /// Returns the minimum value in the array.
    /// </summary>
    /// <param name="array">The input array.</param>
    /// <param name="nullHandling">How to handle null values.</param>
    /// <returns>The minimum value, or null if the array is empty, all values are null,
    /// or nullHandling is Propagate and any null exists.</returns>
    /// <exception cref="InvalidDataException">If the array's element type is not a supported numeric type.</exception>
    /// <remarks>The result is widened to double, so 64-bit integers beyond 2^53 may be rounded.</remarks>
    public static double? Min(IArrowArray array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip) =>
        array.Data.DataType.TypeId switch
        {
            ArrowTypeId.Double => Min((DoubleArray)array, nullHandling),
            ArrowTypeId.Float => Min((FloatArray)array, nullHandling),
            ArrowTypeId.Int32 => Min((Int32Array)array, nullHandling),
            ArrowTypeId.Int64 => Min((Int64Array)array, nullHandling),
            ArrowTypeId.UInt32 => Min((UInt32Array)array, nullHandling),
            ArrowTypeId.UInt64 => Min((UInt64Array)array, nullHandling),
            ArrowTypeId.Int16 => Min((Int16Array)array, nullHandling),
            ArrowTypeId.Int8 => Min((Int8Array)array, nullHandling),
            ArrowTypeId.UInt16 => Min((UInt16Array)array, nullHandling),
            ArrowTypeId.UInt8 => Min((UInt8Array)array, nullHandling),
            _ => throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name),
        };

    /// <summary>
    /// Returns the maximum value in the array.
    /// </summary>
    /// <typeparam name="T">The numeric type of array elements.</typeparam>
    /// <param name="array">The input array.</param>
    /// <param name="nullHandling">How to handle null values.</param>
    /// <returns>The maximum value, or null if the array is empty, all values are null,
    /// or nullHandling is Propagate and any null exists.</returns>
+ public static T? Max(PrimitiveArray array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip) + where T : struct, INumber + { + if (array.Length == 0) + return null; + + T? max = null; + for (int i = 0; i < array.Length; i++) + { + var value = array.GetValue(i); + if (value == null) + { + if (nullHandling == AggregateNullHandling.Propagate) + return null; + continue; + } + + if (max == null || (T)value > max) + max = value; + } + return max; + } + + /// + /// Returns the maximum value in the array. + /// + /// The input array. + /// How to handle null values. + /// The maximum value, or null if the array is empty, all values are null, + /// or nullHandling is Propagate and any null exists. + public static double? Max(IArrowArray array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return Max((DoubleArray)array, nullHandling); + case ArrowTypeId.Float: + return Max((FloatArray)array, nullHandling); + case ArrowTypeId.Int32: + return Max((Int32Array)array, nullHandling); + case ArrowTypeId.Int64: + return Max((Int64Array)array, nullHandling); + case ArrowTypeId.UInt32: + return Max((UInt32Array)array, nullHandling); + case ArrowTypeId.UInt64: + return Max((UInt64Array)array, nullHandling); + case ArrowTypeId.Int16: + return Max((Int16Array)array, nullHandling); + case ArrowTypeId.Int8: + return Max((Int8Array)array, nullHandling); + case ArrowTypeId.UInt16: + return Max((UInt16Array)array, nullHandling); + case ArrowTypeId.UInt8: + return Max((UInt8Array)array, nullHandling); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + /// + /// Returns the index of the minimum value in the array (first occurrence). + /// + /// The numeric type of array elements. + /// The input array. + /// How to handle null values. 
+ /// The index of the minimum value, or null if the array is empty, all values are null, + /// or nullHandling is Propagate and any null exists. + public static int? ArgMin(PrimitiveArray array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip) + where T : struct, INumber + { + if (array.Length == 0) + return null; + + T? min = null; + int? minIndex = null; + for (int i = 0; i < array.Length; i++) + { + var value = array.GetValue(i); + if (value == null) + { + if (nullHandling == AggregateNullHandling.Propagate) + return null; + continue; + } + + if (min == null || (T)value < min) + { + min = value; + minIndex = i; + } + } + return minIndex; + } + + /// + /// Returns the index of the minimum value in the array (first occurrence). + /// + /// The input array. + /// How to handle null values. + /// The index of the minimum value, or null if the array is empty, all values are null, + /// or nullHandling is Propagate and any null exists. + public static int? ArgMin(IArrowArray array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return ArgMin((DoubleArray)array, nullHandling); + case ArrowTypeId.Float: + return ArgMin((FloatArray)array, nullHandling); + case ArrowTypeId.Int32: + return ArgMin((Int32Array)array, nullHandling); + case ArrowTypeId.Int64: + return ArgMin((Int64Array)array, nullHandling); + case ArrowTypeId.UInt32: + return ArgMin((UInt32Array)array, nullHandling); + case ArrowTypeId.UInt64: + return ArgMin((UInt64Array)array, nullHandling); + case ArrowTypeId.Int16: + return ArgMin((Int16Array)array, nullHandling); + case ArrowTypeId.Int8: + return ArgMin((Int8Array)array, nullHandling); + case ArrowTypeId.UInt16: + return ArgMin((UInt16Array)array, nullHandling); + case ArrowTypeId.UInt8: + return ArgMin((UInt8Array)array, nullHandling); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + 
/// + /// Returns the index of the maximum value in the array (first occurrence). + /// + /// The numeric type of array elements. + /// The input array. + /// How to handle null values. + /// The index of the maximum value, or null if the array is empty, all values are null, + /// or nullHandling is Propagate and any null exists. + public static int? ArgMax(PrimitiveArray array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip) + where T : struct, INumber + { + if (array.Length == 0) + return null; + + T? max = null; + int? maxIndex = null; + for (int i = 0; i < array.Length; i++) + { + var value = array.GetValue(i); + if (value == null) + { + if (nullHandling == AggregateNullHandling.Propagate) + return null; + continue; + } + + if (max == null || (T)value > max) + { + max = value; + maxIndex = i; + } + } + return maxIndex; + } + + /// + /// Returns the index of the maximum value in the array (first occurrence). + /// + /// The input array. + /// How to handle null values. + /// The index of the maximum value, or null if the array is empty, all values are null, + /// or nullHandling is Propagate and any null exists. + public static int? 
ArgMax(IArrowArray array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return ArgMax((DoubleArray)array, nullHandling); + case ArrowTypeId.Float: + return ArgMax((FloatArray)array, nullHandling); + case ArrowTypeId.Int32: + return ArgMax((Int32Array)array, nullHandling); + case ArrowTypeId.Int64: + return ArgMax((Int64Array)array, nullHandling); + case ArrowTypeId.UInt32: + return ArgMax((UInt32Array)array, nullHandling); + case ArrowTypeId.UInt64: + return ArgMax((UInt64Array)array, nullHandling); + case ArrowTypeId.Int16: + return ArgMax((Int16Array)array, nullHandling); + case ArrowTypeId.Int8: + return ArgMax((Int8Array)array, nullHandling); + case ArrowTypeId.UInt16: + return ArgMax((UInt16Array)array, nullHandling); + case ArrowTypeId.UInt8: + return ArgMax((UInt8Array)array, nullHandling); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + /// + /// Returns the sum of all values in the array. + /// + /// The numeric type of array elements. + /// The input array. + /// How to handle null values. + /// The sum of values, or null if the array is empty, all values are null, + /// or nullHandling is Propagate and any null exists. + public static T? Sum(PrimitiveArray array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip) + where T : struct, INumber + { + if (array.Length == 0) + return null; + + T sum = T.Zero; + bool hasValue = false; + for (int i = 0; i < array.Length; i++) + { + var value = array.GetValue(i); + if (value == null) + { + if (nullHandling == AggregateNullHandling.Propagate) + return null; + continue; + } + + sum += (T)value; + hasValue = true; + } + return hasValue ? sum : null; + } + + /// + /// Returns the sum of all values in the array. + /// + /// The input array. + /// How to handle null values. 
+ /// The sum of values, or null if the array is empty, all values are null, + /// or nullHandling is Propagate and any null exists. + public static double? Sum(IArrowArray array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return Sum((DoubleArray)array, nullHandling); + case ArrowTypeId.Float: + return Sum((FloatArray)array, nullHandling); + case ArrowTypeId.Int32: + return Sum((Int32Array)array, nullHandling); + case ArrowTypeId.Int64: + return Sum((Int64Array)array, nullHandling); + case ArrowTypeId.UInt32: + return Sum((UInt32Array)array, nullHandling); + case ArrowTypeId.UInt64: + return Sum((UInt64Array)array, nullHandling); + case ArrowTypeId.Int16: + return Sum((Int16Array)array, nullHandling); + case ArrowTypeId.Int8: + return Sum((Int8Array)array, nullHandling); + case ArrowTypeId.UInt16: + return Sum((UInt16Array)array, nullHandling); + case ArrowTypeId.UInt8: + return Sum((UInt8Array)array, nullHandling); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } + + /// + /// Returns the arithmetic mean of all values in the array. + /// + /// The numeric type of array elements. + /// The input array. + /// How to handle null values. + /// The mean as a double, or null if the array is empty, all values are null, + /// or nullHandling is Propagate and any null exists. + public static double? Mean(PrimitiveArray array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip) + where T : struct, INumber + { + if (array.Length == 0) + return null; + + T sum = T.Zero; + long count = 0; + for (int i = 0; i < array.Length; i++) + { + var value = array.GetValue(i); + if (value == null) + { + if (nullHandling == AggregateNullHandling.Propagate) + return null; + continue; + } + + sum += (T)value; + count++; + } + return count > 0 ? 
double.CreateChecked(sum) / count : null; + } + + /// + /// Returns the arithmetic mean of all values in the array. + /// + /// The input array. + /// How to handle null values. + /// The mean as a double, or null if the array is empty, all values are null, + /// or nullHandling is Propagate and any null exists. + public static double? Mean(IArrowArray array, AggregateNullHandling nullHandling = AggregateNullHandling.Skip) + { + switch (array.Data.DataType.TypeId) + { + case ArrowTypeId.Double: + return Mean((DoubleArray)array, nullHandling); + case ArrowTypeId.Float: + return Mean((FloatArray)array, nullHandling); + case ArrowTypeId.Int32: + return Mean((Int32Array)array, nullHandling); + case ArrowTypeId.Int64: + return Mean((Int64Array)array, nullHandling); + case ArrowTypeId.UInt32: + return Mean((UInt32Array)array, nullHandling); + case ArrowTypeId.UInt64: + return Mean((UInt64Array)array, nullHandling); + case ArrowTypeId.Int16: + return Mean((Int16Array)array, nullHandling); + case ArrowTypeId.Int8: + return Mean((Int8Array)array, nullHandling); + case ArrowTypeId.UInt16: + return Mean((UInt16Array)array, nullHandling); + case ArrowTypeId.UInt8: + return Mean((UInt8Array)array, nullHandling); + default: + throw new InvalidDataException("Unsupported data type " + array.Data.DataType.Name); + } + } +} diff --git a/src/Apache.Arrow.Operations/Text.cs b/src/Apache.Arrow.Operations/Text.cs new file mode 100644 index 00000000..5c7075ce --- /dev/null +++ b/src/Apache.Arrow.Operations/Text.cs @@ -0,0 +1,236 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+using Apache.Arrow.Types;
+
+namespace Apache.Arrow.Operations;
+
+/// <summary>
+/// Pretty printing utilities
+/// </summary>
+public static class Format
+{
+    /// <summary>
+    /// Recursively pretty print format and write `array` into `stream`, indenting as nesting increases.
+    /// </summary>
+    /// <param name="array">The array to render.</param>
+    /// <param name="stream">The writer that receives the rendered text. It is neither flushed nor closed here.</param>
+    /// <param name="indent">Number of copies of <paramref name="indenter"/> placed before the opening and
+    /// closing bracket lines; element lines get one more copy.</param>
+    /// <param name="indenter">The text repeated once per nesting level.</param>
+    /// <exception cref="ArgumentNullException"><paramref name="array"/> or <paramref name="stream"/> is null.</exception>
+    /// <exception cref="NotImplementedException">The array's type id has no case in the switch below.</exception>
+    public static void PrettyPrintFormat(IArrowArray array, StreamWriter stream, int indent = 0, string indenter = " ")
+    {
+        // Fail with a named argument instead of a NullReferenceException further down.
+        ArgumentNullException.ThrowIfNull(array);
+        ArgumentNullException.ThrowIfNull(stream);
+
+        string indentString = string.Concat(Enumerable.Repeat(indenter, indent));
+        string pad = indentString + indenter;
+
+        stream.WriteLine($"{indentString}[ {array.Length} elements");
+        switch (array.Data.DataType.TypeId)
+        {
+            case ArrowTypeId.Float:
+                WriteValues<float>((FloatArray)array, stream, pad);
+                break;
+            case ArrowTypeId.Double:
+                WriteValues<double>((DoubleArray)array, stream, pad);
+                break;
+            case ArrowTypeId.Int32:
+                WriteValues<int>((Int32Array)array, stream, pad);
+                break;
+            case ArrowTypeId.Int64:
+                WriteValues<long>((Int64Array)array, stream, pad);
+                break;
+            case ArrowTypeId.Int16:
+                WriteValues<short>((Int16Array)array, stream, pad);
+                break;
+            case ArrowTypeId.Int8:
+                WriteValues<sbyte>((Int8Array)array, stream, pad);
+                break;
+            case ArrowTypeId.UInt8:
+                WriteValues<byte>((UInt8Array)array, stream, pad);
+                break;
+            case ArrowTypeId.UInt16:
+                WriteValues<ushort>((UInt16Array)array, stream, pad);
+                break;
+            case ArrowTypeId.UInt32:
+                WriteValues<uint>((UInt32Array)array, stream, pad);
+                break;
+            case ArrowTypeId.UInt64:
+                WriteValues<ulong>((UInt64Array)array, stream, pad);
+                break;
+            case ArrowTypeId.Boolean:
+                WriteValues<bool>((BooleanArray)array, stream, pad);
+                break;
+            case ArrowTypeId.HalfFloat:
+                WriteValues<Half>((HalfFloatArray)array, stream, pad);
+                break;
+            case ArrowTypeId.List:
+            {
+                var listArray = (ListArray)array;
+                for (var i = 0; i < listArray.Length; i++)
+                {
+                    if (listArray.IsNull(i))
+                    {
+                        stream.WriteLine($"{pad}null");
+                    }
+                    else
+                    {
+                        // Each non-null entry is itself an array: print it one level deeper.
+                        PrettyPrintFormat(listArray.GetSlicedValues(i), stream, indent + 1, indenter);
+                    }
+                }
+                break;
+            }
+            case ArrowTypeId.String:
+            {
+                var stringArray = (StringArray)array;
+                for (var i = 0; i < stringArray.Length; i++)
+                {
+                    if (stringArray.IsNull(i))
+                    {
+                        stream.WriteLine($"{pad}null");
+                    }
+                    else
+                    {
+                        stream.WriteLine($"{pad}{stringArray.GetString(i)}");
+                    }
+                }
+                break;
+            }
+            case ArrowTypeId.Struct:
+            {
+                var structType = (StructType)array.Data.DataType;
+                var structArray = (StructArray)array;
+                // One "name: type" header per field, followed by that field's column one level deeper.
+                foreach (var (field, column) in structType.Fields.Zip(structArray.Fields))
+                {
+                    stream.WriteLine($"{indentString}{field.Name}: {field.DataType.Name}");
+                    PrettyPrintFormat(column, stream, indent + 1, indenter);
+                }
+                break;
+            }
+            default: throw new NotImplementedException($"{array.Data.DataType.Name}");
+        }
+        stream.WriteLine($"{indentString}]");
+    }
+
+    /// <summary>
+    /// Writes one line per element: <paramref name="pad"/> followed by the element's
+    /// <c>ToString()</c> text, or by the literal text "null" for a null element.
+    /// </summary>
+    private static void WriteValues<T>(IEnumerable<T?> values, StreamWriter stream, string pad)
+        where T : struct
+    {
+        foreach (T? v in values)
+        {
+            stream.WriteLine($"{pad}{(v.HasValue ? v.Value.ToString() : "null")}");
+        }
+    }
+
+    /// <summary>
+    /// Recursively pretty print format and write `array` into a string, indenting as nesting increases.
+    /// </summary>
+    /// <param name="array">The array to render.</param>
+    /// <param name="indent">Number of copies of <paramref name="indenter"/> placed before the outermost bracket lines.</param>
+    /// <param name="indenter">The text repeated once per nesting level.</param>
+    /// <returns>Everything the stream-based overload wrote, decoded back into a string.</returns>
+    public static string PrettyPrintFormat(IArrowArray array, int indent = 0, string indenter = " ")
+    {
+        // using declarations dispose the writer (and then the stream) on every exit path;
+        // the original left the StreamWriter and StreamReader undisposed.
+        using var bufferStream = new MemoryStream();
+        using var writer = new StreamWriter(bufferStream);
+        PrettyPrintFormat(array, writer, indent, indenter);
+        writer.Flush();
+
+        // Decode with the writer's own encoding instead of rewinding and re-reading the stream.
+        return writer.Encoding.GetString(bufferStream.GetBuffer(), 0, (int)bufferStream.Length);
+    }
+
+    /// <summary>
+    /// Pretty print `array` via `Console.WriteLine`. Prefer `PrettyPrintFormat` to control where the
+    /// writing happens.
+    /// </summary>
+    /// <param name="array">The array to render.</param>
+    /// <param name="indent">Number of copies of <paramref name="indenter"/> placed before the outermost bracket lines.</param>
+    /// <param name="indenter">The text repeated once per nesting level.</param>
+    public static void PrettyPrint(IArrowArray array, int indent = 0, string indenter = " ")
+    {
+        Console.WriteLine(PrettyPrintFormat(array, indent, indenter));
+    }
+}
diff --git a/src/Apache.Arrow/Arrays/BooleanArray.cs b/src/Apache.Arrow/Arrays/BooleanArray.cs
index f87c2ec7..8cd018c7 100644
--- a/src/Apache.Arrow/Arrays/BooleanArray.cs
+++ b/src/Apache.Arrow/Arrays/BooleanArray.cs
@@ -38,6 +38,13 @@ public Builder()
                 ValidityBuffer = new ArrowBuffer.BitmapBuilder();
             }
 
+            /// <summary>Creates a builder, passing <paramref name="capacity"/> to both the value and validity bitmap builders.</summary>
+            public Builder(int capacity)
+            {
+                ValueBuffer = new ArrowBuffer.BitmapBuilder(capacity);
+                ValidityBuffer = new ArrowBuffer.BitmapBuilder(capacity);
+            }
+
             public Builder Append(bool value)
             {
                 return NullableAppend(value);
diff --git a/src/Apache.Arrow/Arrays/FixedSizeListArray.cs b/src/Apache.Arrow/Arrays/FixedSizeListArray.cs
index f60daedb..9a81eedf 100644
--- a/src/Apache.Arrow/Arrays/FixedSizeListArray.cs
+++ b/src/Apache.Arrow/Arrays/FixedSizeListArray.cs
@@ -14,12 +14,14 @@
 // limitations under the License.
 
 using System;
+using System.Collections;
+using System.Collections.Generic;
 using Apache.Arrow.Memory;
 using Apache.Arrow.Types;
 
 namespace Apache.Arrow
 {
-    public class FixedSizeListArray : Array
+    public class FixedSizeListArray : Array, IEnumerable<IArrowArray>
     {
         public class Builder : IArrowArrayBuilder
         {
@@ -186,5 +188,17 @@ protected override void Dispose(bool disposing)
             }
             base.Dispose(disposing);
         }
+
+        /// <summary>Yields <c>GetSlicedValues(index)</c> for every index from 0 to Length - 1.</summary>
+        IEnumerator<IArrowArray> IEnumerable<IArrowArray>.GetEnumerator()
+        {
+            for (int index = 0; index < Length; index++)
+            {
+                yield return GetSlicedValues(index);
+            }
+        }
+
+        /// <summary>Non-generic enumeration; forwards to the typed enumerator.</summary>
+        IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable<IArrowArray>)this).GetEnumerator();
     }
 }
diff --git a/src/Apache.Arrow/Arrays/LargeListArray.cs b/src/Apache.Arrow/Arrays/LargeListArray.cs
index 6e37aa4c..df2d90ca 100644
--- a/src/Apache.Arrow/Arrays/LargeListArray.cs
+++ b/src/Apache.Arrow/Arrays/LargeListArray.cs
@@ -14,11 +14,13 @@
 // limitations under the License.
using System;
+using System.Collections;
+using System.Collections.Generic;
 using Apache.Arrow.Types;
 
 namespace Apache.Arrow
 {
-    public class LargeListArray : Array
+    public class LargeListArray : Array, IEnumerable<IArrowArray>
     {
         public IArrowArray Values { get; }
 
@@ -93,5 +95,17 @@ protected override void Dispose(bool disposing)
             }
             base.Dispose(disposing);
         }
+
+        /// <summary>Yields <c>GetSlicedValues(index)</c> for every index from 0 to Length - 1.</summary>
+        IEnumerator<IArrowArray> IEnumerable<IArrowArray>.GetEnumerator()
+        {
+            for (int index = 0; index < Length; index++)
+            {
+                yield return GetSlicedValues(index);
+            }
+        }
+
+        /// <summary>Non-generic enumeration; forwards to the typed enumerator.</summary>
+        IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable<IArrowArray>)this).GetEnumerator();
     }
 }
diff --git a/src/Apache.Arrow/Arrays/ListArray.cs b/src/Apache.Arrow/Arrays/ListArray.cs
index 4d2ff96a..a9d9148c 100644
--- a/src/Apache.Arrow/Arrays/ListArray.cs
+++ b/src/Apache.Arrow/Arrays/ListArray.cs
@@ -14,12 +14,14 @@
 // limitations under the License.
 
 using System;
+using System.Collections;
+using System.Collections.Generic;
 using Apache.Arrow.Memory;
 using Apache.Arrow.Types;
 
 namespace Apache.Arrow
 {
-    public class ListArray : Array
+    public class ListArray : Array, IEnumerable<IArrowArray>
     {
         public class Builder : IArrowArrayBuilder
         {
@@ -204,5 +206,17 @@ protected override void Dispose(bool disposing)
             }
             base.Dispose(disposing);
         }
+
+        /// <summary>Yields <c>GetSlicedValues(index)</c> for every index from 0 to Length - 1.</summary>
+        IEnumerator<IArrowArray> IEnumerable<IArrowArray>.GetEnumerator()
+        {
+            for (int index = 0; index < Length; index++)
+            {
+                yield return GetSlicedValues(index);
+            }
+        }
+
+        /// <summary>Non-generic enumeration; forwards to the typed enumerator.</summary>
+        IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable<IArrowArray>)this).GetEnumerator();
     }
 }
diff --git a/src/Apache.Arrow/Arrays/ListViewArray.cs b/src/Apache.Arrow/Arrays/ListViewArray.cs
index 081385d9..9c6ad4b3 100644
--- a/src/Apache.Arrow/Arrays/ListViewArray.cs
+++ b/src/Apache.Arrow/Arrays/ListViewArray.cs
@@ -14,12 +14,14 @@
 // limitations under the License.
using System;
+using System.Collections;
+using System.Collections.Generic;
 using Apache.Arrow.Memory;
 using Apache.Arrow.Types;
 
 namespace Apache.Arrow
 {
-    public class ListViewArray : Array
+    public class ListViewArray : Array, IEnumerable<IArrowArray>
     {
         public class Builder : IArrowArrayBuilder
         {
@@ -213,5 +215,17 @@ protected override void Dispose(bool disposing)
             }
             base.Dispose(disposing);
         }
+
+        /// <summary>Yields null for each null entry and <c>GetSlicedValues(index)</c> for every other index from 0 to Length - 1.</summary>
+        IEnumerator<IArrowArray> IEnumerable<IArrowArray>.GetEnumerator()
+        {
+            for (int index = 0; index < Length; index++)
+            {
+                yield return IsNull(index) ? null : GetSlicedValues(index);
+            }
+        }
+
+        /// <summary>Non-generic enumeration; forwards to the typed enumerator.</summary>
+        IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable<IArrowArray>)this).GetEnumerator();
     }
 }
diff --git a/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs b/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs
index dc6fba2b..4b83378a 100644
--- a/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs
+++ b/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs
@@ -116,6 +116,13 @@ public PrimitiveArrayBuilder()
             ValidityBuffer = new ArrowBuffer.BitmapBuilder();
         }
 
+        /// <summary>Creates a builder, passing <paramref name="capacity"/> to both the value buffer builder and the validity bitmap builder.</summary>
+        public PrimitiveArrayBuilder(int capacity)
+        {
+            ValueBuffer = new ArrowBuffer.Builder<T>(capacity);
+            ValidityBuffer = new ArrowBuffer.BitmapBuilder(capacity);
+        }
+
         public TBuilder Resize(int length)
         {
             ValueBuffer.Resize(length);
diff --git a/test/Apache.Arrow.Operations.Tests/Apache.Arrow.Operations.Tests.csproj b/test/Apache.Arrow.Operations.Tests/Apache.Arrow.Operations.Tests.csproj
new file mode 100644
index 00000000..4cc754eb
--- /dev/null
+++ b/test/Apache.Arrow.Operations.Tests/Apache.Arrow.Operations.Tests.csproj
@@ -0,0 +1,25 @@
+
+
+
+ net8.0
+
+
+
+
+
+
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers
+
+
+
+
+
+
+
+
+
diff --git a/test/Apache.Arrow.Operations.Tests/TestOperations.cs b/test/Apache.Arrow.Operations.Tests/TestOperations.cs
new file mode 100644
index 00000000..cb138cd4
--- /dev/null
+++ b/test/Apache.Arrow.Operations.Tests/TestOperations.cs
@@ -0,0 +1,148 @@
+// Licensed to the Apache Software Foundation (ASF) under one
or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Linq;
+using Xunit;
+
+
+namespace Apache.Arrow.Operations.Tests;
+
+/// <summary>
+/// Tests for the conversion and selection helpers.
+/// </summary>
+public class ArrowOperationsTests
+{
+
+    /// <summary>Casting long values to double, then to float and int, keeps each value.</summary>
+    [Fact]
+    public void TestConversion()
+    {
+        var vals = Conversion.CastDouble([50L, 52L, 510L]);
+        // xUnit convention: expected value first, actual second, so failure messages read correctly.
+        Assert.Equal(50.0, vals.GetValue(0));
+        Assert.Equal(52.0, vals.GetValue(1));
+        Assert.Equal(510.0, vals.GetValue(2));
+
+        var valsF = Conversion.CastFloat(vals);
+        Assert.Equal(50.0f, valsF.GetValue(0));
+        Assert.Equal(52.0f, valsF.GetValue(1));
+        Assert.Equal(510.0f, valsF.GetValue(2));
+
+        var valsI = Conversion.CastInt32(vals);
+        Assert.Equal(50, valsI.GetValue(0));
+        Assert.Equal(52, valsI.GetValue(1));
+        Assert.Equal(510, valsI.GetValue(2));
+    }
+
+    /// <summary>Taking index 1 returns the second element.</summary>
+    [Fact]
+    public void TestSelectionTakeIndex()
+    {
+        var vals = Conversion.CastInt64([50L, 52L, 510L]);
+        var items = (Int64Array)Select.Take(vals, [1]);
+        Assert.Equal(52, items.GetValue(0));
+    }
+
+    /// <summary>Filtering with an equality mask keeps the matching element.</summary>
+    [Fact]
+    public void TestSelectionFilterMask()
+    {
+        var vals = Conversion.CastInt64([50L, 52L, 510L]);
+        var mask = Comparison.Equal(vals, 52L);
+        var items = (Int64Array)Select.Filter(vals, mask);
+        Assert.Equal(52, items.GetValue(0));
+    }
+}
+
+
+/// <summary>
+/// Tests for the bitwise boolean operations, run on 5000-element arrays.
+/// </summary>
+public class ArrowBooleanOperationsTests
+{
+    /// <summary>Complementing an all-true value buffer yields an all-false array of the same length.</summary>
+    [Fact]
+    public void TestInvert()
+    {
+        var vals = Enumerable.Repeat(true, 5000);
+        var builder = new BooleanArray.Builder(5000);
+        builder.AppendRange(vals);
+        var array = builder.Build();
+        Assert.True(array.All(v => v ?? false));
+
+        var inverted = BitVectorOps.OnesComplement(array.ValueBuffer);
+        var invertedArray = new BooleanArray(inverted, array.NullBitmapBuffer.Clone(), array.Length, array.NullCount, 0);
+        Assert.Equal(array.Length, invertedArray.Length);
+        // Every element must be false; "not all true" would pass with a single flipped bit.
+        Assert.True(invertedArray.All(v => v == false));
+    }
+
+    /// <summary>true AND true is all true; true AND false is all false.</summary>
+    [Fact]
+    public void TestAnd()
+    {
+        var vals = Enumerable.Repeat(true, 5000);
+        var builder = new BooleanArray.Builder(5000);
+        builder.AppendRange(vals);
+        var array = builder.Build();
+        Assert.True(array.All(v => v ?? false));
+
+        var result = Comparison.And(array, array);
+        Assert.True(result.All(v => v ?? false));
+
+        vals = Enumerable.Repeat(false, 5000);
+        builder = new BooleanArray.Builder(5000);
+        builder.AppendRange(vals);
+        var allFalse = builder.Build();
+
+        result = Comparison.And(array, allFalse);
+        Assert.Equal(allFalse.Length, result.Length);
+        // Every element must be false, not merely "not all true".
+        Assert.True(result.All(v => v == false));
+    }
+
+    /// <summary>true OR true and true OR false are both all true.</summary>
+    [Fact]
+    public void TestOr()
+    {
+        var vals = Enumerable.Repeat(true, 5000);
+        var builder = new BooleanArray.Builder(5000);
+        builder.AppendRange(vals);
+        var array = builder.Build();
+        Assert.True(array.All(v => v ?? false));
+
+        var result = Comparison.Or(array, array);
+        Assert.True(result.All(v => v ?? false));
+
+        vals = Enumerable.Repeat(false, 5000);
+        builder = new BooleanArray.Builder(5000);
+        builder.AppendRange(vals);
+        var allFalse = builder.Build();
+
+        result = Comparison.Or(array, allFalse);
+        Assert.Equal(allFalse.Length, result.Length);
+        Assert.True(result.All(v => v ?? false));
+    }
+
+    /// <summary>XOR of (2500 true, 2500 false) with all true clears the first half and sets the second.</summary>
+    [Fact]
+    public void TestXor()
+    {
+        var vals = Enumerable.Repeat(true, 2500);
+        var builder = new BooleanArray.Builder(5000);
+        builder.AppendRange(vals);
+        vals = Enumerable.Repeat(false, 2500);
+        builder.AppendRange(vals);
+        var array = builder.Build();
+
+        Assert.Equal(2500, array.Count(s => s ?? false));
+
+        builder = new BooleanArray.Builder(5000);
+        vals = Enumerable.Repeat(true, 2500);
+        builder.AppendRange(vals);
+        vals = Enumerable.Repeat(true, 2500);
+        builder.AppendRange(vals);
+        var array2 = builder.Build();
+
+        var result = Comparison.Xor(array, array2);
+        Assert.Equal(2500, result.Count(s => s ?? false));
+        Assert.Equal(0, ((BooleanArray)result.Slice(0, 2500)).Count(s => s ?? false));
+    }
+}