diff --git a/encodings/alp/src/alp_rd/compute/mask.rs b/encodings/alp/src/alp_rd/compute/mask.rs new file mode 100644 index 00000000000..49919940c91 --- /dev/null +++ b/encodings/alp/src/alp_rd/compute/mask.rs @@ -0,0 +1,58 @@ +use vortex_array::compute::{mask, MaskFn}; +use vortex_array::{Array, IntoArray}; +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::{ALPRDArray, ALPRDEncoding}; + +impl MaskFn for ALPRDEncoding { + fn mask(&self, array: &ALPRDArray, filter_mask: Mask) -> VortexResult { + Ok(ALPRDArray::try_new( + array.dtype().as_nullable(), + mask(&array.left_parts(), filter_mask)?, + array.left_parts_dict(), + array.right_parts(), + array.right_bit_width(), + array.left_parts_patches(), + )? + .into_array()) + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_array::array::PrimitiveArray; + use vortex_array::compute::test_harness::test_mask; + use vortex_array::IntoArray as _; + + use crate::{ALPRDFloat, RDEncoder}; + + #[rstest] + #[case(0.1f32, 0.2f32, 3e25f32)] + #[case(0.1f64, 0.2f64, 3e100f64)] + fn test_mask_simple(#[case] a: T, #[case] b: T, #[case] outlier: T) { + test_mask( + RDEncoder::new(&[a, b]) + .encode(&PrimitiveArray::from_iter([a, b, outlier, b, outlier])) + .into_array(), + ); + } + + #[rstest] + #[case(0.1f32, 3e25f32)] + #[case(0.5f64, 1e100f64)] + fn test_mask_with_nulls(#[case] a: T, #[case] outlier: T) { + test_mask( + RDEncoder::new(&[a]) + .encode(&PrimitiveArray::from_option_iter([ + Some(a), + None, + Some(outlier), + Some(a), + None, + ])) + .into_array(), + ); + } +} diff --git a/encodings/alp/src/alp_rd/compute/mod.rs b/encodings/alp/src/alp_rd/compute/mod.rs index 53b0bb1a06b..9b9453a6c3f 100644 --- a/encodings/alp/src/alp_rd/compute/mod.rs +++ b/encodings/alp/src/alp_rd/compute/mod.rs @@ -1,10 +1,11 @@ -use vortex_array::compute::{FilterFn, ScalarAtFn, SliceFn, TakeFn}; +use vortex_array::compute::{FilterFn, MaskFn, ScalarAtFn, SliceFn, TakeFn}; use vortex_array::vtable::ComputeVTable; use vortex_array::Array; use crate::ALPRDEncoding; mod filter; +mod mask; mod scalar_at; mod slice; mod take; @@ -14,6 +15,10 @@ impl ComputeVTable for ALPRDEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index fe7048b26dd..e40c31ad6ca 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -1,5 +1,5 @@ use num_traits::AsPrimitive; -use vortex_array::compute::{FillForwardFn, ScalarAtFn, SliceFn, TakeFn}; +use vortex_array::compute::{FillForwardFn, MaskFn, ScalarAtFn, SliceFn, TakeFn}; use vortex_array::validity::Validity; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::vtable::ComputeVTable; @@ -16,6 +16,10 @@ impl ComputeVTable for ByteBoolEncoding { None } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -29,6 +33,13 @@ impl ComputeVTable for ByteBoolEncoding { } } +impl MaskFn for ByteBoolEncoding { + fn mask(&self, array: &ByteBoolArray, mask: Mask) -> VortexResult { + ByteBoolArray::try_new(array.buffer().clone(), array.validity().mask(&mask)?) + .map(IntoArray::into_array) + } +} + impl ScalarAtFn for ByteBoolEncoding { fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult { Ok(Scalar::bool( @@ -139,6 +150,7 @@ impl FillForwardFn for ByteBoolEncoding { #[cfg(test)] mod tests { + use vortex_array::compute::test_harness::test_mask; use vortex_array::compute::{compare, scalar_at, slice, Operator}; use super::*; @@ -211,4 +223,12 @@ mod tests { let s = scalar_at(&arr, 4).unwrap(); assert!(s.is_null()); } + + #[test] + fn test_mask_byte_bool() { + test_mask(ByteBoolArray::from(vec![true, false, true, true, false]).into_array()); + test_mask( + ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).into_array(), + ); + } } diff --git a/encodings/dict/Cargo.toml b/encodings/dict/Cargo.toml index a44299a836b..310cdd39321 100644 --- a/encodings/dict/Cargo.toml +++ b/encodings/dict/Cargo.toml @@ -51,7 +51,11 @@ name = "dict_compare" harness = false required-features = ["test-harness"] +[[bench]] +name = "dict_mask" +harness = false + [[bench]] name = "chunked_dict_array_builder" harness = false -required-features = ["test-harness"] \ No newline at end of file +required-features = ["test-harness"] diff --git a/encodings/dict/benches/dict_mask.rs b/encodings/dict/benches/dict_mask.rs new file mode 100644 index 00000000000..e02fa57619a --- /dev/null +++ b/encodings/dict/benches/dict_mask.rs @@ -0,0 +1,59 @@ +#![allow(clippy::unwrap_used)] + +use divan::Bencher; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng as _}; +use vortex_array::array::PrimitiveArray; +use vortex_array::compute::mask; +use vortex_array::IntoArray as _; +use vortex_dict::DictArray; +use vortex_mask::Mask; + +fn main() { + divan::main(); +} + +fn filter_mask(len: usize, fraction_masked: f64, rng: &mut StdRng) -> Mask { + let indices = (0..len) + .filter(|_| rng.gen_bool(fraction_masked)) + .collect::>(); + Mask::from_indices(len, indices) +} + +#[divan::bench(args = [ + (0.9, 0.9), + (0.9, 0.5), + (0.9, 0.1), + (0.9, 0.01), + (0.5, 0.9), + (0.5, 0.5), + (0.5, 0.1), + (0.5, 0.01), + (0.1, 0.9), + (0.1, 0.5), + (0.1, 0.1), + (0.1, 0.01), + (0.01, 0.9), + (0.01, 0.5), + (0.01, 0.1), + (0.01, 0.01), +])] +fn bench_dict_mask(bencher: Bencher, (fraction_valid, fraction_masked): (f64, f64)) { + let mut rng = StdRng::seed_from_u64(0); + + let len = 65_535; + let codes = PrimitiveArray::from_iter((0..len).map(|_| { + if rng.gen_bool(fraction_valid) { + 1u64 + } else { + 0u64 + } + })) + .into_array(); + let values = PrimitiveArray::from_option_iter([None, Some(42i32)]).into_array(); + let array = DictArray::try_new(codes, values).unwrap().into_array(); + let filter_mask = filter_mask(len, fraction_masked, &mut rng); + bencher + .with_inputs(|| (&array, filter_mask.clone())) + .bench_values(|(array, filter_mask)| mask(array, filter_mask).unwrap()); +} diff --git a/encodings/dict/src/compute/mod.rs b/encodings/dict/src/compute/mod.rs index 5921624eb32..4fd9a2aa9e8 100644 --- a/encodings/dict/src/compute/mod.rs +++ b/encodings/dict/src/compute/mod.rs @@ -79,7 +79,8 @@ impl SliceFn for DictEncoding { #[cfg(test)] mod test { use vortex_array::accessor::ArrayAccessor; - use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinViewArray}; + use vortex_array::array::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray}; + use vortex_array::compute::test_harness::test_mask; use vortex_array::compute::{compare, scalar_at, slice, Operator}; use vortex_array::{Array, IntoArray, IntoArrayVariant}; use vortex_dtype::{DType, Nullability}; @@ -198,4 +199,37 @@ mod test { Scalar::bool(true, Nullability::Nullable) ); } + + #[test] + fn test_mask_dict_array() { + let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()) + .unwrap() + .into_array(); + test_mask(array); + + let array = dict_encode( + &PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]) + .into_array(), + ) + .unwrap() + .into_array(); + test_mask(array); + + let array = dict_encode( + &VarBinArray::from_iter( + [ + Some("hello"), + None, + Some("hello"), + Some("good"), + Some("good"), + ], + DType::Utf8(Nullability::Nullable), + ) + .into_array(), + ) + .unwrap() + .into_array(); + test_mask(array); + } } diff --git a/encodings/sparse/src/compute/mod.rs b/encodings/sparse/src/compute/mod.rs index 467fb78ec3c..d3615e064a7 100644 --- a/encodings/sparse/src/compute/mod.rs +++ b/encodings/sparse/src/compute/mod.rs @@ -104,11 +104,14 @@ impl FilterFn for SparseEncoding { mod test { use rstest::{fixture, rstest}; use vortex_array::array::PrimitiveArray; - use vortex_array::compute::test_harness::test_binary_numeric; - use vortex_array::compute::{filter, search_sorted, slice, SearchResult, SearchSortedSide}; + use vortex_array::compute::test_harness::{test_binary_numeric, test_mask}; + use vortex_array::compute::{ + filter, search_sorted, slice, try_cast, SearchResult, SearchSortedSide, + }; use vortex_array::validity::Validity; use vortex_array::{Array, IntoArray, IntoArrayVariant}; use vortex_buffer::buffer; + use vortex_dtype::{DType, Nullability, PType}; use vortex_mask::Mask; use vortex_scalar::Scalar; @@ -223,4 +226,35 @@ mod test { fn test_sparse_binary_numeric(array: Array) { test_binary_numeric::(array) } + + #[test] + fn test_mask_sparse_array() { + let null_fill_value = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)); + test_mask( + SparseArray::try_new( + buffer![1u64, 2, 4].into_array(), + try_cast( + buffer![100i32, 200, 300].into_array(), + null_fill_value.dtype(), + ) + .unwrap(), + 5, + null_fill_value, + ) + .unwrap() + .into_array(), + ); + + let ten_fill_value = Scalar::from(10i32); + test_mask( + SparseArray::try_new( + buffer![1u64, 2, 4].into_array(), + buffer![100i32, 200, 300].into_array(), + 5, + ten_fill_value, + ) + .unwrap() + .into_array(), + ) + } } diff --git a/vortex-array/src/array/bool/compute/cast.rs b/vortex-array/src/array/bool/compute/cast.rs index 71c8822ff9a..c13c92b2b6a 100644 --- a/vortex-array/src/array/bool/compute/cast.rs +++ b/vortex-array/src/array/bool/compute/cast.rs @@ -11,13 +11,9 @@ impl CastFn for BoolEncoding { vortex_bail!("Cannot cast {} to {}", array.dtype(), dtype); } - // If the types are the same, return the array, - // otherwise set the array nullability as the dtype nullability. - if dtype.is_nullable() || array.all_valid()? { - Ok(BoolArray::new(array.boolean_buffer(), dtype.nullability()).into_array()) - } else { - vortex_bail!("Cannot cast null array to non-nullable type"); - } + let new_nullability = dtype.nullability(); + let new_validity = array.validity().cast_nullability(new_nullability)?; + BoolArray::try_new(array.boolean_buffer(), new_validity).map(IntoArray::into_array) } } diff --git a/vortex-array/src/array/bool/compute/mask.rs b/vortex-array/src/array/bool/compute/mask.rs new file mode 100644 index 00000000000..2e2f60cf5c6 --- /dev/null +++ b/vortex-array/src/array/bool/compute/mask.rs @@ -0,0 +1,13 @@ +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::array::{BoolArray, BoolEncoding}; +use crate::compute::MaskFn; +use crate::{Array, IntoArray}; + +impl MaskFn for BoolEncoding { + fn mask(&self, array: &BoolArray, mask: Mask) -> VortexResult { + BoolArray::try_new(array.boolean_buffer(), array.validity().mask(&mask)?) + .map(IntoArray::into_array) + } +} diff --git a/vortex-array/src/array/bool/compute/mod.rs b/vortex-array/src/array/bool/compute/mod.rs index 6314809ae5b..81d09978796 100644 --- a/vortex-array/src/array/bool/compute/mod.rs +++ b/vortex-array/src/array/bool/compute/mod.rs @@ -1,7 +1,7 @@ use crate::array::BoolEncoding; use crate::compute::{ - BinaryBooleanFn, CastFn, FillForwardFn, FillNullFn, FilterFn, InvertFn, MinMaxFn, ScalarAtFn, - SliceFn, TakeFn, ToArrowFn, + BinaryBooleanFn, CastFn, FillForwardFn, FillNullFn, FilterFn, InvertFn, MaskFn, MinMaxFn, + ScalarAtFn, SliceFn, TakeFn, ToArrowFn, }; use crate::vtable::ComputeVTable; use crate::Array; @@ -12,6 +12,7 @@ mod fill_null; pub mod filter; mod flatten; mod invert; +mod mask; mod min_max; mod scalar_at; mod slice; @@ -43,6 +44,10 @@ impl ComputeVTable for BoolEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/array/bool/mod.rs b/vortex-array/src/array/bool/mod.rs index b7146271ced..a8c1c7a5806 100644 --- a/vortex-array/src/array/bool/mod.rs +++ b/vortex-array/src/array/bool/mod.rs @@ -258,6 +258,7 @@ mod tests { use vortex_dtype::Nullability; use crate::array::{BoolArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; use crate::compute::{scalar_at, slice}; use crate::patches::Patches; use crate::validity::Validity; @@ -371,4 +372,9 @@ mod tests { let (values, _byte_bit_offset) = arr.into_bool().unwrap().into_boolean_builder(); assert_eq!(values.as_slice(), &[254, 127]); } + + #[test] + fn test_mask_primitive_array() { + test_mask(BoolArray::from_iter([true, false, true, true, false]).into_array()); + } } diff --git a/vortex-array/src/array/chunked/compute/filter.rs b/vortex-array/src/array/chunked/compute/filter.rs index cabb6b8a5c5..fa188738d19 100644 --- a/vortex-array/src/array/chunked/compute/filter.rs +++ b/vortex-array/src/array/chunked/compute/filter.rs @@ -8,7 +8,7 @@ use crate::validity::Validity; use crate::{Array, IntoArray, IntoArrayVariant}; // This is modeled after the constant with the equivalent name in arrow-rs. -const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; +pub(crate) const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; impl FilterFn for ChunkedEncoding { fn filter(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult { @@ -32,7 +32,7 @@ impl FilterFn for ChunkedEncoding { /// When we rewrite a set of slices in a filter predicate into chunk addresses, we want to account /// for the fact that some chunks will be wholly skipped. #[derive(Clone)] -enum ChunkFilter { +pub(crate) enum ChunkFilter { All, None, Slices(Vec<(usize, usize)>), @@ -45,6 +45,29 @@ fn filter_slices( ) -> VortexResult> { let mut result = Vec::with_capacity(array.nchunks()); + let chunk_filters = chunk_filters(array, slices)?; + + // Now, apply the chunk filter to every slice. + for (chunk, chunk_filter) in array.chunks().zip(chunk_filters.into_iter()) { + match chunk_filter { + // All => preserve the entire chunk unfiltered. + ChunkFilter::All => result.push(chunk), + // None => whole chunk is filtered out, skip + ChunkFilter::None => {} + // Slices => turn the slices into a boolean buffer. + ChunkFilter::Slices(slices) => { + result.push(filter(&chunk, &Mask::from_slices(chunk.len(), slices))?); + } + } + } + + Ok(result) +} + +pub(crate) fn chunk_filters( + array: &ChunkedArray, + slices: impl Iterator, +) -> VortexResult> { // Pre-materialize the chunk ends for performance. // The chunk ends is nchunks+1, which is expected to be in the hundreds or at most thousands. let chunk_ends = array.chunk_offsets().into_primitive()?; @@ -99,21 +122,7 @@ fn filter_slices( } } - // Now, apply the chunk filter to every slice. - for (chunk, chunk_filter) in array.chunks().zip(chunk_filters.into_iter()) { - match chunk_filter { - // All => preserve the entire chunk unfiltered. - ChunkFilter::All => result.push(chunk), - // None => whole chunk is filtered out, skip - ChunkFilter::None => {} - // Slices => turn the slices into a boolean buffer. - ChunkFilter::Slices(slices) => { - result.push(filter(&chunk, &Mask::from_slices(chunk.len(), slices))?); - } - } - } - - Ok(result) + Ok(chunk_filters) } /// Filter the chunks using indices. @@ -171,7 +180,7 @@ fn filter_indices( // Mirrors the find_chunk_idx method on ChunkedArray, but avoids all of the overhead // from scalars, dtypes, and metadata cloning. -fn find_chunk_idx(idx: usize, chunk_ends: &[u64]) -> (usize, usize) { +pub(crate) fn find_chunk_idx(idx: usize, chunk_ends: &[u64]) -> (usize, usize) { let chunk_id = chunk_ends .search_sorted(&(idx as u64), SearchSortedSide::Right) .to_ends_index(chunk_ends.len()) diff --git a/vortex-array/src/array/chunked/compute/mask.rs b/vortex-array/src/array/chunked/compute/mask.rs new file mode 100644 index 00000000000..37e67046209 --- /dev/null +++ b/vortex-array/src/array/chunked/compute/mask.rs @@ -0,0 +1,147 @@ +use itertools::Itertools as _; +use vortex_dtype::DType; +use vortex_error::{VortexExpect as _, VortexResult}; +use vortex_mask::{AllOr, Mask, MaskIter}; +use vortex_scalar::Scalar; + +use super::filter::{chunk_filters, find_chunk_idx, ChunkFilter}; +use crate::array::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD; +use crate::array::{ChunkedArray, ChunkedEncoding, ConstantArray}; +use crate::compute::{mask, try_cast, MaskFn}; +use crate::{Array, IntoArray, IntoCanonical as _}; + +impl MaskFn for ChunkedEncoding { + fn mask(&self, array: &ChunkedArray, mask: Mask) -> VortexResult { + let new_dtype = array.dtype().as_nullable(); + let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { + AllOr::All => unreachable!("handled in top-level mask"), + AllOr::None => unreachable!("handled in top-level mask"), + AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype), + AllOr::Some(MaskIter::Slices(slices)) => { + mask_slices(array, slices.iter().cloned(), &new_dtype) + } + }?; + debug_assert_eq!(new_chunks.len(), array.nchunks()); + debug_assert_eq!( + new_chunks.iter().map(|x| x.len()).sum::(), + array.len() + ); + ChunkedArray::try_new(new_chunks, new_dtype).map(IntoArray::into_array) + } +} + +fn mask_indices( + array: &ChunkedArray, + indices: &[usize], + new_dtype: &DType, +) -> VortexResult> { + let mut new_chunks = Vec::with_capacity(array.nchunks()); + let mut current_chunk_id = 0; + let mut chunk_indices = Vec::new(); + + // Avoid find_chunk_idx and use our own to avoid the overhead. + // The array should only be some thousands of values in the general case. + let chunk_ends = array.chunk_offsets().into_canonical()?.into_primitive()?; + let chunk_ends = chunk_ends.as_slice::(); + + for &set_index in indices { + let (chunk_id, index) = find_chunk_idx(set_index, chunk_ends); + if chunk_id != current_chunk_id { + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + let masked_chunk = mask(&chunk, Mask::from_indices(chunk.len(), chunk_indices))?; + // Advance the chunk forward, reset the chunk indices buffer. + chunk_indices = Vec::new(); + new_chunks.push(masked_chunk); + current_chunk_id += 1; + + while current_chunk_id < chunk_id { + // Chunks that are not affected by the mask, must still be casted to the correct dtype. + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + new_chunks.push(try_cast(chunk, new_dtype)?); + current_chunk_id += 1; + } + } + + chunk_indices.push(index); + } + + if !chunk_indices.is_empty() { + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + let masked_chunk = mask(&chunk, Mask::from_indices(chunk.len(), chunk_indices))?; + new_chunks.push(masked_chunk); + current_chunk_id += 1; + } + + while current_chunk_id < array.nchunks() { + let chunk = array + .chunk(current_chunk_id) + .vortex_expect("find_chunk_idx must return valid chunk ID"); + new_chunks.push(try_cast(chunk, new_dtype)?); + current_chunk_id += 1; + } + + Ok(new_chunks) +} + +fn mask_slices( + array: &ChunkedArray, + slices: impl Iterator, + new_dtype: &DType, +) -> VortexResult> { + let chunked_filters = chunk_filters(array, slices)?; + + array + .chunks() + .zip_eq(chunked_filters) + .map(|(chunk, chunk_filter)| -> VortexResult { + Ok(match chunk_filter { + ChunkFilter::All => { + // entire chunk is masked out + ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()).into_array() + } + ChunkFilter::None => { + // entire chunk is not affected by mask + chunk + } + ChunkFilter::Slices(slices) => { + // Slices of indices that must be set to null + mask(&chunk, Mask::from_slices(chunk.len(), slices))? + } + }) + }) + .process_results(|iter| iter.collect::>()) +} + +#[cfg(test)] +mod test { + use vortex_buffer::buffer; + use vortex_dtype::{DType, Nullability, PType}; + + use crate::array::{ChunkedArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; + use crate::IntoArray; + + #[test] + fn test_mask_chunked_array() { + let dtype = DType::Primitive(PType::U64, Nullability::NonNullable); + let chunked = ChunkedArray::try_new( + vec![ + buffer![0u64, 1].into_array(), + buffer![2_u64].into_array(), + PrimitiveArray::empty::(dtype.nullability()).into_array(), + buffer![3_u64, 4].into_array(), + ], + dtype, + ) + .unwrap() + .into_array(); + + test_mask(chunked); + } +} diff --git a/vortex-array/src/array/chunked/compute/mod.rs b/vortex-array/src/array/chunked/compute/mod.rs index aa0d712b50b..5013b144cb2 100644 --- a/vortex-array/src/array/chunked/compute/mod.rs +++ b/vortex-array/src/array/chunked/compute/mod.rs @@ -5,7 +5,7 @@ use crate::array::chunked::ChunkedArray; use crate::array::ChunkedEncoding; use crate::compute::{ try_cast, BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, FillNullFn, FilterFn, InvertFn, - MinMaxFn, ScalarAtFn, SliceFn, TakeFn, + MaskFn, MinMaxFn, ScalarAtFn, SliceFn, TakeFn, }; use crate::vtable::ComputeVTable; use crate::{Array, IntoArray}; @@ -16,6 +16,7 @@ mod compare; mod fill_null; mod filter; mod invert; +mod mask; mod min_max; mod scalar_at; mod slice; @@ -50,6 +51,10 @@ impl ComputeVTable for ChunkedEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/array/constant/compute/mod.rs b/vortex-array/src/array/constant/compute/mod.rs index 42864ba5583..da660de20b6 100644 --- a/vortex-array/src/array/constant/compute/mod.rs +++ b/vortex-array/src/array/constant/compute/mod.rs @@ -83,3 +83,21 @@ impl FilterFn for ConstantEncoding { Ok(ConstantArray::new(array.scalar(), mask.true_count()).into_array()) } } + +#[cfg(test)] +mod test { + use vortex_dtype::half::f16; + use vortex_scalar::Scalar; + + use super::ConstantArray; + use crate::compute::test_harness::test_mask; + use crate::IntoArray as _; + + #[test] + fn test_mask_constant() { + test_mask(ConstantArray::new(Scalar::null_typed::(), 5).into_array()); + test_mask(ConstantArray::new(Scalar::from(3u16), 5).into_array()); + test_mask(ConstantArray::new(Scalar::from(1.0f32 / 0.0f32), 5).into_array()); + test_mask(ConstantArray::new(Scalar::from(f16::from_f32(3.0f32)), 5).into_array()); + } +} diff --git a/vortex-array/src/array/list/compute/mod.rs b/vortex-array/src/array/list/compute/mod.rs index 3d34099e262..b8dccbae965 100644 --- a/vortex-array/src/array/list/compute/mod.rs +++ b/vortex-array/src/array/list/compute/mod.rs @@ -4,10 +4,13 @@ use std::sync::Arc; use itertools::Itertools; use vortex_error::VortexResult; +use vortex_mask::Mask; use vortex_scalar::Scalar; use crate::array::{ListArray, ListEncoding}; -use crate::compute::{scalar_at, slice, MinMaxFn, MinMaxResult, ScalarAtFn, SliceFn, ToArrowFn}; +use crate::compute::{ + scalar_at, slice, MaskFn, MinMaxFn, MinMaxResult, ScalarAtFn, SliceFn, ToArrowFn, +}; use crate::vtable::ComputeVTable; use crate::{Array, IntoArray}; @@ -24,6 +27,10 @@ impl ComputeVTable for ListEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn min_max_fn(&self) -> Option<&dyn MinMaxFn> { Some(self) } @@ -53,9 +60,40 @@ impl SliceFn for ListEncoding { } } +impl MaskFn for ListEncoding { + fn mask(&self, array: &ListArray, mask: Mask) -> VortexResult { + ListArray::try_new( + array.elements(), + array.offsets(), + array.validity().mask(&mask)?, + ) + .map(IntoArray::into_array) + } +} + impl MinMaxFn for ListEncoding { fn min_max(&self, _array: &ListArray) -> VortexResult> { // TODO(joe): Implement list min max Ok(None) } } + +#[cfg(test)] +mod test { + use crate::array::{ListArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; + use crate::validity::Validity; + use crate::IntoArray as _; + + #[test] + fn test_mask_list() { + let elements = PrimitiveArray::from_iter(0..35); + let offsets = PrimitiveArray::from_iter([0, 5, 11, 18, 26, 35]); + let validity = Validity::AllValid; + let array = ListArray::try_new(elements.into_array(), offsets.into_array(), validity) + .unwrap() + .into_array(); + + test_mask(array); + } +} diff --git a/vortex-array/src/array/null/compute.rs b/vortex-array/src/array/null/compute.rs index 029d2bb1d10..64a4fed1d41 100644 --- a/vortex-array/src/array/null/compute.rs +++ b/vortex-array/src/array/null/compute.rs @@ -2,16 +2,21 @@ use arrow_array::{new_null_array, ArrayRef}; use arrow_schema::DataType; use vortex_dtype::{match_each_integer_ptype, DType}; use vortex_error::{vortex_bail, VortexResult}; +use vortex_mask::Mask; use vortex_scalar::Scalar; use crate::array::null::NullArray; use crate::array::NullEncoding; -use crate::compute::{MinMaxFn, MinMaxResult, ScalarAtFn, SliceFn, TakeFn, ToArrowFn}; +use crate::compute::{MaskFn, MinMaxFn, MinMaxResult, ScalarAtFn, SliceFn, TakeFn, ToArrowFn}; use crate::variants::PrimitiveArrayTrait; use crate::vtable::ComputeVTable; use crate::{Array, IntoArray, IntoArrayVariant}; impl ComputeVTable for NullEncoding { + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -33,6 +38,12 @@ impl ComputeVTable for NullEncoding { } } +impl MaskFn for NullEncoding { + fn mask(&self, array: &NullArray, _mask: Mask) -> VortexResult { + Ok(array.clone().into_array()) + } +} + impl SliceFn for NullEncoding { fn slice(&self, _array: &NullArray, start: usize, stop: usize) -> VortexResult { Ok(NullArray::new(stop - start).into_array()) diff --git a/vortex-array/src/array/primitive/compute/mask.rs b/vortex-array/src/array/primitive/compute/mask.rs new file mode 100644 index 00000000000..101a0a95db5 --- /dev/null +++ b/vortex-array/src/array/primitive/compute/mask.rs @@ -0,0 +1,18 @@ +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::array::primitive::PrimitiveArray; +use crate::array::PrimitiveEncoding; +use crate::compute::MaskFn; +use crate::variants::PrimitiveArrayTrait as _; +use crate::{Array, IntoArray}; + +impl MaskFn for PrimitiveEncoding { + fn mask(&self, array: &PrimitiveArray, mask: Mask) -> VortexResult { + let validity = array.validity().mask(&mask)?; + Ok( + PrimitiveArray::from_byte_buffer(array.byte_buffer().clone(), array.ptype(), validity) + .into_array(), + ) + } +} diff --git a/vortex-array/src/array/primitive/compute/mod.rs b/vortex-array/src/array/primitive/compute/mod.rs index 129e3f54ebc..acb63f2fe53 100644 --- a/vortex-array/src/array/primitive/compute/mod.rs +++ b/vortex-array/src/array/primitive/compute/mod.rs @@ -1,6 +1,6 @@ use crate::array::PrimitiveEncoding; use crate::compute::{ - CastFn, FillForwardFn, FillNullFn, FilterFn, MinMaxFn, ScalarAtFn, SearchSortedFn, + CastFn, FillForwardFn, FillNullFn, FilterFn, MaskFn, MinMaxFn, ScalarAtFn, SearchSortedFn, SearchSortedUsizeFn, SliceFn, TakeFn, ToArrowFn, }; use crate::vtable::ComputeVTable; @@ -10,6 +10,7 @@ mod cast; mod fill; mod fill_null; mod filter; +mod mask; mod min_max; mod scalar_at; mod search_sorted; @@ -22,6 +23,10 @@ impl ComputeVTable for PrimitiveEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn> { Some(self) } diff --git a/vortex-array/src/array/primitive/mod.rs b/vortex-array/src/array/primitive/mod.rs index 55eb852e646..7f1355e11cb 100644 --- a/vortex-array/src/array/primitive/mod.rs +++ b/vortex-array/src/array/primitive/mod.rs @@ -383,3 +383,29 @@ impl VisitorVTable for PrimitiveEncoding { visitor.visit_validity(&array.validity()) } } + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + + use crate::array::{BoolArray, PrimitiveArray}; + use crate::compute::test_harness::test_mask; + use crate::validity::Validity; + use crate::IntoArray as _; + + #[test] + fn test_mask_primitive_array() { + test_mask(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::NonNullable).into_array()); + test_mask(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid).into_array()); + test_mask(PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllInvalid).into_array()); + test_mask( + PrimitiveArray::new( + buffer![0, 1, 2, 3, 4], + Validity::Array( + BoolArray::from_iter([true, false, true, false, true]).into_array(), + ), + ) + .into_array(), + ); + } +} diff --git a/vortex-array/src/array/struct_/compute/mod.rs b/vortex-array/src/array/struct_/compute/mod.rs index 46281c20b2e..59d2d06249a 100644 --- a/vortex-array/src/array/struct_/compute/mod.rs +++ b/vortex-array/src/array/struct_/compute/mod.rs @@ -1,25 +1,34 @@ mod to_arrow; use itertools::Itertools; -use vortex_error::VortexResult; +use vortex_dtype::DType; +use vortex_error::{vortex_bail, VortexExpect, VortexResult}; use vortex_mask::Mask; use vortex_scalar::Scalar; use crate::array::struct_::StructArray; use crate::array::StructEncoding; use crate::compute::{ - filter, scalar_at, slice, take, FilterFn, MinMaxFn, MinMaxResult, ScalarAtFn, SliceFn, TakeFn, - ToArrowFn, + filter, scalar_at, slice, take, try_cast, CastFn, FilterFn, MaskFn, MinMaxFn, MinMaxResult, + ScalarAtFn, SliceFn, TakeFn, ToArrowFn, }; use crate::variants::StructArrayTrait; use crate::vtable::ComputeVTable; use crate::{Array, IntoArray}; impl ComputeVTable for StructEncoding { + fn cast_fn(&self) -> Option<&dyn CastFn> { + Some(self) + } + fn filter_fn(&self) -> Option<&dyn FilterFn> { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -41,6 +50,38 @@ impl ComputeVTable for StructEncoding { } } +impl CastFn for StructEncoding { + fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult { + let Some(target_sdtype) = dtype.as_struct() else { + vortex_bail!("cannot cast {} to {}", array.dtype(), dtype); + }; + + let source_sdtype = array + .dtype() + .as_struct() + .vortex_expect("struct array must have struct dtype"); + + if target_sdtype.names() != source_sdtype.names() { + vortex_bail!("cannot cast {} to {}", array.dtype(), dtype); + } + + let validity = array.validity().cast_nullability(dtype.nullability())?; + + StructArray::try_new( + target_sdtype.names().clone(), + array + .children() + .into_iter() + .zip_eq(target_sdtype.fields()) + .map(|(field, dtype)| try_cast(&field, &dtype)) + .try_collect()?, + array.len(), + validity, + ) + .map(|a| a.into_array()) + } +} + impl ScalarAtFn for StructEncoding { fn scalar_at(&self, array: &StructArray, index: usize) -> VortexResult { Ok(Scalar::struct_( @@ -102,6 +143,20 @@ impl FilterFn for StructEncoding { } } +impl MaskFn for StructEncoding { + fn mask(&self, array: &StructArray, filter_mask: Mask) -> VortexResult { + let validity = array.validity().mask(&filter_mask)?; + + StructArray::try_new( + array.names().clone(), + array.fields().collect(), + array.len(), + validity, + ) + .map(|a| a.into_array()) + } +} + impl MinMaxFn for StructEncoding { fn min_max(&self, _array: &StructArray) -> VortexResult> { // TODO(joe): Implement struct min max @@ -111,11 +166,17 @@ impl MinMaxFn for StructEncoding { #[cfg(test)] mod tests { + use std::sync::Arc; + + use vortex_buffer::buffer; + use vortex_dtype::{DType, FieldNames, Nullability, PType, StructDType}; use vortex_mask::Mask; - use crate::array::StructArray; - use crate::compute::filter; + use crate::array::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray}; + use crate::compute::test_harness::test_mask; + use crate::compute::{filter, try_cast}; use crate::validity::Validity; + use crate::IntoArray as _; #[test] fn filter_empty_struct() { @@ -135,4 +196,189 @@ mod tests { let filtered = filter(struct_arr.as_ref(), &Mask::from_iter::<[bool; 0]>([])).unwrap(); assert_eq!(filtered.len(), 0); } + + #[test] + fn test_mask_empty_struct() { + test_mask( + StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable) + .unwrap() + .into_array(), + ); + } + + #[test] + fn test_mask_complex_struct() { + let xs = buffer![0i64, 1, 2, 3, 4].into_array(); + let ys = VarBinArray::from_iter( + [Some("a"), Some("b"), None, Some("d"), None], + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + let zs = + BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array(); + + test_mask( + StructArray::try_new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + StructArray::try_new( + ["left".into(), "right".into()].into(), + vec![xs.clone(), xs], + 5, + Validity::NonNullable, + ) + .unwrap() + .into_array(), + ys, + zs, + ], + 5, + Validity::NonNullable, + ) + .unwrap() + .into_array(), + ); + } + + #[test] + fn test_cast_empty_struct() { + let array = StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable) + .unwrap() + .into_array(); + let non_nullable_dtype = DType::Struct( + Arc::from(StructDType::new([].into(), vec![])), + Nullability::NonNullable, + ); + let casted = try_cast(&array, &non_nullable_dtype).unwrap(); + assert_eq!(casted.dtype(), &non_nullable_dtype); + + let nullable_dtype = DType::Struct( + Arc::from(StructDType::new([].into(), vec![])), + Nullability::Nullable, + ); + let casted = try_cast(&array, &nullable_dtype).unwrap(); + assert_eq!(casted.dtype(), &nullable_dtype); + } + + #[test] + fn test_cast_cannot_change_name_order() { + let array = StructArray::try_new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + buffer![1u8].into_array(), + buffer![1u8].into_array(), + buffer![1u8].into_array(), + ], + 1, + Validity::NonNullable, + ) + .unwrap() + .into_array(); + + let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable); + + let result = try_cast( + array, + &DType::Struct( + Arc::from(StructDType::new( + FieldNames::from(["ys".into(), "xs".into(), "zs".into()]), + vec![tu8.clone(), tu8.clone(), tu8], + )), + Nullability::NonNullable, + ), + ); + assert!( + result.as_ref().is_err_and(|err| { + err.to_string() + .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}") + }), + "{:?}", + result + ); + } + + #[test] + fn test_cast_complex_struct() { + let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]) + .into_array(); + let ys = VarBinArray::from_vec( + vec!["a", "b", "c", "d", "e"], + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + let zs = BoolArray::new( + BooleanBuffer::from_iter([true, true, false, false, true]), + Nullability::Nullable, + ) + .into_array(); + let fully_nullable_array = StructArray::try_new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + StructArray::try_new( + ["left".into(), "right".into()].into(), + vec![xs.clone(), xs], + 5, + Validity::AllValid, + ) + .unwrap() + .into_array(), + ys, + zs, + ], + 5, + Validity::AllValid, + ) + .unwrap() + .into_array(); + + let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable(); + let casted = try_cast(&fully_nullable_array, &top_level_non_nullable).unwrap(); + assert_eq!(casted.dtype(), &top_level_non_nullable); + + let non_null_xs_right = DType::Struct( + Arc::from(StructDType::new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + DType::Struct( + Arc::from(StructDType::new( + ["left".into(), "right".into()].into(), + vec![ + DType::Primitive(PType::I64, Nullability::NonNullable), + DType::Primitive(PType::I64, Nullability::Nullable), + ], + )), + Nullability::Nullable, + ), + DType::Utf8(Nullability::Nullable), + DType::Bool(Nullability::Nullable), + ], + )), + Nullability::Nullable, + ); + let casted = try_cast(&fully_nullable_array, &non_null_xs_right).unwrap(); + assert_eq!(casted.dtype(), &non_null_xs_right); + + let non_null_xs = DType::Struct( + Arc::from(StructDType::new( + ["xs".into(), "ys".into(), "zs".into()].into(), + vec![ + DType::Struct( + Arc::from(StructDType::new( + ["left".into(), "right".into()].into(), + vec![ + DType::Primitive(PType::I64, Nullability::Nullable), + DType::Primitive(PType::I64, Nullability::Nullable), + ], + )), + Nullability::NonNullable, + ), + DType::Utf8(Nullability::Nullable), + DType::Bool(Nullability::Nullable), + ], + )), + Nullability::Nullable, + ); + let casted = try_cast(&fully_nullable_array, &non_null_xs).unwrap(); + assert_eq!(casted.dtype(), &non_null_xs); + } } diff --git a/vortex-array/src/array/varbin/compute/cast.rs b/vortex-array/src/array/varbin/compute/cast.rs index 11ead4be79a..eeb107926fa 100644 --- a/vortex-array/src/array/varbin/compute/cast.rs +++ b/vortex-array/src/array/varbin/compute/cast.rs @@ -3,7 +3,6 @@ use vortex_error::{vortex_bail, VortexResult}; use crate::array::{VarBinArray, VarBinEncoding}; use crate::compute::CastFn; -use crate::validity::Validity; use crate::{Array, IntoArray}; impl CastFn for VarBinEncoding { @@ -12,23 +11,11 @@ impl CastFn for VarBinEncoding { vortex_bail!("Cannot cast {} to {}", array.dtype(), dtype); } - // If the types are the same, return the array, - // otherwise set the array nullability as the dtype nullability. - if dtype.is_nullable() || array.all_valid()? { - VarBinArray::try_new( - array.offsets(), - array.bytes(), - dtype.clone(), - if dtype.is_nullable() { - Validity::AllValid - } else { - Validity::NonNullable - }, - ) - .map(|a| a.into_array()) - } else { - vortex_bail!("Cannot cast null array to non-nullable type"); - } + let new_nullability = dtype.nullability(); + let new_validity = array.validity().cast_nullability(new_nullability)?; + let new_dtype = array.dtype().with_nullability(new_nullability); + VarBinArray::try_new(array.offsets(), array.bytes(), new_dtype, new_validity) + .map(IntoArray::into_array) } } diff --git a/vortex-array/src/array/varbin/compute/mask.rs b/vortex-array/src/array/varbin/compute/mask.rs new file mode 100644 index 00000000000..e81cfb55206 --- /dev/null +++ b/vortex-array/src/array/varbin/compute/mask.rs @@ -0,0 +1,45 @@ +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::array::varbin::VarBinArray; +use crate::array::VarBinEncoding; +use crate::compute::MaskFn; +use crate::{Array, IntoArray}; + +impl MaskFn for VarBinEncoding { + fn mask(&self, array: &VarBinArray, mask: Mask) -> VortexResult { + VarBinArray::try_new( + array.offsets(), + array.bytes(), + array.dtype().as_nullable(), + array.validity().mask(&mask)?, + ) + .map(IntoArray::into_array) + } +} + +#[cfg(test)] +mod test { + use vortex_dtype::{DType, Nullability}; + + use crate::array::VarBinArray; + use crate::compute::test_harness::test_mask; + use crate::IntoArray as _; + + #[test] + fn test_mask_var_bin_array() { + let array = VarBinArray::from_vec( + vec!["hello", "world", "filter", "good", "bye"], + DType::Utf8(Nullability::NonNullable), + ) + .into_array(); + test_mask(array); + + let array = VarBinArray::from_iter( + vec![Some("hello"), None, Some("filter"), Some("good"), None], + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + test_mask(array); + } +} diff --git a/vortex-array/src/array/varbin/compute/mod.rs b/vortex-array/src/array/varbin/compute/mod.rs index 73835c4751a..4516f12cf11 100644 --- a/vortex-array/src/array/varbin/compute/mod.rs +++ b/vortex-array/src/array/varbin/compute/mod.rs @@ -5,7 +5,7 @@ use vortex_scalar::Scalar; use crate::array::varbin::{varbin_scalar, VarBinArray}; use crate::array::VarBinEncoding; use crate::compute::{ - CastFn, CompareFn, FilterFn, MinMaxFn, ScalarAtFn, SliceFn, TakeFn, ToArrowFn, + CastFn, CompareFn, FilterFn, MaskFn, MinMaxFn, ScalarAtFn, SliceFn, TakeFn, ToArrowFn, }; use crate::vtable::ComputeVTable; use crate::Array; @@ -13,6 +13,7 @@ use crate::Array; mod cast; mod compare; mod filter; +mod mask; mod min_max; mod slice; mod take; @@ -31,6 +32,10 @@ impl ComputeVTable for VarBinEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } diff --git a/vortex-array/src/array/varbinview/compute/cast.rs b/vortex-array/src/array/varbinview/compute/cast.rs index 056ae82905b..9c63889785b 100644 --- a/vortex-array/src/array/varbinview/compute/cast.rs +++ b/vortex-array/src/array/varbinview/compute/cast.rs @@ -3,7 +3,6 @@ use vortex_error::{vortex_bail, VortexResult}; use crate::array::{VarBinViewArray, VarBinViewEncoding}; use crate::compute::CastFn; -use crate::validity::Validity; use crate::{Array, IntoArray}; impl CastFn for VarBinViewEncoding { @@ -12,23 +11,16 @@ impl CastFn for VarBinViewEncoding { vortex_bail!("Cannot cast {} to {}", array.dtype(), dtype); } - // If the types are the same, return the array, - // otherwise set the array nullability as the dtype nullability. - if dtype.is_nullable() || array.all_valid()? { - VarBinViewArray::try_new( - array.views(), - array.buffers().collect(), - dtype.clone(), - if dtype.is_nullable() { - Validity::AllValid - } else { - Validity::NonNullable - }, - ) - .map(|a| a.into_array()) - } else { - vortex_bail!("Cannot cast null array to non-nullable type"); - } + let new_nullability = dtype.nullability(); + let new_validity = array.validity().cast_nullability(new_nullability)?; + let new_dtype = array.dtype().with_nullability(new_nullability); + VarBinViewArray::try_new( + array.views(), + array.buffers().collect(), + new_dtype, + new_validity, + ) + .map(IntoArray::into_array) } } diff --git a/vortex-array/src/array/varbinview/compute/mod.rs b/vortex-array/src/array/varbinview/compute/mod.rs index 149e16dfbae..0d1df61bd02 100644 --- a/vortex-array/src/array/varbinview/compute/mod.rs +++ b/vortex-array/src/array/varbinview/compute/mod.rs @@ -4,12 +4,13 @@ mod take; mod to_arrow; use vortex_error::VortexResult; +use vortex_mask::Mask; use vortex_scalar::Scalar; use crate::array::varbin::varbin_scalar; use crate::array::varbinview::VarBinViewArray; use crate::array::VarBinViewEncoding; -use crate::compute::{CastFn, MinMaxFn, ScalarAtFn, SliceFn, TakeFn, ToArrowFn}; +use crate::compute::{CastFn, MaskFn, MinMaxFn, ScalarAtFn, SliceFn, TakeFn, ToArrowFn}; use crate::vtable::ComputeVTable; use crate::{Array, IntoArray}; @@ -18,6 +19,10 @@ impl ComputeVTable for VarBinViewEncoding { Some(self) } + fn mask_fn(&self) -> Option<&dyn MaskFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -61,6 +66,18 @@ impl SliceFn for VarBinViewEncoding { } } +impl MaskFn for VarBinViewEncoding { + fn mask(&self, array: &VarBinViewArray, mask: Mask) -> VortexResult { + VarBinViewArray::try_new( + array.views(), + array.buffers().collect(), + array.dtype().as_nullable(), + array.validity().mask(&mask)?, + ) + .map(IntoArray::into_array) + } +} + #[cfg(test)] mod tests { use vortex_buffer::buffer; @@ -68,6 +85,7 @@ mod tests { use crate::accessor::ArrayAccessor; use crate::array::VarBinViewArray; use crate::builders::{ArrayBuilder, VarBinViewBuilder}; + use crate::compute::test_harness::test_mask; use crate::compute::{take, take_into}; use crate::{IntoArray, IntoArrayVariant}; @@ -97,6 +115,24 @@ mod tests { ); } + #[test] + fn take_mask_var_bin_view_array() { + test_mask( + VarBinViewArray::from_iter_str(["one", "two", "three", "four", "five"]).into_array(), + ); + + test_mask( + VarBinViewArray::from_iter_nullable_str([ + Some("one"), + None, + Some("three"), + Some("four"), + Some("five"), + ]) + .into_array(), + ); + } + #[test] fn take_into_nullable() { let arr = VarBinViewArray::from_iter_nullable_str([ diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index 618d82caa46..391ae2f0ce5 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -30,7 +30,30 @@ where } } -/// Return a new array by applying a boolean predicate to select items from a base Array. +/// Keep only the elements for which the corresponding mask value is true. +/// +/// # Examples +/// +/// ``` +/// use vortex_array::IntoArray; +/// use vortex_array::array::{BoolArray, PrimitiveArray}; +/// use vortex_array::compute::{scalar_at, filter, mask}; +/// use vortex_mask::Mask; +/// use vortex_scalar::Scalar; +/// +/// let array = +/// PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]) +/// .into_array(); +/// let mask = Mask::try_from( +/// BoolArray::from_iter([true, false, false, false, true]).into_array(), +/// ) +/// .unwrap(); +/// +/// let filtered = filter(&array, &mask).unwrap(); +/// assert_eq!(filtered.len(), 2); +/// assert_eq!(scalar_at(&filtered, 0).unwrap(), Scalar::from(Some(0_i32))); +/// assert_eq!(scalar_at(&filtered, 1).unwrap(), Scalar::from(Some(2_i32))); +/// ``` /// /// # Performance /// diff --git a/vortex-array/src/compute/mask.rs b/vortex-array/src/compute/mask.rs new file mode 100644 index 00000000000..f400a2beb02 --- /dev/null +++ b/vortex-array/src/compute/mask.rs @@ -0,0 +1,226 @@ +use arrow_array::BooleanArray; +use vortex_error::{vortex_bail, VortexError, VortexResult}; +use vortex_mask::Mask; +use vortex_scalar::Scalar; + +use crate::array::ConstantArray; +use crate::arrow::{FromArrowArray, IntoArrowArray}; +use crate::compute::try_cast; +use crate::encoding::Encoding; +use crate::{Array, IntoArray}; + +pub trait MaskFn { + /// Replace masked values with null in array. + fn mask(&self, array: &A, mask: Mask) -> VortexResult; +} + +impl MaskFn for E +where + E: MaskFn, + for<'a> &'a E::Array: TryFrom<&'a Array, Error = VortexError>, +{ + fn mask(&self, array: &Array, mask: Mask) -> VortexResult { + let (array_ref, encoding) = array.try_downcast_ref::()?; + MaskFn::mask(encoding, array_ref, mask) + } +} + +/// Replace values with null where the mask is true. +/// +/// The returned array is nullable but otherwise has the same dtype and length as `array`. +/// +/// # Examples +/// +/// ``` +/// use vortex_array::IntoArray; +/// use vortex_array::array::{BoolArray, PrimitiveArray}; +/// use vortex_array::compute::{scalar_at, mask}; +/// use vortex_mask::Mask; +/// use vortex_scalar::Scalar; +/// +/// let array = +/// PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)]) +/// .into_array(); +/// let mask_array = Mask::try_from( +/// BoolArray::from_iter([true, false, false, false, true]).into_array(), +/// ) +/// .unwrap(); +/// +/// let masked = mask(&array, mask_array).unwrap(); +/// assert_eq!(masked.len(), 5); +/// assert!(!masked.is_valid(0).unwrap()); +/// assert!(!masked.is_valid(1).unwrap()); +/// assert_eq!(scalar_at(&masked, 2).unwrap(), Scalar::from(Some(1))); +/// assert!(!masked.is_valid(3).unwrap()); +/// assert!(!masked.is_valid(4).unwrap()); +/// ``` +/// +pub fn mask(array: &Array, mask: Mask) -> VortexResult { + if mask.len() != array.len() { + vortex_bail!( + "mask.len() is {}, does not equal array.len() of {}", + mask.len(), + array.len() + ); + } + + let masked = if matches!(mask, Mask::AllFalse(_)) { + // Fast-path for empty mask + try_cast(array, &array.dtype().as_nullable())? + } else if matches!(mask, Mask::AllTrue(_)) { + // Fast-path for full mask. + ConstantArray::new( + Scalar::null(array.dtype().clone().as_nullable()), + array.len(), + ) + .into_array() + } else { + mask_impl(array, mask)? + }; + + debug_assert_eq!( + masked.len(), + array.len(), + "Mask should not change length {}\n\n{:?}\n\n{:?}", + array.encoding(), + array, + masked + ); + debug_assert_eq!( + masked.dtype(), + &array.dtype().as_nullable(), + "Mask dtype mismatch {} {} {} {}", + array.encoding(), + masked.dtype(), + array.dtype(), + array.dtype().as_nullable(), + ); + + Ok(masked) +} + +fn mask_impl(array: &Array, mask: Mask) -> VortexResult { + if let Some(mask_fn) = array.vtable().mask_fn() { + return mask_fn.mask(array, mask); + } + + // Fallback: implement using Arrow kernels. + log::debug!("No mask implementation found for {}", array.encoding()); + + let array_ref = array.clone().into_arrow_preferred()?; + let mask = BooleanArray::new(mask.to_boolean_buffer(), None); + + let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?; + + Ok(Array::from_arrow(masked, true)) +} + +#[cfg(feature = "test-harness")] +pub mod test_harness { + use vortex_mask::Mask; + + use crate::array::BoolArray; + use crate::compute::{mask, scalar_at}; + use crate::{Array, IntoArray}; + + pub fn test_mask(array: Array) { + assert_eq!(array.len(), 5); + test_heterogenous_mask(&array); + test_empty_mask(&array); + test_full_mask(&array); + } + + #[allow(clippy::unwrap_used)] + fn test_heterogenous_mask(array: &Array) { + let mask_array = + Mask::try_from(BoolArray::from_iter([true, false, false, true, true]).into_array()) + .unwrap(); + let masked = mask(array, mask_array).unwrap(); + assert_eq!(masked.len(), array.len()); + assert!(!masked.is_valid(0).unwrap()); + assert_eq!( + scalar_at(&masked, 1).unwrap(), + scalar_at(array, 1).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 2).unwrap(), + scalar_at(array, 2).unwrap().into_nullable() + ); + assert!(!masked.is_valid(3).unwrap()); + assert!(!masked.is_valid(4).unwrap()); + } + + #[allow(clippy::unwrap_used)] + fn test_empty_mask(array: &Array) { + let all_unmasked = + Mask::try_from(BoolArray::from_iter([false, false, false, false, false]).into_array()) + .unwrap(); + let masked = mask(array, all_unmasked).unwrap(); + assert_eq!(masked.len(), array.len()); + assert_eq!( + scalar_at(&masked, 0).unwrap(), + scalar_at(array, 0).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 1).unwrap(), + scalar_at(array, 1).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 2).unwrap(), + scalar_at(array, 2).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 3).unwrap(), + scalar_at(array, 3).unwrap().into_nullable() + ); + assert_eq!( + scalar_at(&masked, 4).unwrap(), + scalar_at(array, 4).unwrap().into_nullable() + ); + } + + #[allow(clippy::unwrap_used)] + fn test_full_mask(array: &Array) { + let all_masked = + Mask::try_from(BoolArray::from_iter([true, true, true, true, true]).into_array()) + .unwrap(); + let masked = mask(array, all_masked).unwrap(); + assert_eq!(masked.len(), array.len()); + assert!(!masked.is_valid(0).unwrap()); + assert!(!masked.is_valid(1).unwrap()); + assert!(!masked.is_valid(2).unwrap()); + assert!(!masked.is_valid(3).unwrap()); + assert!(!masked.is_valid(4).unwrap()); + + let mask1 = + Mask::try_from(BoolArray::from_iter([true, false, false, true, true]).into_array()) + .unwrap(); + let mask2 = + Mask::try_from(BoolArray::from_iter([false, true, false, false, true]).into_array()) + .unwrap(); + let first = mask(array, mask1).unwrap(); + let double_masked = mask(&first, mask2).unwrap(); + assert_eq!(double_masked.len(), array.len()); + assert!(!double_masked.is_valid(0).unwrap()); + assert!(!double_masked.is_valid(1).unwrap()); + assert_eq!( + scalar_at(&double_masked, 2).unwrap(), + scalar_at(array, 2).unwrap().into_nullable() + ); + assert!(!double_masked.is_valid(3).unwrap()); + assert!(!double_masked.is_valid(4).unwrap()); + } +} + +#[cfg(test)] +mod test { + use super::test_harness::test_mask; + use crate::array::PrimitiveArray; + use crate::IntoArray as _; + + #[test] + fn test_mask_non_nullable_array() { + let non_nullable_array = PrimitiveArray::from_iter([1, 2, 3, 4, 5]).into_array(); + test_mask(non_nullable_array); + } +} diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 21e7131d376..df182ef6e09 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -20,6 +20,7 @@ pub use fill_null::{fill_null, FillNullFn}; pub use filter::{filter, FilterFn}; pub use invert::{invert, InvertFn}; pub use like::{like, LikeFn, LikeOptions}; +pub use mask::{mask, MaskFn}; pub use min_max::{min_max, MinMaxFn, MinMaxResult}; pub use scalar_at::{scalar_at, ScalarAtFn}; pub use search_sorted::*; @@ -36,6 +37,7 @@ mod fill_null; mod filter; mod invert; mod like; +mod mask; mod min_max; mod scalar_at; mod search_sorted; @@ -46,4 +48,5 @@ mod to_arrow; #[cfg(feature = "test-harness")] pub mod test_harness { pub use crate::compute::binary_numeric::test_harness::test_binary_numeric; + pub use crate::compute::mask::test_harness::test_mask; } diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 9becc74bec1..71494c50215 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -1,7 +1,7 @@ //! Array validity and nullability behavior, used by arrays and compute functions. use std::fmt::{Debug, Display}; -use std::ops::BitAnd; +use std::ops::{BitAnd, Not}; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use serde::{Deserialize, Serialize}; @@ -285,6 +285,9 @@ impl Validity { } } + /// Keep only the entries for which the mask is true. + /// + /// The result has length equal to the number of true values in mask. pub fn filter(&self, mask: &Mask) -> VortexResult { // NOTE(ngates): we take the mask as a reference to avoid the caller cloning unnecessarily // if we happen to be NonNullable, AllValid, or AllInvalid. @@ -296,6 +299,27 @@ impl Validity { } } + /// Set to false any entries for which the mask is true. + /// + /// The result is always nullable. The result has the same length as self. + pub fn mask(&self, mask: &Mask) -> VortexResult { + match mask.boolean_buffer() { + AllOr::All => Ok(Validity::AllInvalid), + AllOr::None => Ok(self.clone()), + AllOr::Some(make_invalid) => Ok(match self { + Validity::NonNullable | Validity::AllValid => { + Validity::Array(BoolArray::from(make_invalid.not()).into_array()) + } + Validity::AllInvalid => Validity::AllInvalid, + Validity::Array(is_valid) => { + let is_valid = BoolArray::try_from(is_valid.clone())?.boolean_buffer(); + let keep_valid = make_invalid.not(); + Validity::from(is_valid.bitand(&keep_valid)) + } + }), + } + } + pub fn to_logical(&self, length: usize) -> VortexResult { Ok(match self { Self::NonNullable | Self::AllValid => Mask::AllTrue(length), @@ -405,6 +429,35 @@ impl Validity { } } + /// Convert into a non-nullable variant + pub fn into_non_nullable(self) -> Option { + match self { + Self::NonNullable => Some(Self::NonNullable), + Self::AllValid => Some(Self::NonNullable), + Self::AllInvalid => None, + Self::Array(is_valid) => { + is_valid + .statistics() + .compute_min::() + .vortex_expect("validity array must support min") + .then(|| { + // min true => all true + Self::NonNullable + }) + } + } + } + + /// Convert into a variant compatible with the given nullability, if possible. + pub fn cast_nullability(self, nullability: Nullability) -> VortexResult { + match nullability { + Nullability::NonNullable => self.into_non_nullable().ok_or_else(|| { + vortex_err!("Cannot cast array with invalid values to non-nullable type.") + }), + Nullability::Nullable => Ok(self.into_nullable()), + } + } + /// Create Validity from boolean array with given nullability of the array. /// /// Note: You want to pass the nullability of parent array and not the nullability of the validity array itself diff --git a/vortex-array/src/vtable/compute.rs b/vortex-array/src/vtable/compute.rs index c50418f1884..7edb1c88dd6 100644 --- a/vortex-array/src/vtable/compute.rs +++ b/vortex-array/src/vtable/compute.rs @@ -1,7 +1,7 @@ use crate::compute::{ BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, FillForwardFn, FillNullFn, FilterFn, - InvertFn, LikeFn, MinMaxFn, ScalarAtFn, SearchSortedFn, SearchSortedUsizeFn, SliceFn, TakeFn, - ToArrowFn, + InvertFn, LikeFn, MaskFn, MinMaxFn, ScalarAtFn, SearchSortedFn, SearchSortedUsizeFn, SliceFn, + TakeFn, ToArrowFn, }; use crate::Array; @@ -70,6 +70,15 @@ pub trait ComputeVTable { None } + /// Replace masked values with null. + /// + /// This operation does not change the length of the array. + /// + /// See: [MaskFn]. + fn mask_fn(&self) -> Option<&dyn MaskFn> { + None + } + /// Single item indexing on Vortex arrays. /// /// See: [ScalarAtFn]. diff --git a/vortex-buffer/src/buffer.rs b/vortex-buffer/src/buffer.rs index adacc63f400..b4398b8cf4e 100644 --- a/vortex-buffer/src/buffer.rs +++ b/vortex-buffer/src/buffer.rs @@ -469,7 +469,7 @@ impl Buf for ByteBuffer { } } -/// Owned iterator over a `Buffer`. +/// Owned iterator over a [`Buffer`]. pub struct BufferIterator { buffer: Buffer, index: usize,