Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions encodings/fastlanes/src/bitpacking/array/bitpack_compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,13 @@ where
let total_chunks = data.len().div_ceil(1024);
let mut chunk_offsets: BufferMut<u64> = BufferMut::with_capacity(total_chunks);

for (idx, value) in data.iter().enumerate() {
for ((idx, value), valid) in data.iter().enumerate().zip(validity_mask.iter()) {
if (idx % 1024) == 0 {
// Record the patch index offset for each chunk.
chunk_offsets.push(values.len() as u64);
}

if (value.leading_zeros() as usize) < T::PTYPE.bit_width() - bit_width as usize
&& validity_mask.value(idx)
{
if (value.leading_zeros() as usize) < T::PTYPE.bit_width() - bit_width as usize && valid {
indices.push(P::from(idx).vortex_expect("cast index from usize"));
values.push(*value);
}
Expand Down
12 changes: 8 additions & 4 deletions encodings/sparse/src/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,11 @@ fn execute_sparse_lists_inner<I: IntegerPType, O: IntegerPType>(

let mut next_index = 0;

for (patch_idx, sparse_idx) in patch_indices.iter().enumerate() {
for ((patch_idx, sparse_idx), patch_valid) in patch_indices
.iter()
.enumerate()
.zip(patch_values_validity.iter())
{
let sparse_idx = sparse_idx
.to_usize()
.vortex_expect("patch index must fit in usize");
Expand All @@ -237,7 +241,7 @@ fn execute_sparse_lists_inner<I: IntegerPType, O: IntegerPType>(
sparse_idx - next_index,
);

if patch_values_validity.value(patch_idx) {
if patch_valid {
let patch_list = patch_values
.list_elements_at(patch_idx)
.vortex_expect("list_elements_at");
Expand Down Expand Up @@ -318,7 +322,7 @@ fn execute_sparse_fixed_size_list_inner<I: IntegerPType>(
.iter()
.map(|x| (*x).to_usize().vortex_expect("index must fit in usize"));

for (patch_idx, sparse_idx) in indices.enumerate() {
for ((patch_idx, sparse_idx), patch_valid) in indices.enumerate().zip(values_validity.iter()) {
// Fill gap before this patch with fill values.
append_fixed_size_list_fill(
&mut builder,
Expand All @@ -327,7 +331,7 @@ fn execute_sparse_fixed_size_list_inner<I: IntegerPType>(
);

// Append the patch value, handling null patches by appending defaults.
if values_validity.value(patch_idx) {
if patch_valid {
let patch_list = values
.fixed_size_list_elements_at(patch_idx)
.vortex_expect("fixed_size_list_elements_at");
Expand Down
11 changes: 6 additions & 5 deletions vortex-array/src/aggregate_fn/accumulator_grouped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,13 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
)?;
let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());

for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() {
// `validity` is the per-group list-view validity, so it is zipped element-wise with the
// offsets and sizes (one entry per group).
for ((offset, size), valid) in offsets.iter().zip(sizes.iter()).zip(validity.iter()) {
let offset = offset.to_usize().vortex_expect("Offset value is not usize");
let size = size.to_usize().vortex_expect("Size value is not usize");

// validity is for the outer list view, so it must be indexed with `i`
if validity.value(i) {
if valid {
let group = elements.slice(offset..offset + size)?;
accumulator.accumulate(&group, ctx)?;
states.append_scalar(&accumulator.flush()?)?;
Expand Down Expand Up @@ -304,8 +305,8 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
.to_usize()
.vortex_expect("List size is not usize");

for i in 0..groups.len() {
if validity.value(i) {
for valid in validity.iter() {
if valid {
let group = elements.slice(offset..offset + size)?;
accumulator.accumulate(&group, ctx)?;
states.append_scalar(&accumulator.flush()?)?;
Expand Down
4 changes: 1 addition & 3 deletions vortex-array/src/arrays/fixed_size_list/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,11 @@ fn take_nullable_fsl<I: IntegerPType, E: IntegerPType>(
let mut new_validity_builder = BitBufferMut::with_capacity(new_len);

// Build the element indices while tracking which lists are null.
for (i, data_idx) in indices.iter().enumerate() {
for (data_idx, is_index_valid) in indices.iter().zip(indices_validity.iter()) {
let data_idx = data_idx
.to_usize()
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));

let is_index_valid = indices_validity.value(i);

// The list is null if the index is null or the indexed element is null.
if !is_index_valid || !array_validity.value(data_idx) {
// Append placeholder zeros for null lists. These will be masked by the validity array.
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/arrays/list/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPTy
let mut current_offset = OutputOffsetType::zero();
new_offsets.append_zero();

for (idx, data_idx) in indices.iter().enumerate() {
if !indices_validity.value(idx) {
for (data_idx, index_valid) in indices.iter().zip(indices_validity.iter()) {
if !index_valid {
new_offsets.append_value(current_offset);
continue;
}
Expand Down
5 changes: 3 additions & 2 deletions vortex-array/src/arrays/listview/compute/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,15 @@ impl ZipKernel for ListView {

let mut offsets = BufferMut::<u64>::with_capacity(len);
let mut sizes = BufferMut::<u64>::with_capacity(len);
for (idx, (out_offsets, out_sizes)) in offsets
for ((idx, (out_offsets, out_sizes)), selected) in offsets
.spare_capacity_mut()
.iter_mut()
.zip(sizes.spare_capacity_mut().iter_mut())
.take(len)
.enumerate()
.zip(mask.iter())
{
if mask.value(idx) {
if selected {
out_offsets.write(true_offsets[idx]);
out_sizes.write(true_sizes[idx]);
} else {
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/arrays/varbin/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ fn take_nullable<Index: IntegerPType, Offset: IntegerPType, NewOffset: IntegerPT
let mut valid_indices = Vec::with_capacity(indices.len());

// First pass: calculate offsets and validity
for (idx, data_idx) in indices.iter().enumerate() {
if !indices_validity.value(idx) {
for (data_idx, index_valid) in indices.iter().zip(indices_validity.iter()) {
if !index_valid {
validity_buffer.append(false);
new_offsets.push(current_offset);
continue;
Expand Down
63 changes: 63 additions & 0 deletions vortex-mask/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use std::sync::OnceLock;
use itertools::Itertools;
use vortex_buffer::BitBuffer;
use vortex_buffer::BitBufferMut;
use vortex_buffer::BitIterator;
use vortex_error::VortexResult;
use vortex_error::vortex_panic;

Expand Down Expand Up @@ -428,6 +429,26 @@ impl Mask {
}
}

/// Returns an iterator that yields the mask's elements as `bool`s, front to back.
///
/// For [`Mask::AllTrue`] and [`Mask::AllFalse`] the iterator simply repeats the constant
/// value `len` times; for [`Mask::Values`] it delegates to the iterator of the underlying
/// bit buffer. It is meant for sequential per-element scans as an alternative to calling
/// [`Mask::value`] once per index.
#[inline]
pub fn iter(&self) -> MaskBoolIter<'_> {
    // Only the materialized variant needs a real bit iterator; handle it first so the
    // two constant variants can share a single constructor below.
    let (value, remaining) = match self {
        Mask::Values(values) => return MaskBoolIter::Bits(values.bit_buffer().iter()),
        Mask::AllTrue(len) => (true, *len),
        Mask::AllFalse(len) => (false, *len),
    };

    MaskBoolIter::Repeat { value, remaining }
}

/// Returns the first true index in the mask.
pub fn first(&self) -> Option<usize> {
match &self {
Expand Down Expand Up @@ -814,6 +835,48 @@ pub enum MaskIter<'a> {
Slices(&'a [(usize, usize)]),
}

/// Iterator yielding one `bool` per element of a [`Mask`], in order.
///
/// Created by [`Mask::iter`]. The variant is chosen from the mask's representation:
/// [`Mask::AllTrue`] and [`Mask::AllFalse`] become [`MaskBoolIter::Repeat`], while
/// [`Mask::Values`] becomes [`MaskBoolIter::Bits`].
pub enum MaskBoolIter<'a> {
/// An all-true or all-false run: the same `value` is yielded `remaining` more times.
Repeat {
/// The constant value yielded by every element of the run.
value: bool,
/// The number of elements still to yield; decremented by one on each successful
/// `next` call, and iteration ends once it reaches zero.
remaining: usize,
},
/// Per-element bits of a [`Mask::Values`] mask, read through the bit buffer's own iterator.
Bits(BitIterator<'a>),
}

impl Iterator for MaskBoolIter<'_> {
    type Item = bool;

    /// Yields the next element, or `None` once every element has been produced.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            // Materialized masks defer entirely to the wrapped bit iterator.
            Self::Bits(bits) => bits.next(),
            Self::Repeat { value, remaining } => {
                // `checked_sub` returns `None` when nothing is left, which ends iteration
                // without touching the counter.
                *remaining = remaining.checked_sub(1)?;
                Some(*value)
            }
        }
    }

    /// Reports the exact number of elements still to be yielded as both bounds.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let exact = match self {
            Self::Bits(bits) => bits.len(),
            Self::Repeat { remaining, .. } => *remaining,
        };
        (exact, Some(exact))
    }
}

// `size_hint` returns the same value for its lower and upper bound, so the default
// `ExactSizeIterator::len` reports the exact number of elements left.
impl ExactSizeIterator for MaskBoolIter<'_> {}

impl From<BitBuffer> for Mask {
fn from(value: BitBuffer) -> Self {
Self::from_buffer(value)
Expand Down
26 changes: 26 additions & 0 deletions vortex-mask/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,3 +725,29 @@ fn test_mask_concat_mixed_types() {
assert!(!result.value(6)); // from all_false
assert!(!result.value(7)); // from all_false
}

// `Mask::iter` per-element bool iteration

#[rstest]
#[case::all_true(Mask::new_true(4), vec![true, true, true, true])]
#[case::all_false(Mask::new_false(3), vec![false, false, false])]
#[case::values(
    Mask::from_buffer(BitBuffer::from_iter([true, false, true, true, false])),
    vec![true, false, true, true, false]
)]
#[case::empty(Mask::new_true(0), vec![])]
fn mask_iter_matches_value(#[case] mask: Mask, #[case] expected: Vec<bool>) {
    let len = mask.len();

    // The iterator must produce exactly the expected sequence...
    let from_iter = mask.iter().collect::<Vec<bool>>();
    assert_eq!(from_iter, expected);

    // ...and agree element-for-element with random access through `value`.
    let mut from_value = Vec::with_capacity(len);
    for i in 0..len {
        from_value.push(mask.value(i));
    }
    assert_eq!(from_iter, from_value);

    // `len()` is exact up front and shrinks by one after a single element is consumed.
    let mut cursor = mask.iter();
    assert_eq!(cursor.len(), len);
    if cursor.next().is_none() {
        // Empty mask: nothing was consumed, so there is no post-consumption length to check.
        return;
    }
    assert_eq!(cursor.len(), len - 1);
}
Loading