diff --git a/vortex-array/src/arrays/primitive/compute/take/mod.rs b/vortex-array/src/arrays/primitive/compute/take/mod.rs index 4023991c65d..b8799a74369 100644 --- a/vortex-array/src/arrays/primitive/compute/take/mod.rs +++ b/vortex-array/src/arrays/primitive/compute/take/mod.rs @@ -10,10 +10,12 @@ use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::array::ArrayView; +use crate::arrays::ConstantArray; use crate::arrays::Primitive; use crate::arrays::PrimitiveArray; use crate::arrays::dict::TakeExecute; @@ -24,6 +26,7 @@ use crate::dtype::NativePType; use crate::executor::ExecutionCtx; use crate::match_each_integer_ptype; use crate::match_each_native_ptype; +use crate::scalar::Scalar; use crate::validity::Validity; // Kernel selection happens on the first call to `take` and uses a combination of compile-time @@ -81,19 +84,36 @@ impl TakeExecute for Primitive { vortex_bail!("Invalid indices dtype: {}", indices.dtype()) }; + let indices_validity = indices.validity()?; + // Null index lanes are semantically ignored, but their physical values may be out of + // bounds. Redirect those lanes to zero for the cast/gather, then restore the original index + // validity below. + let indices_nulls_zeroed = match indices_validity.execute_mask(indices.len(), ctx)? { + Mask::AllTrue(_) => indices.clone(), + Mask::AllFalse(_) => { + return Ok(Some( + ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len()) + .into_array(), + )); + } + Mask::Values(_) => indices + .clone() + .fill_null(Scalar::from(0).cast(indices.dtype())?)?, + }; + let unsigned_indices = if ptype.is_unsigned_int() { - indices.clone().execute::(ctx)? + indices_nulls_zeroed.execute::(ctx)? } else { // This will fail if all values cannot be converted to unsigned - indices - .clone() + indices_nulls_zeroed .cast(DType::Primitive(ptype.to_unsigned(), *null))? .execute::(ctx)? }; let validity = array .validity()? - .take(&unsigned_indices.clone().into_array())?; + .take(&unsigned_indices.clone().into_array())? + .and(indices_validity)?; // Delegate to the best kernel based on the target CPU { let unsigned_indices = unsigned_indices.as_view(); @@ -200,3 +220,30 @@ mod test { test_take_conformance(&array.into_array()); } } + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + + use crate::IntoArray; + use crate::arrays::BoolArray; + use crate::arrays::PrimitiveArray; + use crate::assert_arrays_eq; + use crate::validity::Validity; + + #[test] + fn take_null_index_skips_out_of_bounds_value() { + let values = PrimitiveArray::from_iter([10i32, 20, 30]); + let indices = PrimitiveArray::new( + buffer![1u64, 3], + Validity::Array(BoolArray::from_iter([true, false]).into_array()), + ); + + let taken = values.take(indices.into_array()).unwrap(); + + assert_arrays_eq!( + taken, + PrimitiveArray::from_option_iter([Some(20i32), None]).into_array() + ); + } +}