diff --git a/Cargo.lock b/Cargo.lock index 452078df739..5a124fedd65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8659,9 +8659,13 @@ dependencies = [ name = "vortex-compute" version = "0.1.0" dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-schema", "num-traits", "vortex-buffer", "vortex-dtype", + "vortex-error", "vortex-mask", "vortex-vector", ] diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index 204f22a52d0..c565a6711a9 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -13,7 +13,7 @@ use log::debug; use num_enum::{IntoPrimitive, TryFromPrimitive}; pub use stats_set::*; use vortex_dtype::Nullability::{NonNullable, Nullable}; -use vortex_dtype::{DECIMAL256_MAX_PRECISION, DType, DecimalDType, PType}; +use vortex_dtype::{DType, DecimalDType, NativeDecimalType, PType, i256}; mod array; mod bound; @@ -210,7 +210,7 @@ impl Stat { // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188 let precision = - u8::min(DECIMAL256_MAX_PRECISION, decimal_dtype.precision() + 10); + u8::min(i256::MAX_PRECISION, decimal_dtype.precision() + 10); DType::Decimal( DecimalDType::new(precision, decimal_dtype.scale()), Nullable, diff --git a/vortex-compute/Cargo.toml b/vortex-compute/Cargo.toml index b538cd1bdc6..7fc6ca6ceec 100644 --- a/vortex-compute/Cargo.toml +++ b/vortex-compute/Cargo.toml @@ -22,15 +22,20 @@ workspace = true [dependencies] vortex-buffer = { workspace = true } vortex-dtype = { workspace = true } +vortex-error = { workspace = true } vortex-mask = { workspace = true } vortex-vector = { workspace = true } +arrow-array = { workspace = true, optional = true } +arrow-buffer = { workspace = true, optional = true } +arrow-schema = { workspace = true, 
optional = true } num-traits = { workspace = true } [features] -default = ["arithmetic", "comparison", "filter", "logical", "mask"] +default = ["arithmetic", "arrow", "comparison", "filter", "logical", "mask"] arithmetic = [] +arrow = ["dep:arrow-array", "dep:arrow-buffer", "dep:arrow-schema"] comparison = [] filter = [] logical = [] diff --git a/vortex-compute/src/arrow/bool.rs b/vortex-compute/src/arrow/bool.rs new file mode 100644 index 00000000000..ba35fcbf675 --- /dev/null +++ b/vortex-compute/src/arrow/bool.rs @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use arrow_array::{ArrayRef, BooleanArray}; +use vortex_error::VortexResult; +use vortex_vector::BoolVector; + +use crate::arrow::IntoArrow; + +impl IntoArrow for BoolVector { + fn into_arrow(self) -> VortexResult { + let (bits, validity) = self.into_parts(); + Ok(Arc::new(BooleanArray::new( + bits.into(), + validity.into_arrow()?, + ))) + } +} diff --git a/vortex-compute/src/arrow/decimal.rs b/vortex-compute/src/arrow/decimal.rs new file mode 100644 index 00000000000..5b1d3d020da --- /dev/null +++ b/vortex-compute/src/arrow/decimal.rs @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use arrow_array::types::{Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type}; +use arrow_array::{ArrayRef, PrimitiveArray}; +use vortex_buffer::Buffer; +use vortex_dtype::i256; +use vortex_error::VortexResult; +use vortex_vector::{DVector, DecimalVector}; + +use crate::arrow::IntoArrow; + +impl IntoArrow for DecimalVector { + fn into_arrow(self) -> VortexResult { + match self { + DecimalVector::D8(v) => v.into_arrow(), + DecimalVector::D16(v) => v.into_arrow(), + DecimalVector::D32(v) => v.into_arrow(), + DecimalVector::D64(v) => v.into_arrow(), + DecimalVector::D128(v) => v.into_arrow(), + DecimalVector::D256(v) => 
v.into_arrow(), + } + } +} + +macro_rules! impl_decimal_upcast_i32 { + ($T:ty) => { + impl IntoArrow for DVector<$T> { + fn into_arrow(self) -> VortexResult { + let (_, elements, validity) = self.into_parts(); + // Upcast the DVector to Arrow's smallest decimal type (Decimal32) + let elements = + Buffer::::from_trusted_len_iter(elements.iter().map(|i| *i as i32)); + Ok(Arc::new(PrimitiveArray::::new( + elements.into_arrow_scalar_buffer(), + validity.into_arrow()?, + ))) + } + } + }; +} + +impl_decimal_upcast_i32!(i8); +impl_decimal_upcast_i32!(i16); + +/// Direct Arrow conversion for vectors that map directly to Arrow decimal types. +macro_rules! impl_decimal { + ($T:ty, $A:ty) => { + impl IntoArrow for DVector<$T> { + fn into_arrow(self) -> VortexResult { + let (_, elements, validity) = self.into_parts(); + Ok(Arc::new(PrimitiveArray::<$A>::new( + elements.into_arrow_scalar_buffer(), + validity.into_arrow()?, + ))) + } + } + }; +} + +impl_decimal!(i32, Decimal32Type); +impl_decimal!(i64, Decimal64Type); +impl_decimal!(i128, Decimal128Type); + +impl IntoArrow for DVector { + fn into_arrow(self) -> VortexResult { + let (_, elements, validity) = self.into_parts(); + + // Transmute the elements from our i256 to Arrow's. + // SAFETY: we use Arrow's type internally for our layout. 
+ let elements = + unsafe { std::mem::transmute::, Buffer>(elements) }; + + Ok(Arc::new(PrimitiveArray::::new( + elements.into_arrow_scalar_buffer(), + validity.into_arrow()?, + ))) + } +} diff --git a/vortex-compute/src/arrow/mask.rs b/vortex-compute/src/arrow/mask.rs new file mode 100644 index 00000000000..16c83e808b4 --- /dev/null +++ b/vortex-compute/src/arrow/mask.rs @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use arrow_buffer::NullBuffer; +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::arrow::IntoArrow; + +impl IntoArrow> for Mask { + fn into_arrow(self) -> VortexResult> { + Ok(match self { + Mask::AllTrue(_) => None, + Mask::AllFalse(len) => Some(NullBuffer::new_null(len)), + Mask::Values(values) => { + // SAFETY: we maintain our own validated true count. + Some(unsafe { + NullBuffer::new_unchecked( + values.bit_buffer().clone().into(), + values.len() - values.true_count(), + ) + }) + } + }) + } +} diff --git a/vortex-compute/src/arrow/mod.rs b/vortex-compute/src/arrow/mod.rs new file mode 100644 index 00000000000..f8c3bdcda1d --- /dev/null +++ b/vortex-compute/src/arrow/mod.rs @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Conversion logic from Vortex vector types to Arrow types. + +use vortex_error::VortexResult; + +mod bool; +mod decimal; +mod mask; +mod null; +mod primitive; +mod struct_; +mod varbin; +mod vector; + +/// Trait for converting Vortex vector types into Arrow types. +pub trait IntoArrow { + /// Convert the Vortex type into an Arrow type. 
+ fn into_arrow(self) -> VortexResult; +} diff --git a/vortex-compute/src/arrow/null.rs b/vortex-compute/src/arrow/null.rs new file mode 100644 index 00000000000..2c934e29a94 --- /dev/null +++ b/vortex-compute/src/arrow/null.rs @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use arrow_array::{ArrayRef, NullArray}; +use vortex_error::VortexResult; +use vortex_vector::{NullVector, VectorOps}; + +use crate::arrow::IntoArrow; + +impl IntoArrow for NullVector { + fn into_arrow(self) -> VortexResult { + Ok(Arc::new(NullArray::new(self.len()))) + } +} diff --git a/vortex-compute/src/arrow/primitive.rs b/vortex-compute/src/arrow/primitive.rs new file mode 100644 index 00000000000..4d03e7c5927 --- /dev/null +++ b/vortex-compute/src/arrow/primitive.rs @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use arrow_array::types::{ + Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, +}; +use arrow_array::{ArrayRef, PrimitiveArray}; +use vortex_dtype::half::f16; +use vortex_error::VortexResult; +use vortex_vector::{PVector, PrimitiveVector, match_each_pvector}; + +use crate::arrow::IntoArrow; + +impl IntoArrow for PrimitiveVector { + fn into_arrow(self) -> VortexResult { + match_each_pvector!(self, |v| { v.into_arrow() }) + } +} + +macro_rules! 
impl_primitive { + ($T:ty, $A:ty) => { + impl IntoArrow for PVector<$T> { + fn into_arrow(self) -> VortexResult { + let (elements, validity) = self.into_parts(); + Ok(Arc::new(PrimitiveArray::<$A>::new( + elements.into_arrow_scalar_buffer(), + validity.into_arrow()?, + ))) + } + } + }; +} + +impl_primitive!(u8, UInt8Type); +impl_primitive!(u16, UInt16Type); +impl_primitive!(u32, UInt32Type); +impl_primitive!(u64, UInt64Type); +impl_primitive!(i8, Int8Type); +impl_primitive!(i16, Int16Type); +impl_primitive!(i32, Int32Type); +impl_primitive!(i64, Int64Type); +impl_primitive!(f16, Float16Type); +impl_primitive!(f32, Float32Type); +impl_primitive!(f64, Float64Type); diff --git a/vortex-compute/src/arrow/struct_.rs b/vortex-compute/src/arrow/struct_.rs new file mode 100644 index 00000000000..a41426f43d5 --- /dev/null +++ b/vortex-compute/src/arrow/struct_.rs @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use arrow_array::{ArrayRef, StructArray}; +use arrow_schema::{Field, Fields}; +use vortex_error::VortexResult; +use vortex_vector::StructVector; + +use crate::arrow::IntoArrow; + +impl IntoArrow for StructVector { + fn into_arrow(self) -> VortexResult { + let (fields, validity) = self.into_parts(); + let arrow_fields = fields + .iter() + .map(|field| field.clone().into_arrow()) + .collect::>>()?; + + // We need to make up the field names since vectors are unnamed. 
+ let fields = Fields::from( + (0..arrow_fields.len()) + .map(|i| Field::new(i.to_string(), arrow_fields[i].data_type().clone(), true)) + .collect::>(), + ); + + Ok(Arc::new(StructArray::new( + fields, + arrow_fields, + validity.into_arrow()?, + ))) + } +} diff --git a/vortex-compute/src/arrow/varbin.rs b/vortex-compute/src/arrow/varbin.rs new file mode 100644 index 00000000000..e38e8dc33f8 --- /dev/null +++ b/vortex-compute/src/arrow/varbin.rs @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use arrow_array::{ArrayRef, GenericByteViewArray}; +use vortex_buffer::Buffer; +use vortex_error::VortexResult; +use vortex_vector::{BinaryType, StringType, VarBinVector}; + +use crate::arrow::IntoArrow; + +macro_rules! impl_varbin { + ($T:ty, $A:ty) => { + impl IntoArrow for VarBinVector<$T> { + fn into_arrow(self) -> VortexResult { + let (views, buffers, validity) = self.into_parts(); + + let views = Buffer::::from_byte_buffer(views.into_byte_buffer()) + .into_arrow_scalar_buffer(); + let buffers: Vec<_> = buffers + .iter() + .cloned() + .map(|b| b.into_arrow_buffer()) + .collect(); + + // SAFETY: our own guarantees are the same as Arrow's guarantees for BinaryViewArray + let array = unsafe { + GenericByteViewArray::<$A>::new_unchecked( + views, + buffers, + validity.into_arrow()?, + ) + }; + Ok(Arc::new(array)) + } + } + }; +} + +impl_varbin!(BinaryType, arrow_array::types::BinaryViewType); +impl_varbin!(StringType, arrow_array::types::StringViewType); diff --git a/vortex-compute/src/arrow/vector.rs b/vortex-compute/src/arrow/vector.rs new file mode 100644 index 00000000000..e9e888c1dba --- /dev/null +++ b/vortex-compute/src/arrow/vector.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use arrow_array::ArrayRef; +use vortex_error::VortexResult; +use vortex_vector::{Vector, match_each_vector}; + +use 
crate::arrow::IntoArrow; + +impl IntoArrow for Vector { + fn into_arrow(self) -> VortexResult { + match_each_vector!(self, |v| { v.into_arrow() }) + } +} diff --git a/vortex-compute/src/lib.rs b/vortex-compute/src/lib.rs index 512024d07c8..e4c8ea0455f 100644 --- a/vortex-compute/src/lib.rs +++ b/vortex-compute/src/lib.rs @@ -9,6 +9,8 @@ #[cfg(feature = "arithmetic")] pub mod arithmetic; +#[cfg(feature = "arrow")] +pub mod arrow; #[cfg(feature = "comparison")] pub mod comparison; #[cfg(feature = "filter")] diff --git a/vortex-compute/src/mask/mod.rs b/vortex-compute/src/mask/mod.rs index a88d81a52ce..125c081b367 100644 --- a/vortex-compute/src/mask/mod.rs +++ b/vortex-compute/src/mask/mod.rs @@ -5,11 +5,11 @@ use std::ops::BitAnd; -use vortex_dtype::NativePType; +use vortex_dtype::{NativeDecimalType, NativePType}; use vortex_mask::Mask; use vortex_vector::{ - BoolVector, NullVector, PVector, PrimitiveVector, StructVector, VarBinType, VarBinVector, - Vector, match_each_pvector, match_each_vector, + BoolVector, DVector, DecimalVector, NullVector, PVector, PrimitiveVector, StructVector, + VarBinType, VarBinVector, Vector, match_each_dvector, match_each_pvector, match_each_vector, }; /// Trait for masking the validity of an array or vector. @@ -42,6 +42,20 @@ impl MaskValidity for BoolVector { } } +impl MaskValidity for DecimalVector { + fn mask_validity(self, mask: &Mask) -> Self { + match_each_dvector!(self, |v| { MaskValidity::mask_validity(v, mask).into() }) + } +} + +impl MaskValidity for DVector { + fn mask_validity(self, mask: &Mask) -> Self { + let (ps, elements, validity) = self.into_parts(); + // SAFETY: we are preserving the original elements buffer and only modifying the validity. 
+ unsafe { Self::new_unchecked(ps, elements, validity.bitand(mask)) } + } +} + impl MaskValidity for PrimitiveVector { fn mask_validity(self, mask: &Mask) -> Self { match_each_pvector!(self, |v| { MaskValidity::mask_validity(v, mask).into() }) diff --git a/vortex-datafusion/src/convert/scalars.rs b/vortex-datafusion/src/convert/scalars.rs index 8558167c758..930741d9542 100644 --- a/vortex-datafusion/src/convert/scalars.rs +++ b/vortex-datafusion/src/convert/scalars.rs @@ -8,7 +8,7 @@ use vortex::buffer::ByteBuffer; use vortex::dtype::datetime::arrow::make_temporal_ext_dtype; use vortex::dtype::datetime::{TemporalMetadata, TimeUnit, is_temporal_ext_type}; use vortex::dtype::half::f16; -use vortex::dtype::{DECIMAL128_MAX_PRECISION, DType, DecimalDType, Nullability, PType}; +use vortex::dtype::{DType, DecimalDType, NativeDecimalType, Nullability, PType}; use vortex::error::{VortexResult, vortex_bail}; use vortex::scalar::{DecimalValue, Scalar, i256}; @@ -40,7 +40,7 @@ impl TryToDataFusion for Scalar { let precision = decimal_type.precision(); let scale = decimal_type.scale(); - if precision <= DECIMAL128_MAX_PRECISION { + if precision <= i128::MAX_PRECISION { match dscalar.decimal_value() { None => ScalarValue::Decimal128(None, precision, scale), Some(DecimalValue::I128(v128)) => { diff --git a/vortex-dtype/src/arbitrary/mod.rs b/vortex-dtype/src/arbitrary/mod.rs index 4302aae0b03..4da6a03ac0a 100644 --- a/vortex-dtype/src/arbitrary/mod.rs +++ b/vortex-dtype/src/arbitrary/mod.rs @@ -7,8 +7,8 @@ use arbitrary::{Arbitrary, Result, Unstructured}; use vortex_error::VortexExpect; use crate::{ - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DType, DecimalDType, FieldName, FieldNames, - Nullability, PType, StructFields, + DType, DecimalDType, FieldName, FieldNames, NativeDecimalType, Nullability, PType, + StructFields, i256, }; mod decimal; @@ -94,8 +94,8 @@ impl<'a> Arbitrary<'a> for DecimalDType { )] fn arbitrary(u: &mut Unstructured<'a>) -> Result { // Get a random 
integer for the scale - let precision = u.int_in_range(1..=DECIMAL256_MAX_PRECISION)?; - let scale = u.int_in_range(-DECIMAL256_MAX_SCALE..=(precision as i8))?; + let precision = u.int_in_range(1..=i256::MAX_PRECISION)?; + let scale = u.int_in_range(-i256::MAX_SCALE..=(precision as i8))?; Ok(Self::new(precision, scale)) } } diff --git a/vortex-scalar/src/arbitrary/decimal.rs b/vortex-dtype/src/decimal/max_precision.rs similarity index 89% rename from vortex-scalar/src/arbitrary/decimal.rs rename to vortex-dtype/src/decimal/max_precision.rs index d88df4b8889..4313e125396 100644 --- a/vortex-scalar/src/arbitrary/decimal.rs +++ b/vortex-dtype/src/decimal/max_precision.rs @@ -1,117 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use arbitrary::{Result, Unstructured}; -use vortex_dtype::{DECIMAL128_MAX_PRECISION, DecimalDType, i256}; +//! Lookup tables for minimum/maximum i256 decimal values for each precision. +//! We cannot perform const computations for i256, so we precompute these values. -use crate::scalar_value::InnerScalarValue; -use crate::{DecimalValue, ScalarValue}; +use crate::i256; -/// Generate an arbitrary decimal scalar confined to the given bounds of precision and scale. 
-pub fn random_decimal(u: &mut Unstructured, decimal_type: &DecimalDType) -> Result { - let precision = decimal_type.precision(); - if precision <= DECIMAL128_MAX_PRECISION { - Ok(ScalarValue(InnerScalarValue::Decimal(DecimalValue::I128( - u.int_in_range( - MIN_DECIMAL128_FOR_EACH_PRECISION[precision as usize] - ..=MAX_DECIMAL128_FOR_EACH_PRECISION[precision as usize], - )?, - )))) - } else { - Ok(ScalarValue(InnerScalarValue::Decimal(DecimalValue::I256( - u.int_in_range( - MIN_DECIMAL256_FOR_EACH_PRECISION[precision as usize] - ..=MAX_DECIMAL256_FOR_EACH_PRECISION[precision as usize], - )?, - )))) - } -} - -const MAX_DECIMAL128_FOR_EACH_PRECISION: [i128; 39] = [ - 0, // unused first element - 9, - 99, - 999, - 9999, - 99999, - 999999, - 9999999, - 99999999, - 999999999, - 9999999999, - 99999999999, - 999999999999, - 9999999999999, - 99999999999999, - 999999999999999, - 9999999999999999, - 99999999999999999, - 999999999999999999, - 9999999999999999999, - 99999999999999999999, - 999999999999999999999, - 9999999999999999999999, - 99999999999999999999999, - 999999999999999999999999, - 9999999999999999999999999, - 99999999999999999999999999, - 999999999999999999999999999, - 9999999999999999999999999999, - 99999999999999999999999999999, - 999999999999999999999999999999, - 9999999999999999999999999999999, - 99999999999999999999999999999999, - 999999999999999999999999999999999, - 9999999999999999999999999999999999, - 99999999999999999999999999999999999, - 999999999999999999999999999999999999, - 9999999999999999999999999999999999999, - 99999999999999999999999999999999999999, -]; - -const MIN_DECIMAL128_FOR_EACH_PRECISION: [i128; 39] = [ - 0, // unused first element - -9, - -99, - -999, - -9999, - -99999, - -999999, - -9999999, - -99999999, - -999999999, - -9999999999, - -99999999999, - -999999999999, - -9999999999999, - -99999999999999, - -999999999999999, - -9999999999999999, - -99999999999999999, - -999999999999999999, - -9999999999999999999, - -99999999999999999999, - 
-999999999999999999999, - -9999999999999999999999, - -99999999999999999999999, - -999999999999999999999999, - -9999999999999999999999999, - -99999999999999999999999999, - -999999999999999999999999999, - -9999999999999999999999999999, - -99999999999999999999999999999, - -999999999999999999999999999999, - -9999999999999999999999999999999, - -99999999999999999999999999999999, - -999999999999999999999999999999999, - -9999999999999999999999999999999999, - -99999999999999999999999999999999999, - -999999999999999999999999999999999999, - -9999999999999999999999999999999999999, - -99999999999999999999999999999999999999, -]; - -const MAX_DECIMAL256_FOR_EACH_PRECISION: [i256; 77] = [ +pub(super) const MAX_DECIMAL256_FOR_EACH_PRECISION: [i256; 77] = [ i256::ZERO, // unused first element i256::from_le_bytes([ 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -419,7 +314,7 @@ const MAX_DECIMAL256_FOR_EACH_PRECISION: [i256; 77] = [ ]), ]; -const MIN_DECIMAL256_FOR_EACH_PRECISION: [i256; 77] = [ +pub(super) const MIN_DECIMAL256_FOR_EACH_PRECISION: [i256; 77] = [ i256::ZERO, // unused first element i256::from_le_bytes([ 247, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, diff --git a/vortex-dtype/src/decimal/mod.rs b/vortex-dtype/src/decimal/mod.rs index 723082c9d2e..cce3f667780 100644 --- a/vortex-dtype/src/decimal/mod.rs +++ b/vortex-dtype/src/decimal/mod.rs @@ -1,29 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +mod max_precision; +mod precision; mod types; + use std::fmt::{Display, Formatter}; use num_traits::ToPrimitive; +pub use precision::*; pub use types::*; use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_panic}; -use crate::DType; - -/// Maximum precision for a Decimal128 type from Arrow -pub const DECIMAL128_MAX_PRECISION: u8 = 38; - -/// Maximum precision for a Decimal256 type from Arrow -pub const 
DECIMAL256_MAX_PRECISION: u8 = 76; +use crate::{DType, i256}; -/// Maximum scale for a Decimal128 type from Arrow -pub const DECIMAL128_MAX_SCALE: i8 = 38; - -/// Maximum scale for a Decimal256 type from Arrow -pub const DECIMAL256_MAX_SCALE: i8 = 76; - -const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION; -const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE; +const MAX_PRECISION: u8 = ::MAX_PRECISION; +const MAX_SCALE: i8 = ::MAX_SCALE; /// Parameters that define the precision and scale of a decimal type. /// @@ -253,20 +245,6 @@ mod tests { assert_eq!(decimal.scale(), -5); } - #[test] - fn test_decimal128_boundaries() { - let decimal = DecimalDType::new(DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE); - assert_eq!(decimal.precision(), 38); - assert_eq!(decimal.scale(), 38); - } - - #[test] - fn test_decimal256_boundaries() { - let decimal = DecimalDType::new(DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE); - assert_eq!(decimal.precision(), 76); - assert_eq!(decimal.scale(), 76); - } - #[test] fn test_required_bit_width() { // Test common decimal precisions diff --git a/vortex-dtype/src/decimal/precision.rs b/vortex-dtype/src/decimal/precision.rs new file mode 100644 index 00000000000..7fe44b201b3 --- /dev/null +++ b/vortex-dtype/src/decimal/precision.rs @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::any::type_name; +use std::fmt::Display; +use std::marker::PhantomData; + +use vortex_error::{VortexExpect, VortexResult, vortex_bail}; + +use crate::{DecimalDType, NativeDecimalType}; + +/// A struct representing the precision and scale of a decimal type, to be represented +/// by the native type `D`. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PrecisionScale { + precision: u8, + scale: i8, + phantom: PhantomData, +} + +impl Display for PrecisionScale { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "decimal({}, p={}, s={})", + type_name::(), + self.precision, + self.scale + ) + } +} + +impl PrecisionScale { + /// Create a new [`PrecisionScale`] with the given precision and scale. + /// + /// # Panics + /// + /// Panics if the precision/scale are invalid. + pub fn new(precision: u8, scale: i8) -> Self { + Self::try_new(precision, scale).vortex_expect("Failed to create `PrecisionScale`") + } + + /// Try to create a new [`PrecisionScale`] with the given precision and scale. + pub fn try_new(precision: u8, scale: i8) -> VortexResult { + if precision == 0 { + vortex_bail!( + "precision cannot be 0, has to be between [1, {}]", + D::MAX_PRECISION + ); + } + if precision > D::MAX_PRECISION { + vortex_bail!( + "Precision {} is greater than max {}", + precision, + D::MAX_PRECISION + ); + } + if scale > D::MAX_SCALE { + vortex_bail!("Scale {} is greater than max {}", scale, D::MAX_SCALE); + } + if scale > 0 && scale as u8 > precision { + vortex_bail!("Scale {} is greater than precision {}", scale, precision); + } + Ok(Self { + precision, + scale, + phantom: Default::default(), + }) + } + + /// Create a new [`PrecisionScale`] with the given precision and scale without validation. + /// + /// # Safety + /// + /// The caller must ensure that the precision and scale are valid. + pub unsafe fn new_unchecked(precision: u8, scale: i8) -> Self { + if cfg!(debug_assertions) { + Self::new(precision, scale) + } else { + Self { + precision, + scale, + phantom: Default::default(), + } + } + } + + /// The precision is the number of significant figures that the decimal tracks. + #[inline(always)] + pub fn precision(&self) -> u8 { + self.precision + } + + /// The scale is the maximum number of digits relative to the decimal point. 
+ #[inline(always)] + pub fn scale(&self) -> i8 { + self.scale + } + + /// Validate whether a given value of type `D` fits within the precision and scale. + #[inline] + pub fn is_valid(&self, value: D) -> bool { + self.precision <= D::MAX_PRECISION + && value >= D::MIN_BY_PRECISION[self.precision as usize] + && value <= D::MAX_BY_PRECISION[self.precision as usize] + } +} + +impl From> for DecimalDType { + fn from(value: PrecisionScale) -> Self { + DecimalDType { + precision: value.precision, + scale: value.scale, + } + } +} + +impl TryFrom<&DecimalDType> for PrecisionScale { + type Error = vortex_error::VortexError; + + fn try_from(value: &DecimalDType) -> VortexResult { + PrecisionScale::try_new(value.precision, value.scale) + } +} diff --git a/vortex-dtype/src/decimal/types.rs b/vortex-dtype/src/decimal/types.rs index e0ae8de2862..f21a6e230e7 100644 --- a/vortex-dtype/src/decimal/types.rs +++ b/vortex-dtype/src/decimal/types.rs @@ -4,8 +4,12 @@ use std::fmt::{Debug, Display}; use std::panic::RefUnwindSafe; +use num_traits::{ConstOne, ConstZero}; use paste::paste; +use crate::decimal::max_precision::{ + MAX_DECIMAL256_FOR_EACH_PRECISION, MIN_DECIMAL256_FOR_EACH_PRECISION, +}; use crate::{BigCast, i256}; /// Type of the decimal values. @@ -51,6 +55,16 @@ pub trait NativeDecimalType: /// The decimal value type corresponding to this native type. const DECIMAL_TYPE: DecimalType; + /// The maximum precision supported by this decimal type. + const MAX_PRECISION: u8; + /// The maximum scale supported by this decimal type. + const MAX_SCALE: i8; + + /// The minimum value for each precision supported by this decimal type. + const MIN_BY_PRECISION: &'static [Self]; + /// The maximum value for each precision supported by this decimal type. + const MAX_BY_PRECISION: &'static [Self]; + /// Downcast the provided object to a type-specific instance. fn downcast(visitor: V) -> V::Output; @@ -102,6 +116,40 @@ macro_rules! 
impl_decimal { impl NativeDecimalType for $T { const DECIMAL_TYPE: DecimalType = DecimalType::$UPPER; + const MAX_PRECISION: u8 = match DecimalType::$UPPER { + DecimalType::I8 => 2, + DecimalType::I16 => 4, + DecimalType::I32 => 9, + DecimalType::I64 => 18, + DecimalType::I128 => 38, + DecimalType::I256 => 76, + }; + const MAX_SCALE: i8 = Self::MAX_PRECISION as i8; + + const MIN_BY_PRECISION: &'static [Self] = &{ + let mut mins = [$T::ZERO; Self::MAX_PRECISION as usize]; + let mut p = $T::ONE; + let mut i = 0; + while i < Self::MAX_PRECISION as usize { + p = p * 10; + mins[i] = -(p - 1); + i += 1; + } + mins + }; + + const MAX_BY_PRECISION: &'static [Self] = &{ + let mut maxs = [$T::ZERO; Self::MAX_PRECISION as usize]; + let mut p = $T::ONE; + let mut i = 0; + while i < Self::MAX_PRECISION as usize { + p = p * 10; + maxs[i] = p - 1; + i += 1; + } + maxs + }; + #[inline] fn downcast(visitor: V) -> V::Output { paste::paste! { visitor.[]() } @@ -121,4 +169,19 @@ impl_decimal!(i16, I16); impl_decimal!(i32, I32); impl_decimal!(i64, I64); impl_decimal!(i128, I128); -impl_decimal!(i256, I256); + +impl NativeDecimalType for i256 { + const DECIMAL_TYPE: DecimalType = DecimalType::I256; + const MAX_PRECISION: u8 = 76; + const MAX_SCALE: i8 = 76; + const MIN_BY_PRECISION: &'static [Self] = &MIN_DECIMAL256_FOR_EACH_PRECISION; + const MAX_BY_PRECISION: &'static [Self] = &MAX_DECIMAL256_FOR_EACH_PRECISION; + + fn downcast(visitor: V) -> V::Output { + visitor.into_i256() + } + + fn upcast(input: V::Input) -> V { + V::from_i256(input) + } +} diff --git a/vortex-scalar/src/arbitrary/mod.rs b/vortex-scalar/src/arbitrary.rs similarity index 76% rename from vortex-scalar/src/arbitrary/mod.rs rename to vortex-scalar/src/arbitrary.rs index e66bfe415ce..0fed180c04a 100644 --- a/vortex-scalar/src/arbitrary/mod.rs +++ b/vortex-scalar/src/arbitrary.rs @@ -6,18 +6,15 @@ //! This module provides functions to generate arbitrary scalar values of various data types. //! 
It is used by the fuzzer to test the correctness of the scalar value implementation. -mod decimal; - use std::iter; use std::sync::Arc; use arbitrary::{Result, Unstructured}; -pub use decimal::random_decimal; use vortex_buffer::{BufferString, ByteBuffer}; use vortex_dtype::half::f16; -use vortex_dtype::{DType, PType}; +use vortex_dtype::{DType, DecimalDType, NativeDecimalType, PType, i256}; -use crate::{InnerScalarValue, PValue, Scalar, ScalarValue}; +use crate::{DecimalValue, InnerScalarValue, PValue, Scalar, ScalarValue}; /// Generate an arbitrary scalar value of the given data type. pub fn random_scalar(u: &mut Unstructured, dtype: &DType) -> Result { @@ -81,3 +78,23 @@ fn random_pvalue(u: &mut Unstructured, ptype: &PType) -> Result { PType::F64 => PValue::F64(u.arbitrary()?), }) } + +/// Generate an arbitrary decimal scalar confined to the given bounds of precision and scale. +pub fn random_decimal(u: &mut Unstructured, decimal_type: &DecimalDType) -> Result { + let precision = decimal_type.precision(); + if precision <= i128::MAX_PRECISION { + Ok(ScalarValue(InnerScalarValue::Decimal(DecimalValue::I128( + u.int_in_range( + i128::MIN_BY_PRECISION[precision as usize] + ..=i128::MAX_BY_PRECISION[precision as usize], + )?, + )))) + } else { + Ok(ScalarValue(InnerScalarValue::Decimal(DecimalValue::I256( + u.int_in_range( + i256::MIN_BY_PRECISION[precision as usize] + ..=i256::MAX_BY_PRECISION[precision as usize], + )?, + )))) + } +} diff --git a/vortex-scalar/src/scalar.rs b/vortex-scalar/src/scalar.rs index 2ac9790fdab..bfc3b27432a 100644 --- a/vortex-scalar/src/scalar.rs +++ b/vortex-scalar/src/scalar.rs @@ -6,7 +6,7 @@ use std::hash::Hash; use std::sync::Arc; use vortex_buffer::Buffer; -use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, NativeDType, Nullability, i256}; +use vortex_dtype::{DType, NativeDType, NativeDecimalType, Nullability, i256}; use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err}; use super::*; @@ -163,7 +163,7 
@@ impl Scalar { DType::Bool(_) => 1, DType::Primitive(ptype, _) => ptype.byte_width(), DType::Decimal(dt, _) => { - if dt.precision() <= DECIMAL128_MAX_PRECISION { + if dt.precision() <= i128::MAX_PRECISION { size_of::() } else { size_of::() diff --git a/vortex-scalar/src/tests/primitives.rs b/vortex-scalar/src/tests/primitives.rs index fade750debf..722669d4562 100644 --- a/vortex-scalar/src/tests/primitives.rs +++ b/vortex-scalar/src/tests/primitives.rs @@ -8,7 +8,7 @@ mod tests { use std::sync::Arc; use vortex_buffer::ByteBuffer; - use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType}; + use vortex_dtype::{DType, ExtDType, ExtID, NativeDecimalType, Nullability, PType}; use vortex_utils::aliases::hash_set::HashSet; use crate::{InnerScalarValue, PValue, Scalar, ScalarValue}; @@ -131,14 +131,14 @@ mod tests { #[test] fn test_decimal_nbytes() { - use vortex_dtype::{DECIMAL128_MAX_PRECISION, DecimalDType}; + use vortex_dtype::DecimalDType; use crate::decimal::DecimalValue; // Test decimal with precision <= 38 (should use i128 = 16 bytes) let decimal_low_precision = Scalar::decimal( DecimalValue::I128(123456789), - DecimalDType::new(DECIMAL128_MAX_PRECISION, 2), // precision 38 + DecimalDType::new(i128::MAX_PRECISION, 2), // precision 38 Nullability::NonNullable, ); assert_eq!( @@ -150,7 +150,7 @@ mod tests { // Test decimal with precision > 38 (should use i256 = 32 bytes) let decimal_high_precision = Scalar::decimal( DecimalValue::I128(123456789), - DecimalDType::new(DECIMAL128_MAX_PRECISION + 1, 2), // precision 39 + DecimalDType::new(i128::MAX_PRECISION + 1, 2), // precision 39 Nullability::NonNullable, ); assert_eq!( diff --git a/vortex-vector/src/decimal/generic.rs b/vortex-vector/src/decimal/generic.rs new file mode 100644 index 00000000000..f00302bca08 --- /dev/null +++ b/vortex-vector/src/decimal/generic.rs @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use 
vortex_buffer::Buffer; +use vortex_dtype::{NativeDecimalType, PrecisionScale}; +use vortex_error::{VortexExpect, VortexResult, vortex_bail}; +use vortex_mask::Mask; + +use crate::{DVectorMut, VectorOps}; + +/// A specifically typed decimal vector. +#[derive(Debug, Clone)] +pub struct DVector { + pub(super) ps: PrecisionScale, + pub(super) elements: Buffer, + pub(super) validity: Mask, +} + +impl DVector { + /// Try to create a new decimal vector from the given elements and validity. + /// + /// # Errors + /// + /// Returns an error if the precision/scale is invalid, the lengths of the elements + /// and validity do not match, or any of the elements are out of bounds for the given + /// precision/scale. + pub fn try_new( + ps: PrecisionScale, + elements: Buffer, + validity: Mask, + ) -> VortexResult { + if elements.len() != validity.len() { + vortex_bail!( + "Elements length {} does not match validity length {}", + elements.len(), + validity.len() + ); + } + + // We assert that each element is within bounds for the given precision/scale. + if !elements.iter().all(|e| ps.is_valid(*e)) { + vortex_bail!( + "One or more elements are out of bounds for precision {} and scale {}", + ps.precision(), + ps.scale() + ); + } + + Ok(Self { + ps, + elements, + validity, + }) + } + + /// Create a new decimal vector from the given elements and validity without validation. + /// + /// # Safety + /// + /// The caller must ensure that the precision/scale is valid, the lengths of the elements + /// and validity match, and all the elements are within bounds for the given precision/scale. + pub unsafe fn new_unchecked( + ps: PrecisionScale, + elements: Buffer, + validity: Mask, + ) -> Self { + if cfg!(debug_assertions) { + Self::try_new(ps, elements, validity).vortex_expect("Failed to create `DVector`") + } else { + Self { + ps, + elements, + validity, + } + } + } + + /// Get the precision/scale of the decimal vector. 
+ pub fn precision_scale(&self) -> PrecisionScale { + self.ps + } + + /// Decomposes the decimal vector into its constituent parts (precision/scale, buffer and validity). + pub fn into_parts(self) -> (PrecisionScale, Buffer, Mask) { + (self.ps, self.elements, self.validity) + } +} + +impl VectorOps for DVector { + type Mutable = DVectorMut; + + fn len(&self) -> usize { + self.elements.len() + } + + fn validity(&self) -> &Mask { + &self.validity + } + + fn try_into_mut(self) -> Result + where + Self: Sized, + { + let elements = match self.elements.try_into_mut() { + Ok(elements) => elements, + Err(elements) => { + return Err(DVector { + ps: self.ps, + elements, + validity: self.validity, + }); + } + }; + + match self.validity.try_into_mut() { + Ok(validity_mut) => Ok(DVectorMut { + ps: self.ps, + elements, + validity: validity_mut, + }), + Err(validity) => Err(DVector { + ps: self.ps, + elements: elements.freeze(), + validity, + }), + } + } +} diff --git a/vortex-vector/src/decimal/generic_mut.rs b/vortex-vector/src/decimal/generic_mut.rs new file mode 100644 index 00000000000..b550f0f2b86 --- /dev/null +++ b/vortex-vector/src/decimal/generic_mut.rs @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_buffer::BufferMut; +use vortex_dtype::{NativeDecimalType, PrecisionScale}; +use vortex_mask::MaskMut; + +use crate::{DVector, VectorMutOps, VectorOps}; + +/// A specifically typed mutable decimal vector. 
+#[derive(Debug, Clone)] +pub struct DVectorMut { + pub(super) ps: PrecisionScale, + pub(super) elements: BufferMut, + pub(super) validity: MaskMut, +} + +impl VectorMutOps for DVectorMut { + type Immutable = DVector; + + fn len(&self) -> usize { + self.elements.len() + } + + fn capacity(&self) -> usize { + self.elements.capacity() + } + + fn reserve(&mut self, additional: usize) { + self.elements.reserve(additional); + self.validity.reserve(additional); + } + + fn extend_from_vector(&mut self, other: &Self::Immutable) { + self.elements.extend_from_slice(&other.elements); + self.validity.append_mask(other.validity()); + } + + fn append_nulls(&mut self, n: usize) { + self.elements.extend((0..n).map(|_| D::default())); + self.validity.append_n(false, n); + } + + fn freeze(self) -> Self::Immutable { + DVector { + ps: self.ps, + elements: self.elements.freeze(), + validity: self.validity.freeze(), + } + } + + fn split_off(&mut self, at: usize) -> Self { + DVectorMut { + ps: self.ps, + elements: self.elements.split_off(at), + validity: self.validity.split_off(at), + } + } + + fn unsplit(&mut self, other: Self) { + self.elements.unsplit(other.elements); + self.validity.unsplit(other.validity); + } +} diff --git a/vortex-vector/src/decimal/generic_mut_impl.rs b/vortex-vector/src/decimal/generic_mut_impl.rs new file mode 100644 index 00000000000..ff787c892c8 --- /dev/null +++ b/vortex-vector/src/decimal/generic_mut_impl.rs @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_dtype::{NativeDecimalType, PrecisionScale}; +use vortex_error::{VortexResult, vortex_bail}; + +use crate::DVectorMut; + +impl DVectorMut { + /// Get the precision/scale of the decimal vector. + pub fn precision_scale(&self) -> PrecisionScale { + self.ps + } + + /// Get a nullable element at the given index. 
+ pub fn get(&self, index: usize) -> Option<&D> { + self.validity.value(index).then(|| &self.elements[index]) + } + + /// Appends a new element to the end of the vector. + /// + /// # Errors + /// + /// Returns an error if the value is out of bounds for the vector's precision/scale. + pub fn try_push(&mut self, value: D) -> VortexResult<()> { + if !self.ps.is_valid(value) { + vortex_bail!("Value {:?} is out of bounds for {}", value, self.ps,); + } + self.elements.push(value); + self.validity.append_n(true, 1); + Ok(()) + } + + /// Returns a mutable reference to the underlying elements buffer. + /// + /// # Safety + /// + /// Modifying the elements buffer directly may violate the precision/scale constraints. + /// The caller must ensure that any modifications maintain these invariants. + pub unsafe fn elements_mut(&mut self) -> &mut [D] { + &mut self.elements + } +} + +impl AsRef<[D]> for DVectorMut { + fn as_ref(&self) -> &[D] { + &self.elements + } +} diff --git a/vortex-vector/src/decimal/macros.rs b/vortex-vector/src/decimal/macros.rs new file mode 100644 index 00000000000..f95070df7e5 --- /dev/null +++ b/vortex-vector/src/decimal/macros.rs @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Helper macros for working with the different variants of [`crate::DecimalVector`] and +//! [`crate::DecimalVectorMut`]. + +/// Matches on all decimal type variants of [`crate::DecimalVector`] and executes the same code for +/// each variant branch. +/// +/// This macro eliminates repetitive match statements when implementing operations that need to work +/// uniformly across all decimal type variants (`D8`, `D16`, `D32`, `D64`, `D128`, `D256`). +#[macro_export] +macro_rules! 
match_each_dvector { + ($self:expr, | $vec:ident | $body:block) => {{ + match $self { + $crate::DecimalVector::D8($vec) => $body, + $crate::DecimalVector::D16($vec) => $body, + $crate::DecimalVector::D32($vec) => $body, + $crate::DecimalVector::D64($vec) => $body, + $crate::DecimalVector::D128($vec) => $body, + $crate::DecimalVector::D256($vec) => $body, + } + }}; +} + +/// Matches on all decimal type variants of [`crate::DecimalVectorMut`] and executes the same code +/// for each variant branch. +/// +/// This macro eliminates repetitive match statements when implementing mutable operations that need +/// to work uniformly across all decimal type variants (`D8`, `D16`, `D32`, `D64`, `D128`, `D256`). +#[macro_export] +macro_rules! match_each_dvector_mut { + ($self:expr, | $vec:ident | $body:block) => {{ + match $self { + $crate::DecimalVectorMut::D8($vec) => $body, + $crate::DecimalVectorMut::D16($vec) => $body, + $crate::DecimalVectorMut::D32($vec) => $body, + $crate::DecimalVectorMut::D64($vec) => $body, + $crate::DecimalVectorMut::D128($vec) => $body, + $crate::DecimalVectorMut::D256($vec) => $body, + } + }}; +} diff --git a/vortex-vector/src/decimal/mod.rs b/vortex-vector/src/decimal/mod.rs new file mode 100644 index 00000000000..c5ea20920c2 --- /dev/null +++ b/vortex-vector/src/decimal/mod.rs @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod generic; +mod generic_mut; +mod generic_mut_impl; +mod macros; +mod vector; +mod vector_mut; + +pub use generic::*; +pub use generic_mut::*; +pub use vector::*; +pub use vector_mut::*; +use vortex_dtype::NativeDecimalType; + +use crate::{Vector, VectorMut}; + +impl From for Vector { + fn from(v: DecimalVector) -> Self { + Self::Decimal(v) + } +} + +impl From> for Vector { + fn from(v: DVector) -> Self { + Self::Decimal(DecimalVector::from(v)) + } +} + +impl From> for DecimalVector { + fn from(value: DVector) -> Self { + D::upcast(value) + } +} + 
+impl From for VectorMut { + fn from(v: DecimalVectorMut) -> Self { + Self::Decimal(v) + } +} + +impl From> for DecimalVectorMut { + fn from(val: DVectorMut) -> Self { + D::upcast(val) + } +} + +impl From> for VectorMut { + fn from(val: DVectorMut) -> Self { + Self::Decimal(DecimalVectorMut::from(val)) + } +} diff --git a/vortex-vector/src/decimal/vector.rs b/vortex-vector/src/decimal/vector.rs new file mode 100644 index 00000000000..e95a4785e97 --- /dev/null +++ b/vortex-vector/src/decimal/vector.rs @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_dtype::{DecimalTypeDowncast, DecimalTypeUpcast, NativeDecimalType, i256}; +use vortex_error::vortex_panic; +use vortex_mask::Mask; + +use crate::decimal::DVector; +use crate::{DecimalVectorMut, VectorOps, match_each_dvector}; + +/// An enum over all supported decimal vector types. +#[derive(Clone, Debug)] +pub enum DecimalVector { + /// A decimal vector with 8-bit integer representation. + D8(DVector), + /// A decimal vector with 16-bit integer representation. + D16(DVector), + /// A decimal vector with 32-bit integer representation. + D32(DVector), + /// A decimal vector with 64-bit integer representation. + D64(DVector), + /// A decimal vector with 128-bit integer representation. + D128(DVector), + /// A decimal vector with 256-bit integer representation.
+ D256(DVector), +} + +impl VectorOps for DecimalVector { + type Mutable = DecimalVectorMut; + + fn len(&self) -> usize { + match_each_dvector!(self, |v| { v.len() }) + } + + fn validity(&self) -> &Mask { + match_each_dvector!(self, |v| { v.validity() }) + } + + fn try_into_mut(self) -> Result + where + Self: Sized, + { + match_each_dvector!(self, |v| { + v.try_into_mut() + .map(DecimalVectorMut::from) + .map_err(DecimalVector::from) + }) + } +} + +impl DecimalTypeDowncast for DecimalVector { + type Output = DVector; + + fn into_i8(self) -> Self::Output { + if let DecimalVector::D8(vec) = self { + return vec; + } + vortex_panic!("DecimalVector is not of type D8"); + } + + fn into_i16(self) -> Self::Output { + if let DecimalVector::D16(vec) = self { + return vec; + } + vortex_panic!("DecimalVector is not of type D16"); + } + + fn into_i32(self) -> Self::Output { + if let DecimalVector::D32(vec) = self { + return vec; + } + vortex_panic!("DecimalVector is not of type D32"); + } + + fn into_i64(self) -> Self::Output { + if let DecimalVector::D64(vec) = self { + return vec; + } + vortex_panic!("DecimalVector is not of type D64"); + } + + fn into_i128(self) -> Self::Output { + if let DecimalVector::D128(vec) = self { + return vec; + } + vortex_panic!("DecimalVector is not of type D128"); + } + + fn into_i256(self) -> Self::Output { + if let DecimalVector::D256(vec) = self { + return vec; + } + vortex_panic!("DecimalVector is not of type D256"); + } +} + +impl DecimalTypeUpcast for DecimalVector { + type Input = DVector; + + fn from_i8(input: Self::Input) -> Self { + DecimalVector::D8(input) + } + + fn from_i16(input: Self::Input) -> Self { + DecimalVector::D16(input) + } + + fn from_i32(input: Self::Input) -> Self { + DecimalVector::D32(input) + } + + fn from_i64(input: Self::Input) -> Self { + DecimalVector::D64(input) + } + + fn from_i128(input: Self::Input) -> Self { + DecimalVector::D128(input) + } + + fn from_i256(input: Self::Input) -> Self { + 
DecimalVector::D256(input) + } +} diff --git a/vortex-vector/src/decimal/vector_mut.rs b/vortex-vector/src/decimal/vector_mut.rs new file mode 100644 index 00000000000..4c283012312 --- /dev/null +++ b/vortex-vector/src/decimal/vector_mut.rs @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_dtype::{DecimalTypeDowncast, DecimalTypeUpcast, NativeDecimalType, i256}; +use vortex_error::vortex_panic; + +use crate::decimal::DVectorMut; +use crate::{DecimalVector, VectorMutOps, match_each_dvector_mut}; + +/// An enum over all supported decimal mutable vector types. +#[derive(Clone, Debug)] +pub enum DecimalVectorMut { + /// A decimal vector with 8-bit integer representation. + D8(DVectorMut), + /// A decimal vector with 16-bit integer representation. + D16(DVectorMut), + /// A decimal vector with 32-bit integer representation. + D32(DVectorMut), + /// A decimal vector with 64-bit integer representation. + D64(DVectorMut), + /// A decimal vector with 128-bit integer representation. + D128(DVectorMut), + /// A decimal vector with 256-bit integer representation. 
+ D256(DVectorMut), +} + +impl VectorMutOps for DecimalVectorMut { + type Immutable = DecimalVector; + + fn len(&self) -> usize { + match_each_dvector_mut!(self, |d| { d.len() }) + } + + fn capacity(&self) -> usize { + match_each_dvector_mut!(self, |d| { d.capacity() }) + } + + fn reserve(&mut self, additional: usize) { + match_each_dvector_mut!(self, |d| { d.reserve(additional) }) + } + + fn extend_from_vector(&mut self, other: &Self::Immutable) { + match (self, other) { + (DecimalVectorMut::D8(s), DecimalVector::D8(o)) => s.extend_from_vector(o), + (DecimalVectorMut::D16(s), DecimalVector::D16(o)) => s.extend_from_vector(o), + (DecimalVectorMut::D32(s), DecimalVector::D32(o)) => s.extend_from_vector(o), + (DecimalVectorMut::D64(s), DecimalVector::D64(o)) => s.extend_from_vector(o), + (DecimalVectorMut::D128(s), DecimalVector::D128(o)) => s.extend_from_vector(o), + (DecimalVectorMut::D256(s), DecimalVector::D256(o)) => s.extend_from_vector(o), + _ => vortex_panic!("Mismatched decimal vector types in extend_from_vector"), + } + } + + fn append_nulls(&mut self, n: usize) { + match_each_dvector_mut!(self, |d| { d.append_nulls(n) }) + } + + fn freeze(self) -> Self::Immutable { + match_each_dvector_mut!(self, |d| { d.freeze().into() }) + } + + fn split_off(&mut self, at: usize) -> Self { + match_each_dvector_mut!(self, |d| { d.split_off(at).into() }) + } + + fn unsplit(&mut self, other: Self) { + match (self, other) { + (DecimalVectorMut::D8(s), DecimalVectorMut::D8(o)) => s.unsplit(o), + (DecimalVectorMut::D16(s), DecimalVectorMut::D16(o)) => s.unsplit(o), + (DecimalVectorMut::D32(s), DecimalVectorMut::D32(o)) => s.unsplit(o), + (DecimalVectorMut::D64(s), DecimalVectorMut::D64(o)) => s.unsplit(o), + (DecimalVectorMut::D128(s), DecimalVectorMut::D128(o)) => s.unsplit(o), + (DecimalVectorMut::D256(s), DecimalVectorMut::D256(o)) => s.unsplit(o), + _ => vortex_panic!("Mismatched decimal vector types in unsplit"), + } + } +} + +impl DecimalTypeDowncast for DecimalVectorMut 
{ + type Output = DVectorMut; + + fn into_i8(self) -> Self::Output { + if let DecimalVectorMut::D8(vec) = self { + return vec; + } + vortex_panic!("DecimalVectorMut is not of type D8"); + } + + fn into_i16(self) -> Self::Output { + if let DecimalVectorMut::D16(vec) = self { + return vec; + } + vortex_panic!("DecimalVectorMut is not of type D16"); + } + + fn into_i32(self) -> Self::Output { + if let DecimalVectorMut::D32(vec) = self { + return vec; + } + vortex_panic!("DecimalVectorMut is not of type D32"); + } + + fn into_i64(self) -> Self::Output { + if let DecimalVectorMut::D64(vec) = self { + return vec; + } + vortex_panic!("DecimalVectorMut is not of type D64"); + } + + fn into_i128(self) -> Self::Output { + if let DecimalVectorMut::D128(vec) = self { + return vec; + } + vortex_panic!("DecimalVectorMut is not of type D128"); + } + + fn into_i256(self) -> Self::Output { + if let DecimalVectorMut::D256(vec) = self { + return vec; + } + vortex_panic!("DecimalVectorMut is not of type D256"); + } +} + +impl DecimalTypeUpcast for DecimalVectorMut { + type Input = DVectorMut; + + fn from_i8(input: Self::Input) -> Self { + DecimalVectorMut::D8(input) + } + + fn from_i16(input: Self::Input) -> Self { + DecimalVectorMut::D16(input) + } + + fn from_i32(input: Self::Input) -> Self { + DecimalVectorMut::D32(input) + } + + fn from_i64(input: Self::Input) -> Self { + DecimalVectorMut::D64(input) + } + + fn from_i128(input: Self::Input) -> Self { + DecimalVectorMut::D128(input) + } + + fn from_i256(input: Self::Input) -> Self { + DecimalVectorMut::D256(input) + } +} diff --git a/vortex-vector/src/lib.rs b/vortex-vector/src/lib.rs index aa61f2dc919..4a94540e788 100644 --- a/vortex-vector/src/lib.rs +++ b/vortex-vector/src/lib.rs @@ -11,12 +11,14 @@ #![deny(clippy::missing_safety_doc)] mod bool; +mod decimal; mod null; mod primitive; mod struct_; mod varbin; pub use bool::*; +pub use decimal::*; pub use null::*; pub use primitive::*; pub use struct_::*; diff --git 
a/vortex-vector/src/macros.rs b/vortex-vector/src/macros.rs index 59f11c89b20..7aeb66c87df 100644 --- a/vortex-vector/src/macros.rs +++ b/vortex-vector/src/macros.rs @@ -41,6 +41,7 @@ macro_rules! match_each_vector { match $self { $crate::Vector::Null($vec) => $body, $crate::Vector::Bool($vec) => $body, + $crate::Vector::Decimal($vec) => $body, $crate::Vector::Primitive($vec) => $body, $crate::Vector::String($vec) => $body, $crate::Vector::Binary($vec) => $body, @@ -86,6 +87,7 @@ macro_rules! match_each_vector_mut { match $self { $crate::VectorMut::Null($vec) => $body, $crate::VectorMut::Bool($vec) => $body, + $crate::VectorMut::Decimal($vec) => $body, $crate::VectorMut::Primitive($vec) => $body, $crate::VectorMut::String($vec) => $body, $crate::VectorMut::Binary($vec) => $body, diff --git a/vortex-vector/src/primitive/pvector_impl.rs b/vortex-vector/src/primitive/generic_mut_impl.rs similarity index 100% rename from vortex-vector/src/primitive/pvector_impl.rs rename to vortex-vector/src/primitive/generic_mut_impl.rs diff --git a/vortex-vector/src/primitive/mod.rs b/vortex-vector/src/primitive/mod.rs index a411d5f213c..ef26a2147b0 100644 --- a/vortex-vector/src/primitive/mod.rs +++ b/vortex-vector/src/primitive/mod.rs @@ -23,8 +23,8 @@ pub use generic_mut::PVectorMut; mod vector; pub use vector::PrimitiveVector; +mod generic_mut_impl; mod iter; -mod pvector_impl; mod vector_mut; pub use vector_mut::PrimitiveVectorMut; @@ -40,6 +40,12 @@ impl From for Vector { } } +impl From> for PrimitiveVector { + fn from(v: PVector) -> Self { + T::upcast(v) + } +} + impl From> for Vector { fn from(v: PVector) -> Self { Self::Primitive(PrimitiveVector::from(v)) @@ -52,6 +58,12 @@ impl From for VectorMut { } } +impl From> for PrimitiveVectorMut { + fn from(v: PVectorMut) -> Self { + T::upcast(v) + } +} + impl From> for VectorMut { fn from(val: PVectorMut) -> Self { Self::Primitive(PrimitiveVectorMut::from(val)) diff --git a/vortex-vector/src/primitive/vector.rs 
b/vortex-vector/src/primitive/vector.rs index 835adaa8754..fed3798ab42 100644 --- a/vortex-vector/src/primitive/vector.rs +++ b/vortex-vector/src/primitive/vector.rs @@ -85,12 +85,6 @@ impl VectorOps for PrimitiveVector { } } -impl From> for PrimitiveVector { - fn from(v: PVector) -> Self { - T::upcast(v) - } -} - impl PTypeUpcast for PrimitiveVector { type Input = PVector; diff --git a/vortex-vector/src/primitive/vector_mut.rs b/vortex-vector/src/primitive/vector_mut.rs index a77bba9c954..2590bc8cde5 100644 --- a/vortex-vector/src/primitive/vector_mut.rs +++ b/vortex-vector/src/primitive/vector_mut.rs @@ -143,12 +143,6 @@ impl VectorMutOps for PrimitiveVectorMut { } } -impl From> for PrimitiveVectorMut { - fn from(v: PVectorMut) -> Self { - T::upcast(v) - } -} - impl PTypeUpcast for PrimitiveVectorMut { type Input = PVectorMut; diff --git a/vortex-vector/src/private.rs b/vortex-vector/src/private.rs index 22c668e98d2..fcd9fb66a2f 100644 --- a/vortex-vector/src/private.rs +++ b/vortex-vector/src/private.rs @@ -8,7 +8,7 @@ //! usage, which gives us the freedom to add new trait methods in the future without breaking //! backward compatibility. 
-use vortex_dtype::NativePType; +use vortex_dtype::{NativeDecimalType, NativePType}; use crate::*; @@ -24,6 +24,11 @@ impl Sealed for NullVectorMut {} impl Sealed for BoolVector {} impl Sealed for BoolVectorMut {} +impl Sealed for DecimalVector {} +impl Sealed for DecimalVectorMut {} +impl Sealed for DVector {} +impl Sealed for DVectorMut {} + impl Sealed for PrimitiveVector {} impl Sealed for PrimitiveVectorMut {} impl Sealed for PVector {} diff --git a/vortex-vector/src/vector.rs b/vortex-vector/src/vector.rs index 5d6700a209d..69183b80155 100644 --- a/vortex-vector/src/vector.rs +++ b/vortex-vector/src/vector.rs @@ -8,9 +8,9 @@ use vortex_error::vortex_panic; -use crate::varbin::{BinaryVector, StringVector}; use crate::{ - BoolVector, NullVector, PrimitiveVector, StructVector, VectorMut, VectorOps, match_each_vector, + BinaryVector, BoolVector, DecimalVector, NullVector, PrimitiveVector, StringVector, + StructVector, VectorMut, VectorOps, match_each_vector, }; /// An enum over all kinds of immutable vectors, which represent fully decompressed (canonical) @@ -28,13 +28,13 @@ pub enum Vector { Null(NullVector), /// Boolean vectors. Bool(BoolVector), + /// Decimal vectors. + Decimal(DecimalVector), /// Primitive vectors. /// /// Note that [`PrimitiveVector`] is an enum over the different possible (generic) /// [`PVector`](crate::PVector)s. See the documentation for more information.
Primitive(PrimitiveVector), - // Decimal - // Decimal(DecimalVector), /// String vectors String(StringVector), /// Binary vectors diff --git a/vortex-vector/src/vector_mut.rs b/vortex-vector/src/vector_mut.rs index 6dbe0fa5519..f52029765be 100644 --- a/vortex-vector/src/vector_mut.rs +++ b/vortex-vector/src/vector_mut.rs @@ -11,8 +11,8 @@ use vortex_error::vortex_panic; use crate::varbin::{BinaryVectorMut, StringVectorMut}; use crate::{ - BoolVectorMut, NullVectorMut, PrimitiveVectorMut, StructVectorMut, Vector, VectorMutOps, - match_each_vector_mut, match_vector_pair, + BoolVectorMut, DecimalVectorMut, NullVectorMut, PrimitiveVectorMut, StructVectorMut, Vector, + VectorMutOps, match_each_vector_mut, match_vector_pair, }; /// An enum over all kinds of mutable vectors, which represent fully decompressed (canonical) array @@ -30,6 +30,8 @@ pub enum VectorMut { Null(NullVectorMut), /// Mutable Boolean vectors. Bool(BoolVectorMut), + /// Mutable Decimal vectors + Decimal(DecimalVectorMut), /// Mutable Primitive vectors. /// /// Note that [`PrimitiveVectorMut`] is an enum over the different possible (generic)