diff --git a/src/bounded.rs b/src/bounded.rs index 8e0b9f4..dab3a29 100644 --- a/src/bounded.rs +++ b/src/bounded.rs @@ -7,7 +7,7 @@ use crate::sync::atomic::{AtomicUsize, Ordering}; use crate::sync::cell::UnsafeCell; #[allow(unused_imports)] use crate::sync::prelude::*; -use crate::{busy_wait, PopError, PushError}; +use crate::{busy_wait, ForcePushError, PopError, PushError}; /// A slot in a queue. struct Slot { @@ -83,6 +83,74 @@ impl Bounded { /// Attempts to push an item into the queue. pub fn push(&self, value: T) -> Result<(), PushError> { + self.push_or_else(value, |value, tail, _, _| { + let head = self.head.load(Ordering::Relaxed); + + // If the head lags one lap behind the tail as well... + if head.wrapping_add(self.one_lap) == tail { + // ...then the queue is full. + Err(PushError::Full(value)) + } else { + Ok(value) + } + }) + } + + /// Pushes an item into the queue, displacing another item if needed. + pub fn force_push(&self, value: T) -> Result, ForcePushError> { + let result = self.push_or_else(value, |value, tail, new_tail, slot| { + let head = tail.wrapping_sub(self.one_lap); + let new_head = new_tail.wrapping_sub(self.one_lap); + + // Try to move the head. + if self + .head + .compare_exchange_weak(head, new_head, Ordering::SeqCst, Ordering::Relaxed) + .is_ok() + { + // Move the tail. + self.tail.store(new_tail, Ordering::SeqCst); + + // Swap out the old value. + // SAFETY: We know this is initialized, since it's covered by the current queue. + let old = unsafe { + slot.value + .with_mut(|slot| slot.replace(MaybeUninit::new(value)).assume_init()) + }; + + // Update the stamp. + slot.stamp.store(tail + 1, Ordering::Release); + + // Return a PushError. + Err(PushError::Full(old)) + } else { + Ok(value) + } + }); + + match result { + Ok(()) => Ok(None), + Err(PushError::Full(old_value)) => Ok(Some(old_value)), + Err(PushError::Closed(value)) => Err(ForcePushError(value)), + } + } + + /// Attempts to push an item into the queue, running a closure on failure. + /// + /// `fail` is run when there is no more room left in the tail of the queue. The parameters of + /// this function are as follows: + /// + /// - The item that failed to push. + /// - The value of `self.tail` before the new value would be inserted. + /// - The value of `self.tail` after the new value would be inserted. + /// - The slot that we attempted to push into. + /// + /// If `fail` returns `Ok(val)`, we will try pushing `val` to the head of the queue. Otherwise, + /// this function will return the error. + fn push_or_else(&self, mut value: T, mut fail: F) -> Result<(), PushError> + where + F: FnMut(T, usize, usize, &Slot) -> Result>, + { let mut tail = self.tail.load(Ordering::Relaxed); loop { @@ -95,22 +163,23 @@ impl Bounded { let index = tail & (self.mark_bit - 1); let lap = tail & !(self.one_lap - 1); + // Calculate the new location of the tail. + let new_tail = if index + 1 < self.buffer.len() { + // Same lap, incremented index. + // Set to `{ lap: lap, mark: 0, index: index + 1 }`. + tail + 1 + } else { + // One lap forward, index wraps around to zero. + // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`. + lap.wrapping_add(self.one_lap) + }; + // Inspect the corresponding slot. let slot = &self.buffer[index]; let stamp = slot.stamp.load(Ordering::Acquire); // If the tail and the stamp match, we may attempt to push. if tail == stamp { - let new_tail = if index + 1 < self.buffer.len() { - // Same lap, incremented index. - // Set to `{ lap: lap, mark: 0, index: index + 1 }`. - tail + 1 - } else { - // One lap forward, index wraps around to zero. - // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`. - lap.wrapping_add(self.one_lap) - }; - // Try moving the tail. match self.tail.compare_exchange_weak( tail, @@ -132,13 +201,9 @@ impl Bounded { } } else if stamp.wrapping_add(self.one_lap) == tail + 1 { crate::full_fence(); - let head = self.head.load(Ordering::Relaxed); - // If the head lags one lap behind the tail as well... - if head.wrapping_add(self.one_lap) == tail { - // ...then the queue is full. - return Err(PushError::Full(value)); - } + // We've failed to push; run our failure closure. + value = fail(value, tail, new_tail, slot)?; // Loom complains if there isn't an explicit busy wait here. #[cfg(loom)] diff --git a/src/lib.rs b/src/lib.rs index f836d17..8b495ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -181,6 +181,54 @@ impl ConcurrentQueue { } } + /// Push an element into the queue, potentially displacing another element. + /// + /// Attempts to push an element into the queue. If the queue is full, one item from the + /// queue is replaced with the provided item. The displaced item is returned as `Some(T)`. + /// If the queue is closed, an error is returned. + /// + /// # Examples + /// + /// ``` + /// use concurrent_queue::{ConcurrentQueue, ForcePushError, PushError}; + /// + /// let q = ConcurrentQueue::bounded(3); + /// + /// // We can push to the queue. + /// for i in 1..=3 { + /// assert_eq!(q.force_push(i), Ok(None)); + /// } + /// + /// // Push errors because the queue is now full. + /// assert_eq!(q.push(4), Err(PushError::Full(4))); + /// + /// // Pushing a new value replaces the old ones. + /// assert_eq!(q.force_push(5), Ok(Some(1))); + /// assert_eq!(q.force_push(6), Ok(Some(2))); + /// + /// // Close the queue to stop further pushes. + /// q.close(); + /// + /// // Pushing will return an error. + /// assert_eq!(q.force_push(7), Err(ForcePushError(7))); + /// + /// // Popping items will return the force-pushed ones. + /// assert_eq!(q.pop(), Ok(3)); + /// assert_eq!(q.pop(), Ok(5)); + /// assert_eq!(q.pop(), Ok(6)); + /// ``` + pub fn force_push(&self, value: T) -> Result, ForcePushError> { + match &self.0 { + Inner::Single(q) => q.force_push(value), + Inner::Bounded(q) => q.force_push(value), + Inner::Unbounded(q) => match q.push(value) { + Ok(()) => Ok(None), + Err(PushError::Closed(value)) => Err(ForcePushError(value)), + Err(PushError::Full(_)) => unreachable!(), + }, + } + } + /// Attempts to pop an item from the queue. /// /// If the queue is empty, an error is returned. @@ -532,6 +580,32 @@ impl fmt::Display for PushError { } } +/// Error that occurs when force-pushing into a full queue. +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct ForcePushError(pub T); + +impl ForcePushError { + /// Return the inner value that failed to be force-pushed. + pub fn into_inner(self) -> T { + self.0 + } +} + +impl fmt::Debug for ForcePushError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("ForcePushError").field(&self.0).finish() + } +} + +impl fmt::Display for ForcePushError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Closed") + } +} + +#[cfg(feature = "std")] +impl error::Error for ForcePushError {} + /// Equivalent to `atomic::fence(Ordering::SeqCst)`, but in some cases faster. #[inline] fn full_fence() { diff --git a/src/single.rs b/src/single.rs index 0efa0b7..048f486 100644 --- a/src/single.rs +++ b/src/single.rs @@ -1,10 +1,11 @@ use core::mem::MaybeUninit; +use core::ptr; use crate::sync::atomic::{AtomicUsize, Ordering}; use crate::sync::cell::UnsafeCell; #[allow(unused_imports)] use crate::sync::prelude::*; -use crate::{busy_wait, PopError, PushError}; +use crate::{busy_wait, ForcePushError, PopError, PushError}; const LOCKED: usize = 1 << 0; const PUSHED: usize = 1 << 1; @@ -47,6 +48,55 @@ impl Single { } } + /// Attempts to push an item into the queue, displacing another if necessary. + pub fn force_push(&self, value: T) -> Result, ForcePushError> { + // Attempt to lock the slot. + let mut state = 0; + + loop { + // Lock the slot. + let prev = self + .state + .compare_exchange(state, LOCKED | PUSHED, Ordering::SeqCst, Ordering::SeqCst) + .unwrap_or_else(|x| x); + + if prev & CLOSED != 0 { + return Err(ForcePushError(value)); + } + + if prev == state { + // Swap out the value. + // SAFETY: We have locked the state. + let prev_value = unsafe { + self.slot + .with_mut(move |slot| ptr::replace(slot, MaybeUninit::new(value))) + }; + + // We can unlock the slot now. + self.state.fetch_and(!LOCKED, Ordering::Release); + + // If the value was pushed, initialize it and return it. + let prev_value = if prev & PUSHED == 0 { + None + } else { + Some(unsafe { prev_value.assume_init() }) + }; + + // Return the old value. + return Ok(prev_value); + } + + // Try to go for the current (pushed) state. + if prev & LOCKED == 0 { + state = prev; + } else { + // State is locked. + busy_wait(); + state = prev & !LOCKED; + } + } + } + /// Attempts to pop an item from the queue. pub fn pop(&self) -> Result { let mut state = PUSHED; diff --git a/src/sync.rs b/src/sync.rs index 53238d0..d6c4f0a 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -62,10 +62,6 @@ pub(crate) mod prelude { pub(crate) trait UnsafeCellExt { type Value; - fn with(&self, f: F) -> R - where - F: FnOnce(*const Self::Value) -> R; - fn with_mut(&self, f: F) -> R where F: FnOnce(*mut Self::Value) -> R; @@ -74,13 +70,6 @@ pub(crate) mod prelude { impl UnsafeCellExt for cell::UnsafeCell { type Value = T; - fn with(&self, f: F) -> R - where - F: FnOnce(*const Self::Value) -> R, - { - f(self.get()) - } - fn with_mut(&self, f: F) -> R where F: FnOnce(*mut Self::Value) -> R, diff --git a/tests/single.rs b/tests/single.rs index 4dc1182..8d2a0d6 100644 --- a/tests/single.rs +++ b/tests/single.rs @@ -1,6 +1,6 @@ #![allow(clippy::bool_assert_comparison)] -use concurrent_queue::{ConcurrentQueue, PopError, PushError}; +use concurrent_queue::{ConcurrentQueue, ForcePushError, PopError, PushError}; #[cfg(not(target_family = "wasm"))] use easy_parallel::Parallel; @@ -65,6 +65,21 @@ fn close() { assert_eq!(q.pop(), Err(PopError::Closed)); } +#[test] +fn force_push() { + let q = ConcurrentQueue::::bounded(1); + assert_eq!(q.force_push(10), Ok(None)); + + assert!(!q.is_closed()); + assert_eq!(q.force_push(20), Ok(Some(10))); + assert_eq!(q.force_push(30), Ok(Some(20))); + + assert!(q.close()); + assert_eq!(q.force_push(40), Err(ForcePushError(40))); + assert_eq!(q.pop(), Ok(30)); + assert_eq!(q.pop(), Err(PopError::Closed)); +} + #[cfg(not(target_family = "wasm"))] #[test] fn spsc() { diff --git a/tests/unbounded.rs b/tests/unbounded.rs index e95dc8c..53ced9a 100644 --- a/tests/unbounded.rs +++ b/tests/unbounded.rs @@ -1,6 +1,6 @@ #![allow(clippy::bool_assert_comparison)] -use concurrent_queue::{ConcurrentQueue, PopError, PushError}; +use concurrent_queue::{ConcurrentQueue, ForcePushError, PopError, PushError}; #[cfg(not(target_family = "wasm"))] use easy_parallel::Parallel; @@ -74,6 +74,32 @@ fn close() { assert_eq!(q.pop(), Err(PopError::Closed)); } +#[test] +fn force_push() { + let q = ConcurrentQueue::::bounded(5); + + for i in 1..=5 { + assert_eq!(q.force_push(i), Ok(None)); + } + + assert!(!q.is_closed()); + for i in 6..=10 { + assert_eq!(q.force_push(i), Ok(Some(i - 5))); + } + assert_eq!(q.pop(), Ok(6)); + assert_eq!(q.force_push(11), Ok(None)); + for i in 12..=15 { + assert_eq!(q.force_push(i), Ok(Some(i - 5))); + } + + assert!(q.close()); + assert_eq!(q.force_push(40), Err(ForcePushError(40))); + for i in 11..=15 { + assert_eq!(q.pop(), Ok(i)); + } + assert_eq!(q.pop(), Err(PopError::Closed)); +} + #[cfg(not(target_family = "wasm"))] #[test] fn spsc() {