Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 86 additions & 33 deletions vortex-buffer/src/buffer_mut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,59 +477,113 @@ impl<T> AsMut<[T]> for BufferMut<T> {
}

impl<T> BufferMut<T> {
/// A helper method for the two [`Extend`] implementations.
///
/// We use the lower bound hint on the iterator to manually write data, and then we continue to
/// push items normally past the lower bound.
fn extend_iter(&mut self, mut iter: impl Iterator<Item = T>) {
// Attempt to reserve enough memory up-front, although this is only a lower bound.
let (lower, _) = iter.size_hint();
self.reserve(lower);
// Since we do not know the length of the iterator, we can only guess how much memory we
// need to reserve. Note that these hints may be inaccurate.
let (lower_bound, upper_bound_opt) = iter.size_hint();

// In the case that the upper bound is adversarial, we put a hard limit on the amount of
// memory we reserve (and the OS should handle the rest with zero pages).
let reserve_amount = upper_bound_opt
.unwrap_or(lower_bound)
.min(i32::MAX as usize);
self.reserve(reserve_amount);

let remaining = self.capacity() - self.len();
let unwritten = self.capacity() - self.len();

// We store `begin` in the case that the lower bound hint is incorrect.
let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
let mut dst: *mut T = begin.cast_mut();
for _ in 0..remaining {
if let Some(item) = iter.next() {
unsafe {
// SAFETY: We know we have enough capacity to write the item.
dst.write(item);
// Note. we used to have dst.add(iteration).write(item), here.
// however this was much slower than just incrementing dst.
dst = dst.add(1);
}
} else {

// As a first step, we manually iterate the iterator up to the known capacity.
for _ in 0..unwritten {
let Some(item) = iter.next() else {
// The lower bound hint may be incorrect.
break;
}
};

// SAFETY: We have reserved enough capacity to hold this item, and `dst` is a pointer
// derived from a valid reference to byte data.
unsafe { dst.write(item) };

// Note: We used to have `dst.add(iteration).write(item)`, here. However this was much
// slower than just incrementing `dst`.
// SAFETY: The offsets fits in `isize`, and because we were able to reserve the memory
// we know that `add` will not overflow.
unsafe { dst = dst.add(1) };
}

// TODO(joe): replace with ptr_sub when stable
let length = self.len() + unsafe { dst.byte_offset_from(begin) as usize / size_of::<T>() };
// SAFETY: `dst` was derived from `begin`, which were both valid references to byte data,
// and since the only operation that `dst` has is `add`, we know that `dst >= begin`.
let items_written = unsafe { dst.offset_from_unsigned(begin) };
let length = self.len() + items_written;

// SAFETY: We have written valid items between the old length and the new length.
unsafe { self.set_len(length) };

// Append remaining elements
// Finally, since the iterator will have arbitrarily more items to yield, we push the
// remaining items normally.
iter.for_each(|item| self.push(item));
}

/// Extends the `BufferMut` with an iterator with `TrustedLen`.
///
/// The caller guarantees that the iterator will have a trusted upper bound, which allows the
/// implementation to reserve all of the memory needed up front.
///
/// The upper bound is unwrapped with `vortex_expect`, so an iterator that reports no upper
/// bound is treated as an invariant violation.
pub fn extend_trusted<I: TrustedLen<Item = T>>(&mut self, iter: I) {
    // Since we know the exact upper bound (from `TrustedLen`), we can reserve all of the memory
    // for this operation up front.
    let (_, upper_bound) = iter.size_hint();
    self.reserve(
        upper_bound
            .vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound"),
    );

    // We store `begin` so that the number of items actually written can be recovered once the
    // iterator has been drained.
    let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
    let mut dst: *mut T = begin.cast_mut();

    iter.for_each(|item| {
        // SAFETY: We have reserved enough capacity to hold this item, and `dst` is a pointer
        // derived from a valid reference to byte data.
        unsafe { dst.write(item) };

        // Note: We used to have `dst.add(iteration).write(item)`, here. However this was much
        // slower than just incrementing `dst`.
        // SAFETY: The offsets fits in `isize`, and because we were able to reserve the memory
        // we know that `add` will not overflow.
        unsafe { dst = dst.add(1) };
    });

    // SAFETY: `dst` was derived from `begin`, which were both valid references to byte data,
    // and since the only operation that `dst` has is `add`, we know that `dst >= begin`.
    let items_written = unsafe { dst.offset_from_unsigned(begin) };
    let length = self.len() + items_written;

    // SAFETY: We have written valid items between the old length and the new length.
    unsafe { self.set_len(length) };
}

/// Builds a fresh `BufferMut` out of an iterator whose length is trusted.
///
/// The buffer is created with capacity for the iterator's reported upper bound, and the items
/// are then appended through [`extend_trusted()`](Self::extend_trusted).
pub fn from_trusted_len_iter<I>(iter: I) -> Self
where
    I: TrustedLen<Item = T>,
{
    // Only the upper half of the size hint matters here; a missing upper bound is treated as an
    // invariant violation.
    let capacity = iter
        .size_hint()
        .1
        .vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound");

    let mut out = Self::with_capacity(capacity);
    out.extend_trusted(iter);
    out
}
}

impl<T> Extend<T> for BufferMut<T> {
Expand All @@ -554,7 +608,6 @@ impl<T> FromIterator<T> for BufferMut<T> {
// We don't infer the capacity here and just let the first call to `extend` do it for us.
let mut buffer = Self::with_capacity(0);
buffer.extend(iter);
debug_assert_eq!(buffer.alignment(), Alignment::of::<T>());
buffer
}
}
Expand Down
Loading