diff --git a/encodings/bytebool/src/stats.rs b/encodings/bytebool/src/stats.rs index 1acaf40beff..794e5ab4c06 100644 --- a/encodings/bytebool/src/stats.rs +++ b/encodings/bytebool/src/stats.rs @@ -23,8 +23,6 @@ impl StatisticsVTable for ByteBoolEncoding { #[cfg(test)] mod tests { use vortex_array::stats::ArrayStatistics; - use vortex_dtype::{DType, Nullability}; - use vortex_scalar::Scalar; use super::*; @@ -90,14 +88,8 @@ mod tests { assert!(!bool_arr.statistics().compute_is_strict_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_constant().unwrap()); - assert_eq!( - bool_arr.statistics().compute(Stat::Min).unwrap(), - Scalar::null(DType::Bool(Nullability::Nullable)) - ); - assert_eq!( - bool_arr.statistics().compute(Stat::Max).unwrap(), - Scalar::null(DType::Bool(Nullability::Nullable)) - ); + assert_eq!(bool_arr.statistics().compute(Stat::Min), None); + assert_eq!(bool_arr.statistics().compute(Stat::Max), None); assert_eq!(bool_arr.statistics().compute_run_count().unwrap(), 1); assert_eq!(bool_arr.statistics().compute_true_count().unwrap(), 0); } diff --git a/vortex-array/src/array/bool/stats.rs b/vortex-array/src/array/bool/stats.rs index e2139c69bc6..e8408847cdf 100644 --- a/vortex-array/src/array/bool/stats.rs +++ b/vortex-array/src/array/bool/stats.rs @@ -169,9 +169,7 @@ impl BoolStatsAccumulator { #[cfg(test)] mod test { use arrow_buffer::BooleanBuffer; - use vortex_dtype::Nullability::Nullable; - use vortex_dtype::{DType, Nullability}; - use vortex_scalar::Scalar; + use vortex_dtype::Nullability; use crate::array::BoolArray; use crate::stats::{ArrayStatistics, Stat}; @@ -278,14 +276,8 @@ mod test { assert!(!bool_arr.statistics().compute_is_strict_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_constant().unwrap()); - assert_eq!( - bool_arr.statistics().compute(Stat::Min).unwrap(), - Scalar::null(DType::Bool(Nullable)) - ); - assert_eq!( - bool_arr.statistics().compute(Stat::Max).unwrap(), - Scalar::null(DType::Bool(Nullable)) - ); + assert_eq!(bool_arr.statistics().compute(Stat::Min), None); + assert_eq!(bool_arr.statistics().compute(Stat::Max), None); assert_eq!(bool_arr.statistics().compute_run_count().unwrap(), 1); assert_eq!(bool_arr.statistics().compute_true_count().unwrap(), 0); assert_eq!(bool_arr.statistics().compute_null_count().unwrap(), 5); diff --git a/vortex-array/src/array/primitive/stats.rs b/vortex-array/src/array/primitive/stats.rs index ea9feebc64e..9b74439b0c1 100644 --- a/vortex-array/src/array/primitive/stats.rs +++ b/vortex-array/src/array/primitive/stats.rs @@ -334,7 +334,6 @@ impl BitWidthAccumulator { #[cfg(test)] mod test { - use vortex_dtype::{DType, Nullability, PType}; use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; @@ -402,8 +401,7 @@ mod test { let arr = PrimitiveArray::from_option_iter([Option::::None, None, None]); let min: Option = arr.statistics().compute(Stat::Min); let max: Option = arr.statistics().compute(Stat::Max); - let null_i32 = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)); - assert_eq!(min, Some(null_i32.clone())); - assert_eq!(max, Some(null_i32)); + assert_eq!(min, None); + assert_eq!(max, None); } } diff --git a/vortex-array/src/data/statistics.rs b/vortex-array/src/data/statistics.rs index 4e0a139756e..720d8e5383a 100644 --- a/vortex-array/src/data/statistics.rs +++ b/vortex-array/src/data/statistics.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use enum_iterator::all; use itertools::Itertools; use vortex_dtype::{DType, Nullability, PType}; -use vortex_error::vortex_panic; +use vortex_error::{vortex_panic, VortexExpect as _}; use vortex_scalar::{Scalar, ScalarValue}; use crate::data::InnerArrayData; @@ -118,7 +118,7 @@ impl Statistics for ArrayData { let s = self .encoding() .compute_statistics(self, stat) - .ok()? + .vortex_expect("compute_statistics must not fail") .get(stat) .cloned(); diff --git a/vortex-array/src/stats/statsset.rs b/vortex-array/src/stats/statsset.rs index 0babda26e20..aeb6cc1e616 100644 --- a/vortex-array/src/stats/statsset.rs +++ b/vortex-array/src/stats/statsset.rs @@ -33,8 +33,6 @@ impl StatsSet { /// an array consisting entirely of [null](vortex_dtype::DType::Null) values. pub fn nulls(len: usize, dtype: &DType) -> Self { let mut stats = Self::new_unchecked(vec![ - (Stat::Min, Scalar::null(dtype.clone())), - (Stat::Max, Scalar::null(dtype.clone())), (Stat::RunCount, 1.into()), (Stat::NullCount, len.into()), ]); @@ -86,8 +84,10 @@ impl StatsSet { stats.set(Stat::TrueCount, true_count); } - stats.set(Stat::Min, scalar.clone()); - stats.set(Stat::Max, scalar.clone()); + if !scalar.is_null() { + stats.set(Stat::Min, scalar.clone()); + stats.set(Stat::Max, scalar.clone()); + } stats } diff --git a/vortex-scalar/src/binary.rs b/vortex-scalar/src/binary.rs index f09f76d803c..4bb83b7bb6b 100644 --- a/vortex-scalar/src/binary.rs +++ b/vortex-scalar/src/binary.rs @@ -1,6 +1,6 @@ use vortex_buffer::ByteBuffer; use vortex_dtype::{DType, Nullability}; -use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; +use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult}; use crate::value::{InnerScalarValue, ScalarValue}; use crate::Scalar; @@ -20,8 +20,19 @@ impl<'a> BinaryScalar<'a> { self.value.as_ref().cloned() } - pub fn cast(&self, _dtype: &DType) -> VortexResult { - todo!() + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + if !matches!(dtype, DType::Binary(..)) { + vortex_bail!("Can't cast binary to {}", dtype) + } + Ok(Scalar::new( + dtype.clone(), + ScalarValue(InnerScalarValue::Buffer( + self.value + .as_ref() + .vortex_expect("nullness handled in Scalar::cast") + .clone(), + )), + )) } } diff --git a/vortex-scalar/src/bool.rs b/vortex-scalar/src/bool.rs index 7f85e22e19c..0948854fd79 100644 --- a/vortex-scalar/src/bool.rs +++ b/vortex-scalar/src/bool.rs @@ -1,6 +1,6 @@ use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, Nullability}; -use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; +use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult}; use crate::value::ScalarValue; use crate::{InnerScalarValue, Scalar}; @@ -20,14 +20,14 @@ impl<'a> BoolScalar<'a> { self.value } - pub fn cast(&self, dtype: &DType) -> VortexResult { - match dtype { - DType::Bool(_) => Ok(Scalar::bool( - self.value().ok_or_else(|| vortex_err!("not a bool"))?, - dtype.nullability(), - )), - _ => vortex_bail!("Can't cast {} to bool", dtype), + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + if !matches!(dtype, DType::Bool(..)) { + vortex_bail!("Can't cast bool to {}", dtype) } + Ok(Scalar::bool( + self.value.vortex_expect("nullness handled in Scalar::cast"), + dtype.nullability(), + )) } pub fn invert(self) -> BoolScalar<'a> { diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs index e1cc6920853..78264b5927b 100644 --- a/vortex-scalar/src/extension.rs +++ b/vortex-scalar/src/extension.rs @@ -1,42 +1,57 @@ use std::sync::Arc; use vortex_dtype::{DType, ExtDType}; -use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexResult}; +use vortex_error::{vortex_bail, VortexError, VortexResult}; use crate::value::ScalarValue; use crate::Scalar; pub struct ExtScalar<'a> { - dtype: &'a DType, + ext_dtype: &'a ExtDType, value: &'a ScalarValue, } impl<'a> ExtScalar<'a> { pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult { - if !matches!(dtype, DType::Extension(..)) { + let DType::Extension(ext_dtype) = dtype else { vortex_bail!("Expected extension scalar, found {}", dtype) - } - - Ok(Self { dtype, value }) - } + }; - #[inline] - pub fn dtype(&self) -> &'a DType { - self.dtype + Ok(Self { ext_dtype, value }) } /// Returns the storage scalar of the extension scalar. pub fn storage(&self) -> Scalar { - let storage_dtype = if let DType::Extension(ext_dtype) = self.dtype() { - ext_dtype.storage_dtype().clone() - } else { - vortex_panic!("Expected extension DType: {}", self.dtype()); - }; - Scalar::new(storage_dtype, self.value.clone()) + Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone()) } - pub fn cast(&self, _dtype: &DType) -> VortexResult { - todo!() + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + if self.value.is_null() && !dtype.is_nullable() { + vortex_bail!( + "cannot cast extension dtype with id {} and storage type {} to {}", + self.ext_dtype.id(), + self.ext_dtype.storage_dtype(), + dtype + ); + } + + if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) { + // Casting from an extension type to the underlying storage type is OK. + return Ok(Scalar::new(dtype.clone(), self.value.clone())); + } + + if let DType::Extension(ext_dtype) = dtype { + if self.ext_dtype.eq_ignore_nullability(ext_dtype) { + return Ok(Scalar::new(dtype.clone(), self.value.clone())); + } + } + + vortex_bail!( + "cannot cast extension dtype with id {} and storage type {} to {}", + self.ext_dtype.id(), + self.ext_dtype.storage_dtype(), + dtype + ); } } diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index e8297cfaf10..683b2df1d8d 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -100,39 +100,38 @@ impl Scalar { } } - pub fn cast(&self, dtype: &DType) -> VortexResult { - if self.is_null() && !dtype.is_nullable() { - vortex_bail!("Can't cast null scalar to non-nullable type") + pub fn cast(&self, target: &DType) -> VortexResult { + if let DType::Extension(ext_dtype) = target { + let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?; + Ok(Scalar::extension(ext_dtype.clone(), storage_scalar)) + } else { + self.cast_to_non_extension(target) } + } - if self.dtype().eq_ignore_nullability(dtype) { - return Ok(Scalar { - dtype: dtype.clone(), - value: self.value.clone(), - }); + pub fn cast_to_non_extension(&self, target: &DType) -> VortexResult { + assert!(!matches!(target, DType::Extension(..))); + if self.is_null() { + if target.is_nullable() { + return Ok(Scalar::new(target.clone(), self.value.clone())); + } else { + vortex_bail!("Can't cast null scalar to non-nullable type {}", target) + } } - match dtype { - DType::Null => vortex_bail!("Can't cast non-null to null"), - DType::Bool(_) => BoolScalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::Primitive(..) => PrimitiveScalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::Utf8(_) => Utf8Scalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::Binary(_) => BinaryScalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::Struct(..) => StructScalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::List(..) => ListScalar::try_from(self).and_then(|s| s.cast(dtype)), - DType::Extension(ext_dtype) => { - if !self.value().is_instance_of(ext_dtype.storage_dtype()) { - vortex_bail!( - "Failed to cast scalar to extension dtype with storage type {:?}, found {:?}", - ext_dtype.storage_dtype(), - self.dtype() - ); - } - Ok(Scalar::extension( - ext_dtype.clone(), - self.cast(ext_dtype.storage_dtype())?, - )) - } + if self.dtype().eq_ignore_nullability(target) { + return Ok(Scalar::new(target.clone(), self.value.clone())); + } + + match &self.dtype { + DType::Null => unreachable!(), // handled by if is_null case + DType::Bool(_) => self.as_bool().cast(target), + DType::Primitive(..) => self.as_primitive().cast(target), + DType::Utf8(_) => self.as_utf8().cast(target), + DType::Binary(_) => self.as_binary().cast(target), + DType::Struct(..) => self.as_struct().cast(target), + DType::List(..) => self.as_list().cast(target), + DType::Extension(..) => self.as_extension().cast(target), } } @@ -296,3 +295,191 @@ from_vec_for_scalar!(f64); from_vec_for_scalar!(String); from_vec_for_scalar!(BufferString); from_vec_for_scalar!(ByteBuffer); + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use rstest::rstest; + use vortex_dtype::{DType, ExtDType, ExtID, Nullability, PType}; + + use crate::{InnerScalarValue, PValue, Scalar, ScalarValue}; + + #[rstest] + fn null_can_cast_to_anything_nullable( + #[values( + DType::Null, + DType::Bool(Nullability::Nullable), + DType::Primitive(PType::I32, Nullability::Nullable), + DType::Extension(Arc::from(ExtDType::new( + ExtID::from("a"), + Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)), + None, + ))), + DType::Extension(Arc::from(ExtDType::new( + ExtID::from("b"), + Arc::from(DType::Utf8(Nullability::Nullable)), + None, + ))) + )] + source_dtype: DType, + #[values( + DType::Null, + DType::Bool(Nullability::Nullable), + DType::Primitive(PType::I32, Nullability::Nullable), + DType::Extension(Arc::from(ExtDType::new( + ExtID::from("a"), + Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)), + None, + ))), + DType::Extension(Arc::from(ExtDType::new( + ExtID::from("b"), + Arc::from(DType::Utf8(Nullability::Nullable)), + None, + ))) + )] + target_dtype: DType, + ) { + assert_eq!( + Scalar::null(source_dtype) + .cast(&target_dtype) + .unwrap() + .dtype(), + &target_dtype + ); + } + + #[test] + fn list_casts() { + let list = Scalar::new( + DType::List( + Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)), + Nullability::Nullable, + ), + ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue( + InnerScalarValue::Primitive(PValue::U16(6)), + )]))), + ); + + let target_u32 = DType::List( + Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)), + Nullability::Nullable, + ); + assert_eq!(list.cast(&target_u32).unwrap().dtype(), &target_u32); + + let target_u32_nonnull = DType::List( + Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)), + Nullability::Nullable, + ); + assert_eq!( + list.cast(&target_u32_nonnull).unwrap().dtype(), + &target_u32_nonnull + ); + + let target_nonnull = DType::List( + Arc::from(DType::Primitive(PType::U32, Nullability::Nullable)), + Nullability::NonNullable, + ); + assert_eq!(list.cast(&target_nonnull).unwrap().dtype(), &target_nonnull); + + let target_u8 = DType::List( + Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)), + Nullability::Nullable, + ); + assert_eq!(list.cast(&target_u8).unwrap().dtype(), &target_u8); + + let list_with_null = Scalar::new( + DType::List( + Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)), + Nullability::Nullable, + ), + ScalarValue(InnerScalarValue::List(Arc::from([ + ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))), + ScalarValue(InnerScalarValue::Null), + ]))), + ); + let target_u8 = DType::List( + Arc::from(DType::Primitive(PType::U8, Nullability::Nullable)), + Nullability::Nullable, + ); + assert_eq!(list_with_null.cast(&target_u8).unwrap().dtype(), &target_u8); + + let target_u32_nonnull = DType::List( + Arc::from(DType::Primitive(PType::U32, Nullability::NonNullable)), + Nullability::Nullable, + ); + assert!(list_with_null.cast(&target_u32_nonnull).is_err()); + } + + #[test] + fn cast_to_from_extension_types() { + let apples = ExtDType::new( + ExtID::new(Arc::from("apples")), + Arc::from(DType::Primitive(PType::U16, Nullability::NonNullable)), + None, + ); + let ext_dtype = DType::Extension(Arc::from(apples.clone())); + let ext_scalar = Scalar::new(ext_dtype.clone(), ScalarValue(InnerScalarValue::Bool(true))); + let storage_scalar = Scalar::new( + DType::clone(apples.storage_dtype()), + ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))), + ); + + // to self + let expected_dtype = &ext_dtype; + let actual = ext_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // to nullable self + let expected_dtype = &ext_dtype.as_nullable(); + let actual = ext_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // cast to the storage type + let expected_dtype = apples.storage_dtype(); + let actual = ext_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // cast to the storage type, nullable + let expected_dtype = &apples.storage_dtype().as_nullable(); + let actual = ext_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // cast from storage type to extension + let expected_dtype = &ext_dtype; + let actual = storage_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // cast from storage type to extension, nullable + let expected_dtype = &ext_dtype.as_nullable(); + let actual = storage_scalar.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // cast from *compatible* storage type to extension + let storage_scalar_u64 = Scalar::new( + DType::clone(apples.storage_dtype()), + ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))), + ); + let expected_dtype = &ext_dtype; + let actual = storage_scalar_u64.cast(expected_dtype).unwrap(); + assert_eq!(actual.dtype(), expected_dtype); + + // cast from *incompatible* storage type to extension + let apples_u8 = ExtDType::new( + ExtID::new(Arc::from("apples")), + Arc::from(DType::Primitive(PType::U8, Nullability::NonNullable)), + None, + ); + let expected_dtype = &DType::Extension(Arc::from(apples_u8)); + let result = storage_scalar.cast(expected_dtype); + assert!( + result.as_ref().is_err_and(|err| { + err + .to_string() + .contains("Can't cast u16 scalar 1000_u16 to u8 (cause: Cannot read primitive value U16(1000) as u8") + }), + "{:?}", + result + ); + } +} diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/list.rs index 9d493e4ea0e..7a32c4ef329 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/list.rs @@ -1,14 +1,16 @@ use std::ops::Deref; use std::sync::Arc; +use itertools::Itertools as _; use vortex_dtype::{DType, Nullability}; -use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexResult}; +use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexExpect as _, VortexResult}; use crate::value::{InnerScalarValue, ScalarValue}; use crate::Scalar; pub struct ListScalar<'a> { dtype: &'a DType, + element_dtype: &'a Arc, elements: Option>, } @@ -65,8 +67,26 @@ impl<'a> ListScalar<'a> { }) } - pub fn cast(&self, _dtype: &DType) -> VortexResult { - todo!() + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + let DType::List(element_dtype, ..) = dtype else { + vortex_bail!("Can't cast {:?} to {}", self.dtype(), dtype) + }; + + Ok(Scalar::new( + dtype.clone(), + ScalarValue(InnerScalarValue::List( + self.elements + .as_ref() + .vortex_expect("nullness handled in Scalar::cast") + .iter() + .map(|element| { + Scalar::new(DType::clone(self.element_dtype), element.clone()) + .cast(element_dtype) + .map(|x| x.value().clone()) + }) + .process_results(|iter| iter.collect())?, + )), + )) } } @@ -105,12 +125,13 @@ impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> { type Error = VortexError; fn try_from(value: &'a Scalar) -> Result { - if !matches!(value.dtype(), DType::List(..)) { + let DType::List(element_dtype, ..) = value.dtype() else { vortex_bail!("Expected list scalar, found {}", value.dtype()) - } + }; Ok(Self { dtype: value.dtype(), + element_dtype, elements: value.value.as_list()?.cloned(), }) } diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index ad719d7f05b..bf002725d4b 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -1,11 +1,12 @@ use std::any::type_name; use std::fmt::{Debug, Display}; -use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, NumCast}; +use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive}; use vortex_dtype::half::f16; use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability, PType}; use vortex_error::{ - vortex_bail, vortex_err, vortex_panic, VortexError, VortexResult, VortexUnwrap, + vortex_bail, vortex_err, vortex_panic, VortexError, VortexExpect as _, VortexResult, + VortexUnwrap, }; use crate::pvalue::PValue; @@ -67,17 +68,19 @@ impl<'a> PrimitiveScalar<'a> { self.pvalue.map(|pv| pv.as_primitive::().vortex_unwrap()) } - pub fn cast(&self, dtype: &DType) -> VortexResult { + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { let ptype = PType::try_from(dtype)?; - match_each_native_ptype!(ptype, |$Q| { - match_each_native_ptype!(self.ptype(), |$T| { - Ok(Scalar::primitive::<$Q>( - <$Q as NumCast>::from(self.typed_value::<$T>().expect("Invalid value")) - .ok_or_else(|| vortex_err!("Can't cast {} scalar {} to {}", self.ptype, self.typed_value::<$T>().expect("Invalid value"), dtype))?, - dtype.nullability(), - )) - }) - }) + let pvalue = self + .pvalue + .vortex_expect("nullness handled in Scalar::cast"); + Ok(match_each_native_ptype!(ptype, |$Q| { + Scalar::primitive( + pvalue + .as_primitive::<$Q>() + .map_err(|err| vortex_err!("Can't cast {} scalar {} to {} (cause: {})", self.ptype, pvalue, dtype, err))?, + dtype.nullability() + ) + })) } /// Attempt to extract the primitive value as the given type. diff --git a/vortex-scalar/src/utf8.rs b/vortex-scalar/src/utf8.rs index cc353db3c36..ef6e534dd27 100644 --- a/vortex-scalar/src/utf8.rs +++ b/vortex-scalar/src/utf8.rs @@ -1,7 +1,7 @@ use vortex_buffer::BufferString; use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, Nullability}; -use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; +use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult}; use crate::value::ScalarValue; use crate::{InnerScalarValue, Scalar}; @@ -21,8 +21,19 @@ impl<'a> Utf8Scalar<'a> { self.value.as_ref().cloned() } - pub fn cast(&self, _dtype: &DType) -> VortexResult { - todo!() + pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { + if !matches!(dtype, DType::Utf8(..)) { + vortex_bail!("Can't cast utf8 to {}", dtype) + } + Ok(Scalar::new( + dtype.clone(), + ScalarValue(InnerScalarValue::BufferString( + self.value + .as_ref() + .vortex_expect("nullness handled in Scalar::cast") + .clone(), + )), + )) } }