Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 90 additions & 158 deletions vortex-array/src/arrays/primitive/compute/take/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ use crate::arrays::primitive::compute::take::TakeImpl;
use crate::arrays::primitive::compute::take::take_primitive_scalar;
use crate::arrays::primitive::vtable::Primitive;
use crate::dtype::NativePType;
use crate::dtype::PType;
use crate::dtype::UnsignedPType;
use crate::match_each_native_ptype;
use crate::match_each_unsigned_integer_ptype;
Expand Down Expand Up @@ -112,6 +111,16 @@ where

/// Takes the specified indices into a new [`Buffer`] using AVX2 SIMD.
///
/// An AVX2 gather only moves raw bytes, so signedness and float-ness are irrelevant — only the
/// byte width of `V` matters. Any 4-byte value rides the gather through the `u32` lane and any
/// 8-byte value through the `u64` lane, regardless of its actual type. Values 1 or 2 bytes wide
/// (AVX2 has no sub-32-bit gather) and wider than 8 bytes (`i128`, decimals) fall back to the
/// scalar kernel.
///
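/// This treats `V` as plain-old-data: reinterpreting the gathered bytes as `V` is only sound
/// because every bit pattern is a valid `V`. All primitive and decimal-backing types satisfy
/// this, as does any `Copy` POD type the caller supplies. Note that the `V: Copy` bound alone
/// does not enforce this (e.g. `bool` and `char` are `Copy` but have invalid bit patterns), so
/// supplying a POD type is the caller's responsibility.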
///
/// # Panics
///
/// This function panics if any of the provided `indices` are out of bounds for `values`.
Expand All @@ -121,60 +130,29 @@ where
/// The caller must ensure the `avx2` feature is enabled.
#[target_feature(enable = "avx2")]
#[doc(hidden)]
unsafe fn take_avx2<V: NativePType, I: UnsignedPType>(buffer: &[V], indices: &[I]) -> Buffer<V> {
macro_rules! dispatch_avx2 {
($indices:ty, $values:ty) => {
{ let result = dispatch_avx2!($indices, $values, cast: $values); result }
};
($indices:ty, $values:ty, cast: $cast:ty) => {{
let indices = unsafe { std::mem::transmute::<&[I], &[$indices]>(indices) };
let values = unsafe { std::mem::transmute::<&[V], &[$cast]>(buffer) };

let result = exec_take::<$cast, $indices, AVX2Gather>(values, indices);
unsafe { result.transmute::<V>() }
}};
}

unsafe fn take_avx2<V: Copy, I: UnsignedPType>(buffer: &[V], indices: &[I]) -> Buffer<V> {
if buffer.is_empty() {
return Buffer::zeroed(indices.len());
}

match (I::PTYPE, V::PTYPE) {
// Int value types. Only 32 and 64 bit types are supported.
(PType::U8, PType::I32) => dispatch_avx2!(u8, i32),
(PType::U8, PType::U32) => dispatch_avx2!(u8, u32),
(PType::U8, PType::I64) => dispatch_avx2!(u8, i64),
(PType::U8, PType::U64) => dispatch_avx2!(u8, u64),
(PType::U16, PType::I32) => dispatch_avx2!(u16, i32),
(PType::U16, PType::U32) => dispatch_avx2!(u16, u32),
(PType::U16, PType::I64) => dispatch_avx2!(u16, i64),
(PType::U16, PType::U64) => dispatch_avx2!(u16, u64),
(PType::U32, PType::I32) => dispatch_avx2!(u32, i32),
(PType::U32, PType::U32) => dispatch_avx2!(u32, u32),
(PType::U32, PType::I64) => dispatch_avx2!(u32, i64),
(PType::U32, PType::U64) => dispatch_avx2!(u32, u64),

// Float value types, treat them as if they were corresponding int types.
(PType::U8, PType::F32) => dispatch_avx2!(u8, f32, cast: u32),
(PType::U16, PType::F32) => dispatch_avx2!(u16, f32, cast: u32),
(PType::U32, PType::F32) => dispatch_avx2!(u32, f32, cast: u32),
(PType::U64, PType::F32) => dispatch_avx2!(u64, f32, cast: u32),

(PType::U8, PType::F64) => dispatch_avx2!(u8, f64, cast: u64),
(PType::U16, PType::F64) => dispatch_avx2!(u16, f64, cast: u64),
(PType::U32, PType::F64) => dispatch_avx2!(u32, f64, cast: u64),
(PType::U64, PType::F64) => dispatch_avx2!(u64, f64, cast: u64),

// Scalar fallback for unsupported value types.
_ => {
tracing::trace!(
"take AVX2 kernel missing for indices {} values {}, falling back to scalar",
I::PTYPE,
V::PTYPE
);

take_primitive_scalar(buffer, indices)
}
// Dispatch on the gather lane width. The index type must still be concretized to select the
// right `GatherFn` impl, so re-dispatch it with `match_each_unsigned_integer_ptype!`.
macro_rules! dispatch {
($lane:ty) => {{
match_each_unsigned_integer_ptype!(I::PTYPE, |Idx| {
// SAFETY: `Idx` has the same `PTYPE` as `I`, so this is a no-op reinterpret of the
// index slice into the concrete type the gather impl is keyed on.
let indices = unsafe { std::mem::transmute::<&[I], &[Idx]>(indices) };
exec_take::<V, $lane, Idx, AVX2Gather>(buffer, indices)
})
}};
}

match size_of::<V>() {
4 => dispatch!(u32),
8 => dispatch!(u64),
// 1/2-byte and >8-byte values have no AVX2 gather lane, so fall back to scalar.
_ => take_primitive_scalar(buffer, indices),
}
}

Expand Down Expand Up @@ -259,17 +237,6 @@ impl_gather!(u8,
store: _mm256_storeu_si256,
WIDTH = 8, STRIDE = 16
},
{ i32 =>
load: _mm_loadu_si128,
extend: _mm256_cvtepu8_epi32,
splat: _mm256_set1_epi32,
zero_vec: _mm256_setzero_si256,
mask_indices: _mm256_cmpgt_epi32,
mask_cvt: |x| { x },
gather: _mm256_mask_i32gather_epi32,
store: _mm256_storeu_si256,
WIDTH = 8, STRIDE = 16
},

// 64-bit values, loaded 4 at a time
{ u64 =>
Expand All @@ -282,17 +249,6 @@ impl_gather!(u8,
gather: _mm256_mask_i64gather_epi64,
store: _mm256_storeu_si256,
WIDTH = 4, STRIDE = 16
},
{ i64 =>
load: _mm_loadu_si128,
extend: _mm256_cvtepu8_epi64,
splat: _mm256_set1_epi64x,
zero_vec: _mm256_setzero_si256,
mask_indices: _mm256_cmpgt_epi64,
mask_cvt: |x| { x },
gather: _mm256_mask_i64gather_epi64,
store: _mm256_storeu_si256,
WIDTH = 4, STRIDE = 16
}
);

Expand All @@ -310,17 +266,6 @@ impl_gather!(u16,
store: _mm256_storeu_si256,
WIDTH = 8, STRIDE = 8
},
{ i32 =>
load: _mm_loadu_si128,
extend: _mm256_cvtepu16_epi32,
splat: _mm256_set1_epi32,
zero_vec: _mm256_setzero_si256,
mask_indices: _mm256_cmpgt_epi32,
mask_cvt: |x| { x },
gather: _mm256_mask_i32gather_epi32,
store: _mm256_storeu_si256,
WIDTH = 8, STRIDE = 8
},

// 64-bit values. 8x indices loaded at a time and 4x values loaded at a time.
{ u64 =>
Expand All @@ -333,17 +278,6 @@ impl_gather!(u16,
gather: _mm256_mask_i64gather_epi64,
store: _mm256_storeu_si256,
WIDTH = 4, STRIDE = 8
},
{ i64 =>
load: _mm_loadu_si128,
extend: _mm256_cvtepu16_epi64,
splat: _mm256_set1_epi64x,
zero_vec: _mm256_setzero_si256,
mask_indices: _mm256_cmpgt_epi64,
mask_cvt: |x| { x },
gather: _mm256_mask_i64gather_epi64,
store: _mm256_storeu_si256,
WIDTH = 4, STRIDE = 8
}
);

Expand All @@ -361,17 +295,6 @@ impl_gather!(u32,
store: _mm256_storeu_si256,
WIDTH = 8, STRIDE = 8
},
{ i32 =>
load: _mm256_loadu_si256,
extend: identity,
splat: _mm256_set1_epi32,
zero_vec: _mm256_setzero_si256,
mask_indices: _mm256_cmpgt_epi32,
mask_cvt: |x| { x },
gather: _mm256_mask_i32gather_epi32,
store: _mm256_storeu_si256,
WIDTH = 8, STRIDE = 8
},

// 64-bit values.
{ u64 =>
Expand All @@ -384,17 +307,6 @@ impl_gather!(u32,
gather: _mm256_mask_i64gather_epi64,
store: _mm256_storeu_si256,
WIDTH = 4, STRIDE = 4
},
{ i64 =>
load: _mm_loadu_si128,
extend: _mm256_cvtepu32_epi64,
splat: _mm256_set1_epi64x,
zero_vec: _mm256_setzero_si256,
mask_indices: _mm256_cmpgt_epi64,
mask_cvt: |x| { x },
gather: _mm256_mask_i64gather_epi64,
store: _mm256_storeu_si256,
WIDTH = 4, STRIDE = 4
}
);

Expand All @@ -419,25 +331,6 @@ impl_gather!(u64,
store: _mm_storeu_si128,
WIDTH = 4, STRIDE = 4
},
{ i32 =>
load: _mm256_loadu_si256,
extend: identity,
splat: _mm256_set1_epi64x,
zero_vec: _mm_setzero_si128,
mask_indices: _mm256_cmpgt_epi64,
mask_cvt: |m| {
unsafe {
let lo_bits = _mm256_extracti128_si256::<0>(m); // lower half
let hi_bits = _mm256_extracti128_si256::<1>(m); // upper half
let lo_packed = _mm_shuffle_epi32::<0b01_01_01_01>(lo_bits);
let hi_packed = _mm_shuffle_epi32::<0b01_01_01_01>(hi_bits);
_mm_unpacklo_epi64(lo_packed, hi_packed)
}
},
gather: _mm256_mask_i64gather_epi32,
store: _mm_storeu_si128,
WIDTH = 4, STRIDE = 4
},

// 64-bit values.
{ u64 =>
Expand All @@ -450,49 +343,55 @@ impl_gather!(u64,
gather: _mm256_mask_i64gather_epi64,
store: _mm256_storeu_si256,
WIDTH = 4, STRIDE = 4
},
{ i64 =>
load: _mm256_loadu_si256,
extend: identity,
splat: _mm256_set1_epi64x,
zero_vec: _mm256_setzero_si256,
mask_indices: _mm256_cmpgt_epi64,
mask_cvt: |x| { x },
gather: _mm256_mask_i64gather_epi64,
store: _mm256_storeu_si256,
WIDTH = 4, STRIDE = 4
}
);

/// AVX2 core inner loop for certain `Idx` and `Value` type.
/// AVX2 core inner loop for a given index type `Idx`, output element type `Out`, and gather
/// `Lane` type.
///
/// `Out` is the element type written to the output buffer; `Lane` (`u32` or `u64`) is the
/// integer type the gather intrinsics operate on. The caller must pair them so that
/// `size_of::<Out>() == size_of::<Lane>()` (the only caller, [`take_avx2`], picks `Lane` from
/// `size_of::<Out>()`). The gather moves `size_of::<Lane>()` raw bytes per element, which only
/// yields a valid `Out` because `Out` is plain-old-data (every bit pattern is a valid `Out`).
/// Pointers into the `Out`-typed slices are cast to `*const Lane`/`*mut Lane`; gather tolerates
/// the (possibly weaker) `Out` alignment.
#[inline(always)]
fn exec_take<Value, Idx, Gather>(values: &[Value], indices: &[Idx]) -> Buffer<Value>
fn exec_take<Out, Lane, Idx, Gather>(values: &[Out], indices: &[Idx]) -> Buffer<Out>
where
Value: Copy,
Out: Copy,
Idx: UnsignedPType,
Gather: GatherFn<Idx, Value>,
Gather: GatherFn<Idx, Lane>,
{
debug_assert_eq!(
size_of::<Out>(),
size_of::<Lane>(),
"gather lane and output element must have the same size"
);

let indices_len = indices.len();
let max_index = Idx::from(values.len()).unwrap_or_else(|| Idx::max_value());
let mut buffer =
BufferMut::<Value>::with_capacity_aligned(indices_len, Alignment::of::<__m256i>());
BufferMut::<Out>::with_capacity_aligned(indices_len, Alignment::of::<__m256i>());
let buf_uninit = buffer.spare_capacity_mut();

let mut offset = 0;
// Loop terminates STRIDE elements before end of the indices array because the `GatherFn`
// might read up to STRIDE src elements at a time, even though it only advances WIDTH elements
// in the dst.
while offset + Gather::STRIDE < indices_len {
// SAFETY: `gather_simd` preconditions satisfied:
// SAFETY: `gather` preconditions satisfied:
// 1. `(indices + offset)..(indices + offset + STRIDE)` is in-bounds for indices
// allocation.
// 2. `buffer` has same len as indices so `buffer + offset + STRIDE` is always valid.
// 2. `buffer` has same len as indices so `buffer + offset + WIDTH` is always valid.
// 3. `size_of::<Out>() == size_of::<Lane>()` (asserted above), so the `Lane`-typed
// pointers address the same bytes as the `Out`-typed `values`/`buffer` allocations.
unsafe {
Gather::gather(
indices.as_ptr().add(offset),
max_index,
values.as_ptr(),
buf_uninit.as_mut_ptr().add(offset).cast(),
values.as_ptr().cast::<Lane>(),
buf_uninit.as_mut_ptr().add(offset).cast::<Lane>(),
)
};
offset += Gather::WIDTH;
Expand All @@ -509,11 +408,11 @@ where
// SAFETY: All elements have been initialized.
unsafe { buffer.set_len(indices_len) };

// Reset the buffer alignment to the Value type.
// Reset the buffer alignment to the output type.
// NOTE: if we don't do this, we pass back a Buffer which is over-aligned to the SIMD
// register width. The caller expects that this memory should be aligned to the value type
// so that we can slice it at value boundaries.
buffer = buffer.aligned(Alignment::of::<Value>());
buffer = buffer.aligned(Alignment::of::<Out>());

buffer.freeze()
}
Expand Down Expand Up @@ -601,4 +500,37 @@ mod avx2_tests {
let result = unsafe { take_avx2(&values, &indices) };
assert_eq!(&vec![65535; indices.len()], result.as_slice());
}

/// Exercises the 4-byte gather path with `[u8; 4]`, a `Copy` POD element type that is not a
/// `NativePType`. An identity take over all 200 elements must reproduce the input, showing that
/// an arbitrary 4-byte value type is gathered through the `u32` SIMD lane.
#[test]
fn test_avx2_take_simd_array_u8x4() {
    let source: Vec<[u8; 4]> = (1u32..=200).map(|word| word.to_le_bytes()).collect();
    let positions: Vec<u32> = (0u32..200).collect();

    // NOTE(review): assumes the host CPU supports AVX2 — confirm the test module is gated
    // accordingly.
    let gathered = unsafe { take_avx2(&source, &positions) };

    assert_eq!(source.as_slice(), gathered.as_slice());
}

/// AVX2 offers no gather for 2-byte elements, so `u16` values are routed to the scalar fallback.
/// An identity take over all 300 elements must still reproduce the input exactly.
#[test]
fn test_avx2_take_scalar_fallback_u16() {
    let source: Vec<u16> = (1u16..=300).collect();
    let positions: Vec<u32> = (0u32..300).collect();

    // NOTE(review): assumes the host CPU supports AVX2 — confirm the test module is gated
    // accordingly.
    let gathered = unsafe { take_avx2(&source, &positions) };

    assert_eq!(source.as_slice(), gathered.as_slice());
}

/// Elements wider than 8 bytes (such as `i128` / decimal backing storage) do not fit a gather
/// lane and are handled by the scalar kernel. A 16-byte `[u8; 16]` element stands in for them
/// here; an identity take over all 200 elements must reproduce the input.
#[test]
fn test_avx2_take_scalar_fallback_array_u8x16() {
    let source: Vec<[u8; 16]> = (0u128..200).map(|wide| wide.to_le_bytes()).collect();
    let positions: Vec<u32> = (0u32..200).collect();

    // NOTE(review): assumes the host CPU supports AVX2 — confirm the test module is gated
    // accordingly.
    let gathered = unsafe { take_avx2(&source, &positions) };

    assert_eq!(source.as_slice(), gathered.as_slice());
}
}
6 changes: 1 addition & 5 deletions vortex-array/src/arrays/primitive/compute/take/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ use crate::arrays::dict::TakeExecute;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::IntegerPType;
use crate::dtype::NativePType;
use crate::executor::ExecutionCtx;
use crate::match_each_integer_ptype;
use crate::match_each_native_ptype;
Expand Down Expand Up @@ -106,10 +105,7 @@ impl TakeExecute for Primitive {

// Compiler may see this as unused based on enabled features
#[inline(always)]
fn take_primitive_scalar<T: NativePType, I: IntegerPType>(
buffer: &[T],
indices: &[I],
) -> Buffer<T> {
fn take_primitive_scalar<T: Copy, I: IntegerPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
// NB: The simpler `indices.iter().map(|idx| buffer[idx.as_()]).collect()` generates suboptimal
// assembly where the buffer length is repeatedly loaded from the stack on each iteration.

Expand Down
Loading