diff --git a/vortex-array/src/dtype/arrow.rs b/vortex-array/src/dtype/arrow.rs index 8dd80082533..103c52daf93 100644 --- a/vortex-array/src/dtype/arrow.rs +++ b/vortex-array/src/dtype/arrow.rs @@ -450,6 +450,18 @@ mod test { ); } + #[rstest] + #[case(1, DataType::Decimal128(1, 0))] + #[case(38, DataType::Decimal128(38, 0))] + #[case(39, DataType::Decimal256(39, 0))] + #[case(76, DataType::Decimal256(76, 0))] + fn test_decimal_dtype_to_arrow(#[case] precision: u8, #[case] expected: DataType) { + use crate::dtype::DecimalDType; + + let dtype = DType::Decimal(DecimalDType::new(precision, 0), Nullability::NonNullable); + assert_eq!(dtype.to_arrow_dtype().unwrap(), expected); + } + #[test] fn test_variant_dtype_to_arrow_dtype_errors() { let err = DType::Variant(Nullability::NonNullable) diff --git a/vortex-cuda/kernels/src/decimal_cast.cu b/vortex-cuda/kernels/src/decimal_cast.cu new file mode 100644 index 00000000000..22a8a130729 --- /dev/null +++ b/vortex-cuda/kernels/src/decimal_cast.cu @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#include "config.cuh" +#include "types.cuh" +#include +#include + +// Arrow decimal schemas fix the physical values buffer width: +// - Decimal32: 4 bytes per value. +// - Decimal64: 8 bytes per value. +// - Decimal128: 16 bytes per value. +// - Decimal256: 32 bytes per value. +// +// Vortex storage width can differ, so export casts to the schema-implied width. +// Rust-side export rejects narrowing casts because detecting overflow on-device +// would require synchronizing an overflow flag back to the host. + +// Low 64-bit conversion for Decimal32/64 outputs. +template +__device__ __forceinline__ int64_t decimal_to_i64(Input value) { + if constexpr (std::is_same_v) { + return value.lo; + } else if constexpr (std::is_same_v) { + return value.parts[0]; + } else { + return static_cast(value); + } +} + +// 128-bit conversion for Decimal128 outputs. +template +__device__ __forceinline__ int128_t decimal_to_i128(Input value) { + if constexpr (std::is_same_v) { + return value; + } else if constexpr (std::is_same_v) { + return int128_t {value.parts[0], value.parts[1]}; + } else { + const int64_t lo = static_cast(value); + const int64_t hi = value < 0 ? -1 : 0; + return int128_t {lo, hi}; + } +} + +// Convert one value to the Arrow schema's physical width. +template +__device__ __forceinline__ Output decimal_cast_value(Input value) { + if constexpr (std::is_same_v) { + return static_cast(decimal_to_i64(value)); + } else if constexpr (std::is_same_v) { + return decimal_to_i64(value); + } else if constexpr (std::is_same_v) { + return decimal_to_i128(value); + } else { + static_assert(std::is_same_v); + if constexpr (std::is_same_v) { + return value; + } else { + const int128_t value128 = decimal_to_i128(value); + const int64_t sign = value128.hi < 0 ? -1 : 0; + return int256_t {{value128.lo, value128.hi, sign, sign}}; + } + } +} + +// Cast a contiguous values buffer to the Arrow schema's physical width. +template +__device__ void +decimal_cast_device(const Input *__restrict input, Output *__restrict output, uint64_t array_len) { + const uint64_t worker = blockIdx.x * blockDim.x + threadIdx.x; + const uint64_t startElem = start_elem(worker, array_len); + const uint64_t stopElem = stop_elem(worker, array_len); + + if (startElem >= array_len) { + return; + } + + for (uint64_t idx = startElem; idx < stopElem; idx++) { + output[idx] = decimal_cast_value(input[idx]); + } +} + +// Generate Decimal32/64/128/256 cast kernels for one input storage type. +#define GENERATE_DECIMAL_CAST_KERNELS(input_suffix, InputType) \ + extern "C" __global__ void decimal_cast_##input_suffix##_i32(const InputType *__restrict input, \ + int32_t *__restrict output, \ + uint64_t array_len) { \ + decimal_cast_device(input, output, array_len); \ + } \ + extern "C" __global__ void decimal_cast_##input_suffix##_i64(const InputType *__restrict input, \ + int64_t *__restrict output, \ + uint64_t array_len) { \ + decimal_cast_device(input, output, array_len); \ + } \ + extern "C" __global__ void decimal_cast_##input_suffix##_i128(const InputType *__restrict input, \ + int128_t *__restrict output, \ + uint64_t array_len) { \ + decimal_cast_device(input, output, array_len); \ + } \ + extern "C" __global__ void decimal_cast_##input_suffix##_i256(const InputType *__restrict input, \ + int256_t *__restrict output, \ + uint64_t array_len) { \ + decimal_cast_device(input, output, array_len); \ + } + +FOR_EACH_SIGNED_INT(GENERATE_DECIMAL_CAST_KERNELS) +FOR_EACH_LARGE_DECIMAL(GENERATE_DECIMAL_CAST_KERNELS) diff --git a/vortex-cuda/src/arrow/canonical.rs b/vortex-cuda/src/arrow/canonical.rs index b1bee7c018f..af180400675 100644 --- a/vortex-cuda/src/arrow/canonical.rs +++ b/vortex-cuda/src/arrow/canonical.rs @@ -3,11 +3,15 @@ use std::mem; use std::ptr; +use std::sync::Arc; use async_trait::async_trait; +use cudarc::driver::DeviceRepr; +use cudarc::driver::PushKernelArg; use futures::future::BoxFuture; use vortex::array::ArrayRef; use vortex::array::Canonical; +use vortex::array::arrays::DecimalArray; use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::ListArray; use vortex::array::arrays::PrimitiveArray; @@ -24,13 +28,16 @@ use vortex::array::arrays::struct_::StructDataParts; use vortex::array::arrays::varbinview::VarBinViewDataParts; use vortex::array::buffer::BufferHandle; use vortex::array::builtins::ArrayBuiltins; +use vortex::array::match_each_decimal_value_type; use vortex::array::validity::Validity; use vortex::buffer::Buffer; use vortex::buffer::ByteBuffer; use vortex::dtype::DType; use vortex::dtype::DecimalType; +use vortex::dtype::NativeDecimalType; use vortex::dtype::Nullability; use vortex::dtype::PType; +use vortex::dtype::i256; use vortex::error::VortexResult; use vortex::error::vortex_bail; use vortex::error::vortex_ensure; @@ -39,6 +46,8 @@ use vortex::extension::datetime::AnyTemporal; use vortex::mask::Mask; use super::list_view::export_device_list_view; +use crate::CudaBufferExt; +use crate::CudaDeviceBuffer; use crate::CudaExecutionCtx; use crate::arrow::ARROW_DEVICE_CUDA; use crate::arrow::ArrowArray; @@ -46,6 +55,7 @@ use crate::arrow::ArrowDeviceArray; use crate::arrow::ExportDeviceArray; use crate::arrow::PrivateData; use crate::arrow::SyncEvent; +use crate::arrow::cuda_decimal_value_type; use crate::executor::CudaArrayExt; /// An implementation of `ExportDeviceArray` that exports Vortex arrays to `ArrowDeviceArray` by @@ -106,27 +116,7 @@ fn export_canonical( // we don't need a sync event for Null since no data is copied. Ok((array, ptr::null_mut())) } - Canonical::Decimal(decimal) => { - let len = decimal.len(); - let DecimalDataParts { - values, - values_type, - validity, - .. - } = decimal.into_data_parts(); - - // TODO(aduffy): GPU kernel for upcasting. - vortex_ensure!( - values_type >= DecimalType::I32, - "cannot export DecimalArray with values type {values_type}. must be i32 or wider." - ); - - let (validity_buffer, null_count) = - export_arrow_validity_buffer(validity, len, 0, ctx).await?; - let buffer = ctx.ensure_on_device(values).await?; - - export_fixed_size(buffer, len, 0, validity_buffer, null_count, ctx) - } + Canonical::Decimal(decimal) => export_decimal(decimal, ctx).await, Canonical::Extension(extension) => { if !extension.ext_dtype().is::() { vortex_bail!("only support temporal extension types currently"); @@ -242,6 +232,115 @@ fn export_canonical( }) } +/// Exports decimals with value buffers cast to Arrow's Decimal32/64/128/256 layout. +/// +/// Decimal values are already decoded; this only adapts the physical buffer width. Storage-to-Arrow +/// narrowing is rejected instead of checked on-device to avoid a device-to-host synchronization +/// point. +async fn export_decimal( + decimal: DecimalArray, + ctx: &mut CudaExecutionCtx, +) -> VortexResult<(ArrowArray, SyncEvent)> { + let len = decimal.len(); + let DecimalDataParts { + decimal_dtype, + values, + values_type, + validity, + } = decimal.into_data_parts(); + + let (validity_buffer, null_count) = export_arrow_validity_buffer(validity, len, 0, ctx).await?; + let target_type = cuda_decimal_value_type(decimal_dtype); + let values = export_decimal_values(values, values_type, target_type, len, ctx).await?; + + export_fixed_size(values, len, 0, validity_buffer, null_count, ctx) +} + +/// Ensure the values buffer is on-device and has the Arrow-required decimal width. +/// +/// Storage wider than the precision-implied Arrow width is rejected. Callers that hit this +/// should narrow the storage via a decimal cast before exporting. +async fn export_decimal_values( + values: BufferHandle, + values_type: DecimalType, + target_type: DecimalType, + len: usize, + ctx: &mut CudaExecutionCtx, +) -> VortexResult { + if values_type.byte_width() > target_type.byte_width() { + vortex_bail!( + "cannot export decimal values from {values_type} storage to Arrow {target_type}: narrowing would require a device-to-host overflow check", + ); + } + let values = ctx.ensure_on_device(values).await?; + if values_type == target_type { + return Ok(values); + } + + match_each_decimal_value_type!(values_type, |S| { + export_decimal_values_from::(values, target_type, len, ctx).await + }) +} + +/// Dispatch from a concrete storage type `S` to the Arrow-required output width. +async fn export_decimal_values_from( + values: BufferHandle, + target_type: DecimalType, + len: usize, + ctx: &mut CudaExecutionCtx, +) -> VortexResult +where + S: NativeDecimalType + DeviceRepr, +{ + match target_type { + DecimalType::I32 => decimal_cast::(values, len, ctx).await, + DecimalType::I64 => decimal_cast::(values, len, ctx).await, + DecimalType::I128 => decimal_cast::(values, len, ctx).await, + DecimalType::I256 => decimal_cast::(values, len, ctx).await, + target_type => { + vortex_bail!("cannot export DecimalArray as Arrow decimal value type {target_type}") + } + } +} + +/// Launches the CUDA kernel that casts from Vortex storage type `S` to Arrow output type `D`. +/// +/// The caller must ensure this is not a narrowing cast. +async fn decimal_cast( + values: BufferHandle, + len: usize, + ctx: &mut CudaExecutionCtx, +) -> VortexResult +where + S: NativeDecimalType + DeviceRepr, + D: NativeDecimalType + DeviceRepr, +{ + if len == 0 { + return ctx + .ensure_on_device(BufferHandle::new_host( + Buffer::::empty().into_byte_buffer(), + )) + .await; + } + + let output_buffer = ctx.device_alloc::(len)?; + let output_device = CudaDeviceBuffer::new(output_buffer); + + let values_view = values.cuda_view::()?; + let output_view = output_device.as_view::(); + let len_u64 = len as u64; + let cuda_function = ctx.load_function_with_suffixes( + "decimal_cast", + &[&S::DECIMAL_TYPE.to_string(), &D::DECIMAL_TYPE.to_string()], + )?; + + ctx.launch_kernel(&cuda_function, len, |args| { + args.arg(&values_view).arg(&output_view).arg(&len_u64); + })?; + + Ok(BufferHandle::new_device(Arc::new(output_device))) +} + /// Export Vortex validity as an Arrow validity byte buffer. /// /// Returns `None` for the buffer when Arrow can omit validity because all rows are valid. @@ -537,16 +636,19 @@ mod tests { use vortex::array::arrays::VarBinViewArray; use vortex::array::arrays::primitive::PrimitiveArrayExt; use vortex::array::arrays::varbinview::BinaryView; + use vortex::array::buffer::BufferHandle; use vortex::array::validity::Validity; use vortex::buffer::Buffer; use vortex::buffer::ByteBuffer; use vortex::dtype::DType; use vortex::dtype::DecimalDType; use vortex::dtype::FieldNames; + use vortex::dtype::NativeDecimalType; use vortex::dtype::NativePType; use vortex::dtype::Nullability; use vortex::dtype::PType; use vortex::dtype::half::f16; + use vortex::dtype::i256; use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::error::vortex_bail; @@ -761,6 +863,14 @@ mod tests { .collect()) } + fn assert_exported_decimal_values( + value_buffer: &BufferHandle, + expected: &[T], + ) { + let values = Buffer::::from_byte_buffer(value_buffer.to_host_sync()); + assert_eq!(values.as_slice(), expected); + } + // Build a nested struct fixture with an out-of-line string-view value. fn nested_struct_array() -> ArrayRef { let nested = StructArray::new( @@ -855,25 +965,197 @@ mod tests { Ok(()) } + async fn assert_exported_decimal( + array: ArrayRef, + expected_data_type: DataType, + expected_values: Vec, + ) -> VortexResult<()> { + let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty()) + .vortex_expect("failed to create execution context"); + + let mut exported = array.export_device_array_with_schema(&mut ctx).await?; + + let field = Field::try_from(&exported.schema)?; + assert_eq!(field, Field::new("", expected_data_type, false)); + assert_eq!( + exported.array.array.length, + i64::try_from(expected_values.len())? + ); + assert_eq!(exported.array.array.null_count, 0); + assert_eq!(exported.array.array.n_buffers, 2); + assert_eq!(exported.array.array.n_children, 0); + assert!(exported.array.array.release.is_some()); + assert_eq!(exported.array.device_type, ARROW_DEVICE_CUDA); + + let private_data = unsafe { &*exported.array.array.private_data.cast::() }; + let value_buffer = private_data.buffers[1] + .as_ref() + .vortex_expect("value buffer should be present"); + assert_eq!(value_buffer.len(), expected_values.len() * size_of::()); + assert_exported_decimal_values(value_buffer, &expected_values); + + unsafe { release_exported_array(&raw mut exported.array.array) }; + Ok(()) + } + + #[rstest] + #[case::i8( + DecimalArray::from_iter([1i8, -2, 3], DecimalDType::new(2, 1)).into_array(), + DataType::Decimal32(2, 1), + vec![1i32, -2, 3] + )] + #[case::i16( + DecimalArray::from_iter([100i16, -200, 300], DecimalDType::new(4, 2)).into_array(), + DataType::Decimal32(4, 2), + vec![100i32, -200, 300] + )] + #[case::i32( + DecimalArray::from_iter([10_000i32, -20_000, 30_000], DecimalDType::new(9, 2)).into_array(), + DataType::Decimal32(9, 2), + vec![10_000i32, -20_000, 30_000] + )] #[crate::test] - async fn test_export_decimal() -> VortexResult<()> { + async fn test_export_decimal32( + #[case] array: ArrayRef, + #[case] expected_data_type: DataType, + #[case] expected_values: Vec, + ) -> VortexResult<()> { + assert_exported_decimal(array, expected_data_type, expected_values).await + } + + #[rstest] + #[case::i32( + DecimalArray::from_iter([1_000_000i32, -2_000_000, 3_000_000], DecimalDType::new(10, 2)).into_array(), + DataType::Decimal64(10, 2), + vec![1_000_000i64, -2_000_000, 3_000_000] + )] + #[case::i32_boundary( + DecimalArray::from_iter([i32::MIN, -1i32, 0, 1, i32::MAX], DecimalDType::new(10, 0)).into_array(), + DataType::Decimal64(10, 0), + vec![i32::MIN as i64, -1, 0, 1, i32::MAX as i64] + )] + #[case::i64( + DecimalArray::from_iter([1_000_000i64, -2_000_000, 3_000_000], DecimalDType::new(18, 2)).into_array(), + DataType::Decimal64(18, 2), + vec![1_000_000i64, -2_000_000, 3_000_000] + )] + #[crate::test] + async fn test_export_decimal64( + #[case] array: ArrayRef, + #[case] expected_data_type: DataType, + #[case] expected_values: Vec, + ) -> VortexResult<()> { + assert_exported_decimal(array, expected_data_type, expected_values).await + } + + #[rstest] + #[case::i64_boundary( + DecimalArray::from_iter([i64::MIN, -1i64, 0, 1, i64::MAX], DecimalDType::new(19, 0)).into_array(), + DataType::Decimal128(19, 0), + vec![i64::MIN as i128, -1, 0, 1, i64::MAX as i128] + )] + #[case::i128( + DecimalArray::from_iter([1i128, -2, 3], DecimalDType::new(38, 2)).into_array(), + DataType::Decimal128(38, 2), + vec![1i128, -2, 3] + )] + #[crate::test] + async fn test_export_decimal128( + #[case] array: ArrayRef, + #[case] expected_data_type: DataType, + #[case] expected_values: Vec, + ) -> VortexResult<()> { + assert_exported_decimal(array, expected_data_type, expected_values).await + } + + #[crate::test] + async fn test_export_empty_decimal() -> VortexResult<()> { + assert_exported_decimal( + DecimalArray::new( + Buffer::::empty(), + DecimalDType::new(9, 2), + Validity::NonNullable, + ) + .into_array(), + DataType::Decimal32(9, 2), + Vec::::new(), + ) + .await + } + + #[crate::test] + async fn test_export_empty_decimal_widening() -> VortexResult<()> { + assert_exported_decimal( + DecimalArray::new( + Buffer::::empty(), + DecimalDType::new(9, 2), + Validity::NonNullable, + ) + .into_array(), + DataType::Decimal32(9, 2), + Vec::::new(), + ) + .await + } + + #[crate::test] + async fn test_export_decimal_narrowing_errors() -> VortexResult<()> { let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty()) .vortex_expect("failed to create execution context"); + let array = DecimalArray::from_iter([i256::from_parts(0, 1)], DecimalDType::new(38, 0)) + .into_array(); - let array = DecimalArray::from_iter(0i128..5, DecimalDType::new(38, 2)).into_array(); - let mut device_array = array.export_device_array(&mut ctx).await?; + let err = array + .export_device_array_with_schema(&mut ctx) + .await + .unwrap_err(); + assert!(err.to_string().contains("narrowing would require")); + Ok(()) + } - assert_eq!(device_array.array.length, 5); - assert_eq!(device_array.array.null_count, 0); - assert_eq!(device_array.array.n_buffers, 2); - assert_eq!(device_array.array.n_children, 0); - assert!(device_array.array.release.is_some()); - assert_eq!(device_array.device_type, ARROW_DEVICE_CUDA); + #[crate::test] + async fn test_export_decimal_narrowing_from_arrow_import() -> VortexResult<()> { + let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty()) + .vortex_expect("failed to create execution context"); + let array = DecimalArray::from_iter([0i128, 1, -2], DecimalDType::new(10, 2)).into_array(); - unsafe { release_exported_array(&raw mut device_array.array) }; + let err = array + .export_device_array_with_schema(&mut ctx) + .await + .unwrap_err(); + assert!(err.to_string().contains("narrowing would require")); Ok(()) } + #[rstest] + #[case::i64( + DecimalArray::from_iter([i64::MIN, -1i64, 1, i64::MAX], DecimalDType::new(39, 0)).into_array(), + DataType::Decimal256(39, 0), + vec![i256::from_i128(i64::MIN as i128), i256::from_i128(-1), i256::from_i128(1), i256::from_i128(i64::MAX as i128)] + )] + #[case::i128( + DecimalArray::from_iter([1i128, -2, 3], DecimalDType::new(39, 2)).into_array(), + DataType::Decimal256(39, 2), + vec![i256::from_i128(1), i256::from_i128(-2), i256::from_i128(3)] + )] + #[case::i256( + DecimalArray::from_iter( + [i256::from_i128(10), i256::from_i128(-20), i256::from_i128(30)], + DecimalDType::new(76, 2), + ) + .into_array(), + DataType::Decimal256(76, 2), + vec![i256::from_i128(10), i256::from_i128(-20), i256::from_i128(30)] + )] + #[crate::test] + async fn test_export_decimal256( + #[case] array: ArrayRef, + #[case] expected_data_type: DataType, + #[case] expected_values: Vec, + ) -> VortexResult<()> { + assert_exported_decimal(array, expected_data_type, expected_values).await + } + #[crate::test] async fn test_export_temporal() -> VortexResult<()> { let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty()) @@ -1519,6 +1801,14 @@ mod tests { &mut ctx, ) .await?; + + let private_data = unsafe { &*decimal.array.private_data.cast::() }; + let value_buffer = private_data.buffers[1] + .as_ref() + .vortex_expect("value buffer should be present"); + assert_eq!(value_buffer.len(), 3 * size_of::()); + assert_exported_decimal_values(value_buffer, &[100i64, 0, 300]); + unsafe { release_exported_array(&raw mut decimal.array) }; Ok(()) @@ -1794,6 +2084,67 @@ mod tests { Ok(()) } + #[crate::test] + async fn test_export_nested_struct_decimal_with_schema() -> VortexResult<()> { + let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty()) + .vortex_expect("failed to create execution context"); + + let nested = StructArray::new( + FieldNames::from_iter(["amount"]), + vec![ + DecimalArray::from_iter([100i32, -200, 300], DecimalDType::new(9, 2)).into_array(), + ], + 3, + Validity::NonNullable, + ) + .into_array(); + let array = StructArray::new( + FieldNames::from_iter(["nested"]), + vec![nested], + 3, + Validity::NonNullable, + ) + .into_array(); + let mut exported = array.export_device_array_with_schema(&mut ctx).await?; + + let schema = Schema::try_from(&exported.schema)?; + assert_eq!( + schema, + Schema::new(vec![Field::new( + "nested", + DataType::Struct(Fields::from(vec![Field::new( + "amount", + DataType::Decimal32(9, 2), + false, + )])), + false, + )]) + ); + + let children = unsafe { + std::slice::from_raw_parts( + exported.array.array.children, + usize::try_from(exported.array.array.n_children)?, + ) + }; + let nested_child = unsafe { &*children[0] }; + let nested_children = unsafe { + std::slice::from_raw_parts( + nested_child.children, + usize::try_from(nested_child.n_children)?, + ) + }; + let decimal_child = unsafe { &*nested_children[0] }; + let private_data = unsafe { &*decimal_child.private_data.cast::() }; + let value_buffer = private_data.buffers[1] + .as_ref() + .vortex_expect("value buffer should be present"); + assert_eq!(value_buffer.len(), 3 * size_of::()); + + unsafe { release_exported_array(&raw mut exported.array.array) }; + Ok(()) + } + #[crate::test] async fn test_export_primitive_with_schema_is_column_shaped() -> VortexResult<()> { let mut ctx = CudaSession::create_execution_ctx(&VortexSession::empty()) diff --git a/vortex-cuda/src/arrow/mod.rs b/vortex-cuda/src/arrow/mod.rs index e21383d96e7..99fa14add3d 100644 --- a/vortex-cuda/src/arrow/mod.rs +++ b/vortex-cuda/src/arrow/mod.rs @@ -16,6 +16,9 @@ use std::fmt::Debug; use std::ptr; use std::sync::Arc; +use arrow_schema::DataType; +use arrow_schema::Field; +use arrow_schema::Schema; use arrow_schema::ffi::FFI_ArrowSchema; use async_trait::async_trait; pub(crate) use canonical::CanonicalDeviceArrayExport; @@ -26,6 +29,8 @@ use vortex::array::ArrayRef; use vortex::array::arrow::ArrowSessionExt; use vortex::array::buffer::BufferHandle; use vortex::dtype::DType; +use vortex::dtype::DecimalDType; +use vortex::dtype::DecimalType; use vortex::dtype::StructFields; use vortex::error::VortexResult; use vortex::error::vortex_err; @@ -159,6 +164,10 @@ pub trait DeviceArrayExt { /// /// The returned array owns any device buffers allocated during export. Call the embedded /// Arrow release callback when the consumer is done with the array. + /// + /// Arrow arrays are not self-describing, so callers that use this method directly must provide + /// a matching schema out-of-band. Prefer [`Self::export_device_array_with_schema`] unless a + /// consumer already has the CUDA export schema. async fn export_device_array( self, ctx: &mut CudaExecutionCtx, @@ -167,13 +176,15 @@ pub trait DeviceArrayExt { /// Export this array as an Arrow C Device array together with its matching Arrow C schema. /// /// Arrow arrays are not self-describing: consumers need both the [`ArrowDeviceArray`] and an - /// Arrow schema to interpret the buffer layout. This helper derives the schema from the - /// Vortex dtype using the session's Arrow conversion rules and returns it alongside the device - /// array. + /// Arrow schema to interpret the buffer layout. This helper derives the schema that matches the + /// CUDA device export layout and returns it alongside the device array. /// /// Top-level struct arrays are exported as table-like Arrow schemas and struct-shaped device /// arrays. Top-level non-struct arrays are exported as column-shaped field schemas and /// column-shaped device arrays; this method does not wrap them in a single-field struct. + /// + /// Decimal exports use the Arrow decimal width implied by precision; storage wider than that + /// width is rejected rather than narrowed on device. async fn export_device_array_with_schema( self, ctx: &mut CudaExecutionCtx, @@ -200,22 +211,75 @@ impl DeviceArrayExt for ArrayRef { } } -/// Build the Arrow C schema that describes the device array exported for `array`. -/// -/// Top-level Vortex structs are represented as Arrow schemas, which is the shape expected for -/// table-like consumers. Non-struct arrays are represented as a single Arrow field schema, matching -/// the column-shaped [`ArrowDeviceArray`] returned by [`DeviceArrayExt::export_device_array`]. +/// Build the Arrow C schema that describes the exported device array. fn arrow_schema_for_array( array: &ArrayRef, ctx: &mut CudaExecutionCtx, ) -> VortexResult { - let arrow = ctx.execution_ctx().session().arrow(); let dtype = arrow_device_export_dtype(array.dtype()); match &dtype { - DType::Struct(..) => Ok(FFI_ArrowSchema::try_from(arrow.to_arrow_schema(&dtype)?)?), - _ => Ok(FFI_ArrowSchema::try_from( - arrow.to_arrow_field("", &dtype)?, - )?), + DType::Struct(struct_dtype, _) => Ok(FFI_ArrowSchema::try_from(Schema::new( + cuda_arrow_struct_fields(struct_dtype, ctx)?, + ))?), + _ => Ok(FFI_ArrowSchema::try_from(cuda_arrow_field( + "", &dtype, ctx, + )?)?), + } +} + +fn cuda_arrow_struct_fields( + struct_dtype: &StructFields, + ctx: &mut CudaExecutionCtx, +) -> VortexResult> { + let mut fields = Vec::with_capacity(struct_dtype.nfields()); + for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) { + fields.push(cuda_arrow_field(field_name.as_ref(), &field_dtype, ctx)?); + } + Ok(fields) +} + +fn cuda_arrow_field( + name: impl AsRef, + dtype: &DType, + ctx: &mut CudaExecutionCtx, +) -> VortexResult { + let field = ctx + .execution_ctx() + .session() + .arrow() + .to_arrow_field(name.as_ref(), dtype)?; + + let data_type = match dtype { + DType::Decimal(decimal_dtype, _) => cuda_arrow_decimal_data_type(*decimal_dtype), + DType::Struct(struct_dtype, _) => { + DataType::Struct(cuda_arrow_struct_fields(struct_dtype, ctx)?.into()) + } + _ => return Ok(field), + }; + + Ok( + Field::new(field.name().clone(), data_type, field.is_nullable()) + .with_metadata(field.metadata().clone()), + ) +} + +fn cuda_arrow_decimal_data_type(decimal_dtype: DecimalDType) -> DataType { + match cuda_decimal_value_type(decimal_dtype) { + DecimalType::I32 => DataType::Decimal32(decimal_dtype.precision(), decimal_dtype.scale()), + DecimalType::I64 => DataType::Decimal64(decimal_dtype.precision(), decimal_dtype.scale()), + DecimalType::I128 => DataType::Decimal128(decimal_dtype.precision(), decimal_dtype.scale()), + DecimalType::I256 => DataType::Decimal256(decimal_dtype.precision(), decimal_dtype.scale()), + decimal_type => unreachable!("unsupported decimal value type {decimal_type}"), + } +} + +pub(crate) fn cuda_decimal_value_type(decimal_dtype: DecimalDType) -> DecimalType { + match decimal_dtype.precision() { + 1..=9 => DecimalType::I32, + 10..=18 => DecimalType::I64, + 19..=38 => DecimalType::I128, + 39..=76 => DecimalType::I256, + p => unreachable!("precision {p} is invalid for a DecimalDType"), } } diff --git a/vortex-test/e2e-cuda/src/lib.rs b/vortex-test/e2e-cuda/src/lib.rs index 11a540ff70e..04b84efb03d 100644 --- a/vortex-test/e2e-cuda/src/lib.rs +++ b/vortex-test/e2e-cuda/src/lib.rs @@ -14,6 +14,8 @@ use std::sync::LazyLock; use arrow_array::Array; use arrow_array::ArrayRef as ArrowArrayRef; use arrow_array::Date32Array; +use arrow_array::Decimal32Array; +use arrow_array::Decimal64Array; use arrow_array::Decimal128Array; use arrow_array::StringArray; use arrow_array::cast::AsArray; @@ -146,9 +148,19 @@ fn export_array_inner(schema_ptr: &mut FFI_ArrowSchema, array_ptr: &mut ArrowDev return 1; } }; - let decimal = DecimalArray::from_option_iter( - [Some(0i128), Some(1), None, Some(3), Some(4)], - DecimalDType::new(38, 2), + // cuDF supports Arrow decimal device imports through Decimal128. Decimal256 is intentionally + // not included here because cuDF has no DECIMAL256 type_id or Arrow interop mapping. + let decimal32 = DecimalArray::from_option_iter( + [Some(0i8), Some(1), None, Some(3), Some(4)], + DecimalDType::new(9, 2), + ); + let decimal64 = DecimalArray::from_option_iter( + [Some(0i32), Some(1), None, Some(3), Some(4)], + DecimalDType::new(10, 2), + ); + let decimal128 = DecimalArray::from_option_iter( + [Some(0i64), Some(1), None, Some(3), Some(4)], + DecimalDType::new(19, 2), ); let strings = VarBinViewArray::from_iter_nullable_str([ Some("one"), @@ -166,7 +178,9 @@ fn export_array_inner(schema_ptr: &mut FFI_ArrowSchema, array_ptr: &mut ArrowDev let array = StructArray::new( FieldNames::from_iter([ "prims", - "decimals", + "decimal32", + "decimal64", + "decimal128", "strings", "dates", "lists", @@ -174,7 +188,9 @@ fn export_array_inner(schema_ptr: &mut FFI_ArrowSchema, array_ptr: &mut ArrowDev ]), vec![ primitive, - decimal.into_array(), + decimal32.into_array(), + decimal64.into_array(), + decimal128.into_array(), strings.into_array(), dates.into_array(), list_array(), @@ -244,7 +260,14 @@ fn validate_array_inner(ffi_schema: &FFI_ArrowSchema, ffi_array: &mut FFI_ArrowA &mut SESSION.create_execution_ctx(), ) .expect("expected primitive Arrow array"); - let decimal = Decimal128Array::from_iter([Some(0i128), Some(1), None, Some(3), Some(4)]) + let decimal32 = Decimal32Array::from_iter([Some(0i32), Some(1), None, Some(3), Some(4)]) + // cuDF stores decimals using the maximum precision for the physical width and preserves scale. + .with_precision_and_scale(9, 2) + .expect("with_precision_and_scale"); + let decimal64 = Decimal64Array::from_iter([Some(0i64), Some(1), None, Some(3), Some(4)]) + .with_precision_and_scale(18, 2) + .expect("with_precision_and_scale"); + let decimal128 = Decimal128Array::from_iter([Some(0i128), Some(1), None, Some(3), Some(4)]) .with_precision_and_scale(38, 2) .expect("with_precision_and_scale"); let string = StringArray::from_iter([ @@ -270,7 +293,9 @@ fn validate_array_inner(ffi_schema: &FFI_ArrowSchema, ffi_array: &mut FFI_ArrowA let expected_fields = Fields::from_iter([ Field::new("prims", primitive.data_type().clone(), true), - Field::new("decimals", decimal.data_type().clone(), true), + Field::new("decimal32", decimal32.data_type().clone(), true), + Field::new("decimal64", decimal64.data_type().clone(), true), + Field::new("decimal128", decimal128.data_type().clone(), true), Field::new("strings", string.data_type().clone(), true), Field::new("dates", date.data_type().clone(), true), cudf_list_field("lists"), @@ -278,12 +303,16 @@ fn validate_array_inner(ffi_schema: &FFI_ArrowSchema, ffi_array: &mut FFI_ArrowA ]); if &expected_fields != struct_array.fields() { eprintln!("wrong fields for host array"); + eprintln!("expected fields: {}", format_fields(&expected_fields)); + eprintln!("actual fields: {}", format_fields(struct_array.fields())); return 1; } - let expected_arrays: [ArrowArrayRef; 4] = [ + let expected_arrays: [ArrowArrayRef; 6] = [ primitive, - Arc::new(decimal), + Arc::new(decimal32), + Arc::new(decimal64), + Arc::new(decimal128), Arc::new(string), Arc::new(date), ]; @@ -299,11 +328,11 @@ fn validate_array_inner(ffi_schema: &FFI_ArrowSchema, ffi_array: &mut FFI_ArrowA } } - if !list_values_eq(list.as_ref(), struct_array.column(4).as_ref()) { + if !list_values_eq(list.as_ref(), struct_array.column(6).as_ref()) { eprintln!("wrong values for lists column"); return 1; } - if !list_values_eq(fixed_size_list.as_ref(), struct_array.column(5).as_ref()) { + if !list_values_eq(fixed_size_list.as_ref(), struct_array.column(7).as_ref()) { eprintln!("wrong values for fixed_lists column"); return 1; } @@ -315,6 +344,21 @@ fn cudf_list_field(name: &str) -> Field { Field::new_list(name, Field::new("element", DataType::Int32, false), true) } +fn format_fields(fields: &Fields) -> String { + fields + .iter() + .map(|field| { + format!( + "{}: {}{}", + field.name(), + field.data_type(), + if field.is_nullable() { "?" } else { "" } + ) + }) + .collect::>() + .join(", ") +} + fn list_values_eq(expected: &dyn Array, actual: &dyn Array) -> bool { let expected = expected.as_list::(); let actual = actual.as_list::();