From 6b3b101103bafa11bc67c90d3268dc6aefc0b916 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 6 Jan 2025 18:13:12 -0500 Subject: [PATCH 01/26] feat: mask Mask sets entries of an array to null. I like the analogy to light: the array is a sequence of lights (each value might be a different wavelength). Null is represented by the absence of light. Placing a mask (i.e. a piece of plastic with slits) over the array causes those values where the mask is present (i.e. "on", "true") to be dark. An example in pseudo-code: ```rust a = [1, 2, 3, 4, 5] a_mask = [t, f, f, t, f] mask(a, a_mask) == [null, 2, 3, null, 5] ``` Specializations --------------- I only fall back to Arrow for two of the core arrays: - Sparse. I was skeptical that I could do better than decompressing and applying it. - Constant. If the mask is sparse, SparseArray might be a good choice. I didn't investigate. For the non-core arrays, I'm missing the following. It's not clear to me that I can beat decompression for run-end. The others are easy enough but require some amount of typing and testing. - fastlanes - fsst - roaring - runend - runend-bool - zigzag Naming ------ Pandas also calls this operation [`mask`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.mask.html) but accepts an optional second argument which is an array of values to use instead of null (which makes Pandas' mask more like an `if_else`). Arrow-rs calls this [`nullif`](https://arrow.apache.org/rust/arrow/compute/fn.nullif.html). Arrow-cpp has [`if_else(condition, consequent, alternate)`](https://arrow.apache.org/docs/cpp/compute.html#cpp-compute-scalar-selections) and [`replace_with_mask(array, mask, replacements)`](https://arrow.apache.org/docs/cpp/compute.html#replace-functions), both of which can implement our `mask` by passing a `NullArray` as the third argument. 
--- encodings/alp/src/alp/compute/mask.rs | 72 ++++++ encodings/alp/src/alp/compute/mod.rs | 10 +- encodings/alp/src/alp_rd/compute/mask.rs | 57 +++++ encodings/alp/src/alp_rd/compute/mod.rs | 7 +- encodings/bytebool/src/compute.rs | 24 +- encodings/datetime-parts/src/compute/cast.rs | 25 ++ encodings/datetime-parts/src/compute/mask.rs | 58 +++++ encodings/datetime-parts/src/compute/mod.rs | 13 +- encodings/dict/Cargo.toml | 4 + encodings/dict/benches/dict_mask.rs | 50 ++++ encodings/dict/src/array.rs | 5 + encodings/dict/src/compute/mask.rs | 48 ++++ encodings/dict/src/compute/mod.rs | 21 +- vortex-array/src/array/bool/compute/cast.rs | 20 ++ vortex-array/src/array/bool/compute/mask.rs | 12 + vortex-array/src/array/bool/compute/mod.rs | 14 +- vortex-array/src/array/bool/mod.rs | 6 + .../src/array/chunked/compute/filter.rs | 126 +++++----- .../src/array/chunked/compute/mask.rs | 152 ++++++++++++ vortex-array/src/array/chunked/compute/mod.rs | 7 +- .../src/array/constant/compute/mod.rs | 18 ++ .../src/array/extension/compute/cast.rs | 27 ++ .../src/array/extension/compute/mask.rs | 48 ++++ .../src/array/extension/compute/mod.rs | 18 +- vortex-array/src/array/list/compute/mod.rs | 37 ++- vortex-array/src/array/null/compute.rs | 12 +- .../src/array/primitive/compute/mask.rs | 17 ++ .../src/array/primitive/compute/mod.rs | 7 +- vortex-array/src/array/primitive/mod.rs | 26 ++ vortex-array/src/array/sparse/compute/mod.rs | 36 ++- vortex-array/src/array/struct_/compute.rs | 209 +++++++++++++++- vortex-array/src/array/varbin/compute/cast.rs | 25 ++ vortex-array/src/array/varbin/compute/mask.rs | 44 ++++ vortex-array/src/array/varbin/compute/mod.rs | 14 +- .../src/array/varbinview/compute/mod.rs | 65 ++++- vortex-array/src/compute/filter.rs | 43 +++- vortex-array/src/compute/mask.rs | 232 ++++++++++++++++++ vortex-array/src/compute/mod.rs | 16 +- vortex-array/src/validity.rs | 51 +++- 39 files changed, 1583 insertions(+), 93 deletions(-) create mode 100644 
encodings/alp/src/alp/compute/mask.rs create mode 100644 encodings/alp/src/alp_rd/compute/mask.rs create mode 100644 encodings/datetime-parts/src/compute/cast.rs create mode 100644 encodings/datetime-parts/src/compute/mask.rs create mode 100644 encodings/dict/benches/dict_mask.rs create mode 100644 encodings/dict/src/compute/mask.rs create mode 100644 vortex-array/src/array/bool/compute/cast.rs create mode 100644 vortex-array/src/array/bool/compute/mask.rs create mode 100644 vortex-array/src/array/chunked/compute/mask.rs create mode 100644 vortex-array/src/array/extension/compute/cast.rs create mode 100644 vortex-array/src/array/extension/compute/mask.rs create mode 100644 vortex-array/src/array/primitive/compute/mask.rs create mode 100644 vortex-array/src/array/varbin/compute/cast.rs create mode 100644 vortex-array/src/array/varbin/compute/mask.rs create mode 100644 vortex-array/src/compute/mask.rs diff --git a/encodings/alp/src/alp/compute/mask.rs b/encodings/alp/src/alp/compute/mask.rs new file mode 100644 index 00000000000..5737378f626 --- /dev/null +++ b/encodings/alp/src/alp/compute/mask.rs @@ -0,0 +1,72 @@ +use vortex_array::compute::{mask, try_cast, FilterMask, MaskFn}; +use vortex_array::{ArrayDType as _, ArrayData, IntoArrayData}; +use vortex_error::VortexResult; + +use crate::{ALPArray, ALPEncoding}; + +impl MaskFn for ALPEncoding { + fn mask(&self, array: &ALPArray, filter_mask: FilterMask) -> VortexResult { + ALPArray::try_new( + mask(&array.encoded(), filter_mask)?, + array.exponents(), + array + .patches() + .map(|patches| { + patches.map_values(|values| try_cast(&values, &values.dtype().as_nullable())) + }) + .transpose()?, + ) + .map(IntoArrayData::into_array) + } +} + +#[cfg(test)] +mod tests { + use vortex_array::array::PrimitiveArray; + use vortex_array::compute::test_harness::test_mask; + use vortex_array::validity::Validity; + use vortex_array::IntoArrayData as _; + use vortex_buffer::buffer; + + use crate::alp_encode; + + #[test] + fn 
test_mask_no_patches_alp_array() { + test_mask( + alp_encode(&PrimitiveArray::new( + buffer![1.0f32, 2.0, 3.0, 4.0, 5.0], + Validity::AllValid, + )) + .unwrap() + .into_array(), + ); + + test_mask( + alp_encode(&PrimitiveArray::new( + buffer![1.0f32, 2.0, 3.0, 4.0, 5.0], + Validity::NonNullable, + )) + .unwrap() + .into_array(), + ); + } + + #[test] + fn test_mask_patched_alp_array() { + let alp_array = alp_encode(&PrimitiveArray::new( + buffer![1.0f32, 2.0, 3.0, 4.0, 1e10], + Validity::AllValid, + )) + .unwrap(); + assert!(alp_array.patches().is_some()); + test_mask(alp_array.into_array()); + + let alp_array = alp_encode(&PrimitiveArray::new( + buffer![1.0f32, 2.0, 3.0, 4.0, 1e10], + Validity::NonNullable, + )) + .unwrap(); + assert!(alp_array.patches().is_some()); + test_mask(alp_array.into_array()); + } +} diff --git a/encodings/alp/src/alp/compute/mod.rs b/encodings/alp/src/alp/compute/mod.rs index 74c0546ff6f..895dbece9d1 100644 --- a/encodings/alp/src/alp/compute/mod.rs +++ b/encodings/alp/src/alp/compute/mod.rs @@ -1,6 +1,8 @@ +mod mask; + use vortex_array::compute::{ - filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, SliceFn, - TakeFn, + filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, MaskFn, ScalarAtFn, + SliceFn, TakeFn, }; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; @@ -14,6 +16,10 @@ impl ComputeVTable for ALPEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/encodings/alp/src/alp_rd/compute/mask.rs b/encodings/alp/src/alp_rd/compute/mask.rs new file mode 100644 index 00000000000..0f9da516509 --- /dev/null +++ b/encodings/alp/src/alp_rd/compute/mask.rs @@ -0,0 +1,57 @@ +use vortex_array::compute::{mask, FilterMask, MaskFn}; +use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; +use vortex_error::VortexResult; + +use 
crate::{ALPRDArray, ALPRDEncoding}; + +impl MaskFn for ALPRDEncoding { + fn mask(&self, array: &ALPRDArray, filter_mask: FilterMask) -> VortexResult { + Ok(ALPRDArray::try_new( + array.dtype().as_nullable(), + mask(&array.left_parts(), filter_mask)?, + array.left_parts_dict(), + array.right_parts(), + array.right_bit_width(), + array.left_parts_patches(), + )? + .into_array()) + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_array::array::PrimitiveArray; + use vortex_array::compute::test_harness::test_mask; + use vortex_array::IntoArrayData as _; + + use crate::{ALPRDFloat, RDEncoder}; + + #[rstest] + #[case(0.1f32, 0.2f32, 3e25f32)] + #[case(0.1f64, 0.2f64, 3e100f64)] + fn test_mask_simple(#[case] a: T, #[case] b: T, #[case] outlier: T) { + test_mask( + RDEncoder::new(&[a, b]) + .encode(&PrimitiveArray::from_iter([a, b, outlier, b, outlier])) + .into_array(), + ); + } + + #[rstest] + #[case(0.1f32, 3e25f32)] + #[case(0.5f64, 1e100f64)] + fn test_mask_with_nulls(#[case] a: T, #[case] outlier: T) { + test_mask( + RDEncoder::new(&[a]) + .encode(&PrimitiveArray::from_option_iter([ + Some(a), + None, + Some(outlier), + Some(a), + None, + ])) + .into_array(), + ); + } +} diff --git a/encodings/alp/src/alp_rd/compute/mod.rs b/encodings/alp/src/alp_rd/compute/mod.rs index c696d1c51a0..c87a80b9b9a 100644 --- a/encodings/alp/src/alp_rd/compute/mod.rs +++ b/encodings/alp/src/alp_rd/compute/mod.rs @@ -1,9 +1,10 @@ -use vortex_array::compute::{ComputeVTable, FilterFn, ScalarAtFn, SliceFn, TakeFn}; +use vortex_array::compute::{ComputeVTable, FilterFn, MaskFn, ScalarAtFn, SliceFn, TakeFn}; use vortex_array::ArrayData; use crate::ALPRDEncoding; mod filter; +mod mask; mod scalar_at; mod slice; mod take; @@ -13,6 +14,10 @@ impl ComputeVTable for ALPRDEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/encodings/bytebool/src/compute.rs 
b/encodings/bytebool/src/compute.rs index a21a37c3e12..fa6e04145f8 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -1,5 +1,7 @@ use num_traits::AsPrimitive; -use vortex_array::compute::{ComputeVTable, FillForwardFn, ScalarAtFn, SliceFn, TakeFn}; +use vortex_array::compute::{ + ComputeVTable, FillForwardFn, FilterMask, MaskFn, ScalarAtFn, SliceFn, TakeFn, +}; use vortex_array::validity::{ArrayValidity, Validity}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; @@ -14,6 +16,10 @@ impl ComputeVTable for ByteBoolEncoding { None } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -27,6 +33,13 @@ impl ComputeVTable for ByteBoolEncoding { } } +impl MaskFn for ByteBoolEncoding { + fn mask(&self, array: &ByteBoolArray, mask: FilterMask) -> VortexResult { + ByteBoolArray::try_new(array.buffer().clone(), array.validity().mask(&mask)?) 
+ .map(IntoArrayData::into_array) + } +} + impl ScalarAtFn for ByteBoolEncoding { fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult { Ok(Scalar::bool( @@ -136,6 +149,7 @@ impl FillForwardFn for ByteBoolEncoding { #[cfg(test)] mod tests { + use vortex_array::compute::test_harness::test_mask; use vortex_array::compute::{compare, scalar_at, slice, Operator}; use super::*; @@ -208,4 +222,12 @@ mod tests { let s = scalar_at(&arr, 4).unwrap(); assert!(s.is_null()); } + + #[test] + fn test_mask_byte_bool() { + test_mask(ByteBoolArray::from(vec![true, false, true, true, false]).into_array()); + test_mask( + ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).into_array(), + ); + } } diff --git a/encodings/datetime-parts/src/compute/cast.rs b/encodings/datetime-parts/src/compute/cast.rs new file mode 100644 index 00000000000..ce703d630fa --- /dev/null +++ b/encodings/datetime-parts/src/compute/cast.rs @@ -0,0 +1,25 @@ +use vortex_array::compute::{try_cast, CastFn}; +use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::{DateTimePartsArray, DateTimePartsEncoding}; + +impl CastFn for DateTimePartsEncoding { + fn cast(&self, array: &DateTimePartsArray, dtype: &DType) -> VortexResult { + if !array.dtype().eq_ignore_nullability(dtype) { + vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype); + }; + + Ok(DateTimePartsArray::try_new( + array.dtype().clone().as_nullable(), + try_cast( + array.days().as_ref(), + &array.days().dtype().with_nullability(dtype.nullability()), + )?, + array.seconds(), + array.subsecond(), + )? 
+ .into_array()) + } +} diff --git a/encodings/datetime-parts/src/compute/mask.rs b/encodings/datetime-parts/src/compute/mask.rs new file mode 100644 index 00000000000..9c52974b464 --- /dev/null +++ b/encodings/datetime-parts/src/compute/mask.rs @@ -0,0 +1,58 @@ +use vortex_array::compute::{mask, FilterMask, MaskFn}; +use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; +use vortex_error::VortexResult; + +use crate::{DateTimePartsArray, DateTimePartsEncoding}; + +impl MaskFn for DateTimePartsEncoding { + fn mask(&self, array: &DateTimePartsArray, filter_mask: FilterMask) -> VortexResult { + Ok(DateTimePartsArray::try_new( + array.dtype().clone().as_nullable(), + mask(array.days().as_ref(), filter_mask)?, + array.seconds(), + array.subsecond(), + )? + .into_array()) + } +} + +#[cfg(test)] +mod tests { + use vortex_array::array::TemporalArray; + use vortex_array::compute::test_harness::test_mask; + use vortex_array::IntoArrayData as _; + use vortex_buffer::buffer; + use vortex_datetime_dtype::TimeUnit; + use vortex_dtype::DType; + + use crate::{split_temporal, DateTimePartsArray, TemporalParts}; + + #[test] + fn test_mask_datetime_parts_array() { + let raw_millis = buffer![ + 86_400i64, // element with only day component + 86_400i64 + 1000, // element with day + second components + 86_400i64 + 1000 + 1, // element with day + second + sub-second components + 86_400i64 + 1000 + 5, // element with day + second + sub-second components + 86_400i64 + 1000 + 55, // element with day + second + sub-second components + ] + .into_array(); + let temporal_array = + TemporalArray::new_timestamp(raw_millis, TimeUnit::Ms, Some("UTC".to_string())); + let TemporalParts { + days, + seconds, + subseconds, + } = split_temporal(temporal_array.clone()).unwrap(); + let date_times = DateTimePartsArray::try_new( + DType::Extension(temporal_array.ext_dtype()), + days, + seconds, + subseconds, + ) + .unwrap() + .into_array(); + + test_mask(date_times.clone()); + } +} diff --git 
a/encodings/datetime-parts/src/compute/mod.rs b/encodings/datetime-parts/src/compute/mod.rs index de3937c1e8a..d5ebef744f8 100644 --- a/encodings/datetime-parts/src/compute/mod.rs +++ b/encodings/datetime-parts/src/compute/mod.rs @@ -1,9 +1,12 @@ +mod cast; mod filter; +mod mask; mod take; use vortex_array::array::{PrimitiveArray, TemporalArray}; use vortex_array::compute::{ - scalar_at, slice, try_cast, ComputeVTable, FilterFn, ScalarAtFn, SliceFn, TakeFn, + scalar_at, slice, try_cast, CastFn, ComputeVTable, FilterFn, MaskFn, ScalarAtFn, SliceFn, + TakeFn, }; use vortex_array::validity::ArrayValidity; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; @@ -17,10 +20,18 @@ use vortex_scalar::{PrimitiveScalar, Scalar}; use crate::{DateTimePartsArray, DateTimePartsEncoding}; impl ComputeVTable for DateTimePartsEncoding { + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/encodings/dict/Cargo.toml b/encodings/dict/Cargo.toml index 89924fea653..69f81008db8 100644 --- a/encodings/dict/Cargo.toml +++ b/encodings/dict/Cargo.toml @@ -35,3 +35,7 @@ vortex-array = { workspace = true, features = ["test-harness"] } [[bench]] name = "dict_compress" harness = false + +[[bench]] +name = "dict_mask" +harness = false diff --git a/encodings/dict/benches/dict_mask.rs b/encodings/dict/benches/dict_mask.rs new file mode 100644 index 00000000000..4e0241d4b1f --- /dev/null +++ b/encodings/dict/benches/dict_mask.rs @@ -0,0 +1,50 @@ +#![allow(clippy::unwrap_used)] + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng as _}; +use vortex_array::array::PrimitiveArray; +use vortex_array::compute::{mask, FilterMask}; +use vortex_array::IntoArrayData as _; +use vortex_buffer::{buffer, 
Buffer}; +use vortex_dict::DictArray; + +fn filter_mask(len: usize, fraction_masked: f64, rng: &mut StdRng) -> FilterMask { + let indices = (0..len) + .filter(|_| rng.gen_bool(fraction_masked)) + .map(|x| x as u64) + .collect::>(); + FilterMask::from_indices(len, indices) +} + +#[allow(clippy::cast_possible_truncation)] +fn bench_dict_mask(c: &mut Criterion) { + let mut group = c.benchmark_group("bench_dict_mask"); + let mut rng = StdRng::seed_from_u64(0); + + let len = 65_535; + // for fraction_valid in [0.5, 0.1, 0.01, 0.001, 0.0001] { + for fraction_valid in [0.1] { + let codes = + PrimitiveArray::from_iter((0..len).map(|_| (!rng.gen_bool(fraction_valid)) as u64)) + .into_array(); + let values = buffer![1].into_array(); + let array = DictArray::try_new(codes, values).unwrap().into_array(); + // for fraction_masked in [0.1, 0.01, 0.001, 0.0001] { + for fraction_masked in [0.9, 0.5, 0.1, 0.0001] { + let filter_mask = filter_mask(len, fraction_masked, &mut rng); + group.bench_with_input( + BenchmarkId::from_parameter(format!( + "fraction_valid={}, fraction_masked={}", + fraction_valid, fraction_masked + )), + &(&array, filter_mask), + |b, (array, filter_mask)| b.iter(|| mask(array, filter_mask.clone()).unwrap()), + ); + } + } + group.finish() +} + +criterion_group!(benches, bench_dict_mask); +criterion_main!(benches); diff --git a/encodings/dict/src/array.rs b/encodings/dict/src/array.rs index 18584099521..2a9785af649 100644 --- a/encodings/dict/src/array.rs +++ b/encodings/dict/src/array.rs @@ -48,6 +48,11 @@ impl DictArray { ) } + #[inline] + pub fn codes_ptype(&self) -> PType { + self.metadata().codes_ptype + } + #[inline] pub fn codes(&self) -> ArrayData { self.as_ref() diff --git a/encodings/dict/src/compute/mask.rs b/encodings/dict/src/compute/mask.rs new file mode 100644 index 00000000000..10ef542b88f --- /dev/null +++ b/encodings/dict/src/compute/mask.rs @@ -0,0 +1,48 @@ +use vortex_array::compute::{FilterIter, FilterMask, MaskFn}; +use 
vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant as _}; +use vortex_buffer::BufferMut; +use vortex_dtype::{match_each_integer_ptype, NativePType}; +use vortex_error::VortexResult; + +use crate::{DictArray, DictEncoding}; + +impl MaskFn for DictEncoding { + fn mask(&self, array: &DictArray, mask: FilterMask) -> VortexResult { + let new_codes = match_each_integer_ptype!(array.codes_ptype(), |$T| { + let mut codes = array.codes().into_primitive()?.into_buffer_mut(); + typed_mask::<$T>(&mut codes, mask)?; + codes.into_array() + }); + DictArray::try_new(new_codes, array.values()).map(IntoArrayData::into_array) + } +} + +fn typed_mask(codes: &mut BufferMut, mask: FilterMask) -> VortexResult<()> { + match mask.iter()? { + FilterIter::Indices(indices) => { + for index in indices { + codes[*index] = T::zero(); + } + } + FilterIter::IndicesIter(bit_index_iterator) => { + for index in bit_index_iterator { + codes[index] = T::zero(); + } + } + FilterIter::Slices(slices) => { + for slice in slices { + for index in slice.0..slice.1 { + codes[index] = T::zero(); + } + } + } + FilterIter::SlicesIter(bit_slice_iterator) => { + for slice in bit_slice_iterator { + for index in slice.0..slice.1 { + codes[index] = T::zero(); + } + } + } + } + Ok(()) +} diff --git a/encodings/dict/src/compute/mod.rs b/encodings/dict/src/compute/mod.rs index 79cfb45ab76..626db947a2d 100644 --- a/encodings/dict/src/compute/mod.rs +++ b/encodings/dict/src/compute/mod.rs @@ -1,10 +1,11 @@ mod binary_numeric; mod compare; mod like; +mod mask; use vortex_array::compute::{ filter, scalar_at, slice, take, BinaryNumericFn, CompareFn, ComputeVTable, FilterFn, - FilterMask, LikeFn, ScalarAtFn, SliceFn, TakeFn, + FilterMask, LikeFn, MaskFn, ScalarAtFn, SliceFn, TakeFn, }; use vortex_array::{ArrayData, IntoArrayData}; use vortex_error::VortexResult; @@ -29,6 +30,10 @@ impl ComputeVTable for DictEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> 
Option<&dyn ScalarAtFn> { Some(self) } @@ -78,7 +83,7 @@ impl SliceFn for DictEncoding { mod test { use vortex_array::accessor::ArrayAccessor; use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinViewArray}; - use vortex_array::compute::test_harness::test_binary_numeric; + use vortex_array::compute::test_harness::{test_binary_numeric, test_mask}; use vortex_array::compute::{compare, scalar_at, slice, Operator}; use vortex_array::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::{DType, Nullability}; @@ -160,4 +165,16 @@ mod test { let array = sliced_dict_array(); test_binary_numeric::(array) } + + #[test] + fn test_mask_dict_array() { + let reference = + PrimitiveArray::from_option_iter([None, Some(42), Some(-9), Some(42), Some(5)]); + let (codes, values) = dict_encode_primitive(&reference); + test_mask( + DictArray::try_new(codes.into_array(), values.into_array()) + .unwrap() + .into_array(), + ) + } } diff --git a/vortex-array/src/array/bool/compute/cast.rs b/vortex-array/src/array/bool/compute/cast.rs new file mode 100644 index 00000000000..86e600d2100 --- /dev/null +++ b/vortex-array/src/array/bool/compute/cast.rs @@ -0,0 +1,20 @@ +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::array::{BoolArray, BoolEncoding}; +use crate::compute::CastFn; +use crate::{ArrayData, IntoArrayData}; + +impl CastFn for BoolEncoding { + fn cast(&self, array: &BoolArray, dtype: &DType) -> VortexResult { + let DType::Bool(new_nullability) = dtype else { + vortex_bail!(MismatchedTypes: "bool type", dtype); + }; + + BoolArray::try_new( + array.boolean_buffer(), + array.validity().with_nullability(*new_nullability)?, + ) + .map(IntoArrayData::into_array) + } +} diff --git a/vortex-array/src/array/bool/compute/mask.rs b/vortex-array/src/array/bool/compute/mask.rs new file mode 100644 index 00000000000..a1720220075 --- /dev/null +++ b/vortex-array/src/array/bool/compute/mask.rs @@ -0,0 +1,12 @@ +use 
vortex_error::VortexResult; + +use crate::array::{BoolArray, BoolEncoding}; +use crate::compute::{FilterMask, MaskFn}; +use crate::{ArrayData, IntoArrayData}; + +impl MaskFn for BoolEncoding { + fn mask(&self, array: &BoolArray, mask: FilterMask) -> VortexResult { + BoolArray::try_new(array.boolean_buffer(), array.validity().mask(&mask)?) + .map(IntoArrayData::into_array) + } +} diff --git a/vortex-array/src/array/bool/compute/mod.rs b/vortex-array/src/array/bool/compute/mod.rs index 5a9499590b2..dc0dbfdebc4 100644 --- a/vortex-array/src/array/bool/compute/mod.rs +++ b/vortex-array/src/array/bool/compute/mod.rs @@ -1,15 +1,17 @@ use crate::array::BoolEncoding; use crate::compute::{ - BinaryBooleanFn, ComputeVTable, FillForwardFn, FillNullFn, FilterFn, InvertFn, ScalarAtFn, - SliceFn, TakeFn, + BinaryBooleanFn, CastFn, ComputeVTable, FillForwardFn, FillNullFn, FilterFn, InvertFn, MaskFn, + ScalarAtFn, SliceFn, TakeFn, }; use crate::ArrayData; +mod cast; mod fill_forward; mod fill_null; pub mod filter; mod flatten; mod invert; +mod mask; mod scalar_at; mod slice; mod take; @@ -23,6 +25,10 @@ impl ComputeVTable for BoolEncoding { None } + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn> { Some(self) } @@ -39,6 +45,10 @@ impl ComputeVTable for BoolEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/array/bool/mod.rs b/vortex-array/src/array/bool/mod.rs index 4d9ec80fb59..a94b2f4b1f9 100644 --- a/vortex-array/src/array/bool/mod.rs +++ b/vortex-array/src/array/bool/mod.rs @@ -211,6 +211,7 @@ mod tests { use vortex_dtype::Nullability; use crate::array::{BoolArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; use crate::compute::{scalar_at, slice}; use crate::patches::Patches; use crate::validity::Validity; @@ -308,4 +309,9 @@ mod tests { assert_eq!(offset, 
0); assert_eq!(values.as_slice(), &[254, 127]); } + + #[test] + fn test_mask_primitive_array() { + test_mask(BoolArray::from_iter([true, false, true, true, false]).into_array()); + } } diff --git a/vortex-array/src/array/chunked/compute/filter.rs b/vortex-array/src/array/chunked/compute/filter.rs index ea01a01fbf1..af17aba60da 100644 --- a/vortex-array/src/array/chunked/compute/filter.rs +++ b/vortex-array/src/array/chunked/compute/filter.rs @@ -32,7 +32,7 @@ impl FilterFn for ChunkedEncoding { /// When we rewrite a set of slices in a filter predicate into chunk addresses, we want to account /// for the fact that some chunks will be wholly skipped. #[derive(Clone)] -enum ChunkFilter { +pub(crate) enum ChunkFilter { All, None, Slices(Vec<(usize, usize)>), @@ -61,66 +61,12 @@ fn slices_to_mask(slices: &[(usize, usize)], len: usize) -> FilterMask { } /// Filter the chunks using slice ranges. -#[allow(deprecated)] fn filter_slices(array: &ChunkedArray, mask: FilterMask) -> VortexResult> { - let mut result = Vec::with_capacity(array.nchunks()); - - // Pre-materialize the chunk ends for performance. - // The chunk ends is nchunks+1, which is expected to be in the hundreds or at most thousands. - let chunk_ends = array.chunk_offsets().into_canonical()?.into_primitive()?; - let chunk_ends = chunk_ends.as_slice::(); - - let mut chunk_filters = vec![ChunkFilter::None; array.nchunks()]; - - for (slice_start, slice_end) in mask.iter_slices()? { - let (start_chunk, start_idx) = find_chunk_idx(slice_start, chunk_ends); - // NOTE: we adjust slice end back by one, in case it ends on a chunk boundary, we do not - // want to index into the unused chunk. - let (end_chunk, end_idx) = find_chunk_idx(slice_end - 1, chunk_ends); - // Adjust back to an exclusive range - let end_idx = end_idx + 1; - - if start_chunk == end_chunk { - // start == end means that the slice lies within a single chunk. 
- match &mut chunk_filters[start_chunk] { - f @ (ChunkFilter::All | ChunkFilter::None) => { - *f = ChunkFilter::Slices(vec![(start_idx, end_idx)]); - } - ChunkFilter::Slices(slices) => { - slices.push((start_idx, end_idx)); - } - } - } else { - // start != end means that the range is split over at least two chunks: - // start chunk: append a slice from (start_idx, start_chunk_end), i.e. whole chunk. - // end chunk: append a slice from (0, end_idx). - // chunks between start and end: append ChunkFilter::All. - let start_chunk_len: usize = - (chunk_ends[start_chunk + 1] - chunk_ends[start_chunk]).try_into()?; - let start_slice = (start_idx, start_chunk_len); - match &mut chunk_filters[start_chunk] { - f @ (ChunkFilter::All | ChunkFilter::None) => { - *f = ChunkFilter::Slices(vec![start_slice]) - } - ChunkFilter::Slices(slices) => slices.push(start_slice), - } - - let end_slice = (0, end_idx); - match &mut chunk_filters[end_chunk] { - f @ (ChunkFilter::All | ChunkFilter::None) => { - *f = ChunkFilter::Slices(vec![end_slice]); - } - ChunkFilter::Slices(slices) => slices.push(end_slice), - } - - for chunk in &mut chunk_filters[start_chunk + 1..end_chunk] { - *chunk = ChunkFilter::All; - } - } - } + let chunked_filters = chunk_filters(array, mask)?; // Now, apply the chunk filter to every slice. - for (chunk, chunk_filter) in array.chunks().zip(chunk_filters.iter()) { + let mut result = Vec::with_capacity(array.nchunks()); + for (chunk, chunk_filter) in array.chunks().zip(chunked_filters.iter()) { match chunk_filter { // All => preserve the entire chunk unfiltered. ChunkFilter::All => result.push(chunk), @@ -186,9 +132,71 @@ fn filter_indices(array: &ChunkedArray, mask: FilterMask) -> VortexResult VortexResult> { + // Pre-materialize the chunk ends for performance. + // The chunk ends is nchunks+1, which is expected to be in the hundreds or at most thousands. 
+ let chunk_ends = array.chunk_offsets().into_canonical()?.into_primitive()?; + let chunk_ends = chunk_ends.as_slice::(); + + let mut chunk_filters = vec![ChunkFilter::None; array.nchunks()]; + + for (slice_start, slice_end) in mask.iter_slices()? { + let (start_chunk, start_idx) = find_chunk_idx(slice_start, chunk_ends); + // NOTE: we adjust slice end back by one, in case it ends on a chunk boundary, we do not + // want to index into the unused chunk. + let (end_chunk, end_idx) = find_chunk_idx(slice_end - 1, chunk_ends); + // Adjust back to an exclusive range + let end_idx = end_idx + 1; + + if start_chunk == end_chunk { + // start == end means that the slice lies within a single chunk. + match &mut chunk_filters[start_chunk] { + f @ (ChunkFilter::All | ChunkFilter::None) => { + *f = ChunkFilter::Slices(vec![(start_idx, end_idx)]); + } + ChunkFilter::Slices(slices) => { + slices.push((start_idx, end_idx)); + } + } + } else { + // start != end means that the range is split over at least two chunks: + // start chunk: append a slice from (start_idx, start_chunk_end), i.e. whole chunk. + // end chunk: append a slice from (0, end_idx). + // chunks between start and end: append ChunkFilter::All. 
+ let start_chunk_len: usize = + (chunk_ends[start_chunk + 1] - chunk_ends[start_chunk]).try_into()?; + let start_slice = (start_idx, start_chunk_len); + match &mut chunk_filters[start_chunk] { + f @ (ChunkFilter::All | ChunkFilter::None) => { + *f = ChunkFilter::Slices(vec![start_slice]) + } + ChunkFilter::Slices(slices) => slices.push(start_slice), + } + + let end_slice = (0, end_idx); + match &mut chunk_filters[end_chunk] { + f @ (ChunkFilter::All | ChunkFilter::None) => { + *f = ChunkFilter::Slices(vec![end_slice]); + } + ChunkFilter::Slices(slices) => slices.push(end_slice), + } + + for chunk in &mut chunk_filters[start_chunk + 1..end_chunk] { + *chunk = ChunkFilter::All; + } + } + } + + Ok(chunk_filters) +} + // Mirrors the find_chunk_idx method on ChunkedArray, but avoids all of the overhead // from scalars, dtypes, and metadata cloning. -fn find_chunk_idx(idx: usize, chunk_ends: &[u64]) -> (usize, usize) { +pub(crate) fn find_chunk_idx(idx: usize, chunk_ends: &[u64]) -> (usize, usize) { let chunk_id = chunk_ends .search_sorted(&(idx as u64), SearchSortedSide::Right) .to_ends_index(chunk_ends.len()) diff --git a/vortex-array/src/array/chunked/compute/mask.rs b/vortex-array/src/array/chunked/compute/mask.rs new file mode 100644 index 00000000000..70da58d95da --- /dev/null +++ b/vortex-array/src/array/chunked/compute/mask.rs @@ -0,0 +1,152 @@ +use itertools::Itertools as _; +use vortex_buffer::BufferMut; +use vortex_dtype::DType; +use vortex_error::{VortexExpect as _, VortexResult}; +use vortex_scalar::Scalar; + +use super::filter::{chunk_filters, find_chunk_idx, ChunkFilter}; +use crate::array::{ChunkedArray, ChunkedEncoding, ConstantArray}; +use crate::compute::{mask, try_cast, FilterIter, FilterMask, MaskFn}; +use crate::{ArrayDType, ArrayData, ArrayLen as _, IntoArrayData, IntoCanonical as _}; + +impl MaskFn for ChunkedEncoding { + fn mask(&self, array: &ChunkedArray, mask: FilterMask) -> VortexResult { + let new_dtype = array.dtype().as_nullable(); + let 
new_chunks = match mask.iter()? { + FilterIter::Indices(_) => mask_indices(array, mask, &new_dtype), + FilterIter::IndicesIter(_) => mask_indices(array, mask, &new_dtype), + FilterIter::Slices(_) => mask_slices(array, mask, &new_dtype), + FilterIter::SlicesIter(_) => mask_slices(array, mask, &new_dtype), + }?; + debug_assert_eq!(new_chunks.len(), array.nchunks()); + debug_assert_eq!( + new_chunks.iter().map(|x| x.len()).sum::(), + array.len() + ); + ChunkedArray::try_new(new_chunks, new_dtype).map(IntoArrayData::into_array) + } +} + +#[allow(deprecated)] +pub fn mask_indices( + array: &ChunkedArray, + filter_mask: FilterMask, + new_dtype: &DType, +) -> VortexResult> { + let mut new_chunks = Vec::with_capacity(array.nchunks()); + let mut current_chunk_id = 0; + let mut chunk_indices = BufferMut::with_capacity(array.nchunks()); + + // Avoid find_chunk_idx and use our own to avoid the overhead. + // The array should only be some thousands of values in the general case. + let chunk_ends = array.chunk_offsets().into_canonical()?.into_primitive()?; + let chunk_ends = chunk_ends.as_slice::(); + + for set_index in filter_mask.iter_indices()? { + let (chunk_id, index) = find_chunk_idx(set_index, chunk_ends); + if chunk_id != current_chunk_id { + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + let masked_chunk = mask( + &chunk, + FilterMask::from_indices(chunk.len(), chunk_indices.clone().freeze()), + )?; + new_chunks.push(masked_chunk); + current_chunk_id += 1; + + // Advance the chunk forward, reset the chunk indices buffer. 
+ while current_chunk_id < chunk_id { + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + new_chunks.push(try_cast(chunk, new_dtype)?); + current_chunk_id += 1; + } + + chunk_indices.clear(); + } + + chunk_indices.push(index as u64); + } + + if !chunk_indices.is_empty() { + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + let masked_chunk = mask( + &chunk, + FilterMask::from_indices(chunk.len(), chunk_indices.clone().freeze()), + )?; + new_chunks.push(masked_chunk); + current_chunk_id += 1; + } + + while current_chunk_id < array.nchunks() { + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + new_chunks.push(try_cast(chunk, new_dtype)?); + current_chunk_id += 1; + } + + Ok(new_chunks) +} + +pub fn mask_slices( + array: &ChunkedArray, + filter_mask: FilterMask, + new_dtype: &DType, +) -> VortexResult> { + let chunked_filters = chunk_filters(array, filter_mask)?; + + array + .chunks() + .zip_eq(chunked_filters.iter()) + .map(|(chunk, chunk_filter)| -> VortexResult { + Ok(match chunk_filter { + ChunkFilter::All => { + // All => entire chunk is masked out + ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()).into_array() + } + ChunkFilter::None => { + // None => preserve the entire chunk unmasked + chunk + } + // Slices => turn the slices into a boolean buffer. 
+ ChunkFilter::Slices(slices) => mask( + &chunk, + FilterMask::from_slices(chunk.len(), slices.iter().cloned()), + )?, + }) + }) + .process_results(|iter| iter.collect::>()) +} + +#[cfg(test)] +mod test { + use vortex_buffer::buffer; + use vortex_dtype::{DType, Nullability, PType}; + + use crate::array::{ChunkedArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; + use crate::IntoArrayData; + + #[test] + fn test_mask_chunked_array() { + let dtype = DType::Primitive(PType::U64, Nullability::NonNullable); + let chunked = ChunkedArray::try_new( + vec![ + buffer![0u64, 1].into_array(), + buffer![2_u64].into_array(), + PrimitiveArray::empty::(dtype.nullability()).into_array(), + buffer![3_u64, 4].into_array(), + ], + dtype, + ) + .unwrap() + .into_array(); + + test_mask(chunked); + } +} diff --git a/vortex-array/src/array/chunked/compute/mod.rs b/vortex-array/src/array/chunked/compute/mod.rs index 7cb867ffb77..e4a5f83e2b6 100644 --- a/vortex-array/src/array/chunked/compute/mod.rs +++ b/vortex-array/src/array/chunked/compute/mod.rs @@ -5,7 +5,7 @@ use crate::array::chunked::ChunkedArray; use crate::array::ChunkedEncoding; use crate::compute::{ try_cast, BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, ComputeVTable, FillNullFn, - FilterFn, InvertFn, ScalarAtFn, SliceFn, TakeFn, + FilterFn, InvertFn, MaskFn, ScalarAtFn, SliceFn, TakeFn, }; use crate::{ArrayData, IntoArrayData}; @@ -15,6 +15,7 @@ mod compare; mod fill_null; mod filter; mod invert; +mod mask; mod scalar_at; mod slice; mod take; @@ -48,6 +49,10 @@ impl ComputeVTable for ChunkedEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/array/constant/compute/mod.rs b/vortex-array/src/array/constant/compute/mod.rs index dbb7dfbc7b5..291589d41b6 100644 --- a/vortex-array/src/array/constant/compute/mod.rs +++ b/vortex-array/src/array/constant/compute/mod.rs @@ -76,3 
+76,21 @@ impl FilterFn for ConstantEncoding { Ok(ConstantArray::new(array.scalar(), mask.true_count()).into_array()) } } + +#[cfg(test)] +mod test { + use vortex_dtype::half::f16; + use vortex_scalar::Scalar; + + use super::ConstantArray; + use crate::compute::test_harness::test_mask; + use crate::IntoArrayData as _; + + #[test] + fn test_mask_constant() { + test_mask(ConstantArray::new(Scalar::null_typed::(), 5).into_array()); + test_mask(ConstantArray::new(Scalar::from(3u16), 5).into_array()); + test_mask(ConstantArray::new(Scalar::from(1.0f32 / 0.0f32), 5).into_array()); + test_mask(ConstantArray::new(Scalar::from(f16::from_f32(3.0f32)), 5).into_array()); + } +} diff --git a/vortex-array/src/array/extension/compute/cast.rs b/vortex-array/src/array/extension/compute/cast.rs new file mode 100644 index 00000000000..c8b046126d5 --- /dev/null +++ b/vortex-array/src/array/extension/compute/cast.rs @@ -0,0 +1,27 @@ +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::array::extension::ExtensionArray; +use crate::array::ExtensionEncoding; +use crate::compute::{try_cast, CastFn}; +use crate::{ArrayDType as _, ArrayData, IntoArrayData as _}; + +impl CastFn for ExtensionEncoding { + fn cast(&self, array: &ExtensionArray, dtype: &DType) -> VortexResult { + if !array.dtype().eq_ignore_nullability(dtype) { + vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype); + } + let DType::Extension(ext_dtype) = dtype else { + vortex_bail!( + "dtype must have extension dtype {} {}", + array.dtype(), + dtype + ); + }; + Ok(ExtensionArray::new( + ext_dtype.clone(), + try_cast(array.storage(), ext_dtype.storage_dtype())?, + ) + .into_array()) + } +} diff --git a/vortex-array/src/array/extension/compute/mask.rs b/vortex-array/src/array/extension/compute/mask.rs new file mode 100644 index 00000000000..53f1a7d1443 --- /dev/null +++ b/vortex-array/src/array/extension/compute/mask.rs @@ -0,0 +1,48 @@ +use std::sync::Arc; + +use vortex_dtype::{DType, 
Nullability}; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::array::extension::ExtensionArray; +use crate::array::ExtensionEncoding; +use crate::compute::{mask, FilterMask, MaskFn}; +use crate::{ArrayDType as _, ArrayData, IntoArrayData}; + +impl MaskFn for ExtensionEncoding { + fn mask(&self, array: &ExtensionArray, filter_mask: FilterMask) -> VortexResult { + let DType::Extension(ext_dtype) = array.dtype() else { + vortex_bail!("extension array must have extension dtype"); + }; + Ok(ExtensionArray::new( + Arc::from(ext_dtype.with_nullability(Nullability::Nullable)), + mask(&array.storage(), filter_mask)?, + ) + .into_array()) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use vortex_buffer::buffer; + use vortex_dtype::{DType, ExtDType, ExtID, PType}; + + use crate::array::ExtensionArray; + use crate::compute::test_harness::test_mask; + use crate::IntoArrayData as _; + + #[test] + fn test_mask_extension_array() { + let ext_dtype = Arc::new(ExtDType::new( + ExtID::new("timestamp".into()), + DType::from(PType::I64).into(), + None, + )); + + test_mask( + ExtensionArray::new(ext_dtype.clone(), buffer![1i64, 2, 3, 4, 5].into_array()) + .into_array(), + ); + } +} diff --git a/vortex-array/src/array/extension/compute/mod.rs b/vortex-array/src/array/extension/compute/mod.rs index 46c0ec9f834..7e1bd2b445a 100644 --- a/vortex-array/src/array/extension/compute/mod.rs +++ b/vortex-array/src/array/extension/compute/mod.rs @@ -1,4 +1,6 @@ +mod cast; mod compare; +mod mask; use vortex_error::VortexResult; use vortex_scalar::Scalar; @@ -6,23 +8,29 @@ use vortex_scalar::Scalar; use crate::array::extension::ExtensionArray; use crate::array::ExtensionEncoding; use crate::compute::{ - scalar_at, slice, take, CastFn, CompareFn, ComputeVTable, ScalarAtFn, SliceFn, TakeFn, + scalar_at, slice, take, CastFn, CompareFn, ComputeVTable, MaskFn, ScalarAtFn, SliceFn, TakeFn, }; use crate::variants::ExtensionArrayTrait; use crate::{ArrayData, IntoArrayData}; impl 
ComputeVTable for ExtensionEncoding { fn cast_fn(&self) -> Option<&dyn CastFn> { - // It's not possible to cast an extension array to another type. - // TODO(ngates): we should allow some extension arrays to implement a callback - // to support this - None + // It's not possible to cast an extension array to another type, but we can make it + // nullable. + // + // TODO(ngates): we should allow some extension arrays to implement a callback to support + // this + Some(self) } fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/array/list/compute/mod.rs b/vortex-array/src/array/list/compute/mod.rs index 317f574c75b..0b99c390135 100644 --- a/vortex-array/src/array/list/compute/mod.rs +++ b/vortex-array/src/array/list/compute/mod.rs @@ -5,7 +5,7 @@ use vortex_error::VortexResult; use vortex_scalar::Scalar; use crate::array::{ListArray, ListEncoding}; -use crate::compute::{scalar_at, slice, ComputeVTable, ScalarAtFn, SliceFn}; +use crate::compute::{scalar_at, slice, ComputeVTable, FilterMask, MaskFn, ScalarAtFn, SliceFn}; use crate::{ArrayDType, ArrayData, IntoArrayData}; impl ComputeVTable for ListEncoding { @@ -16,6 +16,10 @@ impl ComputeVTable for ListEncoding { fn slice_fn(&self) -> Option<&dyn SliceFn> { Some(self) } + + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } } impl ScalarAtFn for ListEncoding { @@ -41,3 +45,34 @@ impl SliceFn for ListEncoding { .into_array()) } } + +impl MaskFn for ListEncoding { + fn mask(&self, array: &ListArray, mask: FilterMask) -> VortexResult { + ListArray::try_new( + array.elements(), + array.offsets(), + array.validity().mask(&mask)?, + ) + .map(IntoArrayData::into_array) + } +} + +#[cfg(test)] +mod test { + use crate::array::{ListArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; + use crate::validity::Validity; + use 
crate::IntoArrayData as _; + + #[test] + fn test_mask_list() { + let elements = PrimitiveArray::from_iter(0..35); + let offsets = PrimitiveArray::from_iter([0, 5, 11, 18, 26, 35]); + let validity = Validity::AllValid; + let array = ListArray::try_new(elements.into_array(), offsets.into_array(), validity) + .unwrap() + .into_array(); + + test_mask(array); + } +} diff --git a/vortex-array/src/array/null/compute.rs b/vortex-array/src/array/null/compute.rs index 7a224aaec6e..8938267aaba 100644 --- a/vortex-array/src/array/null/compute.rs +++ b/vortex-array/src/array/null/compute.rs @@ -4,11 +4,15 @@ use vortex_scalar::Scalar; use crate::array::null::NullArray; use crate::array::NullEncoding; -use crate::compute::{ComputeVTable, ScalarAtFn, SliceFn, TakeFn}; +use crate::compute::{ComputeVTable, FilterMask, MaskFn, ScalarAtFn, SliceFn, TakeFn}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; impl ComputeVTable for NullEncoding { + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -22,6 +26,12 @@ impl ComputeVTable for NullEncoding { } } +impl MaskFn for NullEncoding { + fn mask(&self, array: &NullArray, _mask: FilterMask) -> VortexResult { + Ok(array.clone().into_array()) + } +} + impl SliceFn for NullEncoding { fn slice(&self, _array: &NullArray, start: usize, stop: usize) -> VortexResult { Ok(NullArray::new(stop - start).into_array()) diff --git a/vortex-array/src/array/primitive/compute/mask.rs b/vortex-array/src/array/primitive/compute/mask.rs new file mode 100644 index 00000000000..c4bb4162724 --- /dev/null +++ b/vortex-array/src/array/primitive/compute/mask.rs @@ -0,0 +1,17 @@ +use vortex_error::VortexResult; + +use crate::array::primitive::PrimitiveArray; +use crate::array::PrimitiveEncoding; +use crate::compute::{FilterMask, MaskFn}; +use crate::variants::PrimitiveArrayTrait as _; +use crate::{ArrayData, IntoArrayData}; + 
+impl MaskFn for PrimitiveEncoding { + fn mask(&self, array: &PrimitiveArray, mask: FilterMask) -> VortexResult { + let validity = array.validity().mask(&mask)?; + Ok( + PrimitiveArray::from_byte_buffer(array.byte_buffer().clone(), array.ptype(), validity) + .into_array(), + ) + } +} diff --git a/vortex-array/src/array/primitive/compute/mod.rs b/vortex-array/src/array/primitive/compute/mod.rs index bac23e2d7f4..cfdde3eec32 100644 --- a/vortex-array/src/array/primitive/compute/mod.rs +++ b/vortex-array/src/array/primitive/compute/mod.rs @@ -1,6 +1,6 @@ use crate::array::PrimitiveEncoding; use crate::compute::{ - CastFn, ComputeVTable, FillForwardFn, FilterFn, ScalarAtFn, SearchSortedFn, + CastFn, ComputeVTable, FillForwardFn, FilterFn, MaskFn, ScalarAtFn, SearchSortedFn, SearchSortedUsizeFn, SliceFn, TakeFn, }; use crate::ArrayData; @@ -8,6 +8,7 @@ use crate::ArrayData; mod cast; mod fill; mod filter; +mod mask; mod scalar_at; mod search_sorted; mod slice; @@ -18,6 +19,10 @@ impl ComputeVTable for PrimitiveEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn> { Some(self) } diff --git a/vortex-array/src/array/primitive/mod.rs b/vortex-array/src/array/primitive/mod.rs index fa212cbda05..a1bdb2547c8 100644 --- a/vortex-array/src/array/primitive/mod.rs +++ b/vortex-array/src/array/primitive/mod.rs @@ -320,3 +320,29 @@ impl VisitorVTable for PrimitiveEncoding { visitor.visit_validity(&array.validity()) } } + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + + use crate::array::{BoolArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; + use crate::validity::Validity; + use crate::IntoArrayData as _; + + #[test] + fn test_mask_primitive_array() { + test_mask(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable).into_array()); + test_mask(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid).into_array()); + 
test_mask(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllInvalid).into_array()); + test_mask( + PrimitiveArray::new( + buffer![0, 1, 2, 3, 4], + Validity::Array( + BoolArray::from_iter([true, false, true, false, true]).into_array(), + ), + ) + .into_array(), + ); + } +} diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index 1d323085b36..ae6f22d4a8a 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -105,13 +105,14 @@ impl FilterFn for SparseEncoding { mod test { use rstest::{fixture, rstest}; use vortex_buffer::buffer; + use vortex_dtype::{DType, Nullability, PType}; use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; use crate::array::sparse::SparseArray; - use crate::compute::test_harness::test_binary_numeric; + use crate::compute::test_harness::{test_binary_numeric, test_mask}; use crate::compute::{ - filter, search_sorted, slice, FilterMask, SearchResult, SearchSortedSide, + filter, search_sorted, slice, try_cast, FilterMask, SearchResult, SearchSortedSide, }; use crate::validity::Validity; use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; @@ -225,4 +226,35 @@ mod test { fn test_sparse_binary_numeric(array: ArrayData) { test_binary_numeric::(array) } + + #[test] + fn test_mask_sparse_array() { + let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)); + test_mask( + SparseArray::try_new( + buffer![1u64, 2, 4].into_array(), + try_cast( + buffer![100i32, 200, 300].into_array(), + null_fill_value.dtype(), + ) + .unwrap(), + 5, + null_fill_value, + ) + .unwrap() + .into_array(), + ); + + let ten_fill_value = Scalar::from(10i32); + test_mask( + SparseArray::try_new( + buffer![1u64, 2, 4].into_array(), + buffer![100i32, 200, 300].into_array(), + 5, + ten_fill_value, + ) + .unwrap() + .into_array(), + ) + } } diff --git a/vortex-array/src/array/struct_/compute.rs 
b/vortex-array/src/array/struct_/compute.rs index ba3b44dce7c..56064d4605a 100644 --- a/vortex-array/src/array/struct_/compute.rs +++ b/vortex-array/src/array/struct_/compute.rs @@ -1,21 +1,30 @@ use itertools::Itertools; -use vortex_error::VortexResult; +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexResult}; use vortex_scalar::Scalar; use crate::array::struct_::StructArray; use crate::array::StructEncoding; use crate::compute::{ - filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, SliceFn, - TakeFn, + filter, scalar_at, slice, take, try_cast, CastFn, ComputeVTable, FilterFn, FilterMask, MaskFn, + ScalarAtFn, SliceFn, TakeFn, }; use crate::variants::StructArrayTrait; -use crate::{ArrayDType, ArrayData, IntoArrayData}; +use crate::{ArrayDType, ArrayData, ArrayLen as _, IntoArrayData}; impl ComputeVTable for StructEncoding { + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -29,6 +38,28 @@ impl ComputeVTable for StructEncoding { } } +impl CastFn for StructEncoding { + fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult { + let Some(sdtype) = dtype.as_struct() else { + vortex_bail!("cannot cast {} to {}", array.dtype(), dtype); + }; + + let validity = array.validity().with_nullability(dtype.nullability())?; + + StructArray::try_new( + array.names().clone(), + array + .children() + .zip_eq(sdtype.dtypes()) + .map(|(field, dtype)| try_cast(&field, &dtype)) + .try_collect()?, + array.len(), + validity, + ) + .map(|a| a.into_array()) + } +} + impl ScalarAtFn for StructEncoding { fn scalar_at(&self, array: &StructArray, index: usize) -> VortexResult { Ok(Scalar::struct_( @@ -90,11 +121,31 @@ impl FilterFn for StructEncoding { } } +impl MaskFn for StructEncoding { + fn mask(&self, array: &StructArray, 
filter_mask: FilterMask) -> VortexResult { + let validity = array.validity().mask(&filter_mask)?; + + StructArray::try_new( + array.names().clone(), + array.children().collect(), + array.len(), + validity, + ) + .map(|a| a.into_array()) + } +} + #[cfg(test)] mod tests { - use crate::array::StructArray; - use crate::compute::{filter, FilterMask}; + use arrow_buffer::BooleanBuffer; + use vortex_buffer::buffer; + use vortex_dtype::{DType, Nullability, PType, StructDType}; + + use crate::array::{BoolArray, PrimitiveArray, StructArray, VarBinArray}; + use crate::compute::test_harness::test_mask; + use crate::compute::{filter, try_cast, FilterMask}; use crate::validity::Validity; + use crate::{ArrayDType as _, IntoArrayData as _}; #[test] fn filter_empty_struct() { @@ -114,4 +165,150 @@ mod tests { let filtered = filter(struct_arr.as_ref(), FilterMask::from_iter::<[bool; 0]>([])).unwrap(); assert_eq!(filtered.len(), 0); } + + #[test] + fn test_mask_empty_struct() { + test_mask( + StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable) + .unwrap() + .into_array(), + ); + } + + #[test] + fn test_mask_complex_struct() { + let xs = buffer![0i64, 1, 2, 3, 4].into_array(); + let ys = VarBinArray::from_iter( + [Some("a"), Some("b"), None, Some("d"), None], + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + let zs = + BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array(); + + test_mask( + StructArray::try_new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + StructArray::try_new( + ["left".into(), "right".into()].into(), + vec![xs.clone(), xs], + 5, + Validity::NonNullable, + ) + .unwrap() + .into_array(), + ys, + zs, + ], + 5, + Validity::NonNullable, + ) + .unwrap() + .into_array(), + ); + } + + #[test] + fn test_cast_empty_struct() { + let array = StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable) + .unwrap() + .into_array(); + let non_nullable_dtype = DType::Struct( + 
StructDType::new([].into(), vec![]), + Nullability::NonNullable, + ); + let casted = try_cast(&array, &non_nullable_dtype).unwrap(); + assert_eq!(casted.dtype(), &non_nullable_dtype); + + let nullable_dtype = + DType::Struct(StructDType::new([].into(), vec![]), Nullability::Nullable); + let casted = try_cast(&array, &nullable_dtype).unwrap(); + assert_eq!(casted.dtype(), &nullable_dtype); + } + + #[test] + fn test_cast_complex_struct() { + let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]) + .into_array(); + let ys = VarBinArray::from_vec( + vec!["a", "b", "c", "d", "e"], + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + let zs = BoolArray::new( + BooleanBuffer::from_iter([true, true, false, false, true]), + Nullability::Nullable, + ) + .into_array(); + let fully_nullable_array = StructArray::try_new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + StructArray::try_new( + ["left".into(), "right".into()].into(), + vec![xs.clone(), xs], + 5, + Validity::AllValid, + ) + .unwrap() + .into_array(), + ys, + zs, + ], + 5, + Validity::AllValid, + ) + .unwrap() + .into_array(); + + let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable(); + let casted = try_cast(&fully_nullable_array, &top_level_non_nullable).unwrap(); + assert_eq!(casted.dtype(), &top_level_non_nullable); + + let non_null_xs_right = DType::Struct( + StructDType::new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + DType::Struct( + StructDType::new( + ["left".into(), "right".into()].into(), + vec![ + DType::Primitive(PType::I64, Nullability::NonNullable), + DType::Primitive(PType::I64, Nullability::Nullable), + ], + ), + Nullability::Nullable, + ), + DType::Utf8(Nullability::Nullable), + DType::Bool(Nullability::Nullable), + ], + ), + Nullability::Nullable, + ); + let casted = try_cast(&fully_nullable_array, &non_null_xs_right).unwrap(); + assert_eq!(casted.dtype(), &non_null_xs_right); + + let non_null_xs = 
DType::Struct( + StructDType::new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + DType::Struct( + StructDType::new( + ["left".into(), "right".into()].into(), + vec![ + DType::Primitive(PType::I64, Nullability::Nullable), + DType::Primitive(PType::I64, Nullability::Nullable), + ], + ), + Nullability::NonNullable, + ), + DType::Utf8(Nullability::Nullable), + DType::Bool(Nullability::Nullable), + ], + ), + Nullability::Nullable, + ); + let casted = try_cast(&fully_nullable_array, &non_null_xs).unwrap(); + assert_eq!(casted.dtype(), &non_null_xs); + } } diff --git a/vortex-array/src/array/varbin/compute/cast.rs b/vortex-array/src/array/varbin/compute/cast.rs new file mode 100644 index 00000000000..c1c2c04d157 --- /dev/null +++ b/vortex-array/src/array/varbin/compute/cast.rs @@ -0,0 +1,25 @@ +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexResult}; + +use crate::array::varbin::VarBinArray; +use crate::array::VarBinEncoding; +use crate::compute::CastFn; +use crate::{ArrayDType, ArrayData, IntoArrayData}; + +impl CastFn for VarBinEncoding { + fn cast(&self, array: &VarBinArray, dtype: &DType) -> VortexResult { + match dtype { + DType::Utf8(nullability) => { + let validity = array.validity().with_nullability(*nullability)?; + VarBinArray::try_new( + array.offsets(), + array.bytes(), + array.dtype().with_nullability(*nullability), + validity, + ) + .map(IntoArrayData::into_array) + } + _ => vortex_bail!("cannot cast {} to {}", array.dtype(), dtype), + } + } +} diff --git a/vortex-array/src/array/varbin/compute/mask.rs b/vortex-array/src/array/varbin/compute/mask.rs new file mode 100644 index 00000000000..a3481c03cf6 --- /dev/null +++ b/vortex-array/src/array/varbin/compute/mask.rs @@ -0,0 +1,44 @@ +use vortex_error::VortexResult; + +use crate::array::varbin::VarBinArray; +use crate::array::VarBinEncoding; +use crate::compute::{FilterMask, MaskFn}; +use crate::{ArrayDType, ArrayData, IntoArrayData}; + +impl MaskFn for VarBinEncoding { + fn 
mask(&self, array: &VarBinArray, mask: FilterMask) -> VortexResult { + VarBinArray::try_new( + array.offsets(), + array.bytes(), + array.dtype().as_nullable(), + array.validity().mask(&mask)?, + ) + .map(IntoArrayData::into_array) + } +} + +#[cfg(test)] +mod test { + use vortex_dtype::{DType, Nullability}; + + use crate::array::VarBinArray; + use crate::compute::test_harness::test_mask; + use crate::IntoArrayData as _; + + #[test] + fn test_mask_var_bin_array() { + let array = VarBinArray::from_vec( + vec!["hello", "world", "filter", "good", "bye"], + DType::Utf8(Nullability::NonNullable), + ) + .into_array(); + test_mask(array); + + let array = VarBinArray::from_iter( + vec![Some("hello"), None, Some("filter"), Some("good"), None], + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + test_mask(array); + } +} diff --git a/vortex-array/src/array/varbin/compute/mod.rs b/vortex-array/src/array/varbin/compute/mod.rs index 2174280a013..f18d1f6be60 100644 --- a/vortex-array/src/array/varbin/compute/mod.rs +++ b/vortex-array/src/array/varbin/compute/mod.rs @@ -3,15 +3,23 @@ use vortex_scalar::Scalar; use crate::array::varbin::{varbin_scalar, VarBinArray}; use crate::array::VarBinEncoding; -use crate::compute::{CompareFn, ComputeVTable, FilterFn, ScalarAtFn, SliceFn, TakeFn}; +use crate::compute::{ + CastFn, CompareFn, ComputeVTable, FilterFn, MaskFn, ScalarAtFn, SliceFn, TakeFn, +}; use crate::{ArrayDType, ArrayData}; +mod cast; mod compare; mod filter; +mod mask; mod slice; mod take; impl ComputeVTable for VarBinEncoding { + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } @@ -20,6 +28,10 @@ impl ComputeVTable for VarBinEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/array/varbinview/compute/mod.rs b/vortex-array/src/array/varbinview/compute/mod.rs index 
0e707b7c746..fd975788e09 100644 --- a/vortex-array/src/array/varbinview/compute/mod.rs +++ b/vortex-array/src/array/varbinview/compute/mod.rs @@ -3,19 +3,29 @@ use std::ops::Deref; use itertools::Itertools; use num_traits::AsPrimitive; use vortex_buffer::{Alignment, Buffer, ByteBuffer}; -use vortex_dtype::{match_each_integer_ptype, PType}; -use vortex_error::VortexResult; +use vortex_dtype::{match_each_integer_ptype, DType, PType}; +use vortex_error::{vortex_bail, VortexResult}; use vortex_scalar::Scalar; use crate::array::varbin::varbin_scalar; use crate::array::varbinview::{VarBinViewArray, VIEW_SIZE_BYTES}; use crate::array::{PrimitiveArray, VarBinViewEncoding}; -use crate::compute::{slice, ComputeVTable, ScalarAtFn, SliceFn, TakeFn}; +use crate::compute::{ + slice, CastFn, ComputeVTable, FilterMask, MaskFn, ScalarAtFn, SliceFn, TakeFn, +}; use crate::validity::Validity; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; impl ComputeVTable for VarBinViewEncoding { + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -138,6 +148,36 @@ fn take_views_unchecked>(views: Buffer, indices: &[I ) } +impl CastFn for VarBinViewEncoding { + fn cast(&self, array: &VarBinViewArray, dtype: &DType) -> VortexResult { + match dtype { + DType::Utf8(nullability) => { + let validity = array.validity().with_nullability(*nullability)?; + VarBinViewArray::try_new( + array.views(), + array.buffers().collect(), + array.dtype().with_nullability(*nullability), + validity, + ) + .map(IntoArrayData::into_array) + } + _ => vortex_bail!("cannot cast {} to {}", array.dtype(), dtype), + } + } +} + +impl MaskFn for VarBinViewEncoding { + fn mask(&self, array: &VarBinViewArray, mask: FilterMask) -> VortexResult { + VarBinViewArray::try_new( + array.views(), + array.buffers().collect(), + 
array.dtype().as_nullable(), + array.validity().mask(&mask)?, + ) + .map(IntoArrayData::into_array) + } +} + #[cfg(test)] mod tests { use vortex_buffer::buffer; @@ -145,6 +185,7 @@ mod tests { use crate::accessor::ArrayAccessor; use crate::array::VarBinViewArray; use crate::compute::take; + use crate::compute::test_harness::test_mask; use crate::{ArrayDType, IntoArrayData, IntoArrayVariant}; #[test] @@ -172,4 +213,22 @@ mod tests { [Some("one".to_string()), Some("four".to_string())] ); } + + #[test] + fn take_mask_var_bin_view_array() { + test_mask( + VarBinViewArray::from_iter_str(["one", "two", "three", "four", "five"]).into_array(), + ); + + test_mask( + VarBinViewArray::from_iter_nullable_str([ + Some("one"), + None, + Some("three"), + Some("four"), + Some("five"), + ]) + .into_array(), + ); + } } diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index 89fab5e6f87..8c82baa7c0f 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -36,7 +36,31 @@ where } } -/// Return a new array by applying a boolean predicate to select items from a base Array. +/// Keep only the elements for which the corresponding mask value is true. 
+/// +/// # Examples +/// +/// ``` +/// use vortex_array::IntoArrayData; +/// use vortex_array::array::{BoolArray, PrimitiveArray}; +/// use vortex_array::compute::{FilterMask, scalar_at}; +/// use vortex_array::compute::filter; +/// use vortex_array::validity::ArrayValidity; +/// use vortex_scalar::Scalar; +/// +/// let array = +/// PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]) +/// .into_array(); +/// let mask = FilterMask::try_from( +/// BoolArray::from_iter([true, false, false, false, true]).into_array(), +/// ) +/// .unwrap(); +/// +/// let filtered = filter(&array, mask).unwrap(); +/// assert_eq!(filtered.len(), 2); +/// assert_eq!(scalar_at(&filtered, 0).unwrap(), Scalar::from(Some(0_i32))); +/// assert_eq!(scalar_at(&filtered, 1).unwrap(), Scalar::from(Some(2_i32))); +/// ``` /// /// # Performance /// @@ -251,7 +275,7 @@ impl FilterMask { self.boolean_buffer().cloned() } - fn boolean_buffer(&self) -> VortexResult<&BooleanBuffer> { + pub fn boolean_buffer(&self) -> VortexResult<&BooleanBuffer> { self.buffer.get_or_try_init(|| { Ok(self .array @@ -351,6 +375,21 @@ impl FromIterator for FilterMask { } } +impl FilterMask { + pub fn from_slices>(length: usize, slices: T) -> Self { + let mut builder = BooleanBufferBuilder::new(length); + let mut cursor = 0; + for (start, end) in slices.into_iter() { + assert!(start <= cursor); + builder.append_n(start - cursor, false); + builder.append_n(end - start, true); + cursor = end; + } + + Self::from(builder.finish()) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/vortex-array/src/compute/mask.rs b/vortex-array/src/compute/mask.rs new file mode 100644 index 00000000000..1ee48cb837d --- /dev/null +++ b/vortex-array/src/compute/mask.rs @@ -0,0 +1,232 @@ +use arrow_array::BooleanArray; +use vortex_error::{vortex_bail, VortexError, VortexResult}; +use vortex_scalar::Scalar; + +use super::FilterMask; +use crate::array::ConstantArray; +use crate::arrow::FromArrowArray; +use 
crate::compute::try_cast; +use crate::encoding::Encoding; +use crate::{ArrayDType, ArrayData, IntoArrayData, IntoCanonical}; + +pub trait MaskFn { + /// Replace masked values with null in array. + fn mask(&self, array: &Array, mask: FilterMask) -> VortexResult; +} + +impl MaskFn for E +where + E: MaskFn, + for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>, +{ + fn mask(&self, array: &ArrayData, mask: FilterMask) -> VortexResult { + let (array_ref, encoding) = array.downcast_array_ref::()?; + MaskFn::mask(encoding, array_ref, mask) + } +} + +/// Replace values with null where the mask is true. +/// +/// The returned array is nullable but otherwise has the same dtype and length as `array`. +/// +/// # Examples +/// +/// ``` +/// use vortex_array::IntoArrayData; +/// use vortex_array::array::{BoolArray, PrimitiveArray}; +/// use vortex_array::compute::{FilterMask, scalar_at}; +/// use vortex_array::compute::mask; +/// use vortex_array::validity::ArrayValidity; +/// use vortex_scalar::Scalar; +/// +/// let array = +/// PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]) +/// .into_array(); +/// let mask_array = FilterMask::try_from( +/// BoolArray::from_iter([true, false, false, false, true]).into_array(), +/// ) +/// .unwrap(); +/// +/// let masked = mask(&array, mask_array).unwrap(); +/// assert_eq!(masked.len(), 5); +/// assert!(!masked.is_valid(0)); +/// assert!(!masked.is_valid(1)); +/// assert_eq!(scalar_at(&masked, 2).unwrap(), Scalar::from(Some(1))); +/// assert!(!masked.is_valid(3)); +/// assert!(!masked.is_valid(4)); +/// ``` +/// +pub fn mask(array: &ArrayData, mask: FilterMask) -> VortexResult { + if mask.len() != array.len() { + vortex_bail!( + "mask.len() is {}, does not equal array.len() of {}", + mask.len(), + array.len() + ); + } + + let true_count = mask.true_count(); + + let masked = if true_count == 0 { + // Fast-path for empty mask + try_cast(array, &array.dtype().as_nullable())? 
+ } else if true_count == mask.len() { + // Fast-path for full mask. + ConstantArray::new( + Scalar::null(array.dtype().clone().as_nullable()), + array.len(), + ) + .into_array() + } else { + mask_impl(array, mask)? + }; + + debug_assert_eq!( + masked.len(), + array.len(), + "Mask should not change length {}\n\n{:?}\n\n{:?}", + array.encoding().id(), + array, + masked + ); + debug_assert_eq!( + masked.dtype(), + &array.dtype().as_nullable(), + "Mask dtype mismatch {} {} {} {}", + array.encoding().id(), + masked.dtype(), + array.dtype(), + array.dtype().as_nullable(), + ); + + Ok(masked) +} + +fn mask_impl(array: &ArrayData, mask: FilterMask) -> VortexResult { + if let Some(mask_fn) = array.encoding().mask_fn() { + return mask_fn.mask(array, mask); + } + + // Fallback: implement using Arrow kernels. + log::debug!("No mask implementation found for {}", array.encoding().id(),); + + let array_ref = array.clone().into_arrow()?; + let mask = BooleanArray::new(mask.to_boolean_buffer()?, None); + + let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?; + + Ok(ArrayData::from_arrow(masked, true)) +} + +#[cfg(feature = "test-harness")] +pub mod test_harness { + use crate::array::BoolArray; + use crate::compute::{mask, scalar_at, FilterMask}; + use crate::validity::ArrayValidity as _; + use crate::{ArrayData, IntoArrayData}; + + pub fn test_mask(array: ArrayData) { + assert_eq!(array.len(), 5); + test_heterogenous_mask(&array); + test_empty_mask(&array); + test_full_mask(&array); + } + + #[allow(clippy::unwrap_used)] + fn test_heterogenous_mask(array: &ArrayData) { + let mask_array = FilterMask::try_from( + BoolArray::from_iter([true, false, false, true, true]).into_array(), + ) + .unwrap(); + let masked = mask(array, mask_array).unwrap(); + assert_eq!(masked.len(), array.len()); + assert!(!masked.is_valid(0)); + assert_eq!( + scalar_at(&masked, 1).unwrap(), + scalar_at(array, 1).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 2).unwrap(), + 
scalar_at(array, 2).unwrap().into_nullable() + ); + assert!(!masked.is_valid(3)); + assert!(!masked.is_valid(4)); + } + + #[allow(clippy::unwrap_used)] + fn test_empty_mask(array: &ArrayData) { + let all_unmasked = FilterMask::try_from( + BoolArray::from_iter([false, false, false, false, false]).into_array(), + ) + .unwrap(); + let masked = mask(array, all_unmasked).unwrap(); + assert_eq!(masked.len(), array.len()); + assert_eq!( + scalar_at(&masked, 0).unwrap(), + scalar_at(array, 0).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 1).unwrap(), + scalar_at(array, 1).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 2).unwrap(), + scalar_at(array, 2).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 3).unwrap(), + scalar_at(array, 3).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 4).unwrap(), + scalar_at(array, 4).unwrap().into_nullable() + ); + } + + #[allow(clippy::unwrap_used)] + fn test_full_mask(array: &ArrayData) { + let all_masked = + FilterMask::try_from(BoolArray::from_iter([true, true, true, true, true]).into_array()) + .unwrap(); + let masked = mask(array, all_masked).unwrap(); + assert_eq!(masked.len(), array.len()); + assert!(!masked.is_valid(0)); + assert!(!masked.is_valid(1)); + assert!(!masked.is_valid(2)); + assert!(!masked.is_valid(3)); + assert!(!masked.is_valid(4)); + + let mask1 = FilterMask::try_from( + BoolArray::from_iter([true, false, false, true, true]).into_array(), + ) + .unwrap(); + let mask2 = FilterMask::try_from( + BoolArray::from_iter([false, true, false, false, true]).into_array(), + ) + .unwrap(); + let first = mask(array, mask1).unwrap(); + let double_masked = mask(&first, mask2).unwrap(); + assert_eq!(double_masked.len(), array.len()); + assert!(!double_masked.is_valid(0)); + assert!(!double_masked.is_valid(1)); + assert_eq!( + scalar_at(&double_masked, 2).unwrap(), + scalar_at(array, 2).unwrap().into_nullable() + ); + assert!(!double_masked.is_valid(3)); + 
assert!(!double_masked.is_valid(4)); + } +} + +#[cfg(test)] +mod test { + use super::test_harness::test_mask; + use crate::array::PrimitiveArray; + use crate::IntoArrayData as _; + + #[test] + fn test_mask_non_nullable_array() { + let non_nullable_array = PrimitiveArray::from_iter([1, 2, 3, 4, 5]).into_array(); + test_mask(non_nullable_array); + } +} diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 9e7e9a135cd..d6b5396d110 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -21,6 +21,7 @@ pub use fill_null::{fill_null, FillNullFn}; pub use filter::{filter, FilterFn, FilterIter, FilterMask}; pub use invert::{invert, InvertFn}; pub use like::{like, LikeFn, LikeOptions}; +pub use mask::{mask, MaskFn}; pub use scalar_at::{scalar_at, ScalarAtFn}; pub use search_sorted::*; pub use slice::{slice, SliceFn}; @@ -37,6 +38,7 @@ mod fill_null; mod filter; mod invert; mod like; +mod mask; mod scalar_at; mod search_sorted; mod slice; @@ -87,13 +89,24 @@ pub trait ComputeVTable { None } - /// Filter an array with a given mask. + /// Remove masked values from the array. + /// + /// The length of the returned array equals the number of true mask values. /// /// See: [FilterFn]. fn filter_fn(&self) -> Option<&dyn FilterFn> { None } + /// Replace masked values with null. + /// + /// This operation does not change the length of the array. + /// + /// See: [MaskFn]. + fn mask_fn(&self) -> Option<&dyn MaskFn> { + None + } + /// Invert a boolean array. Converts true -> false, false -> true, null -> null. 
/// /// See [InvertFn] @@ -148,4 +161,5 @@ pub trait ComputeVTable { #[cfg(feature = "test-harness")] pub mod test_harness { pub use crate::compute::binary_numeric::test_harness::test_binary_numeric; + pub use crate::compute::mask::test_harness::test_mask; } diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 6128796a790..7a7893e61fe 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -1,7 +1,7 @@ //! Array validity and nullability behavior, used by arrays and compute functions. use std::fmt::{Debug, Display}; -use std::ops::BitAnd; +use std::ops::{BitAnd, Not}; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use serde::{Deserialize, Serialize}; @@ -222,6 +222,9 @@ impl Validity { } } + /// Keep only the entries for which the mask is true. + /// + /// The result has length equal to the number of true values in mask. pub fn filter(&self, mask: &FilterMask) -> VortexResult { // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily // if we happen to be NonNullable, AllValid, or AllInvalid. @@ -233,6 +236,23 @@ impl Validity { } } + /// Set to false any entries for which the mask is true. + /// + /// The result is always nullable. The result has the same length as self. 
+ pub fn mask(self, mask: &FilterMask) -> VortexResult { + Ok(match self { + Validity::NonNullable | Validity::AllValid => { + Validity::Array(BoolArray::from(mask.boolean_buffer()?.not()).into_array()) + } + Validity::AllInvalid => Validity::AllInvalid, + Validity::Array(is_valid) => { + let bools = BoolArray::try_from(is_valid)?.boolean_buffer(); + let is_valid_in_mask = mask.boolean_buffer()?.not(); + Validity::from(bools.bitand(&is_valid_in_mask)) + } + }) + } + pub fn to_logical(&self, length: usize) -> LogicalValidity { match self { Self::NonNullable => LogicalValidity::AllValid(length), @@ -334,6 +354,35 @@ impl Validity { } } + /// Convert into a non-nullable variant + pub fn into_non_nullable(self) -> Option { + match self { + Self::NonNullable => Some(Self::NonNullable), + Self::AllValid => Some(Self::NonNullable), + Self::AllInvalid => None, + Self::Array(is_valid) => { + is_valid + .statistics() + .compute_min::() + .unwrap_or(false) + .then(|| { + // min true => all true + Self::NonNullable + }) + } + } + } + + /// Convert into a variant compatible with the given nullability, if possible. + pub fn with_nullability(self, nullability: Nullability) -> VortexResult { + match nullability { + Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| { + vortex_err!("cannot cast array with invalid values to non-nullable type") + }), + Nullability::Nullable => Ok(self.into_nullable()), + } + } + /// Create Validity from boolean array with given nullability of the array. 
/// /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself From c25b6687aae24411c5852b816f169d5891be560c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 13 Jan 2025 14:43:01 -0500 Subject: [PATCH 02/26] remove unnecessary clone --- vortex-array/src/array/extension/compute/mask.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vortex-array/src/array/extension/compute/mask.rs b/vortex-array/src/array/extension/compute/mask.rs index 53f1a7d1443..72a733ffde2 100644 --- a/vortex-array/src/array/extension/compute/mask.rs +++ b/vortex-array/src/array/extension/compute/mask.rs @@ -41,8 +41,7 @@ mod test { )); test_mask( - ExtensionArray::new(ext_dtype.clone(), buffer![1i64, 2, 3, 4, 5].into_array()) - .into_array(), + ExtensionArray::new(ext_dtype, buffer![1i64, 2, 3, 4, 5].into_array()).into_array(), ); } } From 26c461300112e7f1406aee12599593f91bac4e6c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 13 Jan 2025 15:00:15 -0500 Subject: [PATCH 03/26] clippy again --- encodings/datetime-parts/src/compute/mask.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/encodings/datetime-parts/src/compute/mask.rs b/encodings/datetime-parts/src/compute/mask.rs index 9c52974b464..1ddbfc66da8 100644 --- a/encodings/datetime-parts/src/compute/mask.rs +++ b/encodings/datetime-parts/src/compute/mask.rs @@ -53,6 +53,6 @@ mod tests { .unwrap() .into_array(); - test_mask(date_times.clone()); + test_mask(date_times); } } From 469f401045d5ae19a3727fee85133155e85bfa66 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 14 Jan 2025 13:05:02 -0500 Subject: [PATCH 04/26] revert datetime-parts changes --- encodings/datetime-parts/src/compute/cast.rs | 25 --------- encodings/datetime-parts/src/compute/mask.rs | 58 -------------------- encodings/datetime-parts/src/compute/mod.rs | 13 +---- 3 files changed, 1 insertion(+), 95 deletions(-) delete mode 100644
encodings/datetime-parts/src/compute/cast.rs delete mode 100644 encodings/datetime-parts/src/compute/mask.rs diff --git a/encodings/datetime-parts/src/compute/cast.rs b/encodings/datetime-parts/src/compute/cast.rs deleted file mode 100644 index ce703d630fa..00000000000 --- a/encodings/datetime-parts/src/compute/cast.rs +++ /dev/null @@ -1,25 +0,0 @@ -use vortex_array::compute::{try_cast, CastFn}; -use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; -use vortex_dtype::DType; -use vortex_error::{vortex_bail, VortexResult}; - -use crate::{DateTimePartsArray, DateTimePartsEncoding}; - -impl CastFn for DateTimePartsEncoding { - fn cast(&self, array: &DateTimePartsArray, dtype: &DType) -> VortexResult { - if !array.dtype().eq_ignore_nullability(dtype) { - vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype); - }; - - Ok(DateTimePartsArray::try_new( - array.dtype().clone().as_nullable(), - try_cast( - array.days().as_ref(), - &array.days().dtype().with_nullability(dtype.nullability()), - )?, - array.seconds(), - array.subsecond(), - )? - .into_array()) - } -} diff --git a/encodings/datetime-parts/src/compute/mask.rs b/encodings/datetime-parts/src/compute/mask.rs deleted file mode 100644 index 1ddbfc66da8..00000000000 --- a/encodings/datetime-parts/src/compute/mask.rs +++ /dev/null @@ -1,58 +0,0 @@ -use vortex_array::compute::{mask, FilterMask, MaskFn}; -use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; -use vortex_error::VortexResult; - -use crate::{DateTimePartsArray, DateTimePartsEncoding}; - -impl MaskFn for DateTimePartsEncoding { - fn mask(&self, array: &DateTimePartsArray, filter_mask: FilterMask) -> VortexResult { - Ok(DateTimePartsArray::try_new( - array.dtype().clone().as_nullable(), - mask(array.days().as_ref(), filter_mask)?, - array.seconds(), - array.subsecond(), - )? 
- .into_array()) - } -} - -#[cfg(test)] -mod tests { - use vortex_array::array::TemporalArray; - use vortex_array::compute::test_harness::test_mask; - use vortex_array::IntoArrayData as _; - use vortex_buffer::buffer; - use vortex_datetime_dtype::TimeUnit; - use vortex_dtype::DType; - - use crate::{split_temporal, DateTimePartsArray, TemporalParts}; - - #[test] - fn test_mask_datetime_parts_array() { - let raw_millis = buffer![ - 86_400i64, // element with only day component - 86_400i64 + 1000, // element with day + second components - 86_400i64 + 1000 + 1, // element with day + second + sub-second components - 86_400i64 + 1000 + 5, // element with day + second + sub-second components - 86_400i64 + 1000 + 55, // element with day + second + sub-second components - ] - .into_array(); - let temporal_array = - TemporalArray::new_timestamp(raw_millis, TimeUnit::Ms, Some("UTC".to_string())); - let TemporalParts { - days, - seconds, - subseconds, - } = split_temporal(temporal_array.clone()).unwrap(); - let date_times = DateTimePartsArray::try_new( - DType::Extension(temporal_array.ext_dtype()), - days, - seconds, - subseconds, - ) - .unwrap() - .into_array(); - - test_mask(date_times); - } -} diff --git a/encodings/datetime-parts/src/compute/mod.rs b/encodings/datetime-parts/src/compute/mod.rs index d5ebef744f8..de3937c1e8a 100644 --- a/encodings/datetime-parts/src/compute/mod.rs +++ b/encodings/datetime-parts/src/compute/mod.rs @@ -1,12 +1,9 @@ -mod cast; mod filter; -mod mask; mod take; use vortex_array::array::{PrimitiveArray, TemporalArray}; use vortex_array::compute::{ - scalar_at, slice, try_cast, CastFn, ComputeVTable, FilterFn, MaskFn, ScalarAtFn, SliceFn, - TakeFn, + scalar_at, slice, try_cast, ComputeVTable, FilterFn, ScalarAtFn, SliceFn, TakeFn, }; use vortex_array::validity::ArrayValidity; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; @@ -20,18 +17,10 @@ use vortex_scalar::{PrimitiveScalar, Scalar}; use crate::{DateTimePartsArray, 
DateTimePartsEncoding}; impl ComputeVTable for DateTimePartsEncoding { - fn cast_fn(&self) -> Option<&dyn CastFn> { - Some(self) - } - fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } - fn mask_fn(&self) -> Option<&dyn MaskFn> { - Some(self) - } - fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } From 757d1b823af039a7cd4be725f5e701b6f2cc6ba7 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 15 Jan 2025 10:40:57 -0500 Subject: [PATCH 05/26] extension arrays are neither castable nor maskable --- .../src/array/extension/compute/cast.rs | 27 ----------- .../src/array/extension/compute/mask.rs | 47 ------------------- .../src/array/extension/compute/mod.rs | 18 ++----- 3 files changed, 5 insertions(+), 87 deletions(-) delete mode 100644 vortex-array/src/array/extension/compute/cast.rs delete mode 100644 vortex-array/src/array/extension/compute/mask.rs diff --git a/vortex-array/src/array/extension/compute/cast.rs b/vortex-array/src/array/extension/compute/cast.rs deleted file mode 100644 index c8b046126d5..00000000000 --- a/vortex-array/src/array/extension/compute/cast.rs +++ /dev/null @@ -1,27 +0,0 @@ -use vortex_dtype::DType; -use vortex_error::{vortex_bail, VortexResult}; - -use crate::array::extension::ExtensionArray; -use crate::array::ExtensionEncoding; -use crate::compute::{try_cast, CastFn}; -use crate::{ArrayDType as _, ArrayData, IntoArrayData as _}; - -impl CastFn for ExtensionEncoding { - fn cast(&self, array: &ExtensionArray, dtype: &DType) -> VortexResult { - if !array.dtype().eq_ignore_nullability(dtype) { - vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype); - } - let DType::Extension(ext_dtype) = dtype else { - vortex_bail!( - "dtype must have extension dtype {} {}", - array.dtype(), - dtype - ); - }; - Ok(ExtensionArray::new( - ext_dtype.clone(), - try_cast(array.storage(), ext_dtype.storage_dtype())?, - ) - .into_array()) - } -} diff --git a/vortex-array/src/array/extension/compute/mask.rs 
b/vortex-array/src/array/extension/compute/mask.rs deleted file mode 100644 index 72a733ffde2..00000000000 --- a/vortex-array/src/array/extension/compute/mask.rs +++ /dev/null @@ -1,47 +0,0 @@ -use std::sync::Arc; - -use vortex_dtype::{DType, Nullability}; -use vortex_error::{vortex_bail, VortexResult}; - -use crate::array::extension::ExtensionArray; -use crate::array::ExtensionEncoding; -use crate::compute::{mask, FilterMask, MaskFn}; -use crate::{ArrayDType as _, ArrayData, IntoArrayData}; - -impl MaskFn for ExtensionEncoding { - fn mask(&self, array: &ExtensionArray, filter_mask: FilterMask) -> VortexResult { - let DType::Extension(ext_dtype) = array.dtype() else { - vortex_bail!("extension array must have extension dtype"); - }; - Ok(ExtensionArray::new( - Arc::from(ext_dtype.with_nullability(Nullability::Nullable)), - mask(&array.storage(), filter_mask)?, - ) - .into_array()) - } -} - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use vortex_buffer::buffer; - use vortex_dtype::{DType, ExtDType, ExtID, PType}; - - use crate::array::ExtensionArray; - use crate::compute::test_harness::test_mask; - use crate::IntoArrayData as _; - - #[test] - fn test_mask_extension_array() { - let ext_dtype = Arc::new(ExtDType::new( - ExtID::new("timestamp".into()), - DType::from(PType::I64).into(), - None, - )); - - test_mask( - ExtensionArray::new(ext_dtype, buffer![1i64, 2, 3, 4, 5].into_array()).into_array(), - ); - } -} diff --git a/vortex-array/src/array/extension/compute/mod.rs b/vortex-array/src/array/extension/compute/mod.rs index 7e1bd2b445a..46c0ec9f834 100644 --- a/vortex-array/src/array/extension/compute/mod.rs +++ b/vortex-array/src/array/extension/compute/mod.rs @@ -1,6 +1,4 @@ -mod cast; mod compare; -mod mask; use vortex_error::VortexResult; use vortex_scalar::Scalar; @@ -8,29 +6,23 @@ use vortex_scalar::Scalar; use crate::array::extension::ExtensionArray; use crate::array::ExtensionEncoding; use crate::compute::{ - scalar_at, slice, take, CastFn, CompareFn, 
ComputeVTable, MaskFn, ScalarAtFn, SliceFn, TakeFn, + scalar_at, slice, take, CastFn, CompareFn, ComputeVTable, ScalarAtFn, SliceFn, TakeFn, }; use crate::variants::ExtensionArrayTrait; use crate::{ArrayData, IntoArrayData}; impl ComputeVTable for ExtensionEncoding { fn cast_fn(&self) -> Option<&dyn CastFn> { - // It's not possible to cast an extension array to another type, but we can make it - // nullable. - // - // TODO(ngates): we should allow some extension arrays to implement a callback to support - // this - Some(self) + // It's not possible to cast an extension array to another type. + // TODO(ngates): we should allow some extension arrays to implement a callback + // to support this + None } fn compare_fn(&self) -> Option<&dyn CompareFn> { Some(self) } - fn mask_fn(&self) -> Option<&dyn MaskFn> { - Some(self) - } - fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } From 4bf1e0c14c99d46be3619362754f7d051965850d Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 16 Jan 2025 16:45:38 -0500 Subject: [PATCH 06/26] revert ALP mask for now --- encodings/alp/src/alp/compute/mask.rs | 72 --------------------------- encodings/alp/src/alp/compute/mod.rs | 10 +--- 2 files changed, 2 insertions(+), 80 deletions(-) delete mode 100644 encodings/alp/src/alp/compute/mask.rs diff --git a/encodings/alp/src/alp/compute/mask.rs b/encodings/alp/src/alp/compute/mask.rs deleted file mode 100644 index 5737378f626..00000000000 --- a/encodings/alp/src/alp/compute/mask.rs +++ /dev/null @@ -1,72 +0,0 @@ -use vortex_array::compute::{mask, try_cast, FilterMask, MaskFn}; -use vortex_array::{ArrayDType as _, ArrayData, IntoArrayData}; -use vortex_error::VortexResult; - -use crate::{ALPArray, ALPEncoding}; - -impl MaskFn for ALPEncoding { - fn mask(&self, array: &ALPArray, filter_mask: FilterMask) -> VortexResult { - ALPArray::try_new( - mask(&array.encoded(), filter_mask)?, - array.exponents(), - array - .patches() - .map(|patches| { - patches.map_values(|values| 
try_cast(&values, &values.dtype().as_nullable())) - }) - .transpose()?, - ) - .map(IntoArrayData::into_array) - } -} - -#[cfg(test)] -mod tests { - use vortex_array::array::PrimitiveArray; - use vortex_array::compute::test_harness::test_mask; - use vortex_array::validity::Validity; - use vortex_array::IntoArrayData as _; - use vortex_buffer::buffer; - - use crate::alp_encode; - - #[test] - fn test_mask_no_patches_alp_array() { - test_mask( - alp_encode(&PrimitiveArray::new( - buffer![1.0f32, 2.0, 3.0, 4.0, 5.0], - Validity::AllValid, - )) - .unwrap() - .into_array(), - ); - - test_mask( - alp_encode(&PrimitiveArray::new( - buffer![1.0f32, 2.0, 3.0, 4.0, 5.0], - Validity::NonNullable, - )) - .unwrap() - .into_array(), - ); - } - - #[test] - fn test_mask_patched_alp_array() { - let alp_array = alp_encode(&PrimitiveArray::new( - buffer![1.0f32, 2.0, 3.0, 4.0, 1e10], - Validity::AllValid, - )) - .unwrap(); - assert!(alp_array.patches().is_some()); - test_mask(alp_array.into_array()); - - let alp_array = alp_encode(&PrimitiveArray::new( - buffer![1.0f32, 2.0, 3.0, 4.0, 1e10], - Validity::NonNullable, - )) - .unwrap(); - assert!(alp_array.patches().is_some()); - test_mask(alp_array.into_array()); - } -} diff --git a/encodings/alp/src/alp/compute/mod.rs b/encodings/alp/src/alp/compute/mod.rs index d442ce5172c..1a8b147c095 100644 --- a/encodings/alp/src/alp/compute/mod.rs +++ b/encodings/alp/src/alp/compute/mod.rs @@ -1,8 +1,6 @@ -mod mask; - use vortex_array::compute::{ - filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, MaskFn, ScalarAtFn, - SliceFn, TakeFn, + filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, SliceFn, + TakeFn, }; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; @@ -16,10 +14,6 @@ impl ComputeVTable for ALPEncoding { Some(self) } - fn mask_fn(&self) -> Option<&dyn MaskFn> { - Some(self) - } - fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { 
Some(self) } From 2110ee2d750bc749995954184ea0f09b1be309c2 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 16 Jan 2025 16:55:15 -0500 Subject: [PATCH 07/26] use slice::fill instead of a loop --- encodings/dict/src/compute/mask.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/encodings/dict/src/compute/mask.rs b/encodings/dict/src/compute/mask.rs index 3a7be8ac3b3..a2ed6e213f1 100644 --- a/encodings/dict/src/compute/mask.rs +++ b/encodings/dict/src/compute/mask.rs @@ -26,9 +26,7 @@ fn typed_mask(codes: &mut BufferMut, mask: FilterMask) -> Vor } FilterIter::Slices(slices) => { for slice in slices { - for index in slice.0..slice.1 { - codes[index] = T::zero(); - } + codes[slice.0..slice.1].fill(T::zero()); } } } From 373ae0259bb8ee32dde7ba5bf9f3118bc3e8f29e Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 16 Jan 2025 17:45:04 -0500 Subject: [PATCH 08/26] fix: non-nullable dict array needs a leading empty element --- encodings/dict/src/compress.rs | 2 +- encodings/dict/src/compute/mask.rs | 72 +++++++++++++++++++++++++++--- encodings/dict/src/compute/mod.rs | 14 +----- vortex-array/src/compute/filter.rs | 2 +- vortex-buffer/src/buffer.rs | 2 +- vortex-buffer/src/buffer_mut.rs | 50 +++++++++++++++++++++ 6 files changed, 120 insertions(+), 22 deletions(-) diff --git a/encodings/dict/src/compress.rs b/encodings/dict/src/compress.rs index 5b9029ff8a0..4a741c317cf 100644 --- a/encodings/dict/src/compress.rs +++ b/encodings/dict/src/compress.rs @@ -158,7 +158,7 @@ fn dict_encode_varbin_bytes<'a, I: Iterator>>( ) } -fn dict_values_validity(nullable: bool, len: usize) -> Validity { +pub(crate) fn dict_values_validity(nullable: bool, len: usize) -> Validity { if nullable { Validity::Array( SparseArray::try_new( diff --git a/encodings/dict/src/compute/mask.rs b/encodings/dict/src/compute/mask.rs index a2ed6e213f1..f5b4ab4effa 100644 --- a/encodings/dict/src/compute/mask.rs +++ b/encodings/dict/src/compute/mask.rs @@ -1,22 +1,48 @@ -use 
vortex_array::compute::{FilterIter, FilterMask, MaskFn}; -use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant as _}; +use vortex_array::array::PrimitiveArray; +use vortex_array::compute::{add_scalar, FilterIter, FilterMask, MaskFn}; +use vortex_array::variants::PrimitiveArrayTrait as _; +use vortex_array::{ArrayDType as _, ArrayData, IntoArrayData, IntoArrayVariant as _}; use vortex_buffer::BufferMut; use vortex_dtype::{match_each_integer_ptype, NativePType}; use vortex_error::VortexResult; +use vortex_scalar::Scalar; -use crate::{DictArray, DictEncoding}; +use crate::{dict_values_validity, DictArray, DictEncoding}; impl MaskFn for DictEncoding { fn mask(&self, array: &DictArray, mask: FilterMask) -> VortexResult { - let new_codes = match_each_integer_ptype!(array.codes_ptype(), |$T| { - let mut codes = array.codes().into_primitive()?.into_buffer_mut(); + let (codes, new_values) = if array.dtype().is_nullable() { + (array.codes().into_primitive()?, array.values().clone()) + } else { + let values = array.values().into_primitive()?; + let values_with_null = + match_each_integer_ptype!(values.ptype(), |$T| add_null_value::<$T>(values))?; + let codes_with_null = add_scalar( + array.codes(), + Scalar::from(1u8).cast(array.codes().dtype())?, + )? 
+ .into_primitive()?; + (codes_with_null, values_with_null) + }; + + let new_codes = match_each_integer_ptype!(codes.ptype(), |$T| { + let mut codes = codes.into_buffer_mut(); typed_mask::<$T>(&mut codes, mask)?; codes.into_array() }); - DictArray::try_new(new_codes, array.values()).map(IntoArrayData::into_array) + DictArray::try_new(new_codes, new_values).map(IntoArrayData::into_array) } } +fn add_null_value(values: PrimitiveArray) -> VortexResult { + let buf: BufferMut = values.into_buffer_mut::(); + let mut new_buf: BufferMut = BufferMut::::with_capacity(buf.len() + 1); + new_buf.push(T::zero()); + new_buf.extend(buf); + let len = new_buf.len(); + Ok(PrimitiveArray::new(new_buf, dict_values_validity(true, len)).into_array()) +} + fn typed_mask(codes: &mut BufferMut, mask: FilterMask) -> VortexResult<()> { match mask.iter() { FilterIter::Indices(indices) => { @@ -32,3 +58,37 @@ fn typed_mask(codes: &mut BufferMut, mask: FilterMask) -> Vor } Ok(()) } + +#[cfg(test)] +mod tests { + use vortex_array::array::PrimitiveArray; + use vortex_array::compute::test_harness::test_mask; + use vortex_array::validity::Validity; + use vortex_array::IntoArrayData as _; + use vortex_buffer::buffer; + + use crate::{dict_encode_primitive, DictArray}; + + #[test] + fn test_mask_nullable_dict_array() { + let reference = + PrimitiveArray::from_option_iter([None, Some(42), Some(-9), Some(42), Some(5)]); + let (codes, values) = dict_encode_primitive(&reference); + test_mask( + DictArray::try_new(codes.into_array(), values.into_array()) + .unwrap() + .into_array(), + ) + } + + #[test] + fn test_mask_non_nullable_dict_array() { + let reference = PrimitiveArray::new(buffer![5, 42, -9, 42, 5], Validity::NonNullable); + let (codes, values) = dict_encode_primitive(&reference); + test_mask( + DictArray::try_new(codes.into_array(), values.into_array()) + .unwrap() + .into_array(), + ) + } +} diff --git a/encodings/dict/src/compute/mod.rs b/encodings/dict/src/compute/mod.rs index 
5b03dfc70dc..d520fcafd8d 100644 --- a/encodings/dict/src/compute/mod.rs +++ b/encodings/dict/src/compute/mod.rs @@ -83,7 +83,7 @@ impl SliceFn for DictEncoding { mod test { use vortex_array::accessor::ArrayAccessor; use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinViewArray}; - use vortex_array::compute::test_harness::{test_binary_numeric, test_mask}; + use vortex_array::compute::test_harness::test_binary_numeric; use vortex_array::compute::{compare, scalar_at, slice, Operator}; use vortex_array::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData}; use vortex_dtype::{DType, Nullability}; @@ -165,16 +165,4 @@ mod test { let array = sliced_dict_array(); test_binary_numeric::(array) } - - #[test] - fn test_mask_dict_array() { - let reference = - PrimitiveArray::from_option_iter([None, Some(42), Some(-9), Some(42), Some(5)]); - let (codes, values) = dict_encode_primitive(&reference); - test_mask( - DictArray::try_new(codes.into_array(), values.into_array()) - .unwrap() - .into_array(), - ) - } } diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index 825bdbe582d..972a985f2b4 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -58,7 +58,7 @@ where /// ) /// .unwrap(); /// -/// let filtered = filter(&array, mask).unwrap(); +/// let filtered = filter(&array, &mask).unwrap(); /// assert_eq!(filtered.len(), 2); /// assert_eq!(scalar_at(&filtered, 0).unwrap(), Scalar::from(Some(0_i32))); /// assert_eq!(scalar_at(&filtered, 1).unwrap(), Scalar::from(Some(2_i32))); diff --git a/vortex-buffer/src/buffer.rs b/vortex-buffer/src/buffer.rs index 3f8b129bc96..1e7b13d7bc5 100644 --- a/vortex-buffer/src/buffer.rs +++ b/vortex-buffer/src/buffer.rs @@ -365,7 +365,7 @@ impl Buf for ByteBuffer { } } -/// Owned iterator over a `Buffer`. +/// Owned iterator over a [`Buffer`]. 
pub struct BufferIterator { buffer: Buffer, index: usize, diff --git a/vortex-buffer/src/buffer_mut.rs b/vortex-buffer/src/buffer_mut.rs index 03719f6b33b..a5d08b7ebaf 100644 --- a/vortex-buffer/src/buffer_mut.rs +++ b/vortex-buffer/src/buffer_mut.rs @@ -450,6 +450,56 @@ impl FromIterator for BufferMut { } } +/// Owned iterator over a [`BufferMut`]. +/// +/// Examples +/// -------- +/// +/// ``` +/// use vortex_buffer::buffer_mut; +/// +/// let mut a = buffer_mut![1u16, 2, 3]; +/// let b = buffer_mut![4u16, 5, 6]; +/// a.extend(b); +/// assert_eq!(a.len(), 6); +/// assert_eq!(a[3..6], [4, 5, 6]); +/// ``` +pub struct BufferMutIterator { + buffer: BufferMut, + index: usize, +} + +impl Iterator for BufferMutIterator { + type Item = T; + + fn next(&mut self) -> Option { + if self.index == self.buffer.len() { + None + } else { + let value = self.buffer[self.index]; + self.index += 1; + Some(value) + } + } + + fn size_hint(&self) -> (usize, Option) { + let remaining = self.buffer.len() - self.index; + (remaining, Some(remaining)) + } +} + +impl IntoIterator for BufferMut { + type Item = T; + type IntoIter = BufferMutIterator; + + fn into_iter(self) -> Self::IntoIter { + BufferMutIterator { + buffer: self, + index: 0, + } + } +} + impl Buf for ByteBufferMut { fn remaining(&self) -> usize { self.len() From e6bf8f04f78d6a2fb9dde3c0cf3b3314ce9b138b Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 16 Jan 2025 17:57:38 -0500 Subject: [PATCH 09/26] better error message when casting BoolArray to non-bool dtype. 
--- vortex-array/src/array/bool/compute/cast.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vortex-array/src/array/bool/compute/cast.rs b/vortex-array/src/array/bool/compute/cast.rs index 86e600d2100..f729c06394d 100644 --- a/vortex-array/src/array/bool/compute/cast.rs +++ b/vortex-array/src/array/bool/compute/cast.rs @@ -3,12 +3,12 @@ use vortex_error::{vortex_bail, VortexResult}; use crate::array::{BoolArray, BoolEncoding}; use crate::compute::CastFn; -use crate::{ArrayData, IntoArrayData}; +use crate::{ArrayDType as _, ArrayData, IntoArrayData}; impl CastFn for BoolEncoding { fn cast(&self, array: &BoolArray, dtype: &DType) -> VortexResult { let DType::Bool(new_nullability) = dtype else { - vortex_bail!(MismatchedTypes: "bool type", dtype); + vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype); }; BoolArray::try_new( From bf2bbd2dfb683f91a6c2a14ff204d3165c67ed32 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 17 Jan 2025 09:44:56 -0500 Subject: [PATCH 10/26] test that changing name order is not allowed --- vortex-array/src/array/struct_/compute.rs | 56 +++++++++++++++++++++-- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/vortex-array/src/array/struct_/compute.rs b/vortex-array/src/array/struct_/compute.rs index 4a3d0133643..8ed30cfda75 100644 --- a/vortex-array/src/array/struct_/compute.rs +++ b/vortex-array/src/array/struct_/compute.rs @@ -1,6 +1,6 @@ use itertools::Itertools; use vortex_dtype::DType; -use vortex_error::{vortex_bail, VortexResult}; +use vortex_error::{vortex_bail, VortexExpect as _, VortexResult}; use vortex_scalar::Scalar; use crate::array::struct_::StructArray; @@ -40,17 +40,26 @@ impl ComputeVTable for StructEncoding { impl CastFn for StructEncoding { fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult { - let Some(sdtype) = dtype.as_struct() else { + let Some(target_sdtype) = dtype.as_struct() else { vortex_bail!("cannot cast {} to {}", array.dtype(), dtype); }; + 
let source_sdtype = array + .dtype() + .as_struct() + .vortex_expect("struct array must have struct dtype"); + + if target_sdtype.names() != source_sdtype.names() { + vortex_bail!("cannot cast {} to {}", array.dtype(), dtype); + } + let validity = array.validity().with_nullability(dtype.nullability())?; StructArray::try_new( - array.names().clone(), + target_sdtype.names().clone(), array .children() - .zip_eq(sdtype.dtypes()) + .zip_eq(target_sdtype.dtypes()) .map(|(field, dtype)| try_cast(&field, &dtype)) .try_collect()?, array.len(), @@ -139,7 +148,7 @@ impl MaskFn for StructEncoding { mod tests { use arrow_buffer::BooleanBuffer; use vortex_buffer::buffer; - use vortex_dtype::{DType, Nullability, PType, StructDType}; + use vortex_dtype::{DType, FieldNames, Nullability, PType, StructDType}; use crate::array::{BoolArray, PrimitiveArray, StructArray, VarBinArray}; use crate::compute::test_harness::test_mask; @@ -228,6 +237,43 @@ mod tests { assert_eq!(casted.dtype(), &nullable_dtype); } + #[test] + fn test_cast_cannot_change_name_order() { + let array = StructArray::try_new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + buffer![1u8].into_array(), + buffer![1u8].into_array(), + buffer![1u8].into_array(), + ], + 1, + Validity::NonNullable, + ) + .unwrap() + .into_array(); + + let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable); + + let result = try_cast( + array, + &DType::Struct( + StructDType::new( + FieldNames::from(["ys".into(), "xs".into(), "zs".into()]), + vec![tu8.clone(), tu8.clone(), tu8.clone()], + ), + Nullability::NonNullable, + ), + ); + assert!( + result.as_ref().is_err_and(|err| { + err.to_string() + .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}") + }), + "{:?}", + result + ); + } + #[test] fn test_cast_complex_struct() { let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]) From 5b988931f57454a59cf12d38dea212c953891975 Mon Sep 17 00:00:00 2001 From: Daniel King 
Date: Fri, 17 Jan 2025 13:43:46 -0500 Subject: [PATCH 11/26] remove DictArray mask for now --- encodings/dict/src/array.rs | 5 -- encodings/dict/src/compress.rs | 2 +- encodings/dict/src/compute/mask.rs | 94 ------------------------------ encodings/dict/src/compute/mod.rs | 7 +-- vortex-buffer/src/buffer_mut.rs | 50 ---------------- 5 files changed, 2 insertions(+), 156 deletions(-) delete mode 100644 encodings/dict/src/compute/mask.rs diff --git a/encodings/dict/src/array.rs b/encodings/dict/src/array.rs index fae21317c3b..2459307d120 100644 --- a/encodings/dict/src/array.rs +++ b/encodings/dict/src/array.rs @@ -49,11 +49,6 @@ impl DictArray { ) } - #[inline] - pub fn codes_ptype(&self) -> PType { - self.metadata().codes_ptype - } - #[inline] pub fn codes(&self) -> ArrayData { self.as_ref() diff --git a/encodings/dict/src/compress.rs b/encodings/dict/src/compress.rs index 4a741c317cf..5b9029ff8a0 100644 --- a/encodings/dict/src/compress.rs +++ b/encodings/dict/src/compress.rs @@ -158,7 +158,7 @@ fn dict_encode_varbin_bytes<'a, I: Iterator>>( ) } -pub(crate) fn dict_values_validity(nullable: bool, len: usize) -> Validity { +fn dict_values_validity(nullable: bool, len: usize) -> Validity { if nullable { Validity::Array( SparseArray::try_new( diff --git a/encodings/dict/src/compute/mask.rs b/encodings/dict/src/compute/mask.rs deleted file mode 100644 index f5b4ab4effa..00000000000 --- a/encodings/dict/src/compute/mask.rs +++ /dev/null @@ -1,94 +0,0 @@ -use vortex_array::array::PrimitiveArray; -use vortex_array::compute::{add_scalar, FilterIter, FilterMask, MaskFn}; -use vortex_array::variants::PrimitiveArrayTrait as _; -use vortex_array::{ArrayDType as _, ArrayData, IntoArrayData, IntoArrayVariant as _}; -use vortex_buffer::BufferMut; -use vortex_dtype::{match_each_integer_ptype, NativePType}; -use vortex_error::VortexResult; -use vortex_scalar::Scalar; - -use crate::{dict_values_validity, DictArray, DictEncoding}; - -impl MaskFn for DictEncoding { - fn mask(&self, 
array: &DictArray, mask: FilterMask) -> VortexResult { - let (codes, new_values) = if array.dtype().is_nullable() { - (array.codes().into_primitive()?, array.values().clone()) - } else { - let values = array.values().into_primitive()?; - let values_with_null = - match_each_integer_ptype!(values.ptype(), |$T| add_null_value::<$T>(values))?; - let codes_with_null = add_scalar( - array.codes(), - Scalar::from(1u8).cast(array.codes().dtype())?, - )? - .into_primitive()?; - (codes_with_null, values_with_null) - }; - - let new_codes = match_each_integer_ptype!(codes.ptype(), |$T| { - let mut codes = codes.into_buffer_mut(); - typed_mask::<$T>(&mut codes, mask)?; - codes.into_array() - }); - DictArray::try_new(new_codes, new_values).map(IntoArrayData::into_array) - } -} - -fn add_null_value(values: PrimitiveArray) -> VortexResult { - let buf: BufferMut = values.into_buffer_mut::(); - let mut new_buf: BufferMut = BufferMut::::with_capacity(buf.len() + 1); - new_buf.push(T::zero()); - new_buf.extend(buf); - let len = new_buf.len(); - Ok(PrimitiveArray::new(new_buf, dict_values_validity(true, len)).into_array()) -} - -fn typed_mask(codes: &mut BufferMut, mask: FilterMask) -> VortexResult<()> { - match mask.iter() { - FilterIter::Indices(indices) => { - for index in indices { - codes[*index] = T::zero(); - } - } - FilterIter::Slices(slices) => { - for slice in slices { - codes[slice.0..slice.1].fill(T::zero()); - } - } - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use vortex_array::array::PrimitiveArray; - use vortex_array::compute::test_harness::test_mask; - use vortex_array::validity::Validity; - use vortex_array::IntoArrayData as _; - use vortex_buffer::buffer; - - use crate::{dict_encode_primitive, DictArray}; - - #[test] - fn test_mask_nullable_dict_array() { - let reference = - PrimitiveArray::from_option_iter([None, Some(42), Some(-9), Some(42), Some(5)]); - let (codes, values) = dict_encode_primitive(&reference); - test_mask( - DictArray::try_new(codes.into_array(), 
values.into_array()) - .unwrap() - .into_array(), - ) - } - - #[test] - fn test_mask_non_nullable_dict_array() { - let reference = PrimitiveArray::new(buffer![5, 42, -9, 42, 5], Validity::NonNullable); - let (codes, values) = dict_encode_primitive(&reference); - test_mask( - DictArray::try_new(codes.into_array(), values.into_array()) - .unwrap() - .into_array(), - ) - } -} diff --git a/encodings/dict/src/compute/mod.rs b/encodings/dict/src/compute/mod.rs index d520fcafd8d..be979c7bf41 100644 --- a/encodings/dict/src/compute/mod.rs +++ b/encodings/dict/src/compute/mod.rs @@ -1,11 +1,10 @@ mod binary_numeric; mod compare; mod like; -mod mask; use vortex_array::compute::{ filter, scalar_at, slice, take, BinaryNumericFn, CompareFn, ComputeVTable, FilterFn, - FilterMask, LikeFn, MaskFn, ScalarAtFn, SliceFn, TakeFn, + FilterMask, LikeFn, ScalarAtFn, SliceFn, TakeFn, }; use vortex_array::{ArrayData, IntoArrayData}; use vortex_error::VortexResult; @@ -30,10 +29,6 @@ impl ComputeVTable for DictEncoding { Some(self) } - fn mask_fn(&self) -> Option<&dyn MaskFn> { - Some(self) - } - fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-buffer/src/buffer_mut.rs b/vortex-buffer/src/buffer_mut.rs index a5d08b7ebaf..03719f6b33b 100644 --- a/vortex-buffer/src/buffer_mut.rs +++ b/vortex-buffer/src/buffer_mut.rs @@ -450,56 +450,6 @@ impl FromIterator for BufferMut { } } -/// Owned iterator over a [`BufferMut`]. 
-/// -/// Examples -/// -------- -/// -/// ``` -/// use vortex_buffer::buffer_mut; -/// -/// let mut a = buffer_mut![1u16, 2, 3]; -/// let b = buffer_mut![4u16, 5, 6]; -/// a.extend(b); -/// assert_eq!(a.len(), 6); -/// assert_eq!(a[3..6], [4, 5, 6]); -/// ``` -pub struct BufferMutIterator { - buffer: BufferMut, - index: usize, -} - -impl Iterator for BufferMutIterator { - type Item = T; - - fn next(&mut self) -> Option { - if self.index == self.buffer.len() { - None - } else { - let value = self.buffer[self.index]; - self.index += 1; - Some(value) - } - } - - fn size_hint(&self) -> (usize, Option) { - let remaining = self.buffer.len() - self.index; - (remaining, Some(remaining)) - } -} - -impl IntoIterator for BufferMut { - type Item = T; - type IntoIter = BufferMutIterator; - - fn into_iter(self) -> Self::IntoIter { - BufferMutIterator { - buffer: self, - index: 0, - } - } -} - impl Buf for ByteBufferMut { fn remaining(&self) -> usize { self.len() From 6d5af142c9baee4ed1489f24151ad02b463c363e Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 21 Jan 2025 10:52:33 +0000 Subject: [PATCH 12/26] fix varbin and varbinview casts --- vortex-array/src/array/varbin/compute/cast.rs | 24 +++++++++---------- .../src/array/varbinview/compute/mod.rs | 24 +++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/vortex-array/src/array/varbin/compute/cast.rs b/vortex-array/src/array/varbin/compute/cast.rs index c1c2c04d157..7c9734a0a5b 100644 --- a/vortex-array/src/array/varbin/compute/cast.rs +++ b/vortex-array/src/array/varbin/compute/cast.rs @@ -8,18 +8,18 @@ use crate::{ArrayDType, ArrayData, IntoArrayData}; impl CastFn for VarBinEncoding { fn cast(&self, array: &VarBinArray, dtype: &DType) -> VortexResult { - match dtype { - DType::Utf8(nullability) => { - let validity = array.validity().with_nullability(*nullability)?; - VarBinArray::try_new( - array.offsets(), - array.bytes(), - array.dtype().with_nullability(*nullability), - validity, - ) - 
.map(IntoArrayData::into_array) - } - _ => vortex_bail!("cannot cast {} to {}", array.dtype(), dtype), + if !array.dtype().eq_ignore_nullability(dtype) { + vortex_bail!("cannot cast {} to {}", array.dtype(), dtype); } + + let new_nullability = dtype.nullability(); + let validity = array.validity().with_nullability(new_nullability)?; + VarBinArray::try_new( + array.offsets(), + array.bytes(), + array.dtype().with_nullability(new_nullability), + validity, + ) + .map(IntoArrayData::into_array) } } diff --git a/vortex-array/src/array/varbinview/compute/mod.rs b/vortex-array/src/array/varbinview/compute/mod.rs index c8aea3a9b6e..4c0a6f3a820 100644 --- a/vortex-array/src/array/varbinview/compute/mod.rs +++ b/vortex-array/src/array/varbinview/compute/mod.rs @@ -127,19 +127,19 @@ fn take_views_unchecked>( impl CastFn for VarBinViewEncoding { fn cast(&self, array: &VarBinViewArray, dtype: &DType) -> VortexResult { - match dtype { - DType::Utf8(nullability) => { - let validity = array.validity().with_nullability(*nullability)?; - VarBinViewArray::try_new( - array.views(), - array.buffers().collect(), - array.dtype().with_nullability(*nullability), - validity, - ) - .map(IntoArrayData::into_array) - } - _ => vortex_bail!("cannot cast {} to {}", array.dtype(), dtype), + if !array.dtype().eq_ignore_nullability(dtype) { + vortex_bail!("cannot cast {} to {}", array.dtype(), dtype); } + + let new_nullability = dtype.nullability(); + let validity = array.validity().with_nullability(new_nullability)?; + VarBinViewArray::try_new( + array.views(), + array.buffers().collect(), + array.dtype().with_nullability(new_nullability), + validity, + ) + .map(IntoArrayData::into_array) } } From af00e5cdeb780638a9a93b696b6944bde366628a Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 21 Jan 2025 12:18:32 +0000 Subject: [PATCH 13/26] clippy --- vortex-array/src/array/chunked/compute/mask.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/vortex-array/src/array/chunked/compute/mask.rs b/vortex-array/src/array/chunked/compute/mask.rs index 473e2ff5dc6..f4286c22dac 100644 --- a/vortex-array/src/array/chunked/compute/mask.rs +++ b/vortex-array/src/array/chunked/compute/mask.rs @@ -91,7 +91,7 @@ fn mask_slices( array .chunks() - .zip_eq(chunked_filters.into_iter()) + .zip_eq(chunked_filters) .map(|(chunk, chunk_filter)| -> VortexResult { Ok(match chunk_filter { ChunkFilter::All => { From 5c4357453d2cb2144fc3a18616adefc91a436642 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 21 Jan 2025 12:20:28 +0000 Subject: [PATCH 14/26] clippy --- encodings/dict/benches/dict_mask.rs | 4 +--- vortex-array/src/array/struct_/compute.rs | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/encodings/dict/benches/dict_mask.rs b/encodings/dict/benches/dict_mask.rs index 4e0241d4b1f..afc39bd15f0 100644 --- a/encodings/dict/benches/dict_mask.rs +++ b/encodings/dict/benches/dict_mask.rs @@ -6,14 +6,12 @@ use rand::{Rng, SeedableRng as _}; use vortex_array::array::PrimitiveArray; use vortex_array::compute::{mask, FilterMask}; use vortex_array::IntoArrayData as _; -use vortex_buffer::{buffer, Buffer}; use vortex_dict::DictArray; fn filter_mask(len: usize, fraction_masked: f64, rng: &mut StdRng) -> FilterMask { let indices = (0..len) .filter(|_| rng.gen_bool(fraction_masked)) - .map(|x| x as u64) - .collect::>(); + .collect::>(); FilterMask::from_indices(len, indices) } diff --git a/vortex-array/src/array/struct_/compute.rs b/vortex-array/src/array/struct_/compute.rs index 8ed30cfda75..c3c74c6768e 100644 --- a/vortex-array/src/array/struct_/compute.rs +++ b/vortex-array/src/array/struct_/compute.rs @@ -259,7 +259,7 @@ mod tests { &DType::Struct( StructDType::new( FieldNames::from(["ys".into(), "xs".into(), "zs".into()]), - vec![tu8.clone(), tu8.clone(), tu8.clone()], + vec![tu8.clone(), tu8.clone(), tu8], ), Nullability::NonNullable, ), From fe393c048d992dd01c66ca91e37abdf1907d616f Mon Sep 17 00:00:00 
2001 From: Daniel King Date: Tue, 21 Jan 2025 12:21:11 +0000 Subject: [PATCH 15/26] clippy --- encodings/dict/benches/dict_mask.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/encodings/dict/benches/dict_mask.rs b/encodings/dict/benches/dict_mask.rs index afc39bd15f0..b4c1164089b 100644 --- a/encodings/dict/benches/dict_mask.rs +++ b/encodings/dict/benches/dict_mask.rs @@ -6,6 +6,7 @@ use rand::{Rng, SeedableRng as _}; use vortex_array::array::PrimitiveArray; use vortex_array::compute::{mask, FilterMask}; use vortex_array::IntoArrayData as _; +use vortex_buffer::buffer; use vortex_dict::DictArray; fn filter_mask(len: usize, fraction_masked: f64, rng: &mut StdRng) -> FilterMask { From 64b714d15b4b25ddc924bda95838d2fcd93d9bb5 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 3 Feb 2025 15:15:43 -0500 Subject: [PATCH 16/26] remove cruft --- vortex-array/src/array/bool/compute/mod.rs | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/vortex-array/src/array/bool/compute/mod.rs b/vortex-array/src/array/bool/compute/mod.rs index 5515a6d83f6..7df64560fd4 100644 --- a/vortex-array/src/array/bool/compute/mod.rs +++ b/vortex-array/src/array/bool/compute/mod.rs @@ -66,24 +66,4 @@ impl ComputeVTable for BoolEncoding { fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn> { Some(self) } - - fn binary_numeric_fn(&self) -> Option<&dyn crate::compute::BinaryNumericFn> { - None - } - - fn compare_fn(&self) -> Option<&dyn crate::compute::CompareFn> { - None - } - - fn like_fn(&self) -> Option<&dyn crate::compute::LikeFn> { - None - } - - fn search_sorted_fn(&self) -> Option<&dyn crate::compute::SearchSortedFn> { - None - } - - fn search_sorted_usize_fn(&self) -> Option<&dyn crate::compute::SearchSortedUsizeFn> { - None - } } From cb0fabd9187f43911d1f0941c6127a1aba1aa16f Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 3 Feb 2025 15:19:59 -0500 Subject: [PATCH 17/26] revert unused changes --- vortex-array/src/array/varbin/compute/mod.rs | 32 
-------------------- 1 file changed, 32 deletions(-) diff --git a/vortex-array/src/array/varbin/compute/mod.rs b/vortex-array/src/array/varbin/compute/mod.rs index 38ebc46c4b5..01e725f992c 100644 --- a/vortex-array/src/array/varbin/compute/mod.rs +++ b/vortex-array/src/array/varbin/compute/mod.rs @@ -47,38 +47,6 @@ impl ComputeVTable for VarBinEncoding { fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn> { Some(self) } - - fn binary_boolean_fn(&self) -> Option<&dyn crate::compute::BinaryBooleanFn> { - None - } - - fn binary_numeric_fn(&self) -> Option<&dyn crate::compute::BinaryNumericFn> { - None - } - - fn fill_forward_fn(&self) -> Option<&dyn crate::compute::FillForwardFn> { - None - } - - fn fill_null_fn(&self) -> Option<&dyn crate::compute::FillNullFn> { - None - } - - fn invert_fn(&self) -> Option<&dyn crate::compute::InvertFn> { - None - } - - fn like_fn(&self) -> Option<&dyn crate::compute::LikeFn> { - None - } - - fn search_sorted_fn(&self) -> Option<&dyn crate::compute::SearchSortedFn> { - None - } - - fn search_sorted_usize_fn(&self) -> Option<&dyn crate::compute::SearchSortedUsizeFn> { - None - } } impl ScalarAtFn for VarBinEncoding { From 77740e850a73ce46b3115949edd218363401e544 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 3 Feb 2025 15:21:56 -0500 Subject: [PATCH 18/26] cleanups --- vortex-array/src/array/varbin/compute/cast.rs | 12 ++++-------- vortex-array/src/array/varbinview/compute/mod.rs | 7 ++++--- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/vortex-array/src/array/varbin/compute/cast.rs b/vortex-array/src/array/varbin/compute/cast.rs index 4e2922f8b1e..31df18910ce 100644 --- a/vortex-array/src/array/varbin/compute/cast.rs +++ b/vortex-array/src/array/varbin/compute/cast.rs @@ -13,13 +13,9 @@ impl CastFn for VarBinEncoding { } let new_nullability = dtype.nullability(); - let validity = array.validity().with_nullability(new_nullability)?; - VarBinArray::try_new( - array.offsets(), - array.bytes(), - 
array.dtype().with_nullability(new_nullability), - validity, - ) - .map(IntoArray::into_array) + let new_validity = array.validity().with_nullability(new_nullability)?; + let new_dtype = array.dtype().with_nullability(new_nullability); + VarBinArray::try_new(array.offsets(), array.bytes(), new_dtype, new_validity) + .map(IntoArray::into_array) } } diff --git a/vortex-array/src/array/varbinview/compute/mod.rs b/vortex-array/src/array/varbinview/compute/mod.rs index d5f866efebd..14b7fd15de2 100644 --- a/vortex-array/src/array/varbinview/compute/mod.rs +++ b/vortex-array/src/array/varbinview/compute/mod.rs @@ -138,12 +138,13 @@ impl CastFn for VarBinViewEncoding { } let new_nullability = dtype.nullability(); - let validity = array.validity().with_nullability(new_nullability)?; + let new_validity = array.validity().with_nullability(new_nullability)?; + let new_dtype = array.dtype().with_nullability(new_nullability); VarBinViewArray::try_new( array.views(), array.buffers().collect(), - array.dtype().with_nullability(new_nullability), - validity, + new_dtype, + new_validity, ) .map(IntoArray::into_array) } From 4d909ebaec12f47aa0f486ac2618b666c881ae9c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 3 Feb 2025 15:30:57 -0500 Subject: [PATCH 19/26] validity array not supporting min is an error --- vortex-array/src/validity.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 15021e3da27..05d010fec74 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -374,7 +374,7 @@ impl Validity { is_valid .statistics() .compute_min::() - .unwrap_or(false) + .vortex_expect("validity array must support min") .then(|| { // min true => all true Self::NonNullable From ac41d69c9df8499cf1e541407c99487089d20196 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 3 Feb 2025 15:31:58 -0500 Subject: [PATCH 20/26] lift validity failure before boolean buffer creation --- 
vortex-array/src/array/bool/compute/cast.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vortex-array/src/array/bool/compute/cast.rs b/vortex-array/src/array/bool/compute/cast.rs index 46a9b60bf24..80bcc082073 100644 --- a/vortex-array/src/array/bool/compute/cast.rs +++ b/vortex-array/src/array/bool/compute/cast.rs @@ -11,10 +11,7 @@ impl CastFn for BoolEncoding { vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype); }; - BoolArray::try_new( - array.boolean_buffer(), - array.validity().with_nullability(*new_nullability)?, - ) - .map(IntoArray::into_array) + let new_validity = array.validity().with_nullability(*new_nullability)?; + BoolArray::try_new(array.boolean_buffer(), new_validity).map(IntoArray::into_array) } } From efb9881dd9766b14a11760add600aeb748a0fade Mon Sep 17 00:00:00 2001 From: Daniel King Date: Mon, 3 Feb 2025 15:50:08 -0500 Subject: [PATCH 21/26] final fixups --- vortex-array/src/array/bool/compute/cast.rs | 2 +- vortex-array/src/array/chunked/compute/mask.rs | 9 +++++---- vortex-array/src/array/struct_/compute/mod.rs | 2 +- vortex-array/src/array/varbin/compute/cast.rs | 2 +- vortex-array/src/array/varbinview/compute/mod.rs | 2 +- vortex-array/src/validity.rs | 6 +++--- 6 files changed, 12 insertions(+), 11 deletions(-) diff --git a/vortex-array/src/array/bool/compute/cast.rs b/vortex-array/src/array/bool/compute/cast.rs index 80bcc082073..0e82b2a21dc 100644 --- a/vortex-array/src/array/bool/compute/cast.rs +++ b/vortex-array/src/array/bool/compute/cast.rs @@ -11,7 +11,7 @@ impl CastFn for BoolEncoding { vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype); }; - let new_validity = array.validity().with_nullability(*new_nullability)?; + let new_validity = array.validity().cast_nullability(*new_nullability)?; BoolArray::try_new(array.boolean_buffer(), new_validity).map(IntoArray::into_array) } } diff --git a/vortex-array/src/array/chunked/compute/mask.rs 
b/vortex-array/src/array/chunked/compute/mask.rs index 02fcf65e009..37e67046209 100644 --- a/vortex-array/src/array/chunked/compute/mask.rs +++ b/vortex-array/src/array/chunked/compute/mask.rs @@ -51,12 +51,13 @@ fn mask_indices( .chunk(current_chunk_id) .vortex_expect("find_chunk_idx must return valid chunk ID"); let masked_chunk = mask(&chunk, Mask::from_indices(chunk.len(), chunk_indices))?; + // Advance the chunk forward, reset the chunk indices buffer. chunk_indices = Vec::new(); new_chunks.push(masked_chunk); current_chunk_id += 1; - // Advance the chunk forward, reset the chunk indices buffer. while current_chunk_id < chunk_id { + // Chunks that are not affected by the mask, must still be casted to the correct dtype. let chunk = array .chunk(current_chunk_id) .vortex_expect("find_chunk_idx must return valid chunk ID"); @@ -101,15 +102,15 @@ fn mask_slices( .map(|(chunk, chunk_filter)| -> VortexResult { Ok(match chunk_filter { ChunkFilter::All => { - // All => entire chunk is masked out + // entire chunk is masked out ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()).into_array() } ChunkFilter::None => { - // None => preserve the entire chunk unmasked + // entire chunk is not affected by mask chunk } - // Slices => turn the slices into a boolean buffer. ChunkFilter::Slices(slices) => { + // Slices of indices that must be set to null mask(&chunk, Mask::from_slices(chunk.len(), slices))? 
} }) diff --git a/vortex-array/src/array/struct_/compute/mod.rs b/vortex-array/src/array/struct_/compute/mod.rs index 0e7e3794123..e367d33f385 100644 --- a/vortex-array/src/array/struct_/compute/mod.rs +++ b/vortex-array/src/array/struct_/compute/mod.rs @@ -61,7 +61,7 @@ impl CastFn for StructEncoding { vortex_bail!("cannot cast {} to {}", array.dtype(), dtype); } - let validity = array.validity().with_nullability(dtype.nullability())?; + let validity = array.validity().cast_nullability(dtype.nullability())?; StructArray::try_new( target_sdtype.names().clone(), diff --git a/vortex-array/src/array/varbin/compute/cast.rs b/vortex-array/src/array/varbin/compute/cast.rs index 31df18910ce..d6c4ebbbaa6 100644 --- a/vortex-array/src/array/varbin/compute/cast.rs +++ b/vortex-array/src/array/varbin/compute/cast.rs @@ -13,7 +13,7 @@ impl CastFn for VarBinEncoding { } let new_nullability = dtype.nullability(); - let new_validity = array.validity().with_nullability(new_nullability)?; + let new_validity = array.validity().cast_nullability(new_nullability)?; let new_dtype = array.dtype().with_nullability(new_nullability); VarBinArray::try_new(array.offsets(), array.bytes(), new_dtype, new_validity) .map(IntoArray::into_array) diff --git a/vortex-array/src/array/varbinview/compute/mod.rs b/vortex-array/src/array/varbinview/compute/mod.rs index 14b7fd15de2..f98c19c3c4b 100644 --- a/vortex-array/src/array/varbinview/compute/mod.rs +++ b/vortex-array/src/array/varbinview/compute/mod.rs @@ -138,7 +138,7 @@ impl CastFn for VarBinViewEncoding { } let new_nullability = dtype.nullability(); - let new_validity = array.validity().with_nullability(new_nullability)?; + let new_validity = array.validity().cast_nullability(new_nullability)?; let new_dtype = array.dtype().with_nullability(new_nullability); VarBinViewArray::try_new( array.views(), diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 05d010fec74..e74862d0018 100644 --- a/vortex-array/src/validity.rs +++ 
b/vortex-array/src/validity.rs @@ -267,9 +267,9 @@ impl Validity { } Validity::AllInvalid => Validity::AllInvalid, Validity::Array(is_valid) => { - let bools = BoolArray::try_from(is_valid.clone())?.boolean_buffer(); + let is_valid = BoolArray::try_from(is_valid.clone())?.boolean_buffer(); let keep_valid = make_invalid.not(); - Validity::from(bools.bitand(&keep_valid)) + Validity::from(is_valid.bitand(&keep_valid)) } }), } @@ -384,7 +384,7 @@ impl Validity { } /// Convert into a variant compatible with the given nullability, if possible. - pub fn with_nullability(self, nullability: Nullability) -> VortexResult { + pub fn cast_nullability(self, nullability: Nullability) -> VortexResult { match nullability { Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| { vortex_err!("cannot cast array with invalid values to non-nullable type") From 256f0134b56351869dc5fe0ac1f1f28a969dd98b Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 6 Feb 2025 15:33:20 -0500 Subject: [PATCH 22/26] use new idioms for true count --- vortex-array/src/compute/mask.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vortex-array/src/compute/mask.rs b/vortex-array/src/compute/mask.rs index 87359632199..f400a2beb02 100644 --- a/vortex-array/src/compute/mask.rs +++ b/vortex-array/src/compute/mask.rs @@ -64,12 +64,10 @@ pub fn mask(array: &Array, mask: Mask) -> VortexResult { ); } - let true_count = mask.true_count(); - - let masked = if true_count == 0 { + let masked = if matches!(mask, Mask::AllFalse(_)) { // Fast-path for empty mask try_cast(array, &array.dtype().as_nullable())? - } else if true_count == mask.len() { + } else if matches!(mask, Mask::AllTrue(_)) { // Fast-path for full mask. 
ConstantArray::new( Scalar::null(array.dtype().clone().as_nullable()), From 3ec13c0a101890f0c74052066c391eb2ea699907 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 18 Feb 2025 10:56:02 -0500 Subject: [PATCH 23/26] fixes --- encodings/dict/benches/dict_mask.rs | 65 +++++++++++-------- vortex-array/src/array/bool/compute/cast.rs | 2 +- vortex-array/src/array/struct_/compute/mod.rs | 5 +- vortex-array/src/array/varbin/compute/cast.rs | 1 - .../src/array/varbinview/compute/cast.rs | 1 - .../src/array/varbinview/compute/mod.rs | 4 +- 6 files changed, 44 insertions(+), 34 deletions(-) diff --git a/encodings/dict/benches/dict_mask.rs b/encodings/dict/benches/dict_mask.rs index 7f67984e94f..e02fa57619a 100644 --- a/encodings/dict/benches/dict_mask.rs +++ b/encodings/dict/benches/dict_mask.rs @@ -1,15 +1,18 @@ #![allow(clippy::unwrap_used)] -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use divan::Bencher; use rand::rngs::StdRng; use rand::{Rng, SeedableRng as _}; use vortex_array::array::PrimitiveArray; use vortex_array::compute::mask; use vortex_array::IntoArray as _; -use vortex_buffer::buffer; use vortex_dict::DictArray; use vortex_mask::Mask; +fn main() { + divan::main(); +} + fn filter_mask(len: usize, fraction_masked: f64, rng: &mut StdRng) -> Mask { let indices = (0..len) .filter(|_| rng.gen_bool(fraction_masked)) @@ -17,34 +20,40 @@ fn filter_mask(len: usize, fraction_masked: f64, rng: &mut StdRng) -> Mask { Mask::from_indices(len, indices) } -#[allow(clippy::cast_possible_truncation)] -fn bench_dict_mask(c: &mut Criterion) { - let mut group = c.benchmark_group("bench_dict_mask"); +#[divan::bench(args = [ + (0.9, 0.9), + (0.9, 0.5), + (0.9, 0.1), + (0.9, 0.01), + (0.5, 0.9), + (0.5, 0.5), + (0.5, 0.1), + (0.5, 0.01), + (0.1, 0.9), + (0.1, 0.5), + (0.1, 0.1), + (0.1, 0.01), + (0.01, 0.9), + (0.01, 0.5), + (0.01, 0.1), + (0.01, 0.01), +])] +fn bench_dict_mask(bencher: Bencher, (fraction_valid, fraction_masked): (f64, f64)) { let 
mut rng = StdRng::seed_from_u64(0); let len = 65_535; - // for fraction_valid in [0.5, 0.1, 0.01, 0.001, 0.0001] { - for fraction_valid in [0.1] { - let codes = - PrimitiveArray::from_iter((0..len).map(|_| (!rng.gen_bool(fraction_valid)) as u64)) - .into_array(); - let values = buffer![1].into_array(); - let array = DictArray::try_new(codes, values).unwrap().into_array(); - // for fraction_masked in [0.1, 0.01, 0.001, 0.0001] { - for fraction_masked in [0.9, 0.5, 0.1, 0.0001] { - let filter_mask = filter_mask(len, fraction_masked, &mut rng); - group.bench_with_input( - BenchmarkId::from_parameter(format!( - "fraction_valid={}, fraction_masked={}", - fraction_valid, fraction_masked - )), - &(&array, filter_mask), - |b, (array, filter_mask)| b.iter(|| mask(array, filter_mask.clone()).unwrap()), - ); + let codes = PrimitiveArray::from_iter((0..len).map(|_| { + if rng.gen_bool(fraction_valid) { + 1u64 + } else { + 0u64 } - } - group.finish() + })) + .into_array(); + let values = PrimitiveArray::from_option_iter([None, Some(42i32)]).into_array(); + let array = DictArray::try_new(codes, values).unwrap().into_array(); + let filter_mask = filter_mask(len, fraction_masked, &mut rng); + bencher + .with_inputs(|| (&array, filter_mask.clone())) + .bench_values(|(array, filter_mask)| mask(array, filter_mask).unwrap()); } - -criterion_group!(benches, bench_dict_mask); -criterion_main!(benches); diff --git a/vortex-array/src/array/bool/compute/cast.rs b/vortex-array/src/array/bool/compute/cast.rs index 2ed5a636fd6..c13c92b2b6a 100644 --- a/vortex-array/src/array/bool/compute/cast.rs +++ b/vortex-array/src/array/bool/compute/cast.rs @@ -13,7 +13,7 @@ impl CastFn for BoolEncoding { let new_nullability = dtype.nullability(); let new_validity = array.validity().cast_nullability(new_nullability)?; - Ok(BoolArray::try_new(array.boolean_buffer(), new_validity).into_array()) + BoolArray::try_new(array.boolean_buffer(), new_validity).map(IntoArray::into_array) } } diff --git 
a/vortex-array/src/array/struct_/compute/mod.rs b/vortex-array/src/array/struct_/compute/mod.rs index 6f887831197..3196c1f82d0 100644 --- a/vortex-array/src/array/struct_/compute/mod.rs +++ b/vortex-array/src/array/struct_/compute/mod.rs @@ -71,7 +71,8 @@ impl CastFn for StructEncoding { target_sdtype.names().clone(), array .children() - .zip_eq(target_sdtype.dtypes()) + .into_iter() + .zip_eq(target_sdtype.fields()) .map(|(field, dtype)| try_cast(&field, &dtype)) .try_collect()?, array.len(), @@ -148,7 +149,7 @@ impl MaskFn for StructEncoding { StructArray::try_new( array.names().clone(), - array.children().collect(), + array.children(), array.len(), validity, ) diff --git a/vortex-array/src/array/varbin/compute/cast.rs b/vortex-array/src/array/varbin/compute/cast.rs index 7434f7c624a..eeb107926fa 100644 --- a/vortex-array/src/array/varbin/compute/cast.rs +++ b/vortex-array/src/array/varbin/compute/cast.rs @@ -3,7 +3,6 @@ use vortex_error::{vortex_bail, VortexResult}; use crate::array::{VarBinArray, VarBinEncoding}; use crate::compute::CastFn; -use crate::validity::Validity; use crate::{Array, IntoArray}; impl CastFn for VarBinEncoding { diff --git a/vortex-array/src/array/varbinview/compute/cast.rs b/vortex-array/src/array/varbinview/compute/cast.rs index 7e0ccf4fe49..c5e43c55a4c 100644 --- a/vortex-array/src/array/varbinview/compute/cast.rs +++ b/vortex-array/src/array/varbinview/compute/cast.rs @@ -3,7 +3,6 @@ use vortex_error::{vortex_bail, VortexResult}; use crate::array::{VarBinViewArray, VarBinViewEncoding}; use crate::compute::CastFn; -use crate::validity::Validity; use crate::{Array, IntoArray}; impl CastFn for VarBinViewEncoding { diff --git a/vortex-array/src/array/varbinview/compute/mod.rs b/vortex-array/src/array/varbinview/compute/mod.rs index d9d8848c04d..0d1df61bd02 100644 --- a/vortex-array/src/array/varbinview/compute/mod.rs +++ b/vortex-array/src/array/varbinview/compute/mod.rs @@ -4,12 +4,13 @@ mod take; mod to_arrow; use 
vortex_error::VortexResult; +use vortex_mask::Mask; use vortex_scalar::Scalar; use crate::array::varbin::varbin_scalar; use crate::array::varbinview::VarBinViewArray; use crate::array::VarBinViewEncoding; -use crate::compute::{CastFn, MinMaxFn, ScalarAtFn, SliceFn, TakeFn, ToArrowFn}; +use crate::compute::{CastFn, MaskFn, MinMaxFn, ScalarAtFn, SliceFn, TakeFn, ToArrowFn}; use crate::vtable::ComputeVTable; use crate::{Array, IntoArray}; @@ -132,6 +133,7 @@ mod tests { ); } + #[test] fn take_into_nullable() { let arr = VarBinViewArray::from_iter_nullable_str([ Some("one"), From 554f492ad0f65832794962ebb879684bb4c61024 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 18 Feb 2025 11:13:31 -0500 Subject: [PATCH 24/26] test mask on dict array --- encodings/dict/src/compute/mod.rs | 36 ++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/encodings/dict/src/compute/mod.rs b/encodings/dict/src/compute/mod.rs index 5921624eb32..4fd9a2aa9e8 100644 --- a/encodings/dict/src/compute/mod.rs +++ b/encodings/dict/src/compute/mod.rs @@ -79,7 +79,8 @@ impl SliceFn for DictEncoding { #[cfg(test)] mod test { use vortex_array::accessor::ArrayAccessor; - use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinViewArray}; + use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray}; + use vortex_array::compute::test_harness::test_mask; use vortex_array::compute::{compare, scalar_at, slice, Operator}; use vortex_array::{Array, IntoArray, IntoArrayVariant}; use vortex_dtype::{DType, Nullability}; @@ -198,4 +199,37 @@ mod test { Scalar::bool(true, Nullability::Nullable) ); } + + #[test] + fn test_mask_dict_array() { + let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()) + .unwrap() + .into_array(); + test_mask(array); + + let array = dict_encode( + &PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]) + .into_array(), + ) + .unwrap() + .into_array(); + 
test_mask(array); + + let array = dict_encode( + &VarBinArray::from_iter( + [ + Some("hello"), + None, + Some("hello"), + Some("good"), + Some("good"), + ], + DType::Utf8(Nullability::Nullable), + ) + .into_array(), + ) + .unwrap() + .into_array(); + test_mask(array); + } } From 344548898916c35b93c5cb12e3d08e4ca093b4e4 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 18 Feb 2025 11:23:28 -0500 Subject: [PATCH 25/26] revert change --- vortex-array/src/array/varbinview/compute/cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vortex-array/src/array/varbinview/compute/cast.rs b/vortex-array/src/array/varbinview/compute/cast.rs index c5e43c55a4c..9c63889785b 100644 --- a/vortex-array/src/array/varbinview/compute/cast.rs +++ b/vortex-array/src/array/varbinview/compute/cast.rs @@ -8,7 +8,7 @@ use crate::{Array, IntoArray}; impl CastFn for VarBinViewEncoding { fn cast(&self, array: &VarBinViewArray, dtype: &DType) -> VortexResult { if !array.dtype().eq_ignore_nullability(dtype) { - vortex_bail!("cannot cast {} to {}", array.dtype(), dtype); + vortex_bail!("Cannot cast {} to {}", array.dtype(), dtype); } let new_nullability = dtype.nullability(); From 5e2f41533c86046bdea557ebc3dfe6b224683abe Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 18 Feb 2025 11:39:29 -0500 Subject: [PATCH 26/26] StructArray children includes validity: must use fields() --- vortex-array/src/array/struct_/compute/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vortex-array/src/array/struct_/compute/mod.rs b/vortex-array/src/array/struct_/compute/mod.rs index 3196c1f82d0..59d2d06249a 100644 --- a/vortex-array/src/array/struct_/compute/mod.rs +++ b/vortex-array/src/array/struct_/compute/mod.rs @@ -149,7 +149,7 @@ impl MaskFn for StructEncoding { StructArray::try_new( array.names().clone(), - array.children(), + array.fields().collect(), array.len(), validity, )