From 6b3b101103bafa11bc67c90d3268dc6aefc0b916 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 6 Jan 2025 18:13:12 -0500 Subject: [PATCH 1/8] feat: mask Mask sets entries of an array to null. I like the analogy to light: the array is a sequence of lights (each value might be a different wavelength). Null is represented by the absence of light. Placing a mask (i.e. a piece of plastic with slits) over the array causes those values where the mask is present (i.e. "on", "true") to be dark. An example in pseudo-code: ```rust a = [1, 2, 3, 4, 5] a_mask = [t, f, f, t, f] mask(a, a_mask) == [null, 2, 3, null, 5] ``` Specializations --------------- I only fallback to Arrow for two of the core arrays: - Sparse. I was skeptical that I could do better than decompressing and applying it. - Constant. If the mask is sparse, SparseArray might be a good choice. I didn't investigate. For the non-core arrays, I'm missing the following. I'm not clear that I can beat decompression for run end. The others are easy enough but some amount of typing and testing. - fastlanes - fsst - roaring - runend - runend-bool - zigzag Naming ------ Pandas also calls this operation [`mask`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.mask.html) but accepts an optional second argument which is an array of values to use instead of null (which makes Pandas' mask more like an `if_else`). Arrow-rs calls this [`nullif`](https://arrow.apache.org/rust/arrow/compute/fn.nullif.html). Arrow-cpp has [`if_else(condition, consequent, alternate)`](https://arrow.apache.org/docs/cpp/compute.html#cpp-compute-scalar-selections) and [`replace_with_mask(array, mask, replacements)`](https://arrow.apache.org/docs/cpp/compute.html#replace-functions) both of which can implement our `mask` by passing a `NullArray` as the third argument. --- encodings/alp/src/alp/compute/mask.rs | 72 ++++++ encodings/alp/src/alp/compute/mod.rs | 10 +- encodings/alp/src/alp_rd/compute/mask.rs | 57 +++++ encodings/alp/src/alp_rd/compute/mod.rs | 7 +- encodings/bytebool/src/compute.rs | 24 +- encodings/datetime-parts/src/compute/cast.rs | 25 ++ encodings/datetime-parts/src/compute/mask.rs | 58 +++++ encodings/datetime-parts/src/compute/mod.rs | 13 +- encodings/dict/Cargo.toml | 4 + encodings/dict/benches/dict_mask.rs | 50 ++++ encodings/dict/src/array.rs | 5 + encodings/dict/src/compute/mask.rs | 48 ++++ encodings/dict/src/compute/mod.rs | 21 +- vortex-array/src/array/bool/compute/cast.rs | 20 ++ vortex-array/src/array/bool/compute/mask.rs | 12 + vortex-array/src/array/bool/compute/mod.rs | 14 +- vortex-array/src/array/bool/mod.rs | 6 + .../src/array/chunked/compute/filter.rs | 126 +++++----- .../src/array/chunked/compute/mask.rs | 152 ++++++++++++ vortex-array/src/array/chunked/compute/mod.rs | 7 +- .../src/array/constant/compute/mod.rs | 18 ++ .../src/array/extension/compute/cast.rs | 27 ++ .../src/array/extension/compute/mask.rs | 48 ++++ .../src/array/extension/compute/mod.rs | 18 +- vortex-array/src/array/list/compute/mod.rs | 37 ++- vortex-array/src/array/null/compute.rs | 12 +- .../src/array/primitive/compute/mask.rs | 17 ++ .../src/array/primitive/compute/mod.rs | 7 +- vortex-array/src/array/primitive/mod.rs | 26 ++ vortex-array/src/array/sparse/compute/mod.rs | 36 ++- vortex-array/src/array/struct_/compute.rs | 209 +++++++++++++++- vortex-array/src/array/varbin/compute/cast.rs | 25 ++ vortex-array/src/array/varbin/compute/mask.rs | 44 ++++ vortex-array/src/array/varbin/compute/mod.rs | 14 +- .../src/array/varbinview/compute/mod.rs | 65 ++++- vortex-array/src/compute/filter.rs | 43 +++- vortex-array/src/compute/mask.rs | 232 ++++++++++++++++++ vortex-array/src/compute/mod.rs | 16 +- vortex-array/src/validity.rs | 51 +++- 39 files changed, 1583 insertions(+), 93 deletions(-) create mode 100644 encodings/alp/src/alp/compute/mask.rs create mode 100644 encodings/alp/src/alp_rd/compute/mask.rs create mode 100644 encodings/datetime-parts/src/compute/cast.rs create mode 100644 encodings/datetime-parts/src/compute/mask.rs create mode 100644 encodings/dict/benches/dict_mask.rs create mode 100644 encodings/dict/src/compute/mask.rs create mode 100644 vortex-array/src/array/bool/compute/cast.rs create mode 100644 vortex-array/src/array/bool/compute/mask.rs create mode 100644 vortex-array/src/array/chunked/compute/mask.rs create mode 100644 vortex-array/src/array/extension/compute/cast.rs create mode 100644 vortex-array/src/array/extension/compute/mask.rs create mode 100644 vortex-array/src/array/primitive/compute/mask.rs create mode 100644 vortex-array/src/array/varbin/compute/cast.rs create mode 100644 vortex-array/src/array/varbin/compute/mask.rs create mode 100644 vortex-array/src/compute/mask.rs diff --git a/encodings/alp/src/alp/compute/mask.rs b/encodings/alp/src/alp/compute/mask.rs new file mode 100644 index 00000000000..5737378f626 --- /dev/null +++ b/encodings/alp/src/alp/compute/mask.rs @@ -0,0 +1,72 @@ +use vortex_array::compute::{mask, try_cast, FilterMask, MaskFn}; +use vortex_array::{ArrayDType as _, ArrayData, IntoArrayData}; +use vortex_error::VortexResult; + +use crate::{ALPArray, ALPEncoding}; + +impl MaskFn for ALPEncoding { + fn mask(&self, array: &ALPArray, filter_mask: FilterMask) -> VortexResult { + ALPArray::try_new( + mask(&array.encoded(), filter_mask)?, + array.exponents(), + array + .patches() + .map(|patches| { + patches.map_values(|values| try_cast(&values, &values.dtype().as_nullable())) + }) + .transpose()?, + ) + .map(IntoArrayData::into_array) + } +} + +#[cfg(test)] +mod tests { + use vortex_array::array::PrimitiveArray; + use vortex_array::compute::test_harness::test_mask; + use vortex_array::validity::Validity; + use vortex_array::IntoArrayData as _; + use vortex_buffer::buffer; + + use crate::alp_encode; + + #[test] + fn test_mask_no_patches_alp_array() { + test_mask( + alp_encode(&PrimitiveArray::new( + buffer![1.0f32, 2.0, 3.0, 4.0, 5.0], + Validity::AllValid, + )) + .unwrap() + .into_array(), + ); + + test_mask( + alp_encode(&PrimitiveArray::new( + buffer![1.0f32, 2.0, 3.0, 4.0, 5.0], + Validity::NonNullable, + )) + .unwrap() + .into_array(), + ); + } + + #[test] + fn test_mask_patched_alp_array() { + let alp_array = alp_encode(&PrimitiveArray::new( + buffer![1.0f32, 2.0, 3.0, 4.0, 1e10], + Validity::AllValid, + )) + .unwrap(); + assert!(alp_array.patches().is_some()); + test_mask(alp_array.into_array()); + + let alp_array = alp_encode(&PrimitiveArray::new( + buffer![1.0f32, 2.0, 3.0, 4.0, 1e10], + Validity::NonNullable, + )) + .unwrap(); + assert!(alp_array.patches().is_some()); + test_mask(alp_array.into_array()); + } +} diff --git a/encodings/alp/src/alp/compute/mod.rs b/encodings/alp/src/alp/compute/mod.rs index 74c0546ff6f..895dbece9d1 100644 --- a/encodings/alp/src/alp/compute/mod.rs +++ b/encodings/alp/src/alp/compute/mod.rs @@ -1,6 +1,8 @@ +mod mask; + use vortex_array::compute::{ - filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, SliceFn, - TakeFn, + filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, MaskFn, ScalarAtFn, + SliceFn, TakeFn, }; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; @@ -14,6 +16,10 @@ impl ComputeVTable for ALPEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/encodings/alp/src/alp_rd/compute/mask.rs b/encodings/alp/src/alp_rd/compute/mask.rs new file mode 100644 index 00000000000..0f9da516509 --- /dev/null +++ b/encodings/alp/src/alp_rd/compute/mask.rs @@ -0,0 +1,57 @@ +use vortex_array::compute::{mask, FilterMask, MaskFn}; +use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; +use vortex_error::VortexResult; + +use crate::{ALPRDArray, ALPRDEncoding}; + +impl MaskFn for ALPRDEncoding { + fn mask(&self, array: &ALPRDArray, filter_mask: FilterMask) -> VortexResult { + Ok(ALPRDArray::try_new( + array.dtype().as_nullable(), + mask(&array.left_parts(), filter_mask)?, + array.left_parts_dict(), + array.right_parts(), + array.right_bit_width(), + array.left_parts_patches(), + )? + .into_array()) + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_array::array::PrimitiveArray; + use vortex_array::compute::test_harness::test_mask; + use vortex_array::IntoArrayData as _; + + use crate::{ALPRDFloat, RDEncoder}; + + #[rstest] + #[case(0.1f32, 0.2f32, 3e25f32)] + #[case(0.1f64, 0.2f64, 3e100f64)] + fn test_mask_simple(#[case] a: T, #[case] b: T, #[case] outlier: T) { + test_mask( + RDEncoder::new(&[a, b]) + .encode(&PrimitiveArray::from_iter([a, b, outlier, b, outlier])) + .into_array(), + ); + } + + #[rstest] + #[case(0.1f32, 3e25f32)] + #[case(0.5f64, 1e100f64)] + fn test_mask_with_nulls(#[case] a: T, #[case] outlier: T) { + test_mask( + RDEncoder::new(&[a]) + .encode(&PrimitiveArray::from_option_iter([ + Some(a), + None, + Some(outlier), + Some(a), + None, + ])) + .into_array(), + ); + } +} diff --git a/encodings/alp/src/alp_rd/compute/mod.rs b/encodings/alp/src/alp_rd/compute/mod.rs index c696d1c51a0..c87a80b9b9a 100644 --- a/encodings/alp/src/alp_rd/compute/mod.rs +++ b/encodings/alp/src/alp_rd/compute/mod.rs @@ -1,9 +1,10 @@ -use vortex_array::compute::{ComputeVTable, FilterFn, ScalarAtFn, SliceFn, TakeFn}; +use vortex_array::compute::{ComputeVTable, FilterFn, MaskFn, ScalarAtFn, SliceFn, TakeFn}; use vortex_array::ArrayData; use crate::ALPRDEncoding; mod filter; +mod mask; mod scalar_at; mod slice; mod take; @@ -13,6 +14,10 @@ impl ComputeVTable for ALPRDEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index a21a37c3e12..fa6e04145f8 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -1,5 +1,7 @@ use num_traits::AsPrimitive; -use vortex_array::compute::{ComputeVTable, FillForwardFn, ScalarAtFn, SliceFn, TakeFn}; +use vortex_array::compute::{ + ComputeVTable, FillForwardFn, FilterMask, MaskFn, ScalarAtFn, SliceFn, TakeFn, +}; use vortex_array::validity::{ArrayValidity, Validity}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; @@ -14,6 +16,10 @@ impl ComputeVTable for ByteBoolEncoding { None } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -27,6 +33,13 @@ impl ComputeVTable for ByteBoolEncoding { } } +impl MaskFn for ByteBoolEncoding { + fn mask(&self, array: &ByteBoolArray, mask: FilterMask) -> VortexResult { + ByteBoolArray::try_new(array.buffer().clone(), array.validity().mask(&mask)?) + .map(IntoArrayData::into_array) + } +} + impl ScalarAtFn for ByteBoolEncoding { fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult { Ok(Scalar::bool( @@ -136,6 +149,7 @@ impl FillForwardFn for ByteBoolEncoding { #[cfg(test)] mod tests { + use vortex_array::compute::test_harness::test_mask; use vortex_array::compute::{compare, scalar_at, slice, Operator}; use super::*; @@ -208,4 +222,12 @@ mod tests { let s = scalar_at(&arr, 4).unwrap(); assert!(s.is_null()); } + + #[test] + fn test_mask_byte_bool() { + test_mask(ByteBoolArray::from(vec![true, false, true, true, false]).into_array()); + test_mask( + ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).into_array(), + ); + } } diff --git a/encodings/datetime-parts/src/compute/cast.rs b/encodings/datetime-parts/src/compute/cast.rs new file mode 100644 index 00000000000..ce703d630fa --- /dev/null +++ b/encodings/datetime-parts/src/compute/cast.rs @@ -0,0 +1,25 @@ +use vortex_array::compute::{try_cast, CastFn}; +use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::{DateTimePartsArray, DateTimePartsEncoding}; + +impl CastFn for DateTimePartsEncoding { + fn cast(&self, array: &DateTimePartsArray, dtype: &DType) -> VortexResult { + if !array.dtype().eq_ignore_nullability(dtype) { + vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype); + }; + + Ok(DateTimePartsArray::try_new( + array.dtype().clone().as_nullable(), + try_cast( + array.days().as_ref(), + &array.days().dtype().with_nullability(dtype.nullability()), + )?, + array.seconds(), + array.subsecond(), + )? + .into_array()) + } +} diff --git a/encodings/datetime-parts/src/compute/mask.rs b/encodings/datetime-parts/src/compute/mask.rs new file mode 100644 index 00000000000..9c52974b464 --- /dev/null +++ b/encodings/datetime-parts/src/compute/mask.rs @@ -0,0 +1,58 @@ +use vortex_array::compute::{mask, FilterMask, MaskFn}; +use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; +use vortex_error::VortexResult; + +use crate::{DateTimePartsArray, DateTimePartsEncoding}; + +impl MaskFn for DateTimePartsEncoding { + fn mask(&self, array: &DateTimePartsArray, filter_mask: FilterMask) -> VortexResult { + Ok(DateTimePartsArray::try_new( + array.dtype().clone().as_nullable(), + mask(array.days().as_ref(), filter_mask)?, + array.seconds(), + array.subsecond(), + )? + .into_array()) + } +} + +#[cfg(test)] +mod tests { + use vortex_array::array::TemporalArray; + use vortex_array::compute::test_harness::test_mask; + use vortex_array::IntoArrayData as _; + use vortex_buffer::buffer; + use vortex_datetime_dtype::TimeUnit; + use vortex_dtype::DType; + + use crate::{split_temporal, DateTimePartsArray, TemporalParts}; + + #[test] + fn test_mask_datetime_parts_array() { + let raw_millis = buffer![ + 86_400i64, // element with only day component + 86_400i64 + 1000, // element with day + second components + 86_400i64 + 1000 + 1, // element with day + second + sub-second components + 86_400i64 + 1000 + 5, // element with day + second + sub-second components + 86_400i64 + 1000 + 55, // element with day + second + sub-second components + ] + .into_array(); + let temporal_array = + TemporalArray::new_timestamp(raw_millis, TimeUnit::Ms, Some("UTC".to_string())); + let TemporalParts { + days, + seconds, + subseconds, + } = split_temporal(temporal_array.clone()).unwrap(); + let date_times = DateTimePartsArray::try_new( + DType::Extension(temporal_array.ext_dtype()), + days, + seconds, + subseconds, + ) + .unwrap() + .into_array(); + + test_mask(date_times.clone()); + } +} diff --git a/encodings/datetime-parts/src/compute/mod.rs b/encodings/datetime-parts/src/compute/mod.rs index de3937c1e8a..d5ebef744f8 100644 --- a/encodings/datetime-parts/src/compute/mod.rs +++ b/encodings/datetime-parts/src/compute/mod.rs @@ -1,9 +1,12 @@ +mod cast; mod filter; +mod mask; mod take; use vortex_array::array::{PrimitiveArray, TemporalArray}; use vortex_array::compute::{ - scalar_at, slice, try_cast, ComputeVTable, FilterFn, ScalarAtFn, SliceFn, TakeFn, + scalar_at, slice, try_cast, CastFn, ComputeVTable, FilterFn, MaskFn, ScalarAtFn, SliceFn, + TakeFn, }; use vortex_array::validity::ArrayValidity; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; @@ -17,10 +20,18 @@ use vortex_scalar::{PrimitiveScalar, Scalar}; use crate::{DateTimePartsArray, DateTimePartsEncoding}; impl ComputeVTable for DateTimePartsEncoding { + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/encodings/dict/Cargo.toml b/encodings/dict/Cargo.toml index 89924fea653..69f81008db8 100644 --- a/encodings/dict/Cargo.toml +++ b/encodings/dict/Cargo.toml @@ -35,3 +35,7 @@ vortex-array = { workspace = true, features = ["test-harness"] } [[bench]] name = "dict_compress" harness = false + +[[bench]] +name = "dict_mask" +harness = false diff --git a/encodings/dict/benches/dict_mask.rs b/encodings/dict/benches/dict_mask.rs new file mode 100644 index 00000000000..4e0241d4b1f --- /dev/null +++ b/encodings/dict/benches/dict_mask.rs @@ -0,0 +1,50 @@ +#![allow(clippy::unwrap_used)] + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng as _}; +use vortex_array::array::PrimitiveArray; +use vortex_array::compute::{mask, FilterMask}; +use vortex_array::IntoArrayData as _; +use vortex_buffer::{buffer, Buffer}; +use vortex_dict::DictArray; + +fn filter_mask(len: usize, fraction_masked: f64, rng: &mut StdRng) -> FilterMask { + let indices = (0..len) + .filter(|_| rng.gen_bool(fraction_masked)) + .map(|x| x as u64) + .collect::>(); + FilterMask::from_indices(len, indices) +} + +#[allow(clippy::cast_possible_truncation)] +fn bench_dict_mask(c: &mut Criterion) { + let mut group = c.benchmark_group("bench_dict_mask"); + let mut rng = StdRng::seed_from_u64(0); + + let len = 65_535; + // for fraction_valid in [0.5, 0.1, 0.01, 0.001, 0.0001] { + for fraction_valid in [0.1] { + let codes = + PrimitiveArray::from_iter((0..len).map(|_| (!rng.gen_bool(fraction_valid)) as u64)) + .into_array(); + let values = buffer![1].into_array(); + let array = DictArray::try_new(codes, values).unwrap().into_array(); + // for fraction_masked in [0.1, 0.01, 0.001, 0.0001] { + for fraction_masked in [0.9, 0.5, 0.1, 0.0001] { + let filter_mask = filter_mask(len, fraction_masked, &mut rng); + group.bench_with_input( + BenchmarkId::from_parameter(format!( + "fraction_valid={}, fraction_masked={}", + fraction_valid, fraction_masked + )), + &(&array, filter_mask), + |b, (array, filter_mask)| b.iter(|| mask(array, filter_mask.clone()).unwrap()), + ); + } + } + group.finish() +} + +criterion_group!(benches, bench_dict_mask); +criterion_main!(benches); diff --git a/encodings/dict/src/array.rs b/encodings/dict/src/array.rs index 18584099521..2a9785af649 100644 --- a/encodings/dict/src/array.rs +++ b/encodings/dict/src/array.rs @@ -48,6 +48,11 @@ impl DictArray { ) } + #[inline] + pub fn codes_ptype(&self) -> PType { + self.metadata().codes_ptype + } + #[inline] pub fn codes(&self) -> ArrayData { self.as_ref() diff --git a/encodings/dict/src/compute/mask.rs b/encodings/dict/src/compute/mask.rs new file mode 100644 index 00000000000..10ef542b88f --- /dev/null +++ b/encodings/dict/src/compute/mask.rs @@ -0,0 +1,48 @@ +use vortex_array::compute::{FilterIter, FilterMask, MaskFn}; +use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant as _}; +use vortex_buffer::BufferMut; +use vortex_dtype::{match_each_integer_ptype, NativePType}; +use vortex_error::VortexResult; + +use crate::{DictArray, DictEncoding}; + +impl MaskFn for DictEncoding { + fn mask(&self, array: &DictArray, mask: FilterMask) -> VortexResult { + let new_codes = match_each_integer_ptype!(array.codes_ptype(), |$T| { + let mut codes = array.codes().into_primitive()?.into_buffer_mut(); + typed_mask::<$T>(&mut codes, mask)?; + codes.into_array() + }); + DictArray::try_new(new_codes, array.values()).map(IntoArrayData::into_array) + } +} + +fn typed_mask(codes: &mut BufferMut, mask: FilterMask) -> VortexResult<()> { + match mask.iter()? { + FilterIter::Indices(indices) => { + for index in indices { + codes[*index] = T::zero(); + } + } + FilterIter::IndicesIter(bit_index_iterator) => { + for index in bit_index_iterator { + codes[index] = T::zero(); + } + } + FilterIter::Slices(slices) => { + for slice in slices { + for index in slice.0..slice.1 { + codes[index] = T::zero(); + } + } + } + FilterIter::SlicesIter(bit_slice_iterator) => { + for slice in bit_slice_iterator { + for index in slice.0..slice.1 { + codes[index] = T::zero(); + } + } + } + } + Ok(()) +} diff --git a/encodings/dict/src/compute/mod.rs b/encodings/dict/src/compute/mod.rs index 79cfb45ab76..626db947a2d 100644 --- a/encodings/dict/src/compute/mod.rs +++ b/encodings/dict/src/compute/mod.rs @@ -1,10 +1,11 @@ mod binary_numeric; mod compare; mod like; +mod mask; use vortex_array::compute::{ filter, scalar_at, slice, take, BinaryNumericFn, CompareFn, ComputeVTable, FilterFn, - FilterMask, LikeFn, ScalarAtFn, SliceFn, TakeFn, + FilterMask, LikeFn, MaskFn, ScalarAtFn, SliceFn, TakeFn, }; use vortex_array::{ArrayData, IntoArrayData}; use vortex_error::VortexResult; @@ -29,6 +30,10 @@ impl ComputeVTable for DictEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -78,7 +83,7 @@ impl SliceFn for DictEncoding { mod test { use vortex_array::accessor::ArrayAccessor; use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinViewArray}; - use vortex_array::compute::test_harness::test_binary_numeric; + use vortex_array::compute::test_harness::{test_binary_numeric, test_mask}; use vortex_array::compute::{compare, scalar_at, slice, Operator}; use vortex_array::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::{DType, Nullability}; @@ -160,4 +165,16 @@ mod test { let array = sliced_dict_array(); test_binary_numeric::(array) } + + #[test] + fn test_mask_dict_array() { + let reference = + PrimitiveArray::from_option_iter([None, Some(42), Some(-9), Some(42), Some(5)]); + let (codes, values) = dict_encode_primitive(&reference); + test_mask( + DictArray::try_new(codes.into_array(), values.into_array()) + .unwrap() + .into_array(), + ) + } } diff --git a/vortex-array/src/array/bool/compute/cast.rs b/vortex-array/src/array/bool/compute/cast.rs new file mode 100644 index 00000000000..86e600d2100 --- /dev/null +++ b/vortex-array/src/array/bool/compute/cast.rs @@ -0,0 +1,20 @@ +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::array::{BoolArray, BoolEncoding}; +use crate::compute::CastFn; +use crate::{ArrayData, IntoArrayData}; + +impl CastFn for BoolEncoding { + fn cast(&self, array: &BoolArray, dtype: &DType) -> VortexResult { + let DType::Bool(new_nullability) = dtype else { + vortex_bail!(MismatchedTypes: "bool type", dtype); + }; + + BoolArray::try_new( + array.boolean_buffer(), + array.validity().with_nullability(*new_nullability)?, + ) + .map(IntoArrayData::into_array) + } +} diff --git a/vortex-array/src/array/bool/compute/mask.rs b/vortex-array/src/array/bool/compute/mask.rs new file mode 100644 index 00000000000..a1720220075 --- /dev/null +++ b/vortex-array/src/array/bool/compute/mask.rs @@ -0,0 +1,12 @@ +use vortex_error::VortexResult; + +use crate::array::{BoolArray, BoolEncoding}; +use crate::compute::{FilterMask, MaskFn}; +use crate::{ArrayData, IntoArrayData}; + +impl MaskFn for BoolEncoding { + fn mask(&self, array: &BoolArray, mask: FilterMask) -> VortexResult { + BoolArray::try_new(array.boolean_buffer(), array.validity().mask(&mask)?) + .map(IntoArrayData::into_array) + } +} diff --git a/vortex-array/src/array/bool/compute/mod.rs b/vortex-array/src/array/bool/compute/mod.rs index 5a9499590b2..dc0dbfdebc4 100644 --- a/vortex-array/src/array/bool/compute/mod.rs +++ b/vortex-array/src/array/bool/compute/mod.rs @@ -1,15 +1,17 @@ use crate::array::BoolEncoding; use crate::compute::{ - BinaryBooleanFn, ComputeVTable, FillForwardFn, FillNullFn, FilterFn, InvertFn, ScalarAtFn, - SliceFn, TakeFn, + BinaryBooleanFn, CastFn, ComputeVTable, FillForwardFn, FillNullFn, FilterFn, InvertFn, MaskFn, + ScalarAtFn, SliceFn, TakeFn, }; use crate::ArrayData; +mod cast; mod fill_forward; mod fill_null; pub mod filter; mod flatten; mod invert; +mod mask; mod scalar_at; mod slice; mod take; @@ -23,6 +25,10 @@ impl ComputeVTable for BoolEncoding { None } + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn> { Some(self) } @@ -39,6 +45,10 @@ impl ComputeVTable for BoolEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/array/bool/mod.rs b/vortex-array/src/array/bool/mod.rs index 4d9ec80fb59..a94b2f4b1f9 100644 --- a/vortex-array/src/array/bool/mod.rs +++ b/vortex-array/src/array/bool/mod.rs @@ -211,6 +211,7 @@ mod tests { use vortex_dtype::Nullability; use crate::array::{BoolArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; use crate::compute::{scalar_at, slice}; use crate::patches::Patches; use crate::validity::Validity; @@ -308,4 +309,9 @@ mod tests { assert_eq!(offset, 0); assert_eq!(values.as_slice(), &[254, 127]); } + + #[test] + fn test_mask_primitive_array() { + test_mask(BoolArray::from_iter([true, false, true, true, false]).into_array()); + } } diff --git a/vortex-array/src/array/chunked/compute/filter.rs b/vortex-array/src/array/chunked/compute/filter.rs index ea01a01fbf1..af17aba60da 100644 --- a/vortex-array/src/array/chunked/compute/filter.rs +++ b/vortex-array/src/array/chunked/compute/filter.rs @@ -32,7 +32,7 @@ impl FilterFn for ChunkedEncoding { /// When we rewrite a set of slices in a filter predicate into chunk addresses, we want to account /// for the fact that some chunks will be wholly skipped. #[derive(Clone)] -enum ChunkFilter { +pub(crate) enum ChunkFilter { All, None, Slices(Vec<(usize, usize)>), @@ -61,66 +61,12 @@ fn slices_to_mask(slices: &[(usize, usize)], len: usize) -> FilterMask { } /// Filter the chunks using slice ranges. -#[allow(deprecated)] fn filter_slices(array: &ChunkedArray, mask: FilterMask) -> VortexResult> { - let mut result = Vec::with_capacity(array.nchunks()); - - // Pre-materialize the chunk ends for performance. - // The chunk ends is nchunks+1, which is expected to be in the hundreds or at most thousands. - let chunk_ends = array.chunk_offsets().into_canonical()?.into_primitive()?; - let chunk_ends = chunk_ends.as_slice::(); - - let mut chunk_filters = vec![ChunkFilter::None; array.nchunks()]; - - for (slice_start, slice_end) in mask.iter_slices()? { - let (start_chunk, start_idx) = find_chunk_idx(slice_start, chunk_ends); - // NOTE: we adjust slice end back by one, in case it ends on a chunk boundary, we do not - // want to index into the unused chunk. - let (end_chunk, end_idx) = find_chunk_idx(slice_end - 1, chunk_ends); - // Adjust back to an exclusive range - let end_idx = end_idx + 1; - - if start_chunk == end_chunk { - // start == end means that the slice lies within a single chunk. - match &mut chunk_filters[start_chunk] { - f @ (ChunkFilter::All | ChunkFilter::None) => { - *f = ChunkFilter::Slices(vec![(start_idx, end_idx)]); - } - ChunkFilter::Slices(slices) => { - slices.push((start_idx, end_idx)); - } - } - } else { - // start != end means that the range is split over at least two chunks: - // start chunk: append a slice from (start_idx, start_chunk_end), i.e. whole chunk. - // end chunk: append a slice from (0, end_idx). - // chunks between start and end: append ChunkFilter::All. - let start_chunk_len: usize = - (chunk_ends[start_chunk + 1] - chunk_ends[start_chunk]).try_into()?; - let start_slice = (start_idx, start_chunk_len); - match &mut chunk_filters[start_chunk] { - f @ (ChunkFilter::All | ChunkFilter::None) => { - *f = ChunkFilter::Slices(vec![start_slice]) - } - ChunkFilter::Slices(slices) => slices.push(start_slice), - } - - let end_slice = (0, end_idx); - match &mut chunk_filters[end_chunk] { - f @ (ChunkFilter::All | ChunkFilter::None) => { - *f = ChunkFilter::Slices(vec![end_slice]); - } - ChunkFilter::Slices(slices) => slices.push(end_slice), - } - - for chunk in &mut chunk_filters[start_chunk + 1..end_chunk] { - *chunk = ChunkFilter::All; - } - } - } + let chunked_filters = chunk_filters(array, mask)?; // Now, apply the chunk filter to every slice. - for (chunk, chunk_filter) in array.chunks().zip(chunk_filters.iter()) { + let mut result = Vec::with_capacity(array.nchunks()); + for (chunk, chunk_filter) in array.chunks().zip(chunked_filters.iter()) { match chunk_filter { // All => preserve the entire chunk unfiltered. ChunkFilter::All => result.push(chunk), @@ -186,9 +132,71 @@ fn filter_indices(array: &ChunkedArray, mask: FilterMask) -> VortexResult VortexResult> { + // Pre-materialize the chunk ends for performance. + // The chunk ends is nchunks+1, which is expected to be in the hundreds or at most thousands. + let chunk_ends = array.chunk_offsets().into_canonical()?.into_primitive()?; + let chunk_ends = chunk_ends.as_slice::(); + + let mut chunk_filters = vec![ChunkFilter::None; array.nchunks()]; + + for (slice_start, slice_end) in mask.iter_slices()? { + let (start_chunk, start_idx) = find_chunk_idx(slice_start, chunk_ends); + // NOTE: we adjust slice end back by one, in case it ends on a chunk boundary, we do not + // want to index into the unused chunk. + let (end_chunk, end_idx) = find_chunk_idx(slice_end - 1, chunk_ends); + // Adjust back to an exclusive range + let end_idx = end_idx + 1; + + if start_chunk == end_chunk { + // start == end means that the slice lies within a single chunk. + match &mut chunk_filters[start_chunk] { + f @ (ChunkFilter::All | ChunkFilter::None) => { + *f = ChunkFilter::Slices(vec![(start_idx, end_idx)]); + } + ChunkFilter::Slices(slices) => { + slices.push((start_idx, end_idx)); + } + } + } else { + // start != end means that the range is split over at least two chunks: + // start chunk: append a slice from (start_idx, start_chunk_end), i.e. whole chunk. + // end chunk: append a slice from (0, end_idx). + // chunks between start and end: append ChunkFilter::All. + let start_chunk_len: usize = + (chunk_ends[start_chunk + 1] - chunk_ends[start_chunk]).try_into()?; + let start_slice = (start_idx, start_chunk_len); + match &mut chunk_filters[start_chunk] { + f @ (ChunkFilter::All | ChunkFilter::None) => { + *f = ChunkFilter::Slices(vec![start_slice]) + } + ChunkFilter::Slices(slices) => slices.push(start_slice), + } + + let end_slice = (0, end_idx); + match &mut chunk_filters[end_chunk] { + f @ (ChunkFilter::All | ChunkFilter::None) => { + *f = ChunkFilter::Slices(vec![end_slice]); + } + ChunkFilter::Slices(slices) => slices.push(end_slice), + } + + for chunk in &mut chunk_filters[start_chunk + 1..end_chunk] { + *chunk = ChunkFilter::All; + } + } + } + + Ok(chunk_filters) +} + // Mirrors the find_chunk_idx method on ChunkedArray, but avoids all of the overhead // from scalars, dtypes, and metadata cloning. -fn find_chunk_idx(idx: usize, chunk_ends: &[u64]) -> (usize, usize) { +pub(crate) fn find_chunk_idx(idx: usize, chunk_ends: &[u64]) -> (usize, usize) { let chunk_id = chunk_ends .search_sorted(&(idx as u64), SearchSortedSide::Right) .to_ends_index(chunk_ends.len()) diff --git a/vortex-array/src/array/chunked/compute/mask.rs b/vortex-array/src/array/chunked/compute/mask.rs new file mode 100644 index 00000000000..70da58d95da --- /dev/null +++ b/vortex-array/src/array/chunked/compute/mask.rs @@ -0,0 +1,152 @@ +use itertools::Itertools as _; +use vortex_buffer::BufferMut; +use vortex_dtype::DType; +use vortex_error::{VortexExpect as _, VortexResult}; +use vortex_scalar::Scalar; + +use super::filter::{chunk_filters, find_chunk_idx, ChunkFilter}; +use crate::array::{ChunkedArray, ChunkedEncoding, ConstantArray}; +use crate::compute::{mask, try_cast, FilterIter, FilterMask, MaskFn}; +use crate::{ArrayDType, ArrayData, ArrayLen as _, IntoArrayData, IntoCanonical as _}; + +impl MaskFn for ChunkedEncoding { + fn mask(&self, array: &ChunkedArray, mask: FilterMask) -> VortexResult { + let new_dtype = array.dtype().as_nullable(); + let new_chunks = match mask.iter()? { + FilterIter::Indices(_) => mask_indices(array, mask, &new_dtype), + FilterIter::IndicesIter(_) => mask_indices(array, mask, &new_dtype), + FilterIter::Slices(_) => mask_slices(array, mask, &new_dtype), + FilterIter::SlicesIter(_) => mask_slices(array, mask, &new_dtype), + }?; + debug_assert_eq!(new_chunks.len(), array.nchunks()); + debug_assert_eq!( + new_chunks.iter().map(|x| x.len()).sum::(), + array.len() + ); + ChunkedArray::try_new(new_chunks, new_dtype).map(IntoArrayData::into_array) + } +} + +#[allow(deprecated)] +pub fn mask_indices( + array: &ChunkedArray, + filter_mask: FilterMask, + new_dtype: &DType, +) -> VortexResult> { + let mut new_chunks = Vec::with_capacity(array.nchunks()); + let mut current_chunk_id = 0; + let mut chunk_indices = BufferMut::with_capacity(array.nchunks()); + + // Avoid find_chunk_idx and use our own to avoid the overhead. + // The array should only be some thousands of values in the general case. + let chunk_ends = array.chunk_offsets().into_canonical()?.into_primitive()?; + let chunk_ends = chunk_ends.as_slice::(); + + for set_index in filter_mask.iter_indices()? { + let (chunk_id, index) = find_chunk_idx(set_index, chunk_ends); + if chunk_id != current_chunk_id { + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + let masked_chunk = mask( + &chunk, + FilterMask::from_indices(chunk.len(), chunk_indices.clone().freeze()), + )?; + new_chunks.push(masked_chunk); + current_chunk_id += 1; + + // Advance the chunk forward, reset the chunk indices buffer. + while current_chunk_id < chunk_id { + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + new_chunks.push(try_cast(chunk, new_dtype)?); + current_chunk_id += 1; + } + + chunk_indices.clear(); + } + + chunk_indices.push(index as u64); + } + + if !chunk_indices.is_empty() { + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + let masked_chunk = mask( + &chunk, + FilterMask::from_indices(chunk.len(), chunk_indices.clone().freeze()), + )?; + new_chunks.push(masked_chunk); + current_chunk_id += 1; + } + + while current_chunk_id < array.nchunks() { + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + new_chunks.push(try_cast(chunk, new_dtype)?); + current_chunk_id += 1; + } + + Ok(new_chunks) +} + +pub fn mask_slices( + array: &ChunkedArray, + filter_mask: FilterMask, + new_dtype: &DType, +) -> VortexResult> { + let chunked_filters = chunk_filters(array, filter_mask)?; + + array + .chunks() + .zip_eq(chunked_filters.iter()) + .map(|(chunk, chunk_filter)| -> VortexResult { + Ok(match chunk_filter { + ChunkFilter::All => { + // All => entire chunk is masked out + ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()).into_array() + } + ChunkFilter::None => { + // None => preserve the entire chunk unmasked + chunk + } + // Slices => turn the slices into a boolean buffer. + ChunkFilter::Slices(slices) => mask( + &chunk, + FilterMask::from_slices(chunk.len(), slices.iter().cloned()), + )?, + }) + }) + .process_results(|iter| iter.collect::>()) +} + +#[cfg(test)] +mod test { + use vortex_buffer::buffer; + use vortex_dtype::{DType, Nullability, PType}; + + use crate::array::{ChunkedArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; + use crate::IntoArrayData; + + #[test] + fn test_mask_chunked_array() { + let dtype = DType::Primitive(PType::U64, Nullability::NonNullable); + let chunked = ChunkedArray::try_new( + vec![ + buffer![0u64, 1].into_array(), + buffer![2_u64].into_array(), + PrimitiveArray::empty::(dtype.nullability()).into_array(), + buffer![3_u64, 4].into_array(), + ], + dtype, + ) + .unwrap() + .into_array(); + + test_mask(chunked); + } +} diff --git a/vortex-array/src/array/chunked/compute/mod.rs b/vortex-array/src/array/chunked/compute/mod.rs index 7cb867ffb77..e4a5f83e2b6 100644 --- a/vortex-array/src/array/chunked/compute/mod.rs +++ b/vortex-array/src/array/chunked/compute/mod.rs @@ -5,7 +5,7 @@ use crate::array::chunked::ChunkedArray; use crate::array::ChunkedEncoding; use crate::compute::{ try_cast, BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, ComputeVTable, FillNullFn, - FilterFn, InvertFn, ScalarAtFn, SliceFn, TakeFn, + FilterFn, InvertFn, MaskFn, ScalarAtFn, SliceFn, TakeFn, }; use crate::{ArrayData, IntoArrayData}; @@ -15,6 +15,7 @@ mod compare; mod fill_null; mod filter; mod invert; +mod mask; mod scalar_at; mod slice; mod take; @@ -48,6 +49,10 @@ impl ComputeVTable for ChunkedEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/array/constant/compute/mod.rs b/vortex-array/src/array/constant/compute/mod.rs index dbb7dfbc7b5..291589d41b6 100644 --- a/vortex-array/src/array/constant/compute/mod.rs +++ b/vortex-array/src/array/constant/compute/mod.rs @@ -76,3 +76,21 @@ impl FilterFn for ConstantEncoding { Ok(ConstantArray::new(array.scalar(), mask.true_count()).into_array()) } } + +#[cfg(test)] +mod test { + use vortex_dtype::half::f16; + use vortex_scalar::Scalar; + + use super::ConstantArray; + use crate::compute::test_harness::test_mask; + use crate::IntoArrayData as _; + + #[test] + fn test_mask_constant() { + test_mask(ConstantArray::new(Scalar::null_typed::(), 5).into_array()); + test_mask(ConstantArray::new(Scalar::from(3u16), 5).into_array()); + test_mask(ConstantArray::new(Scalar::from(1.0f32 / 0.0f32), 5).into_array()); + test_mask(ConstantArray::new(Scalar::from(f16::from_f32(3.0f32)), 5).into_array()); + } +} diff --git a/vortex-array/src/array/extension/compute/cast.rs b/vortex-array/src/array/extension/compute/cast.rs new file mode 100644 index 00000000000..c8b046126d5 --- /dev/null +++ b/vortex-array/src/array/extension/compute/cast.rs @@ -0,0 +1,27 @@ +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::array::extension::ExtensionArray; +use crate::array::ExtensionEncoding; +use crate::compute::{try_cast, CastFn}; +use crate::{ArrayDType as _, ArrayData, IntoArrayData as _}; + +impl CastFn for ExtensionEncoding { + fn cast(&self, array: &ExtensionArray, dtype: &DType) -> VortexResult { + if !array.dtype().eq_ignore_nullability(dtype) { + vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype); + } + let DType::Extension(ext_dtype) = dtype else { + vortex_bail!( + "dtype must have extension dtype {} {}", + array.dtype(), + dtype + ); + }; + Ok(ExtensionArray::new( + ext_dtype.clone(), + try_cast(array.storage(), ext_dtype.storage_dtype())?, + ) + .into_array()) + } +} diff --git a/vortex-array/src/array/extension/compute/mask.rs b/vortex-array/src/array/extension/compute/mask.rs new file mode 100644 index 00000000000..53f1a7d1443 --- /dev/null +++ b/vortex-array/src/array/extension/compute/mask.rs @@ -0,0 +1,48 @@ +use std::sync::Arc; + +use vortex_dtype::{DType, Nullability}; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::array::extension::ExtensionArray; +use crate::array::ExtensionEncoding; +use crate::compute::{mask, FilterMask, MaskFn}; +use crate::{ArrayDType as _, ArrayData, IntoArrayData}; + +impl MaskFn for ExtensionEncoding { + fn mask(&self, array: &ExtensionArray, filter_mask: FilterMask) -> VortexResult { + let DType::Extension(ext_dtype) = array.dtype() else { + vortex_bail!("extension array must have extension dtype"); + }; + Ok(ExtensionArray::new( + Arc::from(ext_dtype.with_nullability(Nullability::Nullable)), + mask(&array.storage(), filter_mask)?, + ) + .into_array()) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use vortex_buffer::buffer; + use vortex_dtype::{DType, ExtDType, ExtID, PType}; + + use crate::array::ExtensionArray; + use crate::compute::test_harness::test_mask; + use crate::IntoArrayData as _; + + #[test] + fn test_mask_extension_array() { + let ext_dtype = Arc::new(ExtDType::new( + ExtID::new("timestamp".into()), + DType::from(PType::I64).into(), + None, + )); + + test_mask( + ExtensionArray::new(ext_dtype.clone(), buffer![1i64, 2, 3, 4, 5].into_array()) + .into_array(), + ); + } +} diff --git a/vortex-array/src/array/extension/compute/mod.rs b/vortex-array/src/array/extension/compute/mod.rs index 46c0ec9f834..7e1bd2b445a 100644 --- a/vortex-array/src/array/extension/compute/mod.rs +++ b/vortex-array/src/array/extension/compute/mod.rs @@ -1,4 +1,6 @@ +mod cast; mod compare; +mod mask; use vortex_error::VortexResult; use vortex_scalar::Scalar; @@ -6,23 +8,29 @@ use vortex_scalar::Scalar; use crate::array::extension::ExtensionArray; use crate::array::ExtensionEncoding; use crate::compute::{ - scalar_at, slice, take, CastFn, CompareFn, ComputeVTable, ScalarAtFn, SliceFn, TakeFn, + scalar_at, slice, take, CastFn, CompareFn, ComputeVTable, MaskFn, ScalarAtFn, SliceFn, TakeFn, }; use crate::variants::ExtensionArrayTrait; use crate::{ArrayData, IntoArrayData}; impl ComputeVTable for ExtensionEncoding { fn cast_fn(&self) -> Option<&dyn CastFn> { - // It's not possible to cast an extension array to another type. - // TODO(ngates): we should allow some extension arrays to implement a callback - // to support this - None + // It's not possible to cast an extension array to another type, but we can make it + // nullable. + // + // TODO(ngates): we should allow some extension arrays to implement a callback to support + // this + Some(self) } fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/array/list/compute/mod.rs b/vortex-array/src/array/list/compute/mod.rs index 317f574c75b..0b99c390135 100644 --- a/vortex-array/src/array/list/compute/mod.rs +++ b/vortex-array/src/array/list/compute/mod.rs @@ -5,7 +5,7 @@ use vortex_error::VortexResult; use vortex_scalar::Scalar; use crate::array::{ListArray, ListEncoding}; -use crate::compute::{scalar_at, slice, ComputeVTable, ScalarAtFn, SliceFn}; +use crate::compute::{scalar_at, slice, ComputeVTable, FilterMask, MaskFn, ScalarAtFn, SliceFn}; use crate::{ArrayDType, ArrayData, IntoArrayData}; impl ComputeVTable for ListEncoding { @@ -16,6 +16,10 @@ impl ComputeVTable for ListEncoding { fn slice_fn(&self) -> Option<&dyn SliceFn> { Some(self) } + + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } } impl ScalarAtFn for ListEncoding { @@ -41,3 +45,34 @@ impl SliceFn for ListEncoding { .into_array()) } } + +impl MaskFn for ListEncoding { + fn mask(&self, array: &ListArray, mask: FilterMask) -> VortexResult { + ListArray::try_new( + array.elements(), + array.offsets(), + array.validity().mask(&mask)?, + ) + .map(IntoArrayData::into_array) + } +} + +#[cfg(test)] +mod test { + use crate::array::{ListArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; + use crate::validity::Validity; + use crate::IntoArrayData as _; + + #[test] + fn test_mask_list() { + let elements = PrimitiveArray::from_iter(0..35); + let offsets = PrimitiveArray::from_iter([0, 5, 11, 18, 26, 35]); + let validity = Validity::AllValid; + let array = ListArray::try_new(elements.into_array(), offsets.into_array(), validity) + .unwrap() + .into_array(); + + test_mask(array); + } +} diff --git a/vortex-array/src/array/null/compute.rs b/vortex-array/src/array/null/compute.rs index 7a224aaec6e..8938267aaba 100644 --- a/vortex-array/src/array/null/compute.rs +++ b/vortex-array/src/array/null/compute.rs @@ -4,11 +4,15 @@ use vortex_scalar::Scalar; use crate::array::null::NullArray; use crate::array::NullEncoding; -use crate::compute::{ComputeVTable, ScalarAtFn, SliceFn, TakeFn}; +use crate::compute::{ComputeVTable, FilterMask, MaskFn, ScalarAtFn, SliceFn, TakeFn}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; impl ComputeVTable for NullEncoding { + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -22,6 +26,12 @@ impl ComputeVTable for NullEncoding { } } +impl MaskFn for NullEncoding { + fn mask(&self, array: &NullArray, _mask: FilterMask) -> VortexResult { + Ok(array.clone().into_array()) + } +} + impl SliceFn for NullEncoding { fn slice(&self, _array: &NullArray, start: usize, stop: usize) -> VortexResult { Ok(NullArray::new(stop - start).into_array()) diff --git a/vortex-array/src/array/primitive/compute/mask.rs b/vortex-array/src/array/primitive/compute/mask.rs new file mode 100644 index 00000000000..c4bb4162724 --- /dev/null +++ b/vortex-array/src/array/primitive/compute/mask.rs @@ -0,0 +1,17 @@ +use vortex_error::VortexResult; + +use crate::array::primitive::PrimitiveArray; +use crate::array::PrimitiveEncoding; +use crate::compute::{FilterMask, MaskFn}; +use crate::variants::PrimitiveArrayTrait as _; +use crate::{ArrayData, IntoArrayData}; + +impl MaskFn for PrimitiveEncoding { + fn mask(&self, array: &PrimitiveArray, mask: FilterMask) -> VortexResult { + let validity = array.validity().mask(&mask)?; + Ok( + PrimitiveArray::from_byte_buffer(array.byte_buffer().clone(), array.ptype(), validity) + .into_array(), + ) + } +} diff --git a/vortex-array/src/array/primitive/compute/mod.rs b/vortex-array/src/array/primitive/compute/mod.rs index bac23e2d7f4..cfdde3eec32 100644 --- a/vortex-array/src/array/primitive/compute/mod.rs +++ b/vortex-array/src/array/primitive/compute/mod.rs @@ -1,6 +1,6 @@ use crate::array::PrimitiveEncoding; use crate::compute::{ - CastFn, ComputeVTable, FillForwardFn, FilterFn, ScalarAtFn, SearchSortedFn, + CastFn, ComputeVTable, FillForwardFn, FilterFn, MaskFn, ScalarAtFn, SearchSortedFn, SearchSortedUsizeFn, SliceFn, TakeFn, }; use crate::ArrayData; @@ -8,6 +8,7 @@ use crate::ArrayData; mod cast; mod fill; mod filter; +mod mask; mod scalar_at; mod search_sorted; mod slice; @@ -18,6 +19,10 @@ impl ComputeVTable for PrimitiveEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn> { Some(self) } diff --git a/vortex-array/src/array/primitive/mod.rs b/vortex-array/src/array/primitive/mod.rs index fa212cbda05..a1bdb2547c8 100644 --- a/vortex-array/src/array/primitive/mod.rs +++ b/vortex-array/src/array/primitive/mod.rs @@ -320,3 +320,29 @@ impl VisitorVTable for PrimitiveEncoding { visitor.visit_validity(&array.validity()) } } + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + + use crate::array::{BoolArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; + use crate::validity::Validity; + use crate::IntoArrayData as _; + + #[test] + fn test_mask_primitive_array() { + test_mask(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable).into_array()); + test_mask(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid).into_array()); + test_mask(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllInvalid).into_array()); + test_mask( + PrimitiveArray::new( + buffer![0, 1, 2, 3, 4], + Validity::Array( + BoolArray::from_iter([true, false, true, false, true]).into_array(), + ), + ) + .into_array(), + ); + } +} diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index 1d323085b36..ae6f22d4a8a 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -105,13 +105,14 @@ impl FilterFn for SparseEncoding { mod test { use rstest::{fixture, rstest}; use vortex_buffer::buffer; + use vortex_dtype::{DType, Nullability, PType}; use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; use crate::array::sparse::SparseArray; - use crate::compute::test_harness::test_binary_numeric; + use crate::compute::test_harness::{test_binary_numeric, test_mask}; use crate::compute::{ - filter, search_sorted, slice, FilterMask, SearchResult, SearchSortedSide, + filter, search_sorted, slice, try_cast, FilterMask, SearchResult, SearchSortedSide, }; use crate::validity::Validity; use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; @@ -225,4 +226,35 @@ mod test { fn test_sparse_binary_numeric(array: ArrayData) { test_binary_numeric::(array) } + + #[test] + fn test_mask_sparse_array() { + let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)); + test_mask( + SparseArray::try_new( + buffer![1u64, 2, 4].into_array(), + try_cast( + buffer![100i32, 200, 300].into_array(), + null_fill_value.dtype(), + ) + .unwrap(), + 5, + null_fill_value, + ) + .unwrap() + .into_array(), + ); + + let ten_fill_value = Scalar::from(10i32); + test_mask( + SparseArray::try_new( + buffer![1u64, 2, 4].into_array(), + buffer![100i32, 200, 300].into_array(), + 5, + ten_fill_value, + ) + .unwrap() + .into_array(), + ) + } } diff --git a/vortex-array/src/array/struct_/compute.rs b/vortex-array/src/array/struct_/compute.rs index ba3b44dce7c..56064d4605a 100644 --- a/vortex-array/src/array/struct_/compute.rs +++ b/vortex-array/src/array/struct_/compute.rs @@ -1,21 +1,30 @@ use itertools::Itertools; -use vortex_error::VortexResult; +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexResult}; use vortex_scalar::Scalar; use crate::array::struct_::StructArray; use crate::array::StructEncoding; use crate::compute::{ - filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, SliceFn, - TakeFn, + filter, scalar_at, slice, take, try_cast, CastFn, ComputeVTable, FilterFn, FilterMask, MaskFn, + ScalarAtFn, SliceFn, TakeFn, }; use crate::variants::StructArrayTrait; -use crate::{ArrayDType, ArrayData, IntoArrayData}; +use crate::{ArrayDType, ArrayData, ArrayLen as _, IntoArrayData}; impl ComputeVTable for StructEncoding { + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -29,6 +38,28 @@ impl ComputeVTable for StructEncoding { } } +impl CastFn for StructEncoding { + fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult { + let Some(sdtype) = dtype.as_struct() else { + vortex_bail!("cannot cast {} to {}", array.dtype(), dtype); + }; + + let validity = array.validity().with_nullability(dtype.nullability())?; + + StructArray::try_new( + array.names().clone(), + array + .children() + .zip_eq(sdtype.dtypes()) + .map(|(field, dtype)| try_cast(&field, &dtype)) + .try_collect()?, + array.len(), + validity, + ) + .map(|a| a.into_array()) + } +} + impl ScalarAtFn for StructEncoding { fn scalar_at(&self, array: &StructArray, index: usize) -> VortexResult { Ok(Scalar::struct_( @@ -90,11 +121,31 @@ impl FilterFn for StructEncoding { } } +impl MaskFn for StructEncoding { + fn mask(&self, array: &StructArray, filter_mask: FilterMask) -> VortexResult { + let validity = array.validity().mask(&filter_mask)?; + + StructArray::try_new( + array.names().clone(), + array.children().collect(), + array.len(), + validity, + ) + .map(|a| a.into_array()) + } +} + #[cfg(test)] mod tests { - use crate::array::StructArray; - use crate::compute::{filter, FilterMask}; + use arrow_buffer::BooleanBuffer; + use vortex_buffer::buffer; + use vortex_dtype::{DType, Nullability, PType, StructDType}; + + use crate::array::{BoolArray, PrimitiveArray, StructArray, VarBinArray}; + use crate::compute::test_harness::test_mask; + use crate::compute::{filter, try_cast, FilterMask}; use crate::validity::Validity; + use crate::{ArrayDType as _, IntoArrayData as _}; #[test] fn filter_empty_struct() { @@ -114,4 +165,150 @@ mod tests { let filtered = filter(struct_arr.as_ref(), FilterMask::from_iter::<[bool; 0]>([])).unwrap(); assert_eq!(filtered.len(), 0); } + + #[test] + fn test_mask_empty_struct() { + test_mask( + StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable) + .unwrap() + .into_array(), + ); + } + + #[test] + fn test_mask_complex_struct() { + let xs = buffer![0i64, 1, 2, 3, 4].into_array(); + let ys = VarBinArray::from_iter( + [Some("a"), Some("b"), None, Some("d"), None], + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + let zs = + BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array(); + + test_mask( + StructArray::try_new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + StructArray::try_new( + ["left".into(), "right".into()].into(), + vec![xs.clone(), xs], + 5, + Validity::NonNullable, + ) + .unwrap() + .into_array(), + ys, + zs, + ], + 5, + Validity::NonNullable, + ) + .unwrap() + .into_array(), + ); + } + + #[test] + fn test_cast_empty_struct() { + let array = StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable) + .unwrap() + .into_array(); + let non_nullable_dtype = DType::Struct( + StructDType::new([].into(), vec![]), + Nullability::NonNullable, + ); + let casted = try_cast(&array, &non_nullable_dtype).unwrap(); + assert_eq!(casted.dtype(), &non_nullable_dtype); + + let nullable_dtype = + DType::Struct(StructDType::new([].into(), vec![]), Nullability::Nullable); + let casted = try_cast(&array, &nullable_dtype).unwrap(); + assert_eq!(casted.dtype(), &nullable_dtype); + } + + #[test] + fn test_cast_complex_struct() { + let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]) + .into_array(); + let ys = VarBinArray::from_vec( + vec!["a", "b", "c", "d", "e"], + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + let zs = BoolArray::new( + BooleanBuffer::from_iter([true, true, false, false, true]), + Nullability::Nullable, + ) + .into_array(); + let fully_nullable_array = StructArray::try_new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + StructArray::try_new( + ["left".into(), "right".into()].into(), + vec![xs.clone(), xs], + 5, + Validity::AllValid, + ) + .unwrap() + .into_array(), + ys, + zs, + ], + 5, + Validity::AllValid, + ) + .unwrap() + .into_array(); + + let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable(); + let casted = try_cast(&fully_nullable_array, &top_level_non_nullable).unwrap(); + assert_eq!(casted.dtype(), &top_level_non_nullable); + + let non_null_xs_right = DType::Struct( + StructDType::new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + DType::Struct( + StructDType::new( + ["left".into(), "right".into()].into(), + vec![ + DType::Primitive(PType::I64, Nullability::NonNullable), + DType::Primitive(PType::I64, Nullability::Nullable), + ], + ), + Nullability::Nullable, + ), + DType::Utf8(Nullability::Nullable), + DType::Bool(Nullability::Nullable), + ], + ), + Nullability::Nullable, + ); + let casted = try_cast(&fully_nullable_array, &non_null_xs_right).unwrap(); + assert_eq!(casted.dtype(), &non_null_xs_right); + + let non_null_xs = DType::Struct( + StructDType::new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + DType::Struct( + StructDType::new( + ["left".into(), "right".into()].into(), + vec![ + DType::Primitive(PType::I64, Nullability::Nullable), + DType::Primitive(PType::I64, Nullability::Nullable), + ], + ), + Nullability::NonNullable, + ), + DType::Utf8(Nullability::Nullable), + DType::Bool(Nullability::Nullable), + ], + ), + Nullability::Nullable, + ); + let casted = try_cast(&fully_nullable_array, &non_null_xs).unwrap(); + assert_eq!(casted.dtype(), &non_null_xs); + } } diff --git a/vortex-array/src/array/varbin/compute/cast.rs b/vortex-array/src/array/varbin/compute/cast.rs new file mode 100644 index 00000000000..c1c2c04d157 --- /dev/null +++ b/vortex-array/src/array/varbin/compute/cast.rs @@ -0,0 +1,25 @@ +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::array::varbin::VarBinArray; +use crate::array::VarBinEncoding; +use crate::compute::CastFn; +use crate::{ArrayDType, ArrayData, IntoArrayData}; + +impl CastFn for VarBinEncoding { + fn cast(&self, array: &VarBinArray, dtype: &DType) -> VortexResult { + match dtype { + DType::Utf8(nullability) => { + let validity = array.validity().with_nullability(*nullability)?; + VarBinArray::try_new( + array.offsets(), + array.bytes(), + array.dtype().with_nullability(*nullability), + validity, + ) + .map(IntoArrayData::into_array) + } + _ => vortex_bail!("cannot cast {} to {}", array.dtype(), dtype), + } + } +} diff --git a/vortex-array/src/array/varbin/compute/mask.rs b/vortex-array/src/array/varbin/compute/mask.rs new file mode 100644 index 00000000000..a3481c03cf6 --- /dev/null +++ b/vortex-array/src/array/varbin/compute/mask.rs @@ -0,0 +1,44 @@ +use vortex_error::VortexResult; + +use crate::array::varbin::VarBinArray; +use crate::array::VarBinEncoding; +use crate::compute::{FilterMask, MaskFn}; +use crate::{ArrayDType, ArrayData, IntoArrayData}; + +impl MaskFn for VarBinEncoding { + fn mask(&self, array: &VarBinArray, mask: FilterMask) -> VortexResult { + VarBinArray::try_new( + array.offsets(), + array.bytes(), + array.dtype().as_nullable(), + array.validity().mask(&mask)?, + ) + .map(IntoArrayData::into_array) + } +} + +#[cfg(test)] +mod test { + use vortex_dtype::{DType, Nullability}; + + use crate::array::VarBinArray; + use crate::compute::test_harness::test_mask; + use crate::IntoArrayData as _; + + #[test] + fn test_mask_var_bin_array() { + let array = VarBinArray::from_vec( + vec!["hello", "world", "filter", "good", "bye"], + DType::Utf8(Nullability::NonNullable), + ) + .into_array(); + test_mask(array); + + let array = VarBinArray::from_iter( + vec![Some("hello"), None, Some("filter"), Some("good"), None], + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + test_mask(array); + } +} diff --git a/vortex-array/src/array/varbin/compute/mod.rs b/vortex-array/src/array/varbin/compute/mod.rs index 2174280a013..f18d1f6be60 100644 --- a/vortex-array/src/array/varbin/compute/mod.rs +++ b/vortex-array/src/array/varbin/compute/mod.rs @@ -3,15 +3,23 @@ use vortex_scalar::Scalar; use crate::array::varbin::{varbin_scalar, VarBinArray}; use crate::array::VarBinEncoding; -use crate::compute::{CompareFn, ComputeVTable, FilterFn, ScalarAtFn, SliceFn, TakeFn}; +use crate::compute::{ + CastFn, CompareFn, ComputeVTable, FilterFn, MaskFn, ScalarAtFn, SliceFn, TakeFn, +}; use crate::{ArrayDType, ArrayData}; +mod cast; mod compare; mod filter; +mod mask; mod slice; mod take; impl ComputeVTable for VarBinEncoding { + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } @@ -20,6 +28,10 @@ impl ComputeVTable for VarBinEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/array/varbinview/compute/mod.rs b/vortex-array/src/array/varbinview/compute/mod.rs index 0e707b7c746..fd975788e09 100644 --- a/vortex-array/src/array/varbinview/compute/mod.rs +++ b/vortex-array/src/array/varbinview/compute/mod.rs @@ -3,19 +3,29 @@ use std::ops::Deref; use itertools::Itertools; use num_traits::AsPrimitive; use vortex_buffer::{Alignment, Buffer, ByteBuffer}; -use vortex_dtype::{match_each_integer_ptype, PType}; -use vortex_error::VortexResult; +use vortex_dtype::{match_each_integer_ptype, DType, PType}; +use vortex_error::{vortex_bail, VortexResult}; use vortex_scalar::Scalar; use crate::array::varbin::varbin_scalar; use crate::array::varbinview::{VarBinViewArray, VIEW_SIZE_BYTES}; use crate::array::{PrimitiveArray, VarBinViewEncoding}; -use crate::compute::{slice, ComputeVTable, ScalarAtFn, SliceFn, TakeFn}; +use crate::compute::{ + slice, CastFn, ComputeVTable, FilterMask, MaskFn, ScalarAtFn, SliceFn, TakeFn, +}; use crate::validity::Validity; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; impl ComputeVTable for VarBinViewEncoding { + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -138,6 +148,36 @@ fn take_views_unchecked>(views: Buffer, indices: &[I ) } +impl CastFn for VarBinViewEncoding { + fn cast(&self, array: &VarBinViewArray, dtype: &DType) -> VortexResult { + match dtype { + DType::Utf8(nullability) => { + let validity = array.validity().with_nullability(*nullability)?; + VarBinViewArray::try_new( + array.views(), + array.buffers().collect(), + array.dtype().with_nullability(*nullability), + validity, + ) + .map(IntoArrayData::into_array) + } + _ => vortex_bail!("cannot cast {} to {}", array.dtype(), dtype), + } + } +} + +impl MaskFn for VarBinViewEncoding { + fn mask(&self, array: &VarBinViewArray, mask: FilterMask) -> VortexResult { + VarBinViewArray::try_new( + array.views(), + array.buffers().collect(), + array.dtype().as_nullable(), + array.validity().mask(&mask)?, + ) + .map(IntoArrayData::into_array) + } +} + #[cfg(test)] mod tests { use vortex_buffer::buffer; @@ -145,6 +185,7 @@ mod tests { use crate::accessor::ArrayAccessor; use crate::array::VarBinViewArray; use crate::compute::take; + use crate::compute::test_harness::test_mask; use crate::{ArrayDType, IntoArrayData, IntoArrayVariant}; #[test] @@ -172,4 +213,22 @@ mod tests { [Some("one".to_string()), Some("four".to_string())] ); } + + #[test] + fn take_mask_var_bin_view_array() { + test_mask( + VarBinViewArray::from_iter_str(["one", "two", "three", "four", "five"]).into_array(), + ); + + test_mask( + VarBinViewArray::from_iter_nullable_str([ + Some("one"), + None, + Some("three"), + Some("four"), + Some("five"), + ]) + .into_array(), + ); + } } diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index 89fab5e6f87..8c82baa7c0f 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -36,7 +36,31 @@ where } } -/// Return a new array by applying a boolean predicate to select items from a base Array. +/// Keep only the elements for which the corresponding mask value is true. +/// +/// # Examples +/// +/// ``` +/// use vortex_array::IntoArrayData; +/// use vortex_array::array::{BoolArray, PrimitiveArray}; +/// use vortex_array::compute::{FilterMask, scalar_at}; +/// use vortex_array::compute::filter; +/// use vortex_array::validity::ArrayValidity; +/// use vortex_scalar::Scalar; +/// +/// let array = +/// PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]) +/// .into_array(); +/// let mask = FilterMask::try_from( +/// BoolArray::from_iter([true, false, false, false, true]).into_array(), +/// ) +/// .unwrap(); +/// +/// let filtered = filter(&array, mask).unwrap(); +/// assert_eq!(filtered.len(), 2); +/// assert_eq!(scalar_at(&filtered, 0).unwrap(), Scalar::from(Some(0_i32))); +/// assert_eq!(scalar_at(&filtered, 1).unwrap(), Scalar::from(Some(2_i32))); +/// ``` /// /// # Performance /// @@ -251,7 +275,7 @@ impl FilterMask { self.boolean_buffer().cloned() } - fn boolean_buffer(&self) -> VortexResult<&BooleanBuffer> { + pub fn boolean_buffer(&self) -> VortexResult<&BooleanBuffer> { self.buffer.get_or_try_init(|| { Ok(self .array @@ -351,6 +375,21 @@ impl FromIterator for FilterMask { } } +impl FilterMask { + pub fn from_slices>(length: usize, slices: T) -> Self { + let mut builder = BooleanBufferBuilder::new(length); + let mut cursor = 0; + for (start, end) in slices.into_iter() { + assert!(start <= cursor); + builder.append_n(start - cursor, false); + builder.append_n(end - start, true); + cursor = end; + } + + Self::from(builder.finish()) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/vortex-array/src/compute/mask.rs b/vortex-array/src/compute/mask.rs new file mode 100644 index 00000000000..1ee48cb837d --- /dev/null +++ b/vortex-array/src/compute/mask.rs @@ -0,0 +1,232 @@ +use arrow_array::BooleanArray; +use vortex_error::{vortex_bail, VortexError, VortexResult}; +use vortex_scalar::Scalar; + +use super::FilterMask; +use crate::array::ConstantArray; +use crate::arrow::FromArrowArray; +use crate::compute::try_cast; +use crate::encoding::Encoding; +use crate::{ArrayDType, ArrayData, IntoArrayData, IntoCanonical}; + +pub trait MaskFn { + /// Replace masked values with null in array. + fn mask(&self, array: &Array, mask: FilterMask) -> VortexResult; +} + +impl MaskFn for E +where + E: MaskFn, + for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>, +{ + fn mask(&self, array: &ArrayData, mask: FilterMask) -> VortexResult { + let (array_ref, encoding) = array.downcast_array_ref::()?; + MaskFn::mask(encoding, array_ref, mask) + } +} + +/// Replace values with null where the mask is true. +/// +/// The returned array is nullable but otherwise has the same dtype and length as `array`. +/// +/// # Examples +/// +/// ``` +/// use vortex_array::IntoArrayData; +/// use vortex_array::array::{BoolArray, PrimitiveArray}; +/// use vortex_array::compute::{FilterMask, scalar_at}; +/// use vortex_array::compute::mask; +/// use vortex_array::validity::ArrayValidity; +/// use vortex_scalar::Scalar; +/// +/// let array = +/// PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]) +/// .into_array(); +/// let mask_array = FilterMask::try_from( +/// BoolArray::from_iter([true, false, false, false, true]).into_array(), +/// ) +/// .unwrap(); +/// +/// let masked = mask(&array, mask_array).unwrap(); +/// assert_eq!(masked.len(), 5); +/// assert!(!masked.is_valid(0)); +/// assert!(!masked.is_valid(1)); +/// assert_eq!(scalar_at(&masked, 2).unwrap(), Scalar::from(Some(1))); +/// assert!(!masked.is_valid(3)); +/// assert!(!masked.is_valid(4)); +/// ``` +/// +pub fn mask(array: &ArrayData, mask: FilterMask) -> VortexResult { + if mask.len() != array.len() { + vortex_bail!( + "mask.len() is {}, does not equal array.len() of {}", + mask.len(), + array.len() + ); + } + + let true_count = mask.true_count(); + + let masked = if true_count == 0 { + // Fast-path for empty mask + try_cast(array, &array.dtype().as_nullable())? + } else if true_count == mask.len() { + // Fast-path for full mask. + ConstantArray::new( + Scalar::null(array.dtype().clone().as_nullable()), + array.len(), + ) + .into_array() + } else { + mask_impl(array, mask)? + }; + + debug_assert_eq!( + masked.len(), + array.len(), + "Mask should not change length {}\n\n{:?}\n\n{:?}", + array.encoding().id(), + array, + masked + ); + debug_assert_eq!( + masked.dtype(), + &array.dtype().as_nullable(), + "Mask dtype mismatch {} {} {} {}", + array.encoding().id(), + masked.dtype(), + array.dtype(), + array.dtype().as_nullable(), + ); + + Ok(masked) +} + +fn mask_impl(array: &ArrayData, mask: FilterMask) -> VortexResult { + if let Some(mask_fn) = array.encoding().mask_fn() { + return mask_fn.mask(array, mask); + } + + // Fallback: implement using Arrow kernels. + log::debug!("No mask implementation found for {}", array.encoding().id(),); + + let array_ref = array.clone().into_arrow()?; + let mask = BooleanArray::new(mask.to_boolean_buffer()?, None); + + let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?; + + Ok(ArrayData::from_arrow(masked, true)) +} + +#[cfg(feature = "test-harness")] +pub mod test_harness { + use crate::array::BoolArray; + use crate::compute::{mask, scalar_at, FilterMask}; + use crate::validity::ArrayValidity as _; + use crate::{ArrayData, IntoArrayData}; + + pub fn test_mask(array: ArrayData) { + assert_eq!(array.len(), 5); + test_heterogenous_mask(&array); + test_empty_mask(&array); + test_full_mask(&array); + } + + #[allow(clippy::unwrap_used)] + fn test_heterogenous_mask(array: &ArrayData) { + let mask_array = FilterMask::try_from( + BoolArray::from_iter([true, false, false, true, true]).into_array(), + ) + .unwrap(); + let masked = mask(array, mask_array).unwrap(); + assert_eq!(masked.len(), array.len()); + assert!(!masked.is_valid(0)); + assert_eq!( + scalar_at(&masked, 1).unwrap(), + scalar_at(array, 1).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 2).unwrap(), + scalar_at(array, 2).unwrap().into_nullable() + ); + assert!(!masked.is_valid(3)); + assert!(!masked.is_valid(4)); + } + + #[allow(clippy::unwrap_used)] + fn test_empty_mask(array: &ArrayData) { + let all_unmasked = FilterMask::try_from( + BoolArray::from_iter([false, false, false, false, false]).into_array(), + ) + .unwrap(); + let masked = mask(array, all_unmasked).unwrap(); + assert_eq!(masked.len(), array.len()); + assert_eq!( + scalar_at(&masked, 0).unwrap(), + scalar_at(array, 0).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 1).unwrap(), + scalar_at(array, 1).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 2).unwrap(), + scalar_at(array, 2).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 3).unwrap(), + scalar_at(array, 3).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 4).unwrap(), + scalar_at(array, 4).unwrap().into_nullable() + ); + } + + #[allow(clippy::unwrap_used)] + fn test_full_mask(array: &ArrayData) { + let all_masked = + FilterMask::try_from(BoolArray::from_iter([true, true, true, true, true]).into_array()) + .unwrap(); + let masked = mask(array, all_masked).unwrap(); + assert_eq!(masked.len(), array.len()); + assert!(!masked.is_valid(0)); + assert!(!masked.is_valid(1)); + assert!(!masked.is_valid(2)); + assert!(!masked.is_valid(3)); + assert!(!masked.is_valid(4)); + + let mask1 = FilterMask::try_from( + BoolArray::from_iter([true, false, false, true, true]).into_array(), + ) + .unwrap(); + let mask2 = FilterMask::try_from( + BoolArray::from_iter([false, true, false, false, true]).into_array(), + ) + .unwrap(); + let first = mask(array, mask1).unwrap(); + let double_masked = mask(&first, mask2).unwrap(); + assert_eq!(double_masked.len(), array.len()); + assert!(!double_masked.is_valid(0)); + assert!(!double_masked.is_valid(1)); + assert_eq!( + scalar_at(&double_masked, 2).unwrap(), + scalar_at(array, 2).unwrap().into_nullable() + ); + assert!(!double_masked.is_valid(3)); + assert!(!double_masked.is_valid(4)); + } +} + +#[cfg(test)] +mod test { + use super::test_harness::test_mask; + use crate::array::PrimitiveArray; + use crate::IntoArrayData as _; + + #[test] + fn test_mask_non_nullable_array() { + let non_nullable_array = PrimitiveArray::from_iter([1, 2, 3, 4, 5]).into_array(); + test_mask(non_nullable_array); + } +} diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 9e7e9a135cd..d6b5396d110 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -21,6 +21,7 @@ pub use fill_null::{fill_null, FillNullFn}; pub use filter::{filter, FilterFn, FilterIter, FilterMask}; pub use invert::{invert, InvertFn}; pub use like::{like, LikeFn, LikeOptions}; +pub use mask::{mask, MaskFn}; pub use scalar_at::{scalar_at, ScalarAtFn}; pub use search_sorted::*; pub use slice::{slice, SliceFn}; @@ -37,6 +38,7 @@ mod fill_null; mod filter; mod invert; mod like; +mod mask; mod scalar_at; mod search_sorted; mod slice; @@ -87,13 +89,24 @@ pub trait ComputeVTable { None } - /// Filter an array with a given mask. + /// Remove masked values from the array. + /// + /// The length of the returned array equals the number of true mask values. /// /// See: [FilterFn]. fn filter_fn(&self) -> Option<&dyn FilterFn> { None } + /// Replace masked values with null. + /// + /// This operation does not change the length of the array. + /// + /// See: [MaskFn]. + fn mask_fn(&self) -> Option<&dyn MaskFn> { + None + } + /// Invert a boolean array. Converts true -> false, false -> true, null -> null. /// /// See [InvertFn] @@ -148,4 +161,5 @@ pub trait ComputeVTable { #[cfg(feature = "test-harness")] pub mod test_harness { pub use crate::compute::binary_numeric::test_harness::test_binary_numeric; + pub use crate::compute::mask::test_harness::test_mask; } diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 6128796a790..7a7893e61fe 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -1,7 +1,7 @@ //! Array validity and nullability behavior, used by arrays and compute functions. use std::fmt::{Debug, Display}; -use std::ops::BitAnd; +use std::ops::{BitAnd, Not}; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use serde::{Deserialize, Serialize}; @@ -222,6 +222,9 @@ impl Validity { } } + /// Keep only the entries for which the mask is true. + /// + /// The result has length equal to the number of true values in mask. pub fn filter(&self, mask: &FilterMask) -> VortexResult { // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily // if we happen to be NonNullable, AllValid, or AllInvalid. @@ -233,6 +236,23 @@ impl Validity { } } + /// Set to false any entries for which the mask is true. + /// + /// The result is always nullable. The result has the same length as self. + pub fn mask(self, mask: &FilterMask) -> VortexResult { + Ok(match self { + Validity::NonNullable | Validity::AllValid => { + Validity::Array(BoolArray::from(mask.boolean_buffer()?.not()).into_array()) + } + Validity::AllInvalid => Validity::AllInvalid, + Validity::Array(is_valid) => { + let bools = BoolArray::try_from(is_valid)?.boolean_buffer(); + let is_valid_in_mask = mask.boolean_buffer()?.not(); + Validity::from(bools.bitand(&is_valid_in_mask)) + } + }) + } + pub fn to_logical(&self, length: usize) -> LogicalValidity { match self { Self::NonNullable => LogicalValidity::AllValid(length), @@ -334,6 +354,35 @@ impl Validity { } } + /// Convert into a non-nullable variant + pub fn into_non_nullable(self) -> Option { + match self { + Self::NonNullable => Some(Self::NonNullable), + Self::AllValid => Some(Self::NonNullable), + Self::AllInvalid => None, + Self::Array(is_valid) => { + is_valid + .statistics() + .compute_min::() + .unwrap_or(false) + .then(|| { + // min true => all true + Self::NonNullable + }) + } + } + } + + /// Convert into a variant compatible with the given nullability, if possible. + pub fn with_nullability(self, nullability: Nullability) -> VortexResult { + match nullability { + Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| { + vortex_err!("cannot cast array with invalid values to non-nullable type") + }), + Nullability::Nullable => Ok(self.into_nullable()), + } + } + /// Create Validity from boolean array with given nullability of the array. /// /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself From c25b6687aae24411c5852b816f169d5891be560c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 13 Jan 2025 14:43:01 -0500 Subject: [PATCH 2/8] remove unnecssary clone --- vortex-array/src/array/extension/compute/mask.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vortex-array/src/array/extension/compute/mask.rs b/vortex-array/src/array/extension/compute/mask.rs index 53f1a7d1443..72a733ffde2 100644 --- a/vortex-array/src/array/extension/compute/mask.rs +++ b/vortex-array/src/array/extension/compute/mask.rs @@ -41,8 +41,7 @@ mod test { )); test_mask( - ExtensionArray::new(ext_dtype.clone(), buffer![1i64, 2, 3, 4, 5].into_array()) - .into_array(), + ExtensionArray::new(ext_dtype, buffer![1i64, 2, 3, 4, 5].into_array()).into_array(), ); } } From 26c461300112e7f1406aee12599593f91bac4e6c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 13 Jan 2025 15:00:15 -0500 Subject: [PATCH 3/8] clippy again --- encodings/datetime-parts/src/compute/mask.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/encodings/datetime-parts/src/compute/mask.rs b/encodings/datetime-parts/src/compute/mask.rs index 9c52974b464..1ddbfc66da8 100644 --- a/encodings/datetime-parts/src/compute/mask.rs +++ b/encodings/datetime-parts/src/compute/mask.rs @@ -53,6 +53,6 @@ mod tests { .unwrap() .into_array(); - test_mask(date_times.clone()); + test_mask(date_times); } } From 15c80b05a3b469bdd021d15568c0b39e791ea806 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 13 Jan 2025 14:38:02 -0500 Subject: [PATCH 4/8] feat: StructArrayTrait::field_by_idx --- vortex-array/src/variants.rs | 225 ++++++++++++++++++++++++++++++++++- 1 file changed, 224 insertions(+), 1 deletion(-) diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index 4ce126acaf2..e59c956d47d 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -7,9 +7,13 @@ use std::sync::Arc; use vortex_dtype::{DType, ExtDType, Field, FieldInfo, FieldNames, PType}; use vortex_error::{vortex_panic, VortexError, VortexExpect as _, VortexResult}; +use vortex_scalar::Scalar; +use crate::array::ConstantArray; +use crate::compute::{mask, try_cast, FilterMask}; use crate::encoding::Encoding; -use crate::{ArrayDType, ArrayData, ArrayTrait}; +use crate::validity::LogicalValidity; +use crate::{ArrayDType, ArrayData, ArrayTrait, IntoArrayData as _}; /// An Array encoding must declare which DTypes it can be downcast into. pub trait VariantsVTable { @@ -208,6 +212,158 @@ pub trait StructArrayTrait: ArrayTrait { self.names().len() } + /// Return a field's array by index, masking by the struct's validity. + /// + /// If either this array or the field array is invalid at a position, the result is invalid at + /// that position. Consequently, if either array is nullable, the result is nullable. + /// + /// # Examples + /// + /// The field of a non-nullable struct array is the same whether accessed by [field_by_idx] or + /// [maybe_null_field_by_idx]: + /// + /// ``` + /// use vortex_array::array::{BoolArray, PrimitiveArray, StructArray}; + /// use vortex_array::validity::{ArrayValidity, Validity}; + /// use vortex_array::variants::StructArrayTrait; + /// use vortex_array::{ArrayDType, IntoArrayData}; + /// use vortex_dtype::FieldNames; + /// + /// let original_field = PrimitiveArray::from_option_iter([ + /// Some(1), None, Some(3), None, Some(4), + /// ]).into_array(); + /// let array = StructArray::try_new( + /// FieldNames::from(["a".into()]), + /// vec![original_field], + /// 5, + /// Validity::NonNullable, + /// ).unwrap(); + /// let field = array.field_by_idx(0).unwrap().unwrap(); + /// let maybe_null_field = array.maybe_null_field_by_idx(0).unwrap(); + /// + /// assert_eq!(field.dtype(), maybe_null_field.dtype()); + /// assert!((0..field.len()).all(|i| { + /// field.is_valid(i) == maybe_null_field.is_valid(i) + /// })); + /// ``` + /// + /// When both a struct and its field are nullable, [field_by_idx] returns the intersection of + /// the validity, which is to say: a position is valid if and only if both the struct and the + /// field are valid at that position. + /// + /// ``` + /// use vortex_array::array::{BoolArray, PrimitiveArray, StructArray}; + /// use vortex_array::compute::scalar_at; + /// use vortex_array::validity::{ArrayValidity, Validity}; + /// use vortex_array::variants::StructArrayTrait; + /// use vortex_array::{ArrayDType, IntoArrayData}; + /// use vortex_dtype::FieldNames; + /// use vortex_scalar::Scalar; + /// + /// let original_field = PrimitiveArray::from_option_iter([ + /// Some(1), None, Some(3), None, Some(4), + /// ]).into_array(); + /// let struct_validity = Validity::Array(BoolArray::from_iter([ + /// true, true, false, false, true, + /// ]).into_array()); + /// let array = StructArray::try_new( + /// FieldNames::from(["a".into()]), + /// vec![original_field], + /// 5, + /// struct_validity, + /// ).unwrap(); + /// let field = array.field_by_idx(0).unwrap().unwrap(); + /// + /// assert!(field.dtype().is_nullable()); + /// assert!(!field.is_valid(0)); + /// assert!(!field.is_valid(1)); + /// assert_eq!(scalar_at(&field, 2).unwrap(), Scalar::from(Some(3))); + /// assert!(!field.is_valid(3)); + /// assert!(!field.is_valid(4)); + /// ``` + /// + /// When a field is non-nullable, but the struct is nullable, the field receives the struct's + /// validity. + /// + /// ``` + /// use vortex_array::array::{BoolArray, StructArray}; + /// use vortex_array::compute::scalar_at; + /// use vortex_array::validity::{ArrayValidity, Validity}; + /// use vortex_array::variants::StructArrayTrait; + /// use vortex_array::{ArrayDType, IntoArrayData}; + /// use vortex_buffer::buffer; + /// use vortex_dtype::FieldNames; + /// use vortex_scalar::Scalar; + /// + /// let original_field = buffer![1, 2, 3, 4, 5].into_array(); + /// let struct_validity = Validity::Array(BoolArray::from_iter([ + /// true, true, false, false, true, + /// ]).into_array()); + /// let array = StructArray::try_new( + /// FieldNames::from(["a".into()]), + /// vec![original_field], + /// 5, + /// struct_validity, + /// ).unwrap(); + /// let field = array.field_by_idx(0).unwrap().unwrap(); + /// + /// assert!(field.dtype().is_nullable()); + /// assert!(!field.is_valid(0)); + /// assert!(!field.is_valid(1)); + /// assert_eq!(scalar_at(&field, 2).unwrap(), Scalar::from(Some(3))); + /// assert_eq!(scalar_at(&field, 3).unwrap(), Scalar::from(Some(4))); + /// assert!(!field.is_valid(4)); + /// ``` + fn field_by_idx(&self, idx: usize) -> VortexResult> { + let Some(maybe_null_field) = self.maybe_null_field_by_idx(idx) else { + return Ok(None); + }; + + if !self.dtype().is_nullable() { + return Ok(Some(maybe_null_field)); + } + + match self.logical_validity() { + LogicalValidity::AllValid(_) => { + let nullable_dtype = maybe_null_field.dtype().as_nullable(); + try_cast(maybe_null_field, &nullable_dtype).map(Some) + } + LogicalValidity::AllInvalid(_) => { + let nullable_dtype = maybe_null_field.dtype().as_nullable(); + + Ok(Some( + ConstantArray::new(Scalar::null(nullable_dtype), maybe_null_field.len()) + .into_array(), + )) + } + LogicalValidity::Array(is_valid) => { + mask(&maybe_null_field, FilterMask::try_from(is_valid)?).map(Some) + } + } + } + + /// Return a field's array by name, masking by the struct's validity. + /// + /// See also [field_by_idx]. + fn field_by_name(&self, name: &str) -> Option { + let field_idx = self + .names() + .iter() + .position(|field_name| field_name.as_ref() == name); + + field_idx.and_then(|field_idx| self.maybe_null_field_by_idx(field_idx)) + } + + /// Return a field's array by name or index, masking by the struct's validity. + /// + /// See also [field_by_idx]. + fn field(&self, field: &Field) -> Option { + match field { + Field::Index(idx) => self.maybe_null_field_by_idx(*idx), + Field::Name(name) => self.maybe_null_field_by_name(name.as_ref()), + } + } + /// Return a field's array by index, ignoring struct nullability fn maybe_null_field_by_idx(&self, idx: usize) -> Option; @@ -221,6 +377,7 @@ pub trait StructArrayTrait: ArrayTrait { field_idx.and_then(|field_idx| self.maybe_null_field_by_idx(field_idx)) } + /// Return a field's array by name or index, ignoring struct nullability fn maybe_null_field(&self, field: &Field) -> Option { match field { Field::Index(idx) => self.maybe_null_field_by_idx(*idx), @@ -245,3 +402,69 @@ pub trait ExtensionArrayTrait: ArrayTrait { /// Returns the underlying [`ArrayData`], without the [`ExtDType`]. fn storage_data(&self) -> ArrayData; } + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_dtype::FieldNames; + use vortex_scalar::Scalar; + + use crate::array::{BoolArray, PrimitiveArray, StructArray}; + use crate::compute::scalar_at; + use crate::validity::{ArrayValidity, Validity}; + use crate::variants::StructArrayTrait; + use crate::{ArrayDType, IntoArrayData}; + + #[test] + fn test_field() { + let original_field = + PrimitiveArray::from_option_iter([Some(1), None, Some(3), None, Some(4)]).into_array(); + let array = StructArray::try_new( + FieldNames::from(["a".into()]), + vec![original_field.clone()], + 5, + Validity::NonNullable, + ) + .unwrap(); + let field = array.field_by_idx(0).unwrap().unwrap(); + let maybe_null_field = array.maybe_null_field_by_idx(0).unwrap(); + + assert_eq!(field.dtype(), maybe_null_field.dtype()); + assert!((0..field.len()).all(|i| { field.is_valid(i) == maybe_null_field.is_valid(i) })); + + let struct_validity = + Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()); + let array = StructArray::try_new( + FieldNames::from(["a".into()]), + vec![original_field], + 5, + struct_validity.clone(), + ) + .unwrap(); + let field = array.field_by_idx(0).unwrap().unwrap(); + + assert!(field.dtype().is_nullable()); + assert!(!field.is_valid(0)); + assert!(!field.is_valid(1)); + assert_eq!(scalar_at(&field, 2).unwrap(), Scalar::from(Some(3))); + assert!(!field.is_valid(3)); + assert!(!field.is_valid(4)); + + let original_field = buffer![1, 2, 3, 4, 5].into_array(); + let array = StructArray::try_new( + FieldNames::from(["a".into()]), + vec![original_field], + 5, + struct_validity, + ) + .unwrap(); + let field = array.field_by_idx(0).unwrap().unwrap(); + + assert!(field.dtype().is_nullable()); + assert!(!field.is_valid(0)); + assert!(!field.is_valid(1)); + assert_eq!(scalar_at(&field, 2).unwrap(), Scalar::from(Some(3))); + assert_eq!(scalar_at(&field, 3).unwrap(), Scalar::from(Some(4))); + assert!(!field.is_valid(4)); + } +} From 3a232f92c43d65ebb211cae649af3e883c00a96b Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 13 Jan 2025 15:13:17 -0500 Subject: [PATCH 5/8] fix --- vortex-array/src/variants.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index e59c956d47d..0d330c1336b 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -219,8 +219,8 @@ pub trait StructArrayTrait: ArrayTrait { /// /// # Examples /// - /// The field of a non-nullable struct array is the same whether accessed by [field_by_idx] or - /// [maybe_null_field_by_idx]: + /// The field of a non-nullable struct array is the same whether accessed by + /// [StructArrayTrait::field_by_idx] or [StructArrayTrait::maybe_null_field_by_idx]: /// /// ``` /// use vortex_array::array::{BoolArray, PrimitiveArray, StructArray}; @@ -247,9 +247,9 @@ pub trait StructArrayTrait: ArrayTrait { /// })); /// ``` /// - /// When both a struct and its field are nullable, [field_by_idx] returns the intersection of - /// the validity, which is to say: a position is valid if and only if both the struct and the - /// field are valid at that position. + /// When both a struct and its field are nullable, [StructArrayTrait::field_by_idx] returns the + /// intersection of the validity, which is to say: a position is valid if and only if both the + /// struct and the field are valid at that position. /// /// ``` /// use vortex_array::array::{BoolArray, PrimitiveArray, StructArray}; @@ -344,7 +344,7 @@ pub trait StructArrayTrait: ArrayTrait { /// Return a field's array by name, masking by the struct's validity. /// - /// See also [field_by_idx]. + /// See also [StructArrayTrait::field_by_idx]. fn field_by_name(&self, name: &str) -> Option { let field_idx = self .names() @@ -356,7 +356,7 @@ pub trait StructArrayTrait: ArrayTrait { /// Return a field's array by name or index, masking by the struct's validity. /// - /// See also [field_by_idx]. + /// See also [StructArrayTrait::field_by_idx]. fn field(&self, field: &Field) -> Option { match field { Field::Index(idx) => self.maybe_null_field_by_idx(*idx), From 88fcc35e84cd4580c67d82a5e62a7df7529eda08 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 13 Jan 2025 15:30:16 -0500 Subject: [PATCH 6/8] fix field_by_name and field to use field_by_idx --- vortex-array/src/variants.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index 0d330c1336b..02b3188b6f7 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -345,22 +345,25 @@ pub trait StructArrayTrait: ArrayTrait { /// Return a field's array by name, masking by the struct's validity. /// /// See also [StructArrayTrait::field_by_idx]. - fn field_by_name(&self, name: &str) -> Option { + fn field_by_name(&self, name: &str) -> VortexResult> { let field_idx = self .names() .iter() .position(|field_name| field_name.as_ref() == name); - field_idx.and_then(|field_idx| self.maybe_null_field_by_idx(field_idx)) + match field_idx { + None => Ok(None), + Some(field_idx) => self.field_by_idx(field_idx), + } } /// Return a field's array by name or index, masking by the struct's validity. /// /// See also [StructArrayTrait::field_by_idx]. - fn field(&self, field: &Field) -> Option { + fn field(&self, field: &Field) -> VortexResult> { match field { - Field::Index(idx) => self.maybe_null_field_by_idx(*idx), - Field::Name(name) => self.maybe_null_field_by_name(name.as_ref()), + Field::Index(idx) => self.field_by_idx(*idx), + Field::Name(name) => self.field_by_name(name.as_ref()), } } From 1ef27476bca7fc12cbaf3dc7c2190b8c7c93abe9 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 13 Jan 2025 16:46:52 -0500 Subject: [PATCH 7/8] validity=not(mask) --- vortex-array/src/variants.rs | 40 ++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index 02b3188b6f7..220f1e618f9 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -10,7 +10,7 @@ use vortex_error::{vortex_panic, VortexError, VortexExpect as _, VortexResult}; use vortex_scalar::Scalar; use crate::array::ConstantArray; -use crate::compute::{mask, try_cast, FilterMask}; +use crate::compute::{invert, mask, try_cast, FilterMask}; use crate::encoding::Encoding; use crate::validity::LogicalValidity; use crate::{ArrayDType, ArrayData, ArrayTrait, IntoArrayData as _}; @@ -261,7 +261,7 @@ pub trait StructArrayTrait: ArrayTrait { /// use vortex_scalar::Scalar; /// /// let original_field = PrimitiveArray::from_option_iter([ - /// Some(1), None, Some(3), None, Some(4), + /// Some(1), None, Some(3), None, Some(5), /// ]).into_array(); /// let struct_validity = Validity::Array(BoolArray::from_iter([ /// true, true, false, false, true, @@ -275,11 +275,11 @@ pub trait StructArrayTrait: ArrayTrait { /// let field = array.field_by_idx(0).unwrap().unwrap(); /// /// assert!(field.dtype().is_nullable()); - /// assert!(!field.is_valid(0)); + /// assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1))); /// assert!(!field.is_valid(1)); - /// assert_eq!(scalar_at(&field, 2).unwrap(), Scalar::from(Some(3))); + /// assert!(!field.is_valid(2)); /// assert!(!field.is_valid(3)); - /// assert!(!field.is_valid(4)); + /// assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5))); /// ``` /// /// When a field is non-nullable, but the struct is nullable, the field receives the struct's @@ -308,11 +308,11 @@ pub trait StructArrayTrait: ArrayTrait { /// let field = array.field_by_idx(0).unwrap().unwrap(); /// /// assert!(field.dtype().is_nullable()); - /// assert!(!field.is_valid(0)); - /// assert!(!field.is_valid(1)); - /// assert_eq!(scalar_at(&field, 2).unwrap(), Scalar::from(Some(3))); - /// assert_eq!(scalar_at(&field, 3).unwrap(), Scalar::from(Some(4))); - /// assert!(!field.is_valid(4)); + /// assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1))); + /// assert_eq!(scalar_at(&field, 1).unwrap(), Scalar::from(Some(2))); + /// assert!(!field.is_valid(2)); + /// assert!(!field.is_valid(3)); + /// assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5))); /// ``` fn field_by_idx(&self, idx: usize) -> VortexResult> { let Some(maybe_null_field) = self.maybe_null_field_by_idx(idx) else { @@ -337,7 +337,7 @@ pub trait StructArrayTrait: ArrayTrait { )) } LogicalValidity::Array(is_valid) => { - mask(&maybe_null_field, FilterMask::try_from(is_valid)?).map(Some) + mask(&maybe_null_field, FilterMask::try_from(invert(&is_valid)?)?).map(Some) } } } @@ -421,7 +421,7 @@ mod tests { #[test] fn test_field() { let original_field = - PrimitiveArray::from_option_iter([Some(1), None, Some(3), None, Some(4)]).into_array(); + PrimitiveArray::from_option_iter([Some(1), None, Some(3), None, Some(5)]).into_array(); let array = StructArray::try_new( FieldNames::from(["a".into()]), vec![original_field.clone()], @@ -447,11 +447,11 @@ mod tests { let field = array.field_by_idx(0).unwrap().unwrap(); assert!(field.dtype().is_nullable()); - assert!(!field.is_valid(0)); + assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1))); assert!(!field.is_valid(1)); - assert_eq!(scalar_at(&field, 2).unwrap(), Scalar::from(Some(3))); + assert!(!field.is_valid(2)); assert!(!field.is_valid(3)); - assert!(!field.is_valid(4)); + assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5))); let original_field = buffer![1, 2, 3, 4, 5].into_array(); let array = StructArray::try_new( @@ -464,10 +464,10 @@ mod tests { let field = array.field_by_idx(0).unwrap().unwrap(); assert!(field.dtype().is_nullable()); - assert!(!field.is_valid(0)); - assert!(!field.is_valid(1)); - assert_eq!(scalar_at(&field, 2).unwrap(), Scalar::from(Some(3))); - assert_eq!(scalar_at(&field, 3).unwrap(), Scalar::from(Some(4))); - assert!(!field.is_valid(4)); + assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1))); + assert_eq!(scalar_at(&field, 1).unwrap(), Scalar::from(Some(2))); + assert!(!field.is_valid(2)); + assert!(!field.is_valid(3)); + assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5))); } } From 567d53436af800d03029bfa3f0372d19650832f4 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 13 Jan 2025 16:21:45 -0500 Subject: [PATCH 8/8] fix: teach column and getitem to respect validity --- vortex-expr/src/column.rs | 39 ++++++++++++++++++++++++++++++++----- vortex-expr/src/get_item.rs | 36 ++++++++++++++++++++++++++++++---- 2 files changed, 66 insertions(+), 9 deletions(-) diff --git a/vortex-expr/src/column.rs b/vortex-expr/src/column.rs index e130ac64c73..af3671c6106 100644 --- a/vortex-expr/src/column.rs +++ b/vortex-expr/src/column.rs @@ -27,9 +27,7 @@ impl Column { } pub fn col(field: impl Into) -> ExprRef { - Arc::new(Column { - field: field.into(), - }) + Column::new_expr(field) } impl From for Column { @@ -69,7 +67,7 @@ impl VortexExpr for Column { batch.dtype() ) })? - .maybe_null_field(&self.field) + .field(&self.field)? .ok_or_else(|| vortex_err!("Array doesn't contain child array {}", self.field)) } @@ -85,7 +83,12 @@ impl VortexExpr for Column { #[cfg(test)] mod tests { - use vortex_dtype::{DType, Nullability, PType}; + use vortex_array::array::{BoolArray, PrimitiveArray, StructArray}; + use vortex_array::compute::scalar_at; + use vortex_array::validity::{ArrayValidity as _, Validity}; + use vortex_array::{ArrayDType as _, IntoArrayData as _}; + use vortex_dtype::{DType, FieldNames, Nullability, PType}; + use vortex_scalar::Scalar; use crate::{col, test_harness}; @@ -101,4 +104,30 @@ mod tests { DType::Primitive(PType::U16, Nullability::Nullable) ); } + + #[test] + fn evaluate_with_nulls() { + let a = PrimitiveArray::from_option_iter([Some(0_i32), None, None, Some(3), Some(4)]) + .into_array(); + let array = StructArray::try_new( + FieldNames::from(["a".into()]), + vec![a], + 5, + Validity::Array(BoolArray::from_iter([true, false, true, false, true]).into_array()), + ) + .unwrap() + .into_array(); + + let a_result = col("a").evaluate(&array).unwrap(); + + assert_eq!( + a_result.dtype(), + &DType::Primitive(PType::I32, Nullability::Nullable) + ); + assert_eq!(scalar_at(&a_result, 0).unwrap(), Scalar::from(Some(0_i32))); + assert!(!a_result.is_valid(1)); + assert!(!a_result.is_valid(2)); + assert!(!a_result.is_valid(3)); + assert_eq!(scalar_at(&a_result, 4).unwrap(), Scalar::from(Some(4_i32))); + } } diff --git a/vortex-expr/src/get_item.rs b/vortex-expr/src/get_item.rs index 513c3024ad9..3adea9b92c3 100644 --- a/vortex-expr/src/get_item.rs +++ b/vortex-expr/src/get_item.rs @@ -53,8 +53,7 @@ impl VortexExpr for GetItem { child .as_struct_array() .ok_or_else(|| vortex_err!("GetItem: child array into struct"))? - // TODO(joe): apply struct validity - .maybe_null_field(self.field()) + .field(self.field())? .ok_or_else(|| vortex_err!("Field {} not found", self.field)) } @@ -76,11 +75,14 @@ impl PartialEq for GetItem { #[cfg(test)] mod tests { - use vortex_array::array::StructArray; + use vortex_array::array::{BoolArray, PrimitiveArray, StructArray}; + use vortex_array::compute::scalar_at; + use vortex_array::validity::{ArrayValidity as _, Validity}; use vortex_array::{ArrayDType, IntoArrayData}; use vortex_buffer::buffer; - use vortex_dtype::DType; use vortex_dtype::PType::{I32, I64}; + use vortex_dtype::{DType, FieldNames, Nullability}; + use vortex_scalar::Scalar; use crate::get_item::get_item; use crate::ident; @@ -115,4 +117,30 @@ mod tests { let item = get_item.evaluate(st.as_ref()).unwrap(); assert_eq!(item.dtype(), &DType::from(I64)) } + + #[test] + fn evaluate_with_nulls() { + let a = PrimitiveArray::from_option_iter([Some(0_i32), None, None, Some(3), Some(4)]) + .into_array(); + let array = StructArray::try_new( + FieldNames::from(["a".into()]), + vec![a], + 5, + Validity::Array(BoolArray::from_iter([true, false, true, false, true]).into_array()), + ) + .unwrap() + .into_array(); + + let a_result = get_item("a", ident()).evaluate(&array).unwrap(); + + assert_eq!( + a_result.dtype(), + &DType::Primitive(I32, Nullability::Nullable) + ); + assert_eq!(scalar_at(&a_result, 0).unwrap(), Scalar::from(Some(0_i32))); + assert!(!a_result.is_valid(1)); + assert!(!a_result.is_valid(2)); + assert!(!a_result.is_valid(3)); + assert_eq!(scalar_at(&a_result, 4).unwrap(), Scalar::from(Some(4_i32))); + } }