From d5f63d59e3bf79aae100c31bf48350857f4bae01 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 18 Jun 2026 13:11:06 -0400 Subject: [PATCH] generalize SIMD take to Copy values An AVX2 gather just moves bytes, so dispatch the take kernel on size_of::() rather than NativePType: 4-byte values gather through a u32 lane, 8-byte through u64, others fall back to scalar. This halves the gather impls (i32/u32/f32 share one, i64/u64/f64 the other) and lets any POD Copy type use the path. exec_take casts only raw pointers, keeping align-1 types like [u8; 4] sound; take_primitive_scalar is relaxed to T: Copy. Signed-off-by: Connor Tsui --- .../src/arrays/primitive/compute/take/avx2.rs | 248 +++++++----------- .../src/arrays/primitive/compute/take/mod.rs | 6 +- 2 files changed, 91 insertions(+), 163 deletions(-) diff --git a/vortex-array/src/arrays/primitive/compute/take/avx2.rs b/vortex-array/src/arrays/primitive/compute/take/avx2.rs index e92304dc34b..fb64e6748b0 100644 --- a/vortex-array/src/arrays/primitive/compute/take/avx2.rs +++ b/vortex-array/src/arrays/primitive/compute/take/avx2.rs @@ -46,7 +46,6 @@ use crate::arrays::primitive::compute::take::TakeImpl; use crate::arrays::primitive::compute::take::take_primitive_scalar; use crate::arrays::primitive::vtable::Primitive; use crate::dtype::NativePType; -use crate::dtype::PType; use crate::dtype::UnsignedPType; use crate::match_each_native_ptype; use crate::match_each_unsigned_integer_ptype; @@ -112,6 +111,16 @@ where /// Takes the specified indices into a new [`Buffer`] using AVX2 SIMD. /// +/// An AVX2 gather only moves raw bytes, so signedness and float-ness are irrelevant — only the +/// byte width of `V` matters. Any 4-byte value rides the gather through the `u32` lane and any +/// 8-byte value through the `u64` lane, regardless of its actual type. Values 1 or 2 bytes wide +/// (AVX2 has no sub-32-bit gather) and wider than 8 bytes (`i128`, decimals) fall back to the +/// scalar kernel. +/// +/// This treats `V` as plain-old-data: reinterpreting the gathered bytes as `V` is only sound +/// because every bit pattern is a valid `V`. All primitive and decimal-backing types satisfy +/// this, as does any `Copy` POD type the caller supplies. +/// /// # Panics /// /// This function panics if any of the provided `indices` are out of bounds for `values`. @@ -121,60 +130,29 @@ where /// The caller must ensure the `avx2` feature is enabled. #[target_feature(enable = "avx2")] #[doc(hidden)] -unsafe fn take_avx2(buffer: &[V], indices: &[I]) -> Buffer { - macro_rules! dispatch_avx2 { - ($indices:ty, $values:ty) => { - { let result = dispatch_avx2!($indices, $values, cast: $values); result } - }; - ($indices:ty, $values:ty, cast: $cast:ty) => {{ - let indices = unsafe { std::mem::transmute::<&[I], &[$indices]>(indices) }; - let values = unsafe { std::mem::transmute::<&[V], &[$cast]>(buffer) }; - - let result = exec_take::<$cast, $indices, AVX2Gather>(values, indices); - unsafe { result.transmute::() } - }}; - } - +unsafe fn take_avx2(buffer: &[V], indices: &[I]) -> Buffer { if buffer.is_empty() { return Buffer::zeroed(indices.len()); } - match (I::PTYPE, V::PTYPE) { - // Int value types. Only 32 and 64 bit types are supported. - (PType::U8, PType::I32) => dispatch_avx2!(u8, i32), - (PType::U8, PType::U32) => dispatch_avx2!(u8, u32), - (PType::U8, PType::I64) => dispatch_avx2!(u8, i64), - (PType::U8, PType::U64) => dispatch_avx2!(u8, u64), - (PType::U16, PType::I32) => dispatch_avx2!(u16, i32), - (PType::U16, PType::U32) => dispatch_avx2!(u16, u32), - (PType::U16, PType::I64) => dispatch_avx2!(u16, i64), - (PType::U16, PType::U64) => dispatch_avx2!(u16, u64), - (PType::U32, PType::I32) => dispatch_avx2!(u32, i32), - (PType::U32, PType::U32) => dispatch_avx2!(u32, u32), - (PType::U32, PType::I64) => dispatch_avx2!(u32, i64), - (PType::U32, PType::U64) => dispatch_avx2!(u32, u64), - - // Float value types, treat them as if they were corresponding int types. - (PType::U8, PType::F32) => dispatch_avx2!(u8, f32, cast: u32), - (PType::U16, PType::F32) => dispatch_avx2!(u16, f32, cast: u32), - (PType::U32, PType::F32) => dispatch_avx2!(u32, f32, cast: u32), - (PType::U64, PType::F32) => dispatch_avx2!(u64, f32, cast: u32), - - (PType::U8, PType::F64) => dispatch_avx2!(u8, f64, cast: u64), - (PType::U16, PType::F64) => dispatch_avx2!(u16, f64, cast: u64), - (PType::U32, PType::F64) => dispatch_avx2!(u32, f64, cast: u64), - (PType::U64, PType::F64) => dispatch_avx2!(u64, f64, cast: u64), - - // Scalar fallback for unsupported value types. - _ => { - tracing::trace!( - "take AVX2 kernel missing for indices {} values {}, falling back to scalar", - I::PTYPE, - V::PTYPE - ); - - take_primitive_scalar(buffer, indices) - } + // Dispatch on the gather lane width. The index type must still be concretized to select the + // right `GatherFn` impl, so re-dispatch it with `match_each_unsigned_integer_ptype!`. + macro_rules! dispatch { + ($lane:ty) => {{ + match_each_unsigned_integer_ptype!(I::PTYPE, |Idx| { + // SAFETY: `Idx` has the same `PTYPE` as `I`, so this is a no-op reinterpret of the + // index slice into the concrete type the gather impl is keyed on. + let indices = unsafe { std::mem::transmute::<&[I], &[Idx]>(indices) }; + exec_take::(buffer, indices) + }) + }}; + } + + match size_of::() { + 4 => dispatch!(u32), + 8 => dispatch!(u64), + // 1/2-byte and >8-byte values have no AVX2 gather lane, so fall back to scalar. + _ => take_primitive_scalar(buffer, indices), } } @@ -259,17 +237,6 @@ impl_gather!(u8, store: _mm256_storeu_si256, WIDTH = 8, STRIDE = 16 }, - { i32 => - load: _mm_loadu_si128, - extend: _mm256_cvtepu8_epi32, - splat: _mm256_set1_epi32, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi32, - mask_cvt: |x| { x }, - gather: _mm256_mask_i32gather_epi32, - store: _mm256_storeu_si256, - WIDTH = 8, STRIDE = 16 - }, // 64-bit values, loaded 4 at a time { u64 => @@ -282,17 +249,6 @@ impl_gather!(u8, gather: _mm256_mask_i64gather_epi64, store: _mm256_storeu_si256, WIDTH = 4, STRIDE = 16 - }, - { i64 => - load: _mm_loadu_si128, - extend: _mm256_cvtepu8_epi64, - splat: _mm256_set1_epi64x, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi64, - mask_cvt: |x| { x }, - gather: _mm256_mask_i64gather_epi64, - store: _mm256_storeu_si256, - WIDTH = 4, STRIDE = 16 } ); @@ -310,17 +266,6 @@ impl_gather!(u16, store: _mm256_storeu_si256, WIDTH = 8, STRIDE = 8 }, - { i32 => - load: _mm_loadu_si128, - extend: _mm256_cvtepu16_epi32, - splat: _mm256_set1_epi32, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi32, - mask_cvt: |x| { x }, - gather: _mm256_mask_i32gather_epi32, - store: _mm256_storeu_si256, - WIDTH = 8, STRIDE = 8 - }, // 64-bit values. 8x indices loaded at a time and 4x values loaded at a time. { u64 => @@ -333,17 +278,6 @@ impl_gather!(u16, gather: _mm256_mask_i64gather_epi64, store: _mm256_storeu_si256, WIDTH = 4, STRIDE = 8 - }, - { i64 => - load: _mm_loadu_si128, - extend: _mm256_cvtepu16_epi64, - splat: _mm256_set1_epi64x, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi64, - mask_cvt: |x| { x }, - gather: _mm256_mask_i64gather_epi64, - store: _mm256_storeu_si256, - WIDTH = 4, STRIDE = 8 } ); @@ -361,17 +295,6 @@ impl_gather!(u32, store: _mm256_storeu_si256, WIDTH = 8, STRIDE = 8 }, - { i32 => - load: _mm256_loadu_si256, - extend: identity, - splat: _mm256_set1_epi32, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi32, - mask_cvt: |x| { x }, - gather: _mm256_mask_i32gather_epi32, - store: _mm256_storeu_si256, - WIDTH = 8, STRIDE = 8 - }, // 64-bit values. { u64 => @@ -384,17 +307,6 @@ impl_gather!(u32, gather: _mm256_mask_i64gather_epi64, store: _mm256_storeu_si256, WIDTH = 4, STRIDE = 4 - }, - { i64 => - load: _mm_loadu_si128, - extend: _mm256_cvtepu32_epi64, - splat: _mm256_set1_epi64x, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi64, - mask_cvt: |x| { x }, - gather: _mm256_mask_i64gather_epi64, - store: _mm256_storeu_si256, - WIDTH = 4, STRIDE = 4 } ); @@ -419,25 +331,6 @@ impl_gather!(u64, store: _mm_storeu_si128, WIDTH = 4, STRIDE = 4 }, - { i32 => - load: _mm256_loadu_si256, - extend: identity, - splat: _mm256_set1_epi64x, - zero_vec: _mm_setzero_si128, - mask_indices: _mm256_cmpgt_epi64, - mask_cvt: |m| { - unsafe { - let lo_bits = _mm256_extracti128_si256::<0>(m); // lower half - let hi_bits = _mm256_extracti128_si256::<1>(m); // upper half - let lo_packed = _mm_shuffle_epi32::<0b01_01_01_01>(lo_bits); - let hi_packed = _mm_shuffle_epi32::<0b01_01_01_01>(hi_bits); - _mm_unpacklo_epi64(lo_packed, hi_packed) - } - }, - gather: _mm256_mask_i64gather_epi32, - store: _mm_storeu_si128, - WIDTH = 4, STRIDE = 4 - }, // 64-bit values. { u64 => @@ -450,32 +343,36 @@ impl_gather!(u64, gather: _mm256_mask_i64gather_epi64, store: _mm256_storeu_si256, WIDTH = 4, STRIDE = 4 - }, - { i64 => - load: _mm256_loadu_si256, - extend: identity, - splat: _mm256_set1_epi64x, - zero_vec: _mm256_setzero_si256, - mask_indices: _mm256_cmpgt_epi64, - mask_cvt: |x| { x }, - gather: _mm256_mask_i64gather_epi64, - store: _mm256_storeu_si256, - WIDTH = 4, STRIDE = 4 } ); -/// AVX2 core inner loop for certain `Idx` and `Value` type. +/// AVX2 core inner loop for a given index type `Idx`, output element type `Out`, and gather +/// `Lane` type. +/// +/// `Out` is the element type written to the output buffer; `Lane` (`u32` or `u64`) is the +/// integer type the gather intrinsics operate on. The caller must pair them so that +/// `size_of::() == size_of::()` (the only caller, [`take_avx2`], picks `Lane` from +/// `size_of::()`). The gather moves `size_of::()` raw bytes per element, which only +/// yields a valid `Out` because `Out` is plain-old-data (every bit pattern is a valid `Out`). +/// Pointers into the `Out`-typed slices are cast to `*const Lane`/`*mut Lane`; gather tolerates +/// the (possibly weaker) `Out` alignment. #[inline(always)] -fn exec_take(values: &[Value], indices: &[Idx]) -> Buffer +fn exec_take(values: &[Out], indices: &[Idx]) -> Buffer where - Value: Copy, + Out: Copy, Idx: UnsignedPType, - Gather: GatherFn, + Gather: GatherFn, { + debug_assert_eq!( + size_of::(), + size_of::(), + "gather lane and output element must have the same size" + ); + let indices_len = indices.len(); let max_index = Idx::from(values.len()).unwrap_or_else(|| Idx::max_value()); let mut buffer = - BufferMut::::with_capacity_aligned(indices_len, Alignment::of::<__m256i>()); + BufferMut::::with_capacity_aligned(indices_len, Alignment::of::<__m256i>()); let buf_uninit = buffer.spare_capacity_mut(); let mut offset = 0; @@ -483,16 +380,18 @@ where // might read up to STRIDE src elements at a time, even though it only advances WIDTH elements // in the dst. while offset + Gather::STRIDE < indices_len { - // SAFETY: `gather_simd` preconditions satisfied: + // SAFETY: `gather` preconditions satisfied: // 1. `(indices + offset)..(indices + offset + STRIDE)` is in-bounds for indices // allocation. - // 2. `buffer` has same len as indices so `buffer + offset + STRIDE` is always valid. + // 2. `buffer` has same len as indices so `buffer + offset + WIDTH` is always valid. + // 3. `size_of::() == size_of::()` (asserted above), so the `Lane`-typed + // pointers address the same bytes as the `Out`-typed `values`/`buffer` allocations. unsafe { Gather::gather( indices.as_ptr().add(offset), max_index, - values.as_ptr(), - buf_uninit.as_mut_ptr().add(offset).cast(), + values.as_ptr().cast::(), + buf_uninit.as_mut_ptr().add(offset).cast::(), ) }; offset += Gather::WIDTH; @@ -509,11 +408,11 @@ where // SAFETY: All elements have been initialized. unsafe { buffer.set_len(indices_len) }; - // Reset the buffer alignment to the Value type. + // Reset the buffer alignment to the output type. // NOTE: if we don't do this, we pass back a Buffer which is over-aligned to the SIMD // register width. The caller expects that this memory should be aligned to the value type // so that we can slice it at value boundaries. - buffer = buffer.aligned(Alignment::of::()); + buffer = buffer.aligned(Alignment::of::()); buffer.freeze() } @@ -601,4 +500,37 @@ mod avx2_tests { let result = unsafe { take_avx2(&values, &indices) }; assert_eq!(&vec![65535; indices.len()], result.as_slice()); } + + /// A `[u8; 4]` is a 4-byte `Copy` POD that is not a `NativePType`. This proves the kernel + /// gathers an arbitrary 4-byte value type through the `u32` SIMD lane. + #[test] + fn test_avx2_take_simd_array_u8x4() { + let values: Vec<[u8; 4]> = (1u32..=200).map(u32::to_le_bytes).collect(); + let indices: Vec = (0..200).collect(); + + let result = unsafe { take_avx2(&values, &indices) }; + assert_eq!(values.as_slice(), result.as_slice()); + } + + /// 2-byte values have no AVX2 gather, so they take the scalar fallback path and must still be + /// correct. + #[test] + fn test_avx2_take_scalar_fallback_u16() { + let values: Vec = (1..=300).collect(); + let indices: Vec = (0..300).collect(); + + let result = unsafe { take_avx2(&values, &indices) }; + assert_eq!(values.as_slice(), result.as_slice()); + } + + /// Values wider than 8 bytes (e.g. `i128`/decimal backing) exceed the gather lane and fall + /// back to the scalar kernel. + #[test] + fn test_avx2_take_scalar_fallback_array_u8x16() { + let values: Vec<[u8; 16]> = (0u128..200).map(u128::to_le_bytes).collect(); + let indices: Vec = (0..200).collect(); + + let result = unsafe { take_avx2(&values, &indices) }; + assert_eq!(values.as_slice(), result.as_slice()); + } } diff --git a/vortex-array/src/arrays/primitive/compute/take/mod.rs b/vortex-array/src/arrays/primitive/compute/take/mod.rs index 4023991c65d..21516731800 100644 --- a/vortex-array/src/arrays/primitive/compute/take/mod.rs +++ b/vortex-array/src/arrays/primitive/compute/take/mod.rs @@ -20,7 +20,6 @@ use crate::arrays::dict::TakeExecute; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::IntegerPType; -use crate::dtype::NativePType; use crate::executor::ExecutionCtx; use crate::match_each_integer_ptype; use crate::match_each_native_ptype; @@ -106,10 +105,7 @@ impl TakeExecute for Primitive { // Compiler may see this as unused based on enabled features #[inline(always)] -fn take_primitive_scalar( - buffer: &[T], - indices: &[I], -) -> Buffer { +fn take_primitive_scalar(buffer: &[T], indices: &[I]) -> Buffer { // NB: The simpler `indices.iter().map(|idx| buffer[idx.as_()]).collect()` generates suboptimal // assembly where the buffer length is repeatedly loaded from the stack on each iteration.