diff --git a/Cargo.lock b/Cargo.lock index a3c9caff8e6..0b5c91ede3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5000,6 +5000,7 @@ dependencies = [ name = "vortex-datetime-parts" version = "0.21.1" dependencies = [ + "rstest", "serde", "vortex-array", "vortex-buffer", diff --git a/encodings/datetime-parts/Cargo.toml b/encodings/datetime-parts/Cargo.toml index 04530e0482a..f45d57e07ca 100644 --- a/encodings/datetime-parts/Cargo.toml +++ b/encodings/datetime-parts/Cargo.toml @@ -26,4 +26,5 @@ vortex-error = { workspace = true } vortex-scalar = { workspace = true } [dev-dependencies] +rstest = { workspace = true } vortex-array = { workspace = true, features = ["test-harness"] } diff --git a/encodings/datetime-parts/src/array.rs b/encodings/datetime-parts/src/array.rs index 9ee15a1caa6..735b03fa8d0 100644 --- a/encodings/datetime-parts/src/array.rs +++ b/encodings/datetime-parts/src/array.rs @@ -8,15 +8,10 @@ use vortex_array::stats::StatsSet; use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity, ValidityVTable}; use vortex_array::variants::{ExtensionArrayTrait, VariantsVTable}; use vortex_array::visitor::{ArrayVisitor, VisitorVTable}; -use vortex_array::{ - impl_encoding, ArrayDType, ArrayData, ArrayLen, ArrayTrait, Canonical, IntoArrayData, - IntoCanonical, -}; +use vortex_array::{impl_encoding, ArrayDType, ArrayData, ArrayLen, ArrayTrait, IntoArrayData}; use vortex_dtype::{DType, PType}; use vortex_error::{vortex_bail, VortexExpect as _, VortexResult, VortexUnwrap}; -use crate::compute::decode_to_temporal; - impl_encoding!("vortex.datetimeparts", ids::DATE_TIME_PARTS, DateTimeParts); #[derive(Clone, Debug, Serialize, Deserialize)] @@ -136,12 +131,6 @@ impl ExtensionArrayTrait for DateTimePartsArray { } } -impl IntoCanonical for DateTimePartsArray { - fn into_canonical(self) -> VortexResult { - Ok(Canonical::Extension(decode_to_temporal(&self)?.into())) - } -} - impl ValidityVTable for DateTimePartsEncoding { fn is_valid(&self, array: &DateTimePartsArray, index: usize) -> bool { array.validity().is_valid(index) diff --git a/encodings/datetime-parts/src/canonical.rs b/encodings/datetime-parts/src/canonical.rs new file mode 100644 index 00000000000..4f594cf8e70 --- /dev/null +++ b/encodings/datetime-parts/src/canonical.rs @@ -0,0 +1,146 @@ +use vortex_array::array::{PrimitiveArray, TemporalArray}; +use vortex_array::compute::try_cast; +use vortex_array::{ + ArrayDType, Canonical, IntoArrayData as _, IntoArrayVariant as _, IntoCanonical, +}; +use vortex_buffer::BufferMut; +use vortex_datetime_dtype::{TemporalMetadata, TimeUnit}; +use vortex_dtype::Nullability::NonNullable; +use vortex_dtype::{DType, PType}; +use vortex_error::{vortex_bail, VortexExpect as _, VortexResult}; +use vortex_scalar::PrimitiveScalar; + +use crate::DateTimePartsArray; + +impl IntoCanonical for DateTimePartsArray { + fn into_canonical(self) -> VortexResult { + Ok(Canonical::Extension(decode_to_temporal(&self)?.into())) + } +} + +/// Decode an [ArrayData] into a [TemporalArray]. +/// +/// Enforces that the passed array is actually a [DateTimePartsArray] with proper metadata. +pub fn decode_to_temporal(array: &DateTimePartsArray) -> VortexResult { + let DType::Extension(ext) = array.dtype().clone() else { + vortex_bail!(ComputeError: "expected dtype to be DType::Extension variant") + }; + + let Ok(temporal_metadata) = TemporalMetadata::try_from(ext.as_ref()) else { + vortex_bail!(ComputeError: "must decode TemporalMetadata from extension metadata"); + }; + + let divisor = match temporal_metadata.time_unit() { + TimeUnit::Ns => 1_000_000_000, + TimeUnit::Us => 1_000_000, + TimeUnit::Ms => 1_000, + TimeUnit::S => 1, + TimeUnit::D => vortex_bail!(InvalidArgument: "cannot decode into TimeUnit::D"), + }; + + let days_buf = try_cast( + array.days(), + &DType::Primitive(PType::I64, array.dtype().nullability()), + )? + .into_primitive()?; + + // We start with the days component, which is always present. + // And then add the seconds and subseconds components. + // We split this into separate passes because often the seconds and/org subsecond components + // are constant. + let mut values: BufferMut = days_buf + .into_buffer_mut::() + .map_each(|d| d * 86_400 * divisor); + + if let Some(seconds) = array.seconds().as_constant() { + let seconds = + PrimitiveScalar::try_from(&seconds.cast(&DType::Primitive(PType::I64, NonNullable))?)? + .typed_value::() + .vortex_expect("non-nullable"); + let seconds = seconds * divisor; + for v in values.iter_mut() { + *v += seconds; + } + } else { + let seconds_buf = try_cast(array.seconds(), &DType::Primitive(PType::U32, NonNullable))? + .into_primitive()?; + for (v, second) in values.iter_mut().zip(seconds_buf.as_slice::()) { + *v += (*second as i64) * divisor; + } + } + + if let Some(subseconds) = array.subsecond().as_constant() { + let subseconds = PrimitiveScalar::try_from( + &subseconds.cast(&DType::Primitive(PType::I64, NonNullable))?, + )? + .typed_value::() + .vortex_expect("non-nullable"); + for v in values.iter_mut() { + *v += subseconds; + } + } else { + let subsecond_buf = try_cast( + array.subsecond(), + &DType::Primitive(PType::I64, NonNullable), + )? + .into_primitive()?; + for (v, subsecond) in values.iter_mut().zip(subsecond_buf.as_slice::()) { + *v += *subsecond; + } + } + + Ok(TemporalArray::new_timestamp( + PrimitiveArray::new(values.freeze(), array.validity()).into_array(), + temporal_metadata.time_unit(), + temporal_metadata.time_zone().map(ToString::to_string), + )) +} + +#[cfg(test)] +mod test { + use rstest::rstest; + use vortex_array::array::{PrimitiveArray, TemporalArray}; + use vortex_array::validity::Validity; + use vortex_array::{IntoArrayData as _, IntoArrayVariant}; + use vortex_buffer::buffer; + use vortex_datetime_dtype::TimeUnit; + + use crate::canonical::decode_to_temporal; + use crate::DateTimePartsArray; + + #[rstest] + #[case(Validity::NonNullable)] + #[case(Validity::AllValid)] + #[case(Validity::AllInvalid)] + #[case(Validity::from_iter([true, false, true]))] + fn test_decode_to_temporal(#[case] validity: Validity) { + let milliseconds = PrimitiveArray::new( + buffer![ + 86_400i64, // element with only day component + 86_400i64 + 1000, // element with day + second components + 86_400i64 + 1000 + 1, // element with day + second + sub-second components + ], + validity.clone(), + ); + let date_times = DateTimePartsArray::try_from(TemporalArray::new_timestamp( + milliseconds.clone().into_array(), + TimeUnit::Ms, + Some("UTC".to_string()), + )) + .unwrap(); + + assert_eq!(date_times.validity(), validity); + + let primitive_values = decode_to_temporal(&date_times) + .unwrap() + .temporal_values() + .into_primitive() + .unwrap(); + + assert_eq!( + primitive_values.as_slice::(), + milliseconds.as_slice::() + ); + assert_eq!(primitive_values.validity(), validity); + } +} diff --git a/encodings/datetime-parts/src/compress.rs b/encodings/datetime-parts/src/compress.rs index aa01912524e..24076cf78f1 100644 --- a/encodings/datetime-parts/src/compress.rs +++ b/encodings/datetime-parts/src/compress.rs @@ -4,7 +4,9 @@ use vortex_array::{ArrayDType as _, ArrayData, ArrayLen, IntoArrayData, IntoArra use vortex_buffer::BufferMut; use vortex_datetime_dtype::TimeUnit; use vortex_dtype::{DType, PType}; -use vortex_error::{vortex_bail, VortexResult}; +use vortex_error::{vortex_bail, VortexError, VortexResult}; + +use crate::DateTimePartsArray; pub struct TemporalParts { pub days: ArrayData, @@ -52,3 +54,62 @@ pub fn split_temporal(array: TemporalArray) -> VortexResult { subseconds: subsecond.into_array(), }) } + +impl TryFrom for DateTimePartsArray { + type Error = VortexError; + + fn try_from(array: TemporalArray) -> Result { + let ext_dtype = array.ext_dtype(); + let TemporalParts { + days, + seconds, + subseconds, + } = split_temporal(array)?; + DateTimePartsArray::try_new(DType::Extension(ext_dtype), days, seconds, subseconds) + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_array::array::{PrimitiveArray, TemporalArray}; + use vortex_array::validity::Validity; + use vortex_array::{IntoArrayData as _, IntoArrayVariant as _}; + use vortex_buffer::buffer; + use vortex_datetime_dtype::TimeUnit; + + use crate::{split_temporal, TemporalParts}; + + #[rstest] + #[case(Validity::NonNullable)] + #[case(Validity::AllValid)] + #[case(Validity::AllInvalid)] + #[case(Validity::from_iter([true, false, true]))] + fn test_split_temporal(#[case] validity: Validity) { + let milliseconds = PrimitiveArray::new( + buffer![ + 86_400i64, // element with only day component + 86_400i64 + 1000, // element with day + second components + 86_400i64 + 1000 + 1, // element with day + second + sub-second components + ], + validity.clone(), + ) + .into_array(); + let temporal_array = + TemporalArray::new_timestamp(milliseconds, TimeUnit::Ms, Some("UTC".to_string())); + let TemporalParts { + days, + seconds, + subseconds, + } = split_temporal(temporal_array).unwrap(); + assert_eq!(days.into_primitive().unwrap().validity(), validity); + assert_eq!( + seconds.into_primitive().unwrap().validity(), + Validity::NonNullable + ); + assert_eq!( + subseconds.into_primitive().unwrap().validity(), + Validity::NonNullable + ); + } +} diff --git a/encodings/datetime-parts/src/compute/cast.rs b/encodings/datetime-parts/src/compute/cast.rs new file mode 100644 index 00000000000..b8f45521060 --- /dev/null +++ b/encodings/datetime-parts/src/compute/cast.rs @@ -0,0 +1,103 @@ +use vortex_array::compute::{try_cast, CastFn}; +use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::{DateTimePartsArray, DateTimePartsEncoding}; + +impl CastFn for DateTimePartsEncoding { + fn cast(&self, array: &DateTimePartsArray, dtype: &DType) -> VortexResult { + if !array.dtype().eq_ignore_nullability(dtype) { + vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype); + }; + + Ok(DateTimePartsArray::try_new( + dtype.clone(), + try_cast( + array.days().as_ref(), + &array.days().dtype().with_nullability(dtype.nullability()), + )?, + array.seconds(), + array.subsecond(), + )? + .into_array()) + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_array::array::{PrimitiveArray, TemporalArray}; + use vortex_array::compute::try_cast; + use vortex_array::validity::Validity; + use vortex_array::{ArrayDType as _, ArrayData, IntoArrayData as _}; + use vortex_buffer::buffer; + use vortex_datetime_dtype::TimeUnit; + use vortex_dtype::{DType, Nullability}; + + use crate::DateTimePartsArray; + + fn date_time_array(validity: Validity) -> ArrayData { + DateTimePartsArray::try_from(TemporalArray::new_timestamp( + PrimitiveArray::new( + buffer![ + 86_400i64, // element with only day component + 86_400i64 + 1000, // element with day + second components + 86_400i64 + 1000 + 1, // element with day + second + sub-second components + ], + validity, + ) + .into_array(), + TimeUnit::Ms, + Some("UTC".to_string()), + )) + .unwrap() + .into_array() + } + + #[rstest] + #[case(Validity::NonNullable, Nullability::Nullable)] + #[case(Validity::AllValid, Nullability::Nullable)] + #[case(Validity::AllInvalid, Nullability::Nullable)] + #[case(Validity::from_iter([true, false, true]), Nullability::Nullable)] + #[case(Validity::NonNullable, Nullability::NonNullable)] + #[case(Validity::AllValid, Nullability::NonNullable)] + #[case(Validity::from_iter([true, true, true]), Nullability::Nullable)] + fn test_cast_to_compatibile_nullability( + #[case] validity: Validity, + #[case] cast_to_nullability: Nullability, + ) { + let array = date_time_array(validity); + let new_dtype = array.dtype().with_nullability(cast_to_nullability); + let result = try_cast(&array, &new_dtype); + assert!(result.is_ok(), "{:?}", result); + assert_eq!(result.unwrap().dtype(), &new_dtype); + } + + #[rstest] + #[case(Validity::AllInvalid)] + #[case(Validity::from_iter([true, false, true]))] + fn test_bad_cast_fails(#[case] validity: Validity) { + let array = date_time_array(validity); + let result = try_cast(&array, &DType::Bool(Nullability::NonNullable)); + assert!( + result + .as_ref() + .is_err_and(|err| err.to_string().contains("cannot cast from")), + "{:?}", + result + ); + + let result = try_cast( + &array, + &array.dtype().with_nullability(Nullability::NonNullable), + ); + assert!( + result.as_ref().is_err_and(|err| err + .to_string() + .contains("invalid cast from nullable to non-nullable")), + "{:?}", + result + ); + } +} diff --git a/encodings/datetime-parts/src/compute/mod.rs b/encodings/datetime-parts/src/compute/mod.rs index de3937c1e8a..1f971315b31 100644 --- a/encodings/datetime-parts/src/compute/mod.rs +++ b/encodings/datetime-parts/src/compute/mod.rs @@ -1,22 +1,25 @@ +mod cast; mod filter; mod take; -use vortex_array::array::{PrimitiveArray, TemporalArray}; use vortex_array::compute::{ - scalar_at, slice, try_cast, ComputeVTable, FilterFn, ScalarAtFn, SliceFn, TakeFn, + scalar_at, slice, CastFn, ComputeVTable, FilterFn, ScalarAtFn, SliceFn, TakeFn, }; use vortex_array::validity::ArrayValidity; -use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; -use vortex_buffer::BufferMut; +use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; use vortex_datetime_dtype::{TemporalMetadata, TimeUnit}; use vortex_dtype::Nullability::{NonNullable, Nullable}; use vortex_dtype::{DType, PType}; -use vortex_error::{vortex_bail, VortexExpect, VortexResult}; -use vortex_scalar::{PrimitiveScalar, Scalar}; +use vortex_error::{vortex_bail, VortexResult}; +use vortex_scalar::Scalar; use crate::{DateTimePartsArray, DateTimePartsEncoding}; impl ComputeVTable for DateTimePartsEncoding { + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } @@ -92,161 +95,3 @@ impl ScalarAtFn for DateTimePartsEncoding { Ok(Scalar::extension(ext, Scalar::from(scalar))) } } - -/// Decode an [ArrayData] into a [TemporalArray]. -/// -/// Enforces that the passed array is actually a [DateTimePartsArray] with proper metadata. -pub fn decode_to_temporal(array: &DateTimePartsArray) -> VortexResult { - let DType::Extension(ext) = array.dtype().clone() else { - vortex_bail!(ComputeError: "expected dtype to be DType::Extension variant") - }; - - let Ok(temporal_metadata) = TemporalMetadata::try_from(ext.as_ref()) else { - vortex_bail!(ComputeError: "must decode TemporalMetadata from extension metadata"); - }; - - let divisor = match temporal_metadata.time_unit() { - TimeUnit::Ns => 1_000_000_000, - TimeUnit::Us => 1_000_000, - TimeUnit::Ms => 1_000, - TimeUnit::S => 1, - TimeUnit::D => vortex_bail!(InvalidArgument: "cannot decode into TimeUnit::D"), - }; - - let days_buf = try_cast( - array.days(), - &DType::Primitive(PType::I64, array.dtype().nullability()), - )? - .into_primitive()?; - - // We start with the days component, which is always present. - // And then add the seconds and subseconds components. - // We split this into separate passes because often the seconds and/org subsecond components - // are constant. - let mut values: BufferMut = days_buf - .into_buffer_mut::() - .map_each(|d| d * 86_400 * divisor); - - if let Some(seconds) = array.seconds().as_constant() { - let seconds = - PrimitiveScalar::try_from(&seconds.cast(&DType::Primitive(PType::I64, NonNullable))?)? - .typed_value::() - .vortex_expect("non-nullable"); - let seconds = seconds * divisor; - for v in values.iter_mut() { - *v += seconds; - } - } else { - let seconds_buf = try_cast(array.seconds(), &DType::Primitive(PType::U32, NonNullable))? - .into_primitive()?; - for (v, second) in values.iter_mut().zip(seconds_buf.as_slice::()) { - *v += (*second as i64) * divisor; - } - } - - if let Some(subseconds) = array.subsecond().as_constant() { - let subseconds = PrimitiveScalar::try_from( - &subseconds.cast(&DType::Primitive(PType::I64, NonNullable))?, - )? - .typed_value::() - .vortex_expect("non-nullable"); - for v in values.iter_mut() { - *v += subseconds; - } - } else { - let subsecond_buf = try_cast( - array.subsecond(), - &DType::Primitive(PType::I64, NonNullable), - )? - .into_primitive()?; - for (v, subsecond) in values.iter_mut().zip(subsecond_buf.as_slice::()) { - *v += *subsecond; - } - } - - Ok(TemporalArray::new_timestamp( - PrimitiveArray::new(values.freeze(), array.validity()).into_array(), - temporal_metadata.time_unit(), - temporal_metadata.time_zone().map(ToString::to_string), - )) -} - -#[cfg(test)] -mod test { - use vortex_array::array::{PrimitiveArray, TemporalArray}; - use vortex_array::validity::Validity; - use vortex_array::{IntoArrayVariant, ToArrayData}; - use vortex_buffer::Buffer; - use vortex_datetime_dtype::TimeUnit; - use vortex_dtype::DType; - - use crate::compute::decode_to_temporal; - use crate::{split_temporal, DateTimePartsArray, TemporalParts}; - - #[test] - fn test_roundtrip_datetimeparts() { - let raw_values = vec![ - 86_400i64, // element with only day component - 86_400i64 + 1000, // element with day + second components - 86_400i64 + 1000 + 1, // element with day + second + sub-second components - ]; - - do_roundtrip_test(&raw_values, Validity::NonNullable); - do_roundtrip_test(&raw_values, Validity::AllValid); - do_roundtrip_test(&raw_values, Validity::AllInvalid); - do_roundtrip_test(&raw_values, Validity::from_iter([true, false, true])); - } - - fn do_roundtrip_test(raw_values: &[i64], validity: Validity) { - let raw_millis = PrimitiveArray::new(Buffer::copy_from(raw_values), validity.clone()); - assert_eq!(raw_millis.validity(), validity); - - let temporal_array = TemporalArray::new_timestamp( - raw_millis.to_array(), - TimeUnit::Ms, - Some("UTC".to_string()), - ); - assert_eq!( - temporal_array - .temporal_values() - .into_primitive() - .unwrap() - .validity(), - validity - ); - - let TemporalParts { - days, - seconds, - subseconds, - } = split_temporal(temporal_array.clone()).unwrap(); - assert_eq!(days.clone().into_primitive().unwrap().validity(), validity); - assert_eq!( - seconds.clone().into_primitive().unwrap().validity(), - Validity::NonNullable - ); - assert_eq!( - subseconds.clone().into_primitive().unwrap().validity(), - Validity::NonNullable - ); - assert_eq!(validity, raw_millis.validity()); - - let date_times = DateTimePartsArray::try_new( - DType::Extension(temporal_array.ext_dtype()), - days, - seconds, - subseconds, - ) - .unwrap(); - assert_eq!(date_times.validity(), validity); - - let primitive_values = decode_to_temporal(&date_times) - .unwrap() - .temporal_values() - .into_primitive() - .unwrap(); - - assert_eq!(primitive_values.as_slice::(), raw_values); - assert_eq!(primitive_values.validity(), validity); - } -} diff --git a/encodings/datetime-parts/src/lib.rs b/encodings/datetime-parts/src/lib.rs index ef9a8711c46..648db210467 100644 --- a/encodings/datetime-parts/src/lib.rs +++ b/encodings/datetime-parts/src/lib.rs @@ -2,6 +2,7 @@ pub use array::*; pub use compress::*; mod array; +mod canonical; mod compress; mod compute; mod stats; diff --git a/vortex-array/src/array/datetime/test.rs b/vortex-array/src/array/datetime/test.rs index b7f0510cbac..18ad51f2b23 100644 --- a/vortex-array/src/array/datetime/test.rs +++ b/vortex-array/src/array/datetime/test.rs @@ -1,7 +1,9 @@ +use rstest::rstest; use vortex_buffer::buffer; use vortex_datetime_dtype::{TemporalMetadata, TimeUnit}; use crate::array::{PrimitiveArray, TemporalArray}; +use crate::validity::Validity; use crate::{IntoArrayData, IntoArrayVariant}; macro_rules! test_temporal_roundtrip { @@ -138,3 +140,30 @@ fn test_timestamp_fails_i32() { let _ = TemporalArray::new_timestamp(ts_array, TimeUnit::S, None); } + +#[rstest] +#[case(Validity::NonNullable)] +#[case(Validity::AllValid)] +#[case(Validity::AllInvalid)] +#[case(Validity::from_iter([true, false, true]))] +fn test_validity_preservation(#[case] validity: Validity) { + let milliseconds = PrimitiveArray::new( + buffer![ + 86_400i64, // element with only day component + 86_400i64 + 1000, // element with day + second components + 86_400i64 + 1000 + 1, // element with day + second + sub-second components + ], + validity.clone(), + ) + .into_array(); + let temporal_array = + TemporalArray::new_timestamp(milliseconds, TimeUnit::Ms, Some("UTC".to_string())); + assert_eq!( + temporal_array + .temporal_values() + .into_primitive() + .unwrap() + .validity(), + validity + ); +}