Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 74 additions & 9 deletions vortex-mask/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::sync::OnceLock;
use itertools::Itertools;
use vortex_buffer::BitBuffer;
use vortex_buffer::BitBufferMut;
use vortex_buffer::BitIterator;
use vortex_buffer::BitChunkIterator;
use vortex_error::VortexResult;
use vortex_error::vortex_panic;

Expand Down Expand Up @@ -431,9 +431,10 @@ impl Mask {

/// Iterate the mask as one `bool` per element, in order.
///
/// Unlike repeatedly calling [`Mask::value`], this advances a single cursor rather than
/// recomputing the byte/bit offset for every element, and it does not allocate for the
/// all-true / all-false variants. Prefer this for sequential per-element scans.
/// For a [`Mask::Values`] mask this decodes the packed bitmap a `u64` word at a time and
/// drains each word with a register shift, avoiding the per-element byte load and offset
/// arithmetic of [`Mask::value`]. The all-true / all-false variants iterate a counter and
/// allocate nothing. Prefer this for sequential per-element scans.
#[inline]
pub fn iter(&self) -> MaskBoolIter<'_> {
match self {
Expand All @@ -445,7 +446,7 @@ impl Mask {
value: false,
remaining: *len,
},
Mask::Values(values) => MaskBoolIter::Bits(values.bit_buffer().iter()),
Mask::Values(values) => MaskBoolIter::Words(WordBoolIter::new(values.bit_buffer())),
}
}

Expand Down Expand Up @@ -846,8 +847,8 @@ pub enum MaskBoolIter<'a> {
/// The number of elements still to yield.
remaining: usize,
},
/// Per-element bits of a [`Mask::Values`] mask.
Bits(BitIterator<'a>),
/// Per-element bits of a [`Mask::Values`] mask, decoded a `u64` word at a time.
Words(WordBoolIter<'a>),
}

impl Iterator for MaskBoolIter<'_> {
Expand All @@ -861,22 +862,86 @@ impl Iterator for MaskBoolIter<'_> {
*remaining -= 1;
Some(*value)
}
Self::Bits(bits) => bits.next(),
Self::Words(words) => words.next(),
}
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let remaining = match self {
Self::Repeat { remaining, .. } => *remaining,
Self::Bits(bits) => bits.len(),
Self::Words(words) => words.len(),
};
(remaining, Some(remaining))
}
}

impl ExactSizeIterator for MaskBoolIter<'_> {}

/// Word-buffered `bool` iterator over a packed bitmap.
///
/// Decodes the bitmap a `u64` word at a time and drains each word bit by bit with a register
/// shift, avoiding the per-element byte load and offset arithmetic of
/// [`BitIterator`](vortex_buffer::BitIterator). Created for the [`Mask::Values`] variant by
/// [`Mask::iter`].
pub struct WordBoolIter<'a> {
/// Yields the full 64-bit chunks of the bitmap.
chunks: BitChunkIterator<'a>,
/// The final partial word (fewer than 64 bits), used once `chunks` is exhausted.
remainder: u64,
/// The word currently being drained; its least-significant bit is the next element.
word: u64,
/// Number of bits still buffered in `word`.
bits_left: u32,
/// Total number of elements still to yield.
remaining: usize,
}

impl<'a> WordBoolIter<'a> {
#[inline]
fn new(buffer: &'a BitBuffer) -> Self {
let chunks = buffer.chunks();
Self {
remainder: chunks.remainder_bits(),
chunks: chunks.iter(),
word: 0,
bits_left: 0,
remaining: buffer.len(),
}
}
}

impl Iterator for WordBoolIter<'_> {
type Item = bool;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
if self.remaining == 0 {
return None;
}
if self.bits_left == 0 {
// Refill from the next full chunk, falling back to the partial remainder word once the
// full chunks are exhausted. `remaining` bounds how many bits we actually read, so the
// unused high bits of the remainder word (and any bogus refill past the end) are never
// observed.
self.word = self.chunks.next().unwrap_or(self.remainder);
self.bits_left = 64;
}
let value = (self.word & 1) != 0;
self.word >>= 1;
self.bits_left -= 1;
self.remaining -= 1;
Some(value)
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
(self.remaining, Some(self.remaining))
}
}

impl ExactSizeIterator for WordBoolIter<'_> {}

impl From<BitBuffer> for Mask {
fn from(value: BitBuffer) -> Self {
Self::from_buffer(value)
Expand Down
19 changes: 19 additions & 0 deletions vortex-mask/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -751,3 +751,22 @@ fn mask_iter_matches_value(#[case] mask: Mask, #[case] expected: Vec<bool>) {
assert_eq!(it.len(), mask.len() - 1);
}
}

#[rstest]
// Spans multiple full 64-bit words plus a partial remainder word.
#[case::multi_word(Mask::from_buffer(BitBuffer::from_iter((0..150).map(|i| i % 3 == 0))), 0, 150)]
// Non-zero bit offset within the first byte (the `BitChunks` `bit_offset` path).
#[case::offset_small(Mask::from_buffer(BitBuffer::from_iter((0..40).map(|i| i % 5 < 2))), 3, 30)]
// Offset spanning past a word boundary so the remainder word shifts too.
#[case::offset_multi_word(
Mask::from_buffer(BitBuffer::from_iter((0..200).map(|i| (i * 7) % 11 < 4))), 5, 180
)]
fn mask_iter_word_path_matches_value(#[case] mask: Mask, #[case] start: usize, #[case] len: usize) {
// Slicing yields a `Mask::Values` with a non-zero bit offset, exercising the word-buffered
// decode against the scalar `value` path.
let sliced = mask.slice(start..start + len);
let collected: Vec<bool> = sliced.iter().collect();
let by_value: Vec<bool> = (0..sliced.len()).map(|i| sliced.value(i)).collect();
assert_eq!(collected, by_value);
assert_eq!(collected.len(), len);
}
Loading