From 65c78096e6144899e6c00649e722df5b90531bed Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Tue, 2 Jun 2026 15:10:48 +0100 Subject: [PATCH 1/3] Implement ZipKernel for ListViewArray Signed-off-by: Robert Kruszewski --- .../src/arrays/listview/compute/mod.rs | 1 + .../src/arrays/listview/compute/zip.rs | 288 ++++++++++++++++++ .../src/arrays/listview/vtable/kernel.rs | 7 +- 3 files changed, 294 insertions(+), 2 deletions(-) create mode 100644 vortex-array/src/arrays/listview/compute/zip.rs diff --git a/vortex-array/src/arrays/listview/compute/mod.rs b/vortex-array/src/arrays/listview/compute/mod.rs index 9a43503c4b5..87587495f7a 100644 --- a/vortex-array/src/arrays/listview/compute/mod.rs +++ b/vortex-array/src/arrays/listview/compute/mod.rs @@ -6,3 +6,4 @@ mod mask; pub(crate) mod rules; mod slice; mod take; +mod zip; diff --git a/vortex-array/src/arrays/listview/compute/zip.rs b/vortex-array/src/arrays/listview/compute/zip.rs new file mode 100644 index 00000000000..0c7bb41ae87 --- /dev/null +++ b/vortex-array/src/arrays/listview/compute/zip.rs @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ops::BitAnd; +use std::ops::BitOr; +use std::ops::Not; + +use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::array::ArrayView; +use crate::arrays::ChunkedArray; +use crate::arrays::ListView; +use crate::arrays::ListViewArray; +use crate::arrays::listview::ListViewArrayExt; +use crate::builtins::ArrayBuiltins; +use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::scalar_fn::fns::zip::ZipKernel; +use crate::validity::Validity; + +/// Zip two [`ListViewArray`]s by selecting whole list views per row. +/// +/// A [`ListViewArray`] addresses each list by an `(offset, size)` pair into a shared `elements` +/// array, and unlike [`ListArray`](crate::arrays::ListArray) it does not require lists to be stored +/// contiguously or in order. Zipping two list views is therefore a metadata-only operation over the +/// `offsets`, `sizes` and `validity` child arrays: we concatenate the two `elements` arrays +/// (without rewriting them) and, for each row, select the `(offset, size)` pair from `if_true` or +/// `if_false` per the mask. `if_false` views are shifted past the end of `if_true`'s elements so +/// they continue to address the correct half of the concatenated elements array. +impl ZipKernel for ListView { + fn zip( + if_true: ArrayView<'_, ListView>, + if_false: &ArrayRef, + mask: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let Some(if_false) = if_false.as_opt::() else { + return Ok(None); + }; + + // Null mask entries select `if_false`, matching `Zip`'s SQL ELSE semantics. + let mask = mask.try_to_mask_fill_null_false(ctx)?; + match &mask { + // Defer the trivial masks to the generic zip, which just casts one side. + Mask::AllTrue(_) | Mask::AllFalse(_) => return Ok(None), + Mask::Values(_) => {} + } + + let len = if_true.len(); + + // `if_false`'s elements share the element dtype up to nullability; normalize so both chunks + // of the concatenated elements array have an identical dtype. + let true_elements = if_true.elements().clone(); + let element_dtype = true_elements.dtype().clone(); + let false_elements = if if_false.elements().dtype() == &element_dtype { + if_false.elements().clone() + } else { + if_false.elements().cast(element_dtype.clone())? + }; + + // `if_false` views index into the second half of the concatenated elements. + let false_shift = true_elements.len() as u64; + let elements = + ChunkedArray::try_new(vec![true_elements, false_elements], element_dtype)?.into_array(); + + let true_offsets = to_u64(if_true.offsets(), ctx)?; + let true_sizes = to_u64(if_true.sizes(), ctx)?; + let false_offsets = to_u64(if_false.offsets(), ctx)?; + let false_sizes = to_u64(if_false.sizes(), ctx)?; + + let mut offsets = BufferMut::::with_capacity(len); + let mut sizes = BufferMut::::with_capacity(len); + for i in 0..len { + // SAFETY: the loop runs exactly `len` times and both buffers reserved `len`. + unsafe { + if mask.value(i) { + offsets.push_unchecked(true_offsets[i]); + sizes.push_unchecked(true_sizes[i]); + } else { + offsets.push_unchecked(false_offsets[i] + false_shift); + sizes.push_unchecked(false_sizes[i]); + } + } + } + + let validity = zip_validity( + if_true.validity()?, + if_false.validity()?, + &mask, + if_true.nullability() | if_false.nullability(), + len, + ctx, + )?; + + Ok(Some( + ListViewArray::try_new( + elements, + offsets.freeze().into_array(), + sizes.freeze().into_array(), + validity, + )? + .into_array(), + )) + } +} + +/// Read a non-nullable integer array into a `u64` buffer. +fn to_u64(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult> { + array + .clone() + .cast(DType::Primitive(PType::U64, Nullability::NonNullable))? + .execute::>(ctx) +} + +/// Combine the two list-level validities, taking `if_true`'s validity where `mask` is set and +/// `if_false`'s where it is not. +fn zip_validity( + if_true: Validity, + if_false: Validity, + mask: &Mask, + nullability: Nullability, + len: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult { + Ok(match (&if_true, &if_false) { + (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable, + (Validity::AllValid, Validity::AllValid) => Validity::AllValid, + (Validity::AllInvalid, Validity::AllInvalid) => Validity::AllInvalid, + _ => { + let true_mask = if_true.execute_mask(len, ctx)?; + let false_mask = if_false.execute_mask(len, ctx)?; + let combined = true_mask + .bitand(mask) + .bitor(&false_mask.bitand(&mask.clone().not())); + Validity::from_mask(combined, nullability) + } + }) +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + use vortex_mask::Mask; + + use crate::ArrayRef; + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::arrays::BoolArray; + use crate::arrays::ListView; + use crate::arrays::ListViewArray; + use crate::assert_arrays_eq; + use crate::builtins::ArrayBuiltins; + use crate::validity::Validity; + + fn list_view( + elements: ArrayRef, + offsets: ArrayRef, + sizes: ArrayRef, + validity: Validity, + ) -> ArrayRef { + ListViewArray::try_new(elements, offsets, sizes, validity) + .unwrap() + .into_array() + } + + /// `zip` of two list views selects whole lists per the mask and keeps the list encoding. + #[test] + fn zip_selects_lists() -> VortexResult<()> { + // [[1, 2], [3], [4, 5, 6]] + let if_true = list_view( + buffer![1i32, 2, 3, 4, 5, 6].into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 3].into_array(), + Validity::NonNullable, + ); + // [[10], [20, 21], [30]] + let if_false = list_view( + buffer![10i32, 20, 21, 30].into_array(), + buffer![0u32, 1, 3].into_array(), + buffer![1u32, 2, 1].into_array(), + Validity::NonNullable, + ); + let mask = Mask::from_iter([true, false, true]); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut ctx)?; + + // The kernel should keep the list-view encoding rather than canonicalizing. + assert!(result.is::()); + + // Expected: [[1, 2], [20, 21], [4, 5, 6]] + let expected = list_view( + buffer![1i32, 2, 20, 21, 4, 5, 6].into_array(), + buffer![0u32, 2, 4].into_array(), + buffer![2u32, 2, 3].into_array(), + Validity::NonNullable, + ); + assert_arrays_eq!(result, expected); + Ok(()) + } + + /// `zip` selects list-level validity from the chosen side and widens nullability. + #[test] + fn zip_selects_validity() -> VortexResult<()> { + // [[1], null, [2]] (list-level nulls) + let if_true = list_view( + buffer![1i32, 2].into_array(), + buffer![0u32, 1, 1].into_array(), + buffer![1u32, 0, 1].into_array(), + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ); + // [[10], [20], null] + let if_false = list_view( + buffer![10i32, 20].into_array(), + buffer![0u32, 1, 2].into_array(), + buffer![1u32, 1, 0].into_array(), + Validity::Array(BoolArray::from_iter([true, true, false]).into_array()), + ); + // true -> if_true, false -> if_false + let mask = Mask::from_iter([false, true, true]); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut ctx)?; + + // Row 0 -> if_false[0] = [10]; row 1 -> if_true[1] = null; row 2 -> if_true[2] = [2] + let expected = list_view( + buffer![10i32, 2].into_array(), + buffer![0u32, 1, 1].into_array(), + buffer![1u32, 0, 1].into_array(), + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ); + assert_arrays_eq!(result, expected); + Ok(()) + } + + /// `zip` handles out-of-order/non-contiguous offsets and widens nullability when only one side + /// is nullable. + #[test] + fn zip_out_of_order_offsets_and_widening() -> VortexResult<()> { + // [[5, 6], [7], [8, 9]] expressed with out-of-order offsets. + let if_true = list_view( + buffer![7i32, 8, 9, 5, 6].into_array(), + buffer![3u32, 0, 1].into_array(), + buffer![2u32, 1, 2].into_array(), + Validity::NonNullable, + ); + // [[100], null, [200, 201]] + let if_false = list_view( + buffer![100i32, 200, 201].into_array(), + buffer![0u32, 1, 1].into_array(), + buffer![1u32, 0, 2].into_array(), + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ); + let mask = Mask::from_iter([true, true, false]); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut ctx)?; + assert!(result.is::()); + + // [[5, 6], [7], [200, 201]], all valid but nullable (widened by if_false). + let expected = list_view( + buffer![5i32, 6, 7, 200, 201].into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 2].into_array(), + Validity::AllValid, + ); + assert_arrays_eq!(result, expected); + Ok(()) + } +} diff --git a/vortex-array/src/arrays/listview/vtable/kernel.rs b/vortex-array/src/arrays/listview/vtable/kernel.rs index f6ceca284bf..1ad98f62a33 100644 --- a/vortex-array/src/arrays/listview/vtable/kernel.rs +++ b/vortex-array/src/arrays/listview/vtable/kernel.rs @@ -4,6 +4,9 @@ use crate::arrays::ListView; use crate::kernel::ParentKernelSet; use crate::scalar_fn::fns::cast::CastExecuteAdaptor; +use crate::scalar_fn::fns::zip::ZipExecuteAdaptor; -pub(super) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&CastExecuteAdaptor(ListView))]); +pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&CastExecuteAdaptor(ListView)), + ParentKernelSet::lift(&ZipExecuteAdaptor(ListView)), +]); From dbbbad9082424cc01694367fb2bf2a6144ee3d71 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Tue, 2 Jun 2026 15:24:20 +0100 Subject: [PATCH 2/3] more Signed-off-by: Robert Kruszewski --- .../src/arrays/listview/compute/zip.rs | 83 ++++++++++++++++++- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/vortex-array/src/arrays/listview/compute/zip.rs b/vortex-array/src/arrays/listview/compute/zip.rs index 0c7bb41ae87..b023e01097d 100644 --- a/vortex-array/src/arrays/listview/compute/zip.rs +++ b/vortex-array/src/arrays/listview/compute/zip.rs @@ -14,9 +14,11 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::IntoArray; use crate::array::ArrayView; +use crate::arrays::Chunked; use crate::arrays::ChunkedArray; use crate::arrays::ListView; use crate::arrays::ListViewArray; +use crate::arrays::chunked::ChunkedArrayExt; use crate::arrays::listview::ListViewArrayExt; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; @@ -67,8 +69,14 @@ impl ZipKernel for ListView { // `if_false` views index into the second half of the concatenated elements. let false_shift = true_elements.len() as u64; - let elements = - ChunkedArray::try_new(vec![true_elements, false_elements], element_dtype)?.into_array(); + + // Concatenate the two `elements` arrays without copying. If either side is already a + // `ChunkedArray` (e.g. the result of a previous list-view zip), splice its chunks in + // directly rather than nesting chunked arrays. + let mut chunks = Vec::new(); + push_element_chunks(true_elements, &mut chunks); + push_element_chunks(false_elements, &mut chunks); + let elements = ChunkedArray::try_new(chunks, element_dtype)?.into_array(); let true_offsets = to_u64(if_true.offsets(), ctx)?; let true_sizes = to_u64(if_true.sizes(), ctx)?; @@ -111,6 +119,15 @@ impl ZipKernel for ListView { } } +/// Appends `array`'s element chunks to `chunks`, flattening a top-level [`ChunkedArray`] so the +/// concatenated elements never nest chunked arrays. +fn push_element_chunks(array: ArrayRef, chunks: &mut Vec) { + match array.as_opt::() { + Some(chunked) => chunks.extend(chunked.iter_chunks().cloned()), + None => chunks.push(array), + } +} + /// Read a non-nullable integer array into a `u64` buffer. fn to_u64(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult> { array @@ -155,10 +172,17 @@ mod tests { use crate::LEGACY_SESSION; use crate::VortexSessionExecute; use crate::arrays::BoolArray; + use crate::arrays::Chunked; + use crate::arrays::ChunkedArray; use crate::arrays::ListView; use crate::arrays::ListViewArray; + use crate::arrays::chunked::ChunkedArrayExt; + use crate::arrays::listview::ListViewArrayExt; use crate::assert_arrays_eq; use crate::builtins::ArrayBuiltins; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; use crate::validity::Validity; fn list_view( @@ -285,4 +309,59 @@ mod tests { assert_arrays_eq!(result, expected); Ok(()) } + + /// When an input's `elements` is already a [`ChunkedArray`], its chunks are spliced in rather + /// than nesting a chunked array inside the concatenated elements. + #[test] + fn zip_flattens_chunked_elements() -> VortexResult<()> { + // elements [1, 2, 3] stored as two chunks; lists [[1, 2], [3]]. + let chunked_elements = ChunkedArray::try_new( + vec![buffer![1i32, 2].into_array(), buffer![3i32].into_array()], + DType::Primitive(PType::I32, Nullability::NonNullable), + )? + .into_array(); + let if_true = list_view( + chunked_elements, + buffer![0u32, 2].into_array(), + buffer![2u32, 1].into_array(), + Validity::NonNullable, + ); + // [[10], [20]] + let if_false = list_view( + buffer![10i32, 20].into_array(), + buffer![0u32, 1].into_array(), + buffer![1u32, 1].into_array(), + Validity::NonNullable, + ); + let mask = Mask::from_iter([true, false]); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut ctx)?; + + // The concatenated elements are chunked, but no chunk is itself a `ChunkedArray`. + let result_lv = result + .as_opt::() + .expect("zip keeps the list-view encoding"); + let chunked = result_lv + .elements() + .as_opt::() + .expect("zip concatenates elements into a chunked array"); + assert!( + chunked.iter_chunks().all(|chunk| !chunk.is::()), + "chunked elements must be flattened, not nested", + ); + + // [[1, 2], [20]] + let expected = list_view( + buffer![1i32, 2, 20].into_array(), + buffer![0u32, 2].into_array(), + buffer![2u32, 1].into_array(), + Validity::NonNullable, + ); + assert_arrays_eq!(result, expected); + Ok(()) + } } From 0ef75690a1add53823c51bfc390b631bc52fd6f5 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Tue, 2 Jun 2026 17:45:45 +0100 Subject: [PATCH 3/3] fixes Signed-off-by: Robert Kruszewski --- .../src/arrays/listview/compute/zip.rs | 66 ++++++------- .../src/arrays/listview/tests/operations.rs | 96 +++++++++++++++++++ vortex-array/src/scalar_fn/fns/zip/mod.rs | 32 +++---- 3 files changed, 144 insertions(+), 50 deletions(-) diff --git a/vortex-array/src/arrays/listview/compute/zip.rs b/vortex-array/src/arrays/listview/compute/zip.rs index b023e01097d..7093c3e0daa 100644 --- a/vortex-array/src/arrays/listview/compute/zip.rs +++ b/vortex-array/src/arrays/listview/compute/zip.rs @@ -57,15 +57,15 @@ impl ZipKernel for ListView { let len = if_true.len(); + let result_elements_dtype = if_true + .elements() + .dtype() + .union_nullability(if_false.elements().dtype().nullability()); + // `if_false`'s elements share the element dtype up to nullability; normalize so both chunks // of the concatenated elements array have an identical dtype. - let true_elements = if_true.elements().clone(); - let element_dtype = true_elements.dtype().clone(); - let false_elements = if if_false.elements().dtype() == &element_dtype { - if_false.elements().clone() - } else { - if_false.elements().cast(element_dtype.clone())? - }; + let true_elements = if_true.elements().cast(result_elements_dtype.clone())?; + let false_elements = if_false.elements().cast(result_elements_dtype.clone())?; // `if_false` views index into the second half of the concatenated elements. let false_shift = true_elements.len() as u64; @@ -73,10 +73,10 @@ impl ZipKernel for ListView { // Concatenate the two `elements` arrays without copying. If either side is already a // `ChunkedArray` (e.g. the result of a previous list-view zip), splice its chunks in // directly rather than nesting chunked arrays. - let mut chunks = Vec::new(); + let mut chunks = Vec::with_capacity(2); push_element_chunks(true_elements, &mut chunks); push_element_chunks(false_elements, &mut chunks); - let elements = ChunkedArray::try_new(chunks, element_dtype)?.into_array(); + let elements = ChunkedArray::try_new(chunks, result_elements_dtype)?.into_array(); let true_offsets = to_u64(if_true.offsets(), ctx)?; let true_sizes = to_u64(if_true.sizes(), ctx)?; @@ -85,27 +85,29 @@ impl ZipKernel for ListView { let mut offsets = BufferMut::::with_capacity(len); let mut sizes = BufferMut::::with_capacity(len); - for i in 0..len { - // SAFETY: the loop runs exactly `len` times and both buffers reserved `len`. - unsafe { - if mask.value(i) { - offsets.push_unchecked(true_offsets[i]); - sizes.push_unchecked(true_sizes[i]); - } else { - offsets.push_unchecked(false_offsets[i] + false_shift); - sizes.push_unchecked(false_sizes[i]); - } + for (idx, (out_offsets, out_sizes)) in offsets + .spare_capacity_mut() + .iter_mut() + .zip(sizes.spare_capacity_mut().iter_mut()) + .take(len) + .enumerate() + { + if mask.value(idx) { + out_offsets.write(true_offsets[idx]); + out_sizes.write(true_sizes[idx]); + } else { + out_offsets.write(false_offsets[idx] + false_shift); + out_sizes.write(false_sizes[idx]); } } - let validity = zip_validity( - if_true.validity()?, - if_false.validity()?, - &mask, - if_true.nullability() | if_false.nullability(), - len, - ctx, - )?; + // SAFETY: the loop above initialized exactly `len` slots in both buffers. + unsafe { + offsets.set_len(len); + sizes.set_len(len); + } + + let validity = zip_validity(if_true.validity()?, if_false.validity()?, &mask, ctx)?; Ok(Some( ListViewArray::try_new( @@ -142,8 +144,6 @@ fn zip_validity( if_true: Validity, if_false: Validity, mask: &Mask, - nullability: Nullability, - len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { Ok(match (&if_true, &if_false) { @@ -151,12 +151,12 @@ fn zip_validity( (Validity::AllValid, Validity::AllValid) => Validity::AllValid, (Validity::AllInvalid, Validity::AllInvalid) => Validity::AllInvalid, _ => { - let true_mask = if_true.execute_mask(len, ctx)?; - let false_mask = if_false.execute_mask(len, ctx)?; + let true_mask = if_true.execute_mask(mask.len(), ctx)?; + let false_mask = if_false.execute_mask(mask.len(), ctx)?; let combined = true_mask .bitand(mask) - .bitor(&false_mask.bitand(&mask.clone().not())); - Validity::from_mask(combined, nullability) + .bitor(&false_mask.bitand(&mask.not())); + Validity::from_mask(combined, if_true.nullability() | if_false.nullability()) } }) } diff --git a/vortex-array/src/arrays/listview/tests/operations.rs b/vortex-array/src/arrays/listview/tests/operations.rs index 235caf53caa..c911b9ba7a5 100644 --- a/vortex-array/src/arrays/listview/tests/operations.rs +++ b/vortex-array/src/arrays/listview/tests/operations.rs @@ -5,11 +5,13 @@ use std::sync::Arc; use rstest::rstest; use vortex_buffer::buffer; +use vortex_error::VortexResult; use vortex_mask::Mask; use super::common::create_basic_listview; use super::common::create_large_listview; use super::common::create_nullable_listview; +use crate::ArrayRef; use crate::IntoArray; use crate::LEGACY_SESSION; #[expect(deprecated)] @@ -382,6 +384,100 @@ fn test_cast_large_dataset() { } } +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Zip tests +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn test_zip_widens_false_element_nullability() -> VortexResult<()> { + // [[1, 2], [3], [4]] + let if_true = ListViewArray::new( + buffer![1i32, 2, 3, 4].into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + // [[10, null], [30], [40]] + let if_false = ListViewArray::new( + PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), Some(40)]).into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + let mask = Mask::from_iter([false, true, false]); + + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut LEGACY_SESSION.create_execution_ctx())?; + assert!(result.is::()); + assert_eq!( + result.dtype(), + &DType::List( + Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)), + Nullability::NonNullable, + ) + ); + + // [[10, null], [3], [40]] + let expected = ListViewArray::new( + PrimitiveArray::from_option_iter([Some(10i32), None, Some(3), Some(40)]).into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + assert_arrays_eq!(result, expected); + Ok(()) +} + +#[test] +fn test_zip_widens_true_element_nullability() -> VortexResult<()> { + // [[1, null], [3], [4]] + let if_true = ListViewArray::new( + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4)]).into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + // [[10], [20], [30]] + let if_false = ListViewArray::new( + buffer![10i32, 20, 30].into_array(), + buffer![0u32, 1, 2].into_array(), + buffer![1u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + let mask = Mask::from_iter([true, false, true]); + + let result = mask + .into_array() + .zip(if_true, if_false)? + .execute::(&mut LEGACY_SESSION.create_execution_ctx())?; + assert!(result.is::()); + assert_eq!( + result.dtype(), + &DType::List( + Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)), + Nullability::NonNullable, + ) + ); + + // [[1, null], [20], [4]] + let expected = ListViewArray::new( + PrimitiveArray::from_option_iter([Some(1i32), None, Some(20), Some(4)]).into_array(), + buffer![0u32, 2, 3].into_array(), + buffer![2u32, 1, 1].into_array(), + Validity::NonNullable, + ) + .into_array(); + assert_arrays_eq!(result, expected); + Ok(()) +} + //////////////////////////////////////////////////////////////////////////////////////////////////// // Constant tests //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 86b5c4f7dc1..b1b17e40bb6 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -90,20 +90,12 @@ impl ScalarFnVTable for Zip { } fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { - vortex_ensure!( - arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]), - "zip requires if_true and if_false to have the same base type, got {} and {}", - arg_dtypes[0], - arg_dtypes[1] - ); vortex_ensure!( matches!(arg_dtypes[2], DType::Bool(_)), "zip requires mask to be a boolean type, got {}", arg_dtypes[2] ); - Ok(arg_dtypes[0] - .clone() - .union_nullability(arg_dtypes[1].nullability())) + zip_return_dtype(&arg_dtypes[0], &arg_dtypes[1]) } fn execute( @@ -120,10 +112,7 @@ impl ScalarFnVTable for Zip { .execute::(ctx)? .to_mask_fill_null_false(ctx); - let return_dtype = if_true - .dtype() - .clone() - .union_nullability(if_false.dtype().nullability()); + let return_dtype = zip_return_dtype(if_true.dtype(), if_false.dtype())?; if mask.all_true() { return if_true.cast(return_dtype)?.execute(ctx); @@ -184,10 +173,7 @@ pub(crate) fn zip_impl( "zip requires arrays to have the same size" ); - let return_type = if_true - .dtype() - .clone() - .union_nullability(if_false.dtype().nullability()); + let return_type = zip_return_dtype(if_true.dtype(), if_false.dtype())?; if mask.all_true() { return if_true.cast(return_type); @@ -211,6 +197,18 @@ pub(crate) fn zip_impl( ) } +fn zip_return_dtype(if_true: &DType, if_false: &DType) -> VortexResult { + vortex_ensure!( + if_true.eq_ignore_nullability(if_false), + "zip requires if_true and if_false to have the same base type, got {} and {}", + if_true, + if_false + ); + Ok(if_true + .least_supertype(if_false) + .vortex_expect("zip inputs with the same base type must have a common dtype")) +} + fn zip_impl_with_builder( if_true: &ArrayRef, if_false: &ArrayRef,