From cf784bacefe12b8db74fdaae6955bb888caa523d Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 15 Jan 2025 13:41:17 -0500 Subject: [PATCH 01/15] chore: all scalars support cast --- vortex-scalar/src/binary.rs | 11 +++++- vortex-scalar/src/bool.rs | 12 +++--- vortex-scalar/src/extension.rs | 13 ++++++- vortex-scalar/src/lib.rs | 71 ++++++++++++++++++++-------------- vortex-scalar/src/list.rs | 17 +++++++- vortex-scalar/src/primitive.rs | 19 +++++---- vortex-scalar/src/utf8.rs | 11 +++++- vortex-scalar/src/value.rs | 50 +++++++++++++++++++++++- 8 files changed, 147 insertions(+), 57 deletions(-) diff --git a/vortex-scalar/src/binary.rs b/vortex-scalar/src/binary.rs index f09f76d803c..b72e27fd08f 100644 --- a/vortex-scalar/src/binary.rs +++ b/vortex-scalar/src/binary.rs @@ -20,8 +20,15 @@ impl<'a> BinaryScalar<'a> { self.value.as_ref().cloned() } - pub fn cast(&self, _dtype: &DType) -> VortexResult { - todo!() + pub fn cast(&self, dtype: &DType) -> VortexResult { + Ok(match (&self.value, dtype) { + (Some(b), DType::Binary(_)) => Scalar::new( + dtype.clone(), + ScalarValue(InnerScalarValue::Buffer(b.clone())), + ), + (None, DType::Binary(Nullability::Nullable)) => Scalar::null(dtype.clone()), + (value, dtype) => vortex_bail!("Can't cast {:?} to {}", value, dtype), + }) } } diff --git a/vortex-scalar/src/bool.rs b/vortex-scalar/src/bool.rs index 7f85e22e19c..366b72e612e 100644 --- a/vortex-scalar/src/bool.rs +++ b/vortex-scalar/src/bool.rs @@ -21,13 +21,11 @@ impl<'a> BoolScalar<'a> { } pub fn cast(&self, dtype: &DType) -> VortexResult { - match dtype { - DType::Bool(_) => Ok(Scalar::bool( - self.value().ok_or_else(|| vortex_err!("not a bool"))?, - dtype.nullability(), - )), - _ => vortex_bail!("Can't cast {} to bool", dtype), - } + Ok(match (self.value, dtype) { + (Some(b), DType::Bool(_)) => Scalar::new(dtype.clone(), ScalarValue::from(b)), + (None, DType::Bool(Nullability::Nullable)) => Scalar::null(dtype.clone()), + (value, dtype) => vortex_bail!("Can't cast {:?} to {}", value, dtype), + }) } pub fn invert(self) -> BoolScalar<'a> { diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs index e1cc6920853..3e60a69d5b4 100644 --- a/vortex-scalar/src/extension.rs +++ b/vortex-scalar/src/extension.rs @@ -35,8 +35,17 @@ impl<'a> ExtScalar<'a> { Scalar::new(storage_dtype, self.value.clone()) } - pub fn cast(&self, _dtype: &DType) -> VortexResult { - todo!() + pub fn cast(&self, dtype: &DType) -> VortexResult { + if self.dtype().eq_ignore_nullability(dtype) { + if self.dtype.is_nullable() && dtype.is_nullable() && self.value.is_null() { + vortex_bail!("cannot cast null value to {}", dtype); + } + // ScalarValue::cast must reject casting _to_ an extension type because it does not know + // its own type. + return Ok(Scalar::new(dtype.clone(), self.value.clone())); + } + + Ok(Scalar::new(dtype.clone(), self.value.cast(dtype)?)) } } diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index e8297cfaf10..1a886f41e17 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -101,39 +101,16 @@ impl Scalar { } pub fn cast(&self, dtype: &DType) -> VortexResult { - if self.is_null() && !dtype.is_nullable() { - vortex_bail!("Can't cast null scalar to non-nullable type") - } - if self.dtype().eq_ignore_nullability(dtype) { - return Ok(Scalar { - dtype: dtype.clone(), - value: self.value.clone(), - }); - } - - match dtype { - DType::Null => vortex_bail!("Can't cast non-null to null"), - DType::Bool(_) => BoolScalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::Primitive(..) => PrimitiveScalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::Utf8(_) => Utf8Scalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::Binary(_) => BinaryScalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::Struct(..) => StructScalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::List(..) => ListScalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::Extension(ext_dtype) => { - if !self.value().is_instance_of(ext_dtype.storage_dtype()) { - vortex_bail!( - "Failed to cast scalar to extension dtype with storage type {:?}, found {:?}", - ext_dtype.storage_dtype(), - self.dtype() - ); - } - Ok(Scalar::extension( - ext_dtype.clone(), - self.cast(ext_dtype.storage_dtype())?, - )) + if self.dtype.is_nullable() && dtype.is_nullable() && self.value.is_null() { + vortex_bail!("cannot cast null value to {}", dtype); } + // ScalarValue::cast must reject casting _to_ an extension type because it does not know + // its own type. + return Ok(Scalar::new(dtype.clone(), self.value.clone())); } + + Ok(Scalar::new(dtype.clone(), self.value.cast(dtype)?)) } pub fn into_nullable(self) -> Self { @@ -296,3 +273,37 @@ from_vec_for_scalar!(f64); from_vec_for_scalar!(String); from_vec_for_scalar!(BufferString); from_vec_for_scalar!(ByteBuffer); + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use vortex_dtype::{DType, ExtDType, ExtID, Nullability}; + + use crate::{InnerScalarValue, Scalar, ScalarValue}; + + #[test] + fn cast_from_extension_types() { + let apples = ExtDType::new( + ExtID::new(Arc::from("apples")), + Arc::from(DType::Bool(Nullability::NonNullable)), + None, + ); + let scalar = Scalar::new( + DType::Extension(Arc::from(apples.clone())), + ScalarValue(InnerScalarValue::Bool(true)), + ); + + let inner = scalar.cast(scalar.dtype()).unwrap(); + assert_eq!(inner.dtype(), scalar.dtype()); + + let inner = scalar.cast(&scalar.dtype().as_nullable()).unwrap(); + assert_eq!(inner.dtype(), &scalar.dtype().as_nullable()); + + let inner = scalar.cast(apples.storage_dtype()).unwrap(); + assert_eq!(inner.dtype(), apples.storage_dtype()); + + let inner = scalar.cast(&apples.storage_dtype().as_nullable()).unwrap(); + assert_eq!(inner.dtype(), &apples.storage_dtype().as_nullable()); + } +} diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/list.rs index 9d493e4ea0e..5cc1bda6d91 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/list.rs @@ -1,6 +1,7 @@ use std::ops::Deref; use std::sync::Arc; +use itertools::Itertools as _; use vortex_dtype::{DType, Nullability}; use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexResult}; @@ -65,8 +66,20 @@ impl<'a> ListScalar<'a> { }) } - pub fn cast(&self, _dtype: &DType) -> VortexResult { - todo!() + pub fn cast(&self, dtype: &DType) -> VortexResult { + Ok(match (&self.elements, dtype) { + (Some(elements), DType::List(element_dtype, _)) => Scalar::new( + dtype.clone(), + ScalarValue(InnerScalarValue::List( + elements + .iter() + .map(|element| element.cast(element_dtype)) + .process_results(|iter| iter.collect())?, + )), + ), + (None, DType::List(_, Nullability::Nullable)) => Scalar::null(dtype.clone()), + (value, dtype) => vortex_bail!("Can't cast {:?} to {}", value, dtype), + }) } } diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index ad719d7f05b..25559a33bd7 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -1,7 +1,7 @@ use std::any::type_name; use std::fmt::{Debug, Display}; -use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, NumCast}; +use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive}; use vortex_dtype::half::f16; use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability, PType}; use vortex_error::{ @@ -68,15 +68,14 @@ impl<'a> PrimitiveScalar<'a> { } pub fn cast(&self, dtype: &DType) -> VortexResult { - let ptype = PType::try_from(dtype)?; - match_each_native_ptype!(ptype, |$Q| { - match_each_native_ptype!(self.ptype(), |$T| { - Ok(Scalar::primitive::<$Q>( - <$Q as NumCast>::from(self.typed_value::<$T>().expect("Invalid value")) - .ok_or_else(|| vortex_err!("Can't cast {} scalar {} to {}", self.ptype, self.typed_value::<$T>().expect("Invalid value"), dtype))?, - dtype.nullability(), - )) - }) + Ok(match (&self.pvalue, dtype) { + (Some(v), DType::Primitive(ptype, _)) => { + match_each_native_ptype!(ptype, |$Q| { + Scalar::primitive(v.as_primitive::<$Q>()?, dtype.nullability()) + }) + } + (None, DType::Primitive(_, Nullability::Nullable)) => Scalar::null(dtype.clone()), + (value, dtype) => vortex_bail!("Can't cast {:?} to {}", value, dtype), }) } diff --git a/vortex-scalar/src/utf8.rs b/vortex-scalar/src/utf8.rs index cc353db3c36..98def4137f6 100644 --- a/vortex-scalar/src/utf8.rs +++ b/vortex-scalar/src/utf8.rs @@ -21,8 +21,15 @@ impl<'a> Utf8Scalar<'a> { self.value.as_ref().cloned() } - pub fn cast(&self, _dtype: &DType) -> VortexResult { - todo!() + pub fn cast(&self, dtype: &DType) -> VortexResult { + Ok(match (&self.value, dtype) { + (Some(bufstr), DType::Utf8(_)) => Scalar::new( + dtype.clone(), + ScalarValue(InnerScalarValue::BufferString(bufstr.clone())), + ), + (None, DType::Utf8(Nullability::Nullable)) => Scalar::null(dtype.clone()), + (value, dtype) => vortex_bail!("Can't cast {:?} to {}", value, dtype), + }) } } diff --git a/vortex-scalar/src/value.rs b/vortex-scalar/src/value.rs index 57deabbcc7f..d510c10f7fa 100644 --- a/vortex-scalar/src/value.rs +++ b/vortex-scalar/src/value.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use itertools::Itertools; use vortex_buffer::{BufferString, ByteBuffer}; -use vortex_dtype::DType; -use vortex_error::{vortex_err, VortexResult}; +use vortex_dtype::{match_each_native_ptype, DType}; +use vortex_error::{vortex_bail, vortex_err, VortexResult}; use crate::pvalue::PValue; @@ -115,6 +115,52 @@ impl ScalarValue { pub(crate) fn as_list(&self) -> VortexResult>> { self.0.as_list() } + + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + Ok(match (&self.0, dtype) { + (InnerScalarValue::Bool(..), DType::Bool(..)) => self.clone(), + (InnerScalarValue::Primitive(pvalue), DType::Primitive(ptype, ..)) => { + let pvalue = match_each_native_ptype!(ptype, |$Q| { + PValue::from(pvalue.as_primitive::<$Q>()?) + }); + ScalarValue(InnerScalarValue::Primitive(pvalue)) + } + (InnerScalarValue::Buffer(..), DType::Binary(..)) => self.clone(), + (InnerScalarValue::BufferString(..), DType::Utf8(..)) => self.clone(), + (InnerScalarValue::List(fields), DType::Struct(sdtype, _)) => { + if fields.len() != sdtype.names().len() { + vortex_bail!( + "cannot cast from {} fields to {} fields", + fields.len(), + sdtype.names().len() + ); + } + ScalarValue(InnerScalarValue::List( + fields + .iter() + .zip_eq(sdtype.dtypes()) + .map(|(value, dtype)| value.cast(&dtype)) + .process_results(|iter| iter.collect())?, + )) + } + (InnerScalarValue::List(values), DType::List(dtype, _)) => { + ScalarValue(InnerScalarValue::List( + values + .iter() + .map(|value| value.cast(&dtype)) + .process_results(|iter| iter.collect())?, + )) + } + (InnerScalarValue::Null, dtype) => { + if !dtype.is_nullable() { + vortex_bail!("cannot cast null to non-nullable dtype: {}", dtype) + } + self.clone() + } + // (_, Extension(..)) we are never allowed to cast _to_ an extension type + _ => vortex_bail!("cannot cast {} to {}", self, dtype), + }) + } } impl InnerScalarValue { From 7bc264cd17098dc22b65a61c827d0a0d4279d37b Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 15 Jan 2025 15:40:24 -0500 Subject: [PATCH 02/15] we indeed may cast to an extension type --- vortex-scalar/src/value.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vortex-scalar/src/value.rs b/vortex-scalar/src/value.rs index d510c10f7fa..0c4262af2a8 100644 --- a/vortex-scalar/src/value.rs +++ b/vortex-scalar/src/value.rs @@ -147,7 +147,7 @@ impl ScalarValue { ScalarValue(InnerScalarValue::List( values .iter() - .map(|value| value.cast(&dtype)) + .map(|value| value.cast(dtype)) .process_results(|iter| iter.collect())?, )) } @@ -157,7 +157,7 @@ impl ScalarValue { } self.clone() } - // (_, Extension(..)) we are never allowed to cast _to_ an extension type + (_, DType::Extension(ext_dtype)) => self.cast(ext_dtype.storage_dtype())?, _ => vortex_bail!("cannot cast {} to {}", self, dtype), }) } From f0a0e99122c9ecfd79a40414758e313b2dfe8761 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 15 Jan 2025 15:45:11 -0500 Subject: [PATCH 03/15] the min/max of an array without valid values is not defined (and definitely not null) --- vortex-array/src/array/primitive/stats.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vortex-array/src/array/primitive/stats.rs b/vortex-array/src/array/primitive/stats.rs index ea9feebc64e..787c8609abf 100644 --- a/vortex-array/src/array/primitive/stats.rs +++ b/vortex-array/src/array/primitive/stats.rs @@ -403,7 +403,7 @@ mod test { let min: Option = arr.statistics().compute(Stat::Min); let max: Option = arr.statistics().compute(Stat::Max); let null_i32 = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)); - assert_eq!(min, Some(null_i32.clone())); - assert_eq!(max, Some(null_i32)); + assert_eq!(min, None); + assert_eq!(max, None); } } From 91c4b364f84692d35987d2e667b1bf9306136235 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 15 Jan 2025 15:47:20 -0500 Subject: [PATCH 04/15] unused variable --- vortex-array/src/array/primitive/stats.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/vortex-array/src/array/primitive/stats.rs b/vortex-array/src/array/primitive/stats.rs index 787c8609abf..8ac42a22a1c 100644 --- a/vortex-array/src/array/primitive/stats.rs +++ b/vortex-array/src/array/primitive/stats.rs @@ -402,7 +402,6 @@ mod test { let arr = PrimitiveArray::from_option_iter([Option::::None, None, None]); let min: Option = arr.statistics().compute(Stat::Min); let max: Option = arr.statistics().compute(Stat::Max); - let null_i32 = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)); assert_eq!(min, None); assert_eq!(max, None); } From 105c894c92112f095ce192716d4d66a4d96cb5d4 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 15 Jan 2025 15:47:56 -0500 Subject: [PATCH 05/15] clippy --- vortex-array/src/array/primitive/stats.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/vortex-array/src/array/primitive/stats.rs b/vortex-array/src/array/primitive/stats.rs index 8ac42a22a1c..9b74439b0c1 100644 --- a/vortex-array/src/array/primitive/stats.rs +++ b/vortex-array/src/array/primitive/stats.rs @@ -334,7 +334,6 @@ impl BitWidthAccumulator { #[cfg(test)] mod test { - use vortex_dtype::{DType, Nullability, PType}; use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; From 33a69b40e43b9995cf9206ee54169e62177c31c3 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 15 Jan 2025 17:31:37 -0500 Subject: [PATCH 06/15] works maybe? --- vortex-scalar/src/binary.rs | 24 +++--- vortex-scalar/src/bool.rs | 16 ++-- vortex-scalar/src/extension.rs | 33 ++++---- vortex-scalar/src/lib.rs | 134 +++++++++++++++++++++++++++++++-- vortex-scalar/src/list.rs | 42 ++++++----- vortex-scalar/src/primitive.rs | 26 ++++--- vortex-scalar/src/utf8.rs | 24 +++--- vortex-scalar/src/value.rs | 50 +----------- 8 files changed, 220 insertions(+), 129 deletions(-) diff --git a/vortex-scalar/src/binary.rs b/vortex-scalar/src/binary.rs index b72e27fd08f..4bb83b7bb6b 100644 --- a/vortex-scalar/src/binary.rs +++ b/vortex-scalar/src/binary.rs @@ -1,6 +1,6 @@ use vortex_buffer::ByteBuffer; use vortex_dtype::{DType, Nullability}; -use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; +use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult}; use crate::value::{InnerScalarValue, ScalarValue}; use crate::Scalar; @@ -20,15 +20,19 @@ impl<'a> BinaryScalar<'a> { self.value.as_ref().cloned() } - pub fn cast(&self, dtype: &DType) -> VortexResult { - Ok(match (&self.value, dtype) { - (Some(b), DType::Binary(_)) => Scalar::new( - dtype.clone(), - ScalarValue(InnerScalarValue::Buffer(b.clone())), - ), - (None, DType::Binary(Nullability::Nullable)) => Scalar::null(dtype.clone()), - (value, dtype) => vortex_bail!("Can't cast {:?} to {}", value, dtype), - }) + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + if !matches!(dtype, DType::Binary(..)) { + vortex_bail!("Can't cast binary to {}", dtype) + } + Ok(Scalar::new( + dtype.clone(), + ScalarValue(InnerScalarValue::Buffer( + self.value + .as_ref() + .vortex_expect("nullness handled in Scalar::cast") + .clone(), + )), + )) } } diff --git a/vortex-scalar/src/bool.rs b/vortex-scalar/src/bool.rs index 366b72e612e..0948854fd79 100644 --- a/vortex-scalar/src/bool.rs +++ b/vortex-scalar/src/bool.rs @@ -1,6 +1,6 @@ use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, Nullability}; -use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; +use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult}; use crate::value::ScalarValue; use crate::{InnerScalarValue, Scalar}; @@ -20,12 +20,14 @@ impl<'a> BoolScalar<'a> { self.value } - pub fn cast(&self, dtype: &DType) -> VortexResult { - Ok(match (self.value, dtype) { - (Some(b), DType::Bool(_)) => Scalar::new(dtype.clone(), ScalarValue::from(b)), - (None, DType::Bool(Nullability::Nullable)) => Scalar::null(dtype.clone()), - (value, dtype) => vortex_bail!("Can't cast {:?} to {}", value, dtype), - }) + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + if !matches!(dtype, DType::Bool(..)) { + vortex_bail!("Can't cast bool to {}", dtype) + } + Ok(Scalar::bool( + self.value.vortex_expect("nullness handled in Scalar::cast"), + dtype.nullability(), + )) } pub fn invert(self) -> BoolScalar<'a> { diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs index 3e60a69d5b4..d761a7028a5 100644 --- a/vortex-scalar/src/extension.rs +++ b/vortex-scalar/src/extension.rs @@ -1,23 +1,28 @@ use std::sync::Arc; use vortex_dtype::{DType, ExtDType}; -use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexResult}; +use vortex_error::{vortex_bail, VortexError, VortexResult}; use crate::value::ScalarValue; use crate::Scalar; pub struct ExtScalar<'a> { dtype: &'a DType, + ext_dtype: &'a ExtDType, value: &'a ScalarValue, } impl<'a> ExtScalar<'a> { pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult { - if !matches!(dtype, DType::Extension(..)) { + let DType::Extension(ext_dtype) = dtype else { vortex_bail!("Expected extension scalar, found {}", dtype) - } + }; - Ok(Self { dtype, value }) + Ok(Self { + dtype, + ext_dtype, + value, + }) } #[inline] @@ -27,25 +32,17 @@ impl<'a> ExtScalar<'a> { /// Returns the storage scalar of the extension scalar. pub fn storage(&self) -> Scalar { - let storage_dtype = if let DType::Extension(ext_dtype) = self.dtype() { - ext_dtype.storage_dtype().clone() - } else { - vortex_panic!("Expected extension DType: {}", self.dtype()); - }; - Scalar::new(storage_dtype, self.value.clone()) + Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone()) } - pub fn cast(&self, dtype: &DType) -> VortexResult { - if self.dtype().eq_ignore_nullability(dtype) { - if self.dtype.is_nullable() && dtype.is_nullable() && self.value.is_null() { - vortex_bail!("cannot cast null value to {}", dtype); - } - // ScalarValue::cast must reject casting _to_ an extension type because it does not know - // its own type. + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + if self.dtype().eq_ignore_nullability(dtype) + || self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) + { return Ok(Scalar::new(dtype.clone(), self.value.clone())); } - Ok(Scalar::new(dtype.clone(), self.value.cast(dtype)?)) + vortex_bail!("cannot cast {} to {}", self.dtype(), dtype); } } diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 1a886f41e17..1114f19083e 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -101,16 +101,31 @@ impl Scalar { } pub fn cast(&self, dtype: &DType) -> VortexResult { + if self.is_null() && !dtype.is_nullable() { + vortex_bail!("Can't cast null scalar to non-nullable type") + } + + if self.is_null() && dtype.is_nullable() { + return Ok(Scalar::new(dtype.clone(), self.value.clone())); + } + if self.dtype().eq_ignore_nullability(dtype) { - if self.dtype.is_nullable() && dtype.is_nullable() && self.value.is_null() { - vortex_bail!("cannot cast null value to {}", dtype); - } - // ScalarValue::cast must reject casting _to_ an extension type because it does not know - // its own type. return Ok(Scalar::new(dtype.clone(), self.value.clone())); } - Ok(Scalar::new(dtype.clone(), self.value.cast(dtype)?)) + match &self.dtype { + DType::Null => { + assert!(dtype.is_nullable()); + Ok(Scalar::new(dtype.clone(), self.value.clone())) + } + DType::Bool(_) => self.as_bool().cast(dtype), + DType::Primitive(..) => self.as_primitive().cast(dtype), + DType::Utf8(_) => self.as_utf8().cast(dtype), + DType::Binary(_) => self.as_binary().cast(dtype), + DType::Struct(..) => self.as_struct().cast(dtype), + DType::List(..) => self.as_list().cast(dtype), + DType::Extension(..) => self.as_extension().cast(dtype), + } } pub fn into_nullable(self) -> Self { @@ -278,9 +293,112 @@ from_vec_for_scalar!(ByteBuffer); mod test { use std::sync::Arc; - use vortex_dtype::{DType, ExtDType, ExtID, Nullability}; + use rstest::rstest; + use vortex_buffer::{BufferString, ByteBuffer}; + use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType, StructDType}; + + use crate::{InnerScalarValue, PValue, Scalar, ScalarValue}; + + #[rstest] + #[case(Scalar::null(DType::Null))] + // + #[case(Scalar::from(true))] + #[case(Scalar::from(Some(true)))] + #[case(Scalar::null(DType::Bool(Nullability::Nullable)))] + // + #[case(Scalar::from(1u8))] + #[case(Scalar::from(-1i8))] + #[case(Scalar::from(Some(1u8)))] + #[case(Scalar::from(Some(-1i8)))] + #[case(Scalar::null(DType::Primitive(PType::U8, Nullability::Nullable)))] + #[case(Scalar::null(DType::Primitive(PType::I8, Nullability::Nullable)))] + // + #[case(Scalar::from(1u64))] + #[case(Scalar::from(-1i64))] + #[case(Scalar::from(Some(1u64)))] + #[case(Scalar::from(Some(-1i64)))] + #[case(Scalar::null(DType::Primitive(PType::U64, Nullability::Nullable)))] + #[case(Scalar::null(DType::Primitive(PType::I64, Nullability::Nullable)))] + // + #[case(Scalar::from("hello"))] + #[case(Scalar::from(Some(BufferString::from("hello"))))] + #[case(Scalar::null(DType::Utf8(Nullability::Nullable)))] + #[case(Scalar::from(ByteBuffer::from(vec![0u8, 1, 2])))] + #[case(Scalar::from(Some(ByteBuffer::from(vec![0u8, 1, 2]))))] + #[case(Scalar::null(DType::Binary(Nullability::Nullable)))] + // + #[case(Scalar::new(DType::Struct(StructDType::new(Arc::from(["a".into()]), + vec![DType::Bool(Nullability::Nullable)]), + Nullability::Nullable), + ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(InnerScalarValue::Bool(true))])))))] + // + #[case(Scalar::new(DType::List(Arc::from(DType::Bool(Nullability::Nullable)), Nullability::Nullable), + ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(InnerScalarValue::Bool(true))])))))] + fn no_op_cast(#[case] scalar: Scalar) { + scalar.cast(&scalar.dtype()).unwrap(); + } + + #[test] + fn list_casts() { + let list = Scalar::new( + DType::List( + Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)), + Nullability::Nullable, + ), + ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue( + InnerScalarValue::Primitive(PValue::U16(6)), + )]))), + ); - use crate::{InnerScalarValue, Scalar, ScalarValue}; + let target_u32 = DType::List( + Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)), + Nullability::Nullable, + ); + assert_eq!(list.cast(&target_u32).unwrap().dtype(), &target_u32); + + let target_u32_nonnull = DType::List( + Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)), + Nullability::Nullable, + ); + assert_eq!( + list.cast(&target_u32_nonnull).unwrap().dtype(), + &target_u32_nonnull + ); + + let target_nonnull = DType::List( + Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)), + Nullability::NonNullable, + ); + assert_eq!(list.cast(&target_nonnull).unwrap().dtype(), &target_nonnull); + + let target_u8 = DType::List( + Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)), + Nullability::Nullable, + ); + assert_eq!(list.cast(&target_u8).unwrap().dtype(), &target_u8); + + let list_with_null = Scalar::new( + DType::List( + Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)), + Nullability::Nullable, + ), + ScalarValue(InnerScalarValue::List(Arc::from([ + ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))), + ScalarValue(InnerScalarValue::Null), + ]))), + ); + let target_u8 = DType::List( + Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)), + Nullability::Nullable, + ); + assert_eq!(list_with_null.cast(&target_u8).unwrap().dtype(), &target_u8); + + let target_u32_nonnull = DType::List( + Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)), + Nullability::Nullable, + ); + assert!(list_with_null.cast(&target_u32_nonnull).is_err()); + } #[test] fn cast_from_extension_types() { diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/list.rs index 5cc1bda6d91..7a32c4ef329 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/list.rs @@ -3,13 +3,14 @@ use std::sync::Arc; use itertools::Itertools as _; use vortex_dtype::{DType, Nullability}; -use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexResult}; +use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexExpect as _, VortexResult}; use crate::value::{InnerScalarValue, ScalarValue}; use crate::Scalar; pub struct ListScalar<'a> { dtype: &'a DType, + element_dtype: &'a Arc, elements: Option>, } @@ -66,20 +67,26 @@ impl<'a> ListScalar<'a> { }) } - pub fn cast(&self, dtype: &DType) -> VortexResult { - Ok(match (&self.elements, dtype) { - (Some(elements), DType::List(element_dtype, _)) => Scalar::new( - dtype.clone(), - ScalarValue(InnerScalarValue::List( - elements - .iter() - .map(|element| element.cast(element_dtype)) - .process_results(|iter| iter.collect())?, - )), - ), - (None, DType::List(_, Nullability::Nullable)) => Scalar::null(dtype.clone()), - (value, dtype) => vortex_bail!("Can't cast {:?} to {}", value, dtype), - }) + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + let DType::List(element_dtype, ..) = dtype else { + vortex_bail!("Can't cast {:?} to {}", self.dtype(), dtype) + }; + + Ok(Scalar::new( + dtype.clone(), + ScalarValue(InnerScalarValue::List( + self.elements + .as_ref() + .vortex_expect("nullness handled in Scalar::cast") + .iter() + .map(|element| { + Scalar::new(DType::clone(self.element_dtype), element.clone()) + .cast(element_dtype) + .map(|x| x.value().clone()) + }) + .process_results(|iter| iter.collect())?, + )), + )) } } @@ -118,12 +125,13 @@ impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> { type Error = VortexError; fn try_from(value: &'a Scalar) -> Result { - if !matches!(value.dtype(), DType::List(..)) { + let DType::List(element_dtype, ..) = value.dtype() else { vortex_bail!("Expected list scalar, found {}", value.dtype()) - } + }; Ok(Self { dtype: value.dtype(), + element_dtype, elements: value.value.as_list()?.cloned(), }) } diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index 25559a33bd7..bf002725d4b 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -5,7 +5,8 @@ use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive}; use vortex_dtype::half::f16; use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability, PType}; use vortex_error::{ - vortex_bail, vortex_err, vortex_panic, VortexError, VortexResult, VortexUnwrap, + vortex_bail, vortex_err, vortex_panic, VortexError, VortexExpect as _, VortexResult, + VortexUnwrap, }; use crate::pvalue::PValue; @@ -67,16 +68,19 @@ impl<'a> PrimitiveScalar<'a> { self.pvalue.map(|pv| pv.as_primitive::().vortex_unwrap()) } - pub fn cast(&self, dtype: &DType) -> VortexResult { - Ok(match (&self.pvalue, dtype) { - (Some(v), DType::Primitive(ptype, _)) => { - match_each_native_ptype!(ptype, |$Q| { - Scalar::primitive(v.as_primitive::<$Q>()?, dtype.nullability()) - }) - } - (None, DType::Primitive(_, Nullability::Nullable)) => Scalar::null(dtype.clone()), - (value, dtype) => vortex_bail!("Can't cast {:?} to {}", value, dtype), - }) + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + let ptype = PType::try_from(dtype)?; + let pvalue = self + .pvalue + .vortex_expect("nullness handled in Scalar::cast"); + Ok(match_each_native_ptype!(ptype, |$Q| { + Scalar::primitive( + pvalue + .as_primitive::<$Q>() + .map_err(|err| vortex_err!("Can't cast {} scalar {} to {} (cause: {})", self.ptype, pvalue, dtype, err))?, + dtype.nullability() + ) + })) } /// Attempt to extract the primitive value as the given type. diff --git a/vortex-scalar/src/utf8.rs b/vortex-scalar/src/utf8.rs index 98def4137f6..4c60b49460d 100644 --- a/vortex-scalar/src/utf8.rs +++ b/vortex-scalar/src/utf8.rs @@ -1,7 +1,7 @@ use vortex_buffer::BufferString; use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, Nullability}; -use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; +use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult}; use crate::value::ScalarValue; use crate::{InnerScalarValue, Scalar}; @@ -21,15 +21,19 @@ impl<'a> Utf8Scalar<'a> { self.value.as_ref().cloned() } - pub fn cast(&self, dtype: &DType) -> VortexResult { - Ok(match (&self.value, dtype) { - (Some(bufstr), DType::Utf8(_)) => Scalar::new( - dtype.clone(), - ScalarValue(InnerScalarValue::BufferString(bufstr.clone())), - ), - (None, DType::Utf8(Nullability::Nullable)) => Scalar::null(dtype.clone()), - (value, dtype) => vortex_bail!("Can't cast {:?} to {}", value, dtype), - }) + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + if !matches!(dtype, DType::Bool(..)) { + vortex_bail!("Can't cast bool to {}", dtype) + } + Ok(Scalar::new( + dtype.clone(), + ScalarValue(InnerScalarValue::BufferString( + self.value + .as_ref() + .vortex_expect("nullness handled in Scalar::cast") + .clone(), + )), + )) } } diff --git a/vortex-scalar/src/value.rs b/vortex-scalar/src/value.rs index 0c4262af2a8..57deabbcc7f 100644 --- a/vortex-scalar/src/value.rs +++ b/vortex-scalar/src/value.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use itertools::Itertools; use vortex_buffer::{BufferString, ByteBuffer}; -use vortex_dtype::{match_each_native_ptype, DType}; -use vortex_error::{vortex_bail, vortex_err, VortexResult}; +use vortex_dtype::DType; +use vortex_error::{vortex_err, VortexResult}; use crate::pvalue::PValue; @@ -115,52 +115,6 @@ impl ScalarValue { pub(crate) fn as_list(&self) -> VortexResult>> { self.0.as_list() } - - pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { - Ok(match (&self.0, dtype) { - (InnerScalarValue::Bool(..), DType::Bool(..)) => self.clone(), - (InnerScalarValue::Primitive(pvalue), DType::Primitive(ptype, ..)) => { - let pvalue = match_each_native_ptype!(ptype, |$Q| { - PValue::from(pvalue.as_primitive::<$Q>()?) - }); - ScalarValue(InnerScalarValue::Primitive(pvalue)) - } - (InnerScalarValue::Buffer(..), DType::Binary(..)) => self.clone(), - (InnerScalarValue::BufferString(..), DType::Utf8(..)) => self.clone(), - (InnerScalarValue::List(fields), DType::Struct(sdtype, _)) => { - if fields.len() != sdtype.names().len() { - vortex_bail!( - "cannot cast from {} fields to {} fields", - fields.len(), - sdtype.names().len() - ); - } - ScalarValue(InnerScalarValue::List( - fields - .iter() - .zip_eq(sdtype.dtypes()) - .map(|(value, dtype)| value.cast(&dtype)) - .process_results(|iter| iter.collect())?, - )) - } - (InnerScalarValue::List(values), DType::List(dtype, _)) => { - ScalarValue(InnerScalarValue::List( - values - .iter() - .map(|value| value.cast(dtype)) - .process_results(|iter| iter.collect())?, - )) - } - (InnerScalarValue::Null, dtype) => { - if !dtype.is_nullable() { - vortex_bail!("cannot cast null to non-nullable dtype: {}", dtype) - } - self.clone() - } - (_, DType::Extension(ext_dtype)) => self.cast(ext_dtype.storage_dtype())?, - _ => vortex_bail!("cannot cast {} to {}", self, dtype), - }) - } } impl InnerScalarValue { From 950f1542a39fbbfbb503651f89b0fe55a665bdea Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 15 Jan 2025 17:32:55 -0500 Subject: [PATCH 07/15] simplify --- vortex-scalar/src/lib.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 1114f19083e..5abb3049018 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -101,12 +101,12 @@ impl Scalar { } pub fn cast(&self, dtype: &DType) -> VortexResult { - if self.is_null() && !dtype.is_nullable() { - vortex_bail!("Can't cast null scalar to non-nullable type") - } - - if self.is_null() && dtype.is_nullable() { - return Ok(Scalar::new(dtype.clone(), self.value.clone())); + if self.is_null() { + if dtype.is_nullable() { + return Ok(Scalar::new(dtype.clone(), self.value.clone())); + } else { + vortex_bail!("Can't cast null scalar to non-nullable type") + } } if self.dtype().eq_ignore_nullability(dtype) { From a852c2ffc1e73f3df5afdfd81fde2485ed6bff4e Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 16 Jan 2025 11:01:09 -0500 Subject: [PATCH 08/15] wip --- vortex-array/src/array/bool/stats.rs | 14 +- vortex-array/src/data/statistics.rs | 4 +- vortex-array/src/stats/statsset.rs | 8 +- vortex-scalar/src/lib.rs | 192 +++++++++++++++++---------- 4 files changed, 134 insertions(+), 84 deletions(-) diff --git a/vortex-array/src/array/bool/stats.rs b/vortex-array/src/array/bool/stats.rs index e2139c69bc6..e8408847cdf 100644 --- a/vortex-array/src/array/bool/stats.rs +++ b/vortex-array/src/array/bool/stats.rs @@ -169,9 +169,7 @@ impl BoolStatsAccumulator { #[cfg(test)] mod test { use arrow_buffer::BooleanBuffer; - use vortex_dtype::Nullability::Nullable; - use vortex_dtype::{DType, Nullability}; - use vortex_scalar::Scalar; + use vortex_dtype::Nullability; use crate::array::BoolArray; use crate::stats::{ArrayStatistics, Stat}; @@ -278,14 +276,8 @@ mod test { assert!(!bool_arr.statistics().compute_is_strict_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_constant().unwrap()); - assert_eq!( - bool_arr.statistics().compute(Stat::Min).unwrap(), - Scalar::null(DType::Bool(Nullable)) - ); - assert_eq!( - bool_arr.statistics().compute(Stat::Max).unwrap(), - Scalar::null(DType::Bool(Nullable)) - ); + assert_eq!(bool_arr.statistics().compute(Stat::Min), None); + assert_eq!(bool_arr.statistics().compute(Stat::Max), None); assert_eq!(bool_arr.statistics().compute_run_count().unwrap(), 1); assert_eq!(bool_arr.statistics().compute_true_count().unwrap(), 0); assert_eq!(bool_arr.statistics().compute_null_count().unwrap(), 5); diff --git a/vortex-array/src/data/statistics.rs b/vortex-array/src/data/statistics.rs index 4e0a139756e..720d8e5383a 100644 --- a/vortex-array/src/data/statistics.rs +++ b/vortex-array/src/data/statistics.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use enum_iterator::all; use itertools::Itertools; use vortex_dtype::{DType, Nullability, PType}; -use vortex_error::vortex_panic; +use vortex_error::{vortex_panic, VortexExpect as _}; use vortex_scalar::{Scalar, ScalarValue}; use crate::data::InnerArrayData; @@ -118,7 +118,7 @@ impl Statistics for ArrayData { let s = self .encoding() .compute_statistics(self, stat) - .ok()? + .vortex_expect("compute_statistics must not fail") .get(stat) .cloned(); diff --git a/vortex-array/src/stats/statsset.rs b/vortex-array/src/stats/statsset.rs index 0babda26e20..aeb6cc1e616 100644 --- a/vortex-array/src/stats/statsset.rs +++ b/vortex-array/src/stats/statsset.rs @@ -33,8 +33,6 @@ impl StatsSet { /// an array consisting entirely of [null](vortex_dtype::DType::Null) values. pub fn nulls(len: usize, dtype: &DType) -> Self { let mut stats = Self::new_unchecked(vec![ - (Stat::Min, Scalar::null(dtype.clone())), - (Stat::Max, Scalar::null(dtype.clone())), (Stat::RunCount, 1.into()), (Stat::NullCount, len.into()), ]); @@ -86,8 +84,10 @@ impl StatsSet { stats.set(Stat::TrueCount, true_count); } - stats.set(Stat::Min, scalar.clone()); - stats.set(Stat::Max, scalar.clone()); + if !scalar.is_null() { + stats.set(Stat::Min, scalar.clone()); + stats.set(Stat::Max, scalar.clone()); + } stats } diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 5abb3049018..f00c2a20c75 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -100,31 +100,41 @@ impl Scalar { } } - pub fn cast(&self, dtype: &DType) -> VortexResult { + pub fn cast(&self, target: &DType) -> VortexResult { + if let DType::Extension(ext_dtype) = target { + let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?; + Ok(Scalar::extension(ext_dtype.clone(), storage_scalar)) + } else { + self.cast_to_non_extension(target) + } + } + + pub fn cast_to_non_extension(&self, target: &DType) -> VortexResult { + assert!(!matches!(target, DType::Extension(..))); if self.is_null() { - if dtype.is_nullable() { - return Ok(Scalar::new(dtype.clone(), self.value.clone())); + if target.is_nullable() { + return Ok(Scalar::new(target.clone(), self.value.clone())); } else { vortex_bail!("Can't cast null scalar to non-nullable type") } } - if self.dtype().eq_ignore_nullability(dtype) { - return Ok(Scalar::new(dtype.clone(), self.value.clone())); + if self.dtype().eq_ignore_nullability(target) { + return Ok(Scalar::new(target.clone(), self.value.clone())); } match &self.dtype { DType::Null => { - assert!(dtype.is_nullable()); - Ok(Scalar::new(dtype.clone(), self.value.clone())) + assert!(target.is_nullable()); + Ok(Scalar::new(target.clone(), self.value.clone())) } - DType::Bool(_) => self.as_bool().cast(dtype), - DType::Primitive(..) => self.as_primitive().cast(dtype), - DType::Utf8(_) => self.as_utf8().cast(dtype), - DType::Binary(_) => self.as_binary().cast(dtype), - DType::Struct(..) => self.as_struct().cast(dtype), - DType::List(..) => self.as_list().cast(dtype), - DType::Extension(..) => self.as_extension().cast(dtype), + DType::Bool(_) => self.as_bool().cast(target), + DType::Primitive(..) => self.as_primitive().cast(target), + DType::Utf8(_) => self.as_utf8().cast(target), + DType::Binary(_) => self.as_binary().cast(target), + DType::Struct(..) => self.as_struct().cast(target), + DType::List(..) => self.as_list().cast(target), + DType::Extension(..) => self.as_extension().cast(target), } } @@ -294,48 +304,52 @@ mod test { use std::sync::Arc; use rstest::rstest; - use vortex_buffer::{BufferString, ByteBuffer}; - use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType, StructDType}; + use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType}; use crate::{InnerScalarValue, PValue, Scalar, ScalarValue}; #[rstest] - #[case(Scalar::null(DType::Null))] - // - #[case(Scalar::from(true))] - #[case(Scalar::from(Some(true)))] - #[case(Scalar::null(DType::Bool(Nullability::Nullable)))] - // - #[case(Scalar::from(1u8))] - #[case(Scalar::from(-1i8))] - #[case(Scalar::from(Some(1u8)))] - #[case(Scalar::from(Some(-1i8)))] - #[case(Scalar::null(DType::Primitive(PType::U8, Nullability::Nullable)))] - #[case(Scalar::null(DType::Primitive(PType::I8, Nullability::Nullable)))] - // - #[case(Scalar::from(1u64))] - #[case(Scalar::from(-1i64))] - #[case(Scalar::from(Some(1u64)))] - #[case(Scalar::from(Some(-1i64)))] - #[case(Scalar::null(DType::Primitive(PType::U64, Nullability::Nullable)))] - #[case(Scalar::null(DType::Primitive(PType::I64, Nullability::Nullable)))] - // - #[case(Scalar::from("hello"))] - #[case(Scalar::from(Some(BufferString::from("hello"))))] - #[case(Scalar::null(DType::Utf8(Nullability::Nullable)))] - #[case(Scalar::from(ByteBuffer::from(vec![0u8, 1, 2])))] - #[case(Scalar::from(Some(ByteBuffer::from(vec![0u8, 1, 2]))))] - #[case(Scalar::null(DType::Binary(Nullability::Nullable)))] - // - #[case(Scalar::new(DType::Struct(StructDType::new(Arc::from(["a".into()]), - vec![DType::Bool(Nullability::Nullable)]), - Nullability::Nullable), - ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(InnerScalarValue::Bool(true))])))))] - // - #[case(Scalar::new(DType::List(Arc::from(DType::Bool(Nullability::Nullable)), Nullability::Nullable), - ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue(InnerScalarValue::Bool(true))])))))] - fn no_op_cast(#[case] scalar: Scalar) { - scalar.cast(&scalar.dtype()).unwrap(); + fn null_can_cast_to_anything_nullable( + #[values( + DType::Null, + DType::Bool(Nullability::Nullable), + DType::Primitive(PType::I32, Nullability::Nullable), + DType::Extension(Arc::from(ExtDType::new( + ExtID::from("a"), + Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)), + None, + ))), + DType::Extension(Arc::from(ExtDType::new( + ExtID::from("b"), + Arc::from(DType::Utf8(Nullability::Nullable)), + None, + ))) + )] + source_dtype: DType, + #[values( + DType::Null, + DType::Bool(Nullability::Nullable), + DType::Primitive(PType::I32, Nullability::Nullable), + DType::Extension(Arc::from(ExtDType::new( + ExtID::from("a"), + Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)), + None, + ))), + DType::Extension(Arc::from(ExtDType::new( + ExtID::from("b"), + Arc::from(DType::Utf8(Nullability::Nullable)), + None, + ))) + )] + target_dtype: DType, + ) { + assert_eq!( + Scalar::null(source_dtype) + .cast(&target_dtype) + .unwrap() + .dtype(), + &target_dtype + ); } #[test] @@ -401,27 +415,71 @@ mod test { } #[test] - fn cast_from_extension_types() { + fn cast_to_from_extension_types() { let apples = ExtDType::new( ExtID::new(Arc::from("apples")), - Arc::from(DType::Bool(Nullability::NonNullable)), + Arc::from(DType::Primitive(PType::U16, Nullability::NonNullable)), None, ); - let scalar = Scalar::new( - DType::Extension(Arc::from(apples.clone())), - ScalarValue(InnerScalarValue::Bool(true)), + let ext_dtype = DType::Extension(Arc::from(apples.clone())); + let ext_scalar = Scalar::new(ext_dtype.clone(), ScalarValue(InnerScalarValue::Bool(true))); + let storage_scalar = Scalar::new( + DType::clone(apples.storage_dtype()), + ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))), ); - let inner = scalar.cast(scalar.dtype()).unwrap(); - assert_eq!(inner.dtype(), scalar.dtype()); - - let inner = scalar.cast(&scalar.dtype().as_nullable()).unwrap(); - assert_eq!(inner.dtype(), &scalar.dtype().as_nullable()); - - let inner = scalar.cast(apples.storage_dtype()).unwrap(); - assert_eq!(inner.dtype(), apples.storage_dtype()); + // to self + let expected_dtype = &ext_dtype; + let actual = ext_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // to nullable self + let expected_dtype = &ext_dtype.as_nullable(); + let actual = ext_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // cast to the storage type + let expected_dtype = apples.storage_dtype(); + let actual = ext_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // cast to the storage type, nullable + let expected_dtype = &apples.storage_dtype().as_nullable(); + let actual = ext_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // cast from storage type to extension + let expected_dtype = &ext_dtype; + let actual = storage_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // cast from storage type to extension, nullable + let expected_dtype = &ext_dtype.as_nullable(); + let actual = storage_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // cast from *compatible* storage type to extension + let storage_scalar_u64 = Scalar::new( + DType::clone(apples.storage_dtype()), + ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))), + ); + let expected_dtype = &ext_dtype; + let actual = storage_scalar_u64.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); - let inner = scalar.cast(&apples.storage_dtype().as_nullable()).unwrap(); - assert_eq!(inner.dtype(), &apples.storage_dtype().as_nullable()); + // cast from *incompatible* storage type to extension + let apples_u8 = ExtDType::new( + ExtID::new(Arc::from("apples")), + Arc::from(DType::Primitive(PType::U8, Nullability::NonNullable)), + None, + ); + let expected_dtype = &DType::Extension(Arc::from(apples_u8)); + let result = storage_scalar.cast(expected_dtype); + assert!( + result.is_err_and(|err| { + err + .to_string() + .contains("Can't cast u16 scalar 1000_u16 to u8 (cause: Cannot read primitive value U16(1000) as u8") + })); } } From 16b1951caf62eca892436d7c48714bfc5ed660da Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 16 Jan 2025 11:05:36 -0500 Subject: [PATCH 09/15] fix simple case --- vortex-scalar/src/lib.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index f00c2a20c75..f606ede89e7 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -124,10 +124,7 @@ impl Scalar { } match &self.dtype { - DType::Null => { - assert!(target.is_nullable()); - Ok(Scalar::new(target.clone(), self.value.clone())) - } + DType::Null => unreachable!(), // handled by if is_null case DType::Bool(_) => self.as_bool().cast(target), DType::Primitive(..) => self.as_primitive().cast(target), DType::Utf8(_) => self.as_utf8().cast(target), From 89f0870665cee7e26c87bb6e4f97e4f32363f760 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 16 Jan 2025 11:07:26 -0500 Subject: [PATCH 10/15] typo --- vortex-scalar/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index f606ede89e7..31d9ab20ae1 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -476,7 +476,7 @@ mod test { result.is_err_and(|err| { err .to_string() - .contains("Can't cast u16 scalar 1000_u16 to u8 (cause: Cannot read primitive value U16(1000) as u8") + .contains("Can't cast u16 scalar 1000_u16 to u8 (cause: Cannot read primitive value U16(1000) as u8)") })); } } From f68b0b8d4eccd2eb744073e2e4a014f4cf8d3081 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 16 Jan 2025 11:14:19 -0500 Subject: [PATCH 11/15] another bad test --- encodings/bytebool/src/stats.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/encodings/bytebool/src/stats.rs b/encodings/bytebool/src/stats.rs index 1acaf40beff..794e5ab4c06 100644 --- a/encodings/bytebool/src/stats.rs +++ b/encodings/bytebool/src/stats.rs @@ -23,8 +23,6 @@ impl StatisticsVTable for ByteBoolEncoding { #[cfg(test)] mod tests { use vortex_array::stats::ArrayStatistics; - use vortex_dtype::{DType, Nullability}; - use vortex_scalar::Scalar; use super::*; @@ -90,14 +88,8 @@ mod tests { assert!(!bool_arr.statistics().compute_is_strict_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_constant().unwrap()); - assert_eq!( - bool_arr.statistics().compute(Stat::Min).unwrap(), - Scalar::null(DType::Bool(Nullability::Nullable)) - ); - assert_eq!( - bool_arr.statistics().compute(Stat::Max).unwrap(), - Scalar::null(DType::Bool(Nullability::Nullable)) - ); + assert_eq!(bool_arr.statistics().compute(Stat::Min), None); + assert_eq!(bool_arr.statistics().compute(Stat::Max), None); assert_eq!(bool_arr.statistics().compute_run_count().unwrap(), 1); assert_eq!(bool_arr.statistics().compute_true_count().unwrap(), 0); } From 8b1815bdd81e5f81c54ba663de3cb1395685213c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 16 Jan 2025 11:18:44 -0500 Subject: [PATCH 12/15] nope --- vortex-scalar/src/lib.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 31d9ab20ae1..9ed2e291096 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -473,10 +473,13 @@ mod test { let expected_dtype = &DType::Extension(Arc::from(apples_u8)); let result = storage_scalar.cast(expected_dtype); assert!( - result.is_err_and(|err| { + result.as_ref().is_err_and(|err| { err .to_string() - .contains("Can't cast u16 scalar 1000_u16 to u8 (cause: Cannot read primitive value U16(1000) as u8)") - })); + .contains("Can't cast u16 scalar 1000_u16 to u8 (cause: Cannot read primitive value U16(1000) as u8") + }), + "{:?}", + result + ); } } From b07b343cb63be394a79fa85cdbe13f03d3d77197 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 16 Jan 2025 12:27:01 -0500 Subject: [PATCH 13/15] fix utf8 cast method --- vortex-scalar/src/utf8.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vortex-scalar/src/utf8.rs b/vortex-scalar/src/utf8.rs index 4c60b49460d..ef6e534dd27 100644 --- a/vortex-scalar/src/utf8.rs +++ b/vortex-scalar/src/utf8.rs @@ -22,8 +22,8 @@ impl<'a> Utf8Scalar<'a> { } pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { - if !matches!(dtype, DType::Bool(..)) { - vortex_bail!("Can't cast bool to {}", dtype) + if !matches!(dtype, DType::Utf8(..)) { + vortex_bail!("Can't cast utf8 to {}", dtype) } Ok(Scalar::new( dtype.clone(), From 6e4371e00c0b8f6de7eaa737b59d260a3320da58 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 21 Jan 2025 11:14:21 +0000 Subject: [PATCH 14/15] remove DType borrow from ExtensionScalar --- vortex-scalar/src/extension.rs | 39 +++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs index d761a7028a5..78264b5927b 100644 --- a/vortex-scalar/src/extension.rs +++ b/vortex-scalar/src/extension.rs @@ -7,7 +7,6 @@ use crate::value::ScalarValue; use crate::Scalar; pub struct ExtScalar<'a> { - dtype: &'a DType, ext_dtype: &'a ExtDType, value: &'a ScalarValue, } @@ -18,16 +17,7 @@ impl<'a> ExtScalar<'a> { vortex_bail!("Expected extension scalar, found {}", dtype) }; - Ok(Self { - dtype, - ext_dtype, - value, - }) - } - - #[inline] - pub fn dtype(&self) -> &'a DType { - self.dtype + Ok(Self { ext_dtype, value }) } /// Returns the storage scalar of the extension scalar. @@ -36,13 +26,32 @@ impl<'a> ExtScalar<'a> { } pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { - if self.dtype().eq_ignore_nullability(dtype) - || self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) - { + if self.value.is_null() && !dtype.is_nullable() { + vortex_bail!( + "cannot cast extension dtype with id {} and storage type {} to {}", + self.ext_dtype.id(), + self.ext_dtype.storage_dtype(), + dtype + ); + } + + if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) { + // Casting from an extension type to the underlying storage type is OK. return Ok(Scalar::new(dtype.clone(), self.value.clone())); } - vortex_bail!("cannot cast {} to {}", self.dtype(), dtype); + if let DType::Extension(ext_dtype) = dtype { + if self.ext_dtype.eq_ignore_nullability(ext_dtype) { + return Ok(Scalar::new(dtype.clone(), self.value.clone())); + } + } + + vortex_bail!( + "cannot cast extension dtype with id {} and storage type {} to {}", + self.ext_dtype.id(), + self.ext_dtype.storage_dtype(), + dtype + ); } } From 48466c9139406862a74a0c0bfcac9952cecc88fd Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 21 Jan 2025 11:18:25 +0000 Subject: [PATCH 15/15] include target type in error --- vortex-scalar/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 9ed2e291096..683b2df1d8d 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -115,7 +115,7 @@ impl Scalar { if target.is_nullable() { return Ok(Scalar::new(target.clone(), self.value.clone())); } else { - vortex_bail!("Can't cast null scalar to non-nullable type") + vortex_bail!("Can't cast null scalar to non-nullable type {}", target) } }