Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions vortex-array/src/arrays/primitive/compute/take/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_mask::Mask;

use crate::ArrayRef;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::ConstantArray;
use crate::arrays::Primitive;
use crate::arrays::PrimitiveArray;
use crate::arrays::dict::TakeExecute;
Expand All @@ -24,6 +26,7 @@ use crate::dtype::NativePType;
use crate::executor::ExecutionCtx;
use crate::match_each_integer_ptype;
use crate::match_each_native_ptype;
use crate::scalar::Scalar;
use crate::validity::Validity;

// Kernel selection happens on the first call to `take` and uses a combination of compile-time
Expand Down Expand Up @@ -81,19 +84,36 @@ impl TakeExecute for Primitive {
vortex_bail!("Invalid indices dtype: {}", indices.dtype())
};

let indices_validity = indices.validity()?;
// Null index lanes are semantically ignored, but their physical values may be out of
// bounds. Redirect those lanes to zero for the cast/gather, then restore the original index
// validity below.
let indices_nulls_zeroed = match indices_validity.execute_mask(indices.len(), ctx)? {
Mask::AllTrue(_) => indices.clone(),
Mask::AllFalse(_) => {
return Ok(Some(
ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
.into_array(),
));
}
Mask::Values(_) => indices
.clone()
.fill_null(Scalar::from(0).cast(indices.dtype())?)?,
};

let unsigned_indices = if ptype.is_unsigned_int() {
indices.clone().execute::<PrimitiveArray>(ctx)?
indices_nulls_zeroed.execute::<PrimitiveArray>(ctx)?
} else {
// This will fail if all values cannot be converted to unsigned
indices
.clone()
indices_nulls_zeroed
.cast(DType::Primitive(ptype.to_unsigned(), *null))?
.execute::<PrimitiveArray>(ctx)?
};

let validity = array
.validity()?
.take(&unsigned_indices.clone().into_array())?;
.take(&unsigned_indices.clone().into_array())?
.and(indices_validity)?;
// Delegate to the best kernel based on the target CPU
{
let unsigned_indices = unsigned_indices.as_view();
Expand Down Expand Up @@ -200,3 +220,30 @@ mod test {
test_take_conformance(&array.into_array());
}
}

#[cfg(test)]
mod tests {
use vortex_buffer::buffer;

use crate::IntoArray;
use crate::arrays::BoolArray;
use crate::arrays::PrimitiveArray;
use crate::assert_arrays_eq;
use crate::validity::Validity;

#[test]
fn take_null_index_skips_out_of_bounds_value() {
let values = PrimitiveArray::from_iter([10i32, 20, 30]);
let indices = PrimitiveArray::new(
buffer![1u64, 3],
Validity::Array(BoolArray::from_iter([true, false]).into_array()),
);

let taken = values.take(indices.into_array()).unwrap();

assert_arrays_eq!(
taken,
PrimitiveArray::from_option_iter([Some(20i32), None]).into_array()
);
}
}
Loading