diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 4b94159127b..b87c04ee204 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -11,7 +11,6 @@ use vortex_error::vortex_err; use vortex_error::vortex_panic; use vortex_mask::Mask; -use crate::AnyCanonical; use crate::ArrayRef; use crate::Canonical; use crate::Columnar; @@ -30,14 +29,152 @@ use crate::arrays::fixed_size_list::FixedSizeListArrayExt; use crate::arrays::listview::ListViewArrayExt; use crate::builders::builder_with_capacity; use crate::builtins::ArrayBuiltins; +use crate::columnar::AnyColumnar; use crate::dtype::DType; -use crate::dtype::IntegerPType; use crate::executor::max_iterations; use crate::match_each_integer_ptype; /// Reference-counted type-erased grouped accumulator. pub type GroupedAccumulatorRef = Box; +/// A batch of grouped values to aggregate. +/// +/// Each outer list value is one group, and the inner element array is shared by all groups. +/// Aggregate implementations can inspect the concrete grouped representation directly, or ask for +/// derived ranges when their algorithm is expressed in terms of `(offset, size)` pairs. +pub enum GroupedArray { + /// Groups represented as a list-view array with per-group offsets and sizes. + ListView(ListViewArray), + /// Groups represented as a fixed-size list array. + FixedSizeList(FixedSizeListArray), +} + +impl From for GroupedArray { + fn from(groups: ListViewArray) -> Self { + Self::ListView(groups) + } +} + +impl From for GroupedArray { + fn from(groups: FixedSizeListArray) -> Self { + Self::FixedSizeList(groups) + } +} + +impl GroupedArray { + /// The inner element array shared by all groups. + pub fn elements(&self) -> &ArrayRef { + match self { + Self::ListView(groups) => groups.elements(), + Self::FixedSizeList(groups) => groups.elements(), + } + } + + /// Return the `(offset, size)` ranges describing each group in `elements`. + pub fn group_ranges(&self, ctx: &mut ExecutionCtx) -> VortexResult { + match self { + Self::ListView(groups) => list_view_group_ranges(groups, ctx), + Self::FixedSizeList(groups) => Ok(fixed_size_list_group_ranges(groups)), + } + } + + /// Return the per-group validity mask. + pub fn group_validity(&self, ctx: &mut ExecutionCtx) -> VortexResult { + match self { + Self::ListView(groups) => groups.validity()?.execute_mask(groups.len(), ctx), + Self::FixedSizeList(groups) => groups.validity()?.execute_mask(groups.len(), ctx), + } + } + + /// The number of groups in this batch. + pub fn len(&self) -> usize { + match self { + Self::ListView(groups) => groups.len(), + Self::FixedSizeList(groups) => groups.len(), + } + } + + /// Returns true when this batch contains no groups. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns true when every group is valid. + pub fn all_groups_valid(&self, ctx: &mut ExecutionCtx) -> VortexResult { + Ok(self.group_validity(ctx)?.all_true()) + } + + unsafe fn with_elements_unchecked(&self, elements: ArrayRef) -> VortexResult { + Ok(match self { + Self::ListView(groups) => unsafe { + ListViewArray::new_unchecked( + elements, + groups.offsets().clone(), + groups.sizes().clone(), + groups.validity()?, + ) + } + .into(), + Self::FixedSizeList(groups) => unsafe { + FixedSizeListArray::new_unchecked( + elements, + groups.list_size(), + groups.validity()?, + groups.len(), + ) + } + .into(), + }) + } +} + +/// The physical ranges of a grouped array. +pub enum GroupRanges { + /// Explicit ranges extracted from a list-view array. + ListView { + /// The `(offset, size)` ranges. + ranges: Vec<(usize, usize)>, + }, + /// Uniform ranges derived from a fixed-size list array. + FixedSizeList { + /// The number of groups. + len: usize, + /// The number of elements in each group. + size: usize, + }, +} + +impl GroupRanges { + /// The number of groups described by these ranges. + pub fn len(&self) -> usize { + match self { + Self::ListView { ranges } => ranges.len(), + Self::FixedSizeList { len, .. } => *len, + } + } + + /// Returns true when there are no groups. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Return the `(offset, size)` range for the group at `index`. + fn range(&self, index: usize) -> (usize, usize) { + match self { + Self::ListView { ranges } => ranges[index], + Self::FixedSizeList { len, size } => { + assert!(index < *len, "range index out of bounds"); + (index * size, *size) + } + } + } + + /// Iterate over all `(offset, size)` group ranges. + pub fn iter(&self) -> impl Iterator + '_ { + (0..self.len()).map(|index| self.range(index)) + } +} + /// An accumulator used for computing grouped aggregates. /// /// Note that the groups must be processed in order, and the accumulator does not support random @@ -128,8 +265,8 @@ impl DynGroupedAccumulator for GroupedAccumulator { Columnar::Constant(c) => c.into_array().execute::(ctx)?, }; match canonical { - Canonical::List(groups) => self.accumulate_list_view(&groups, ctx), - Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx), + Canonical::List(groups) => self.accumulate_grouped_array(groups.into(), ctx), + Canonical::FixedSizeList(groups) => self.accumulate_grouped_array(groups.into(), ctx), _ => vortex_panic!("We checked the DType above, so this should never happen"), } } @@ -155,165 +292,83 @@ impl DynGroupedAccumulator for GroupedAccumulator { } impl GroupedAccumulator { - fn accumulate_list_view( + fn accumulate_grouped_array( &mut self, - groups: &ListViewArray, + groups: GroupedArray, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { let mut elements = groups.elements().clone(); - let groups_validity = groups.validity()?; let session = ctx.session().clone(); for _ in 0..max_iterations() { - if elements.is::() { - break; + // Try a registered grouped kernel for the current element encoding. + if let Some(kernel) = session + .aggregate_fns() + .find_grouped_encoding_kernel(elements.encoding_id(), self.aggregate_fn.id()) + { + // SAFETY: we assume that elements execution is safe + let kernel_groups = unsafe { groups.with_elements_unchecked(elements.clone())? }; + if let Some(result) = + kernel.grouped_aggregate(&self.aggregate_fn, &kernel_groups, ctx)? + { + return self.push_result(result); + } } - if let Some(result) = session + // Try a grouped kernel for the current aggregate regardless of element encoding. + if let Some(kernel) = session .aggregate_fns() - .find_grouped_kernel(elements.encoding_id(), self.aggregate_fn.id()) - .and_then(|kernel| { - // SAFETY: we assume that elements execution is safe - let groups = unsafe { - ListViewArray::new_unchecked( - elements.clone(), - groups.offsets().clone(), - groups.sizes().clone(), - groups_validity.clone(), - ) - }; - kernel - .grouped_aggregate(&self.aggregate_fn, &groups) - .transpose() - }) - .transpose()? + .find_grouped_kernel(self.aggregate_fn.id()) { - return self.push_result(result); + // SAFETY: we preserve the grouped shape and validity while replacing the + // elements with another representation of the same logical array. + let kernel_groups = unsafe { groups.with_elements_unchecked(elements.clone())? }; + if let Some(result) = + kernel.grouped_aggregate(&self.aggregate_fn, &kernel_groups, ctx)? + { + return self.push_result(result); + } + } + + if elements.is::() { + break; } // Execute one step and try again elements = elements.execute(ctx)?; } - // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. let elements = elements.execute::(ctx)?.into_array(); - let offsets = groups.offsets(); - let sizes = groups.sizes().cast(offsets.dtype().clone())?; - let validity = groups_validity.execute_mask(offsets.len(), ctx)?; - - match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { - let offsets = offsets.clone().execute::>(ctx)?; - let sizes = sizes.execute::>(ctx)?; - self.accumulate_list_view_typed( - &elements, - offsets.as_ref(), - sizes.as_ref(), - &validity, - ctx, - ) - }) - } - - fn accumulate_list_view_typed( - &mut self, - elements: &ArrayRef, - offsets: &[O], - sizes: &[O], - validity: &Mask, - ctx: &mut ExecutionCtx, - ) -> VortexResult<()> { - let mut accumulator = Accumulator::try_new( - self.vtable.clone(), - self.options.clone(), - self.dtype.clone(), - )?; - let mut states = builder_with_capacity(&self.partial_dtype, offsets.len()); - - // `validity` is the per-group list-view validity, so it is zipped element-wise with the - // offsets and sizes (one entry per group). - for ((offset, size), valid) in offsets.iter().zip(sizes.iter()).zip(validity.iter()) { - let offset = offset.to_usize().vortex_expect("Offset value is not usize"); - let size = size.to_usize().vortex_expect("Size value is not usize"); - - if valid { - let group = elements.slice(offset..offset + size)?; - accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.flush()?)?; - } else { - states.append_null() - } - } + // SAFETY: we preserve the grouped shape and validity while replacing the elements with an + // executed form of the same logical array. + let grouped = unsafe { groups.with_elements_unchecked(elements)? }; - self.push_result(states.finish()) + // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. + self.accumulate_grouped_fallback(&grouped, ctx) } - fn accumulate_fixed_size_list( + fn accumulate_grouped_fallback( &mut self, - groups: &FixedSizeListArray, + grouped: &GroupedArray, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { - let mut elements = groups.elements().clone(); - let groups_validity = groups.validity()?; - let session = ctx.session().clone(); - - for _ in 0..64 { - if elements.is::() { - break; - } - - if let Some(result) = session - .aggregate_fns() - .find_grouped_kernel(elements.encoding_id(), self.aggregate_fn.id()) - .and_then(|kernel| { - // SAFETY: we assume that elements execution is safe - let groups = unsafe { - FixedSizeListArray::new_unchecked( - elements.clone(), - groups.list_size(), - groups_validity.clone(), - groups.len(), - ) - }; - - kernel - .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups) - .transpose() - }) - .transpose()? - { - return self.push_result(result); - } - - // Execute one step and try again - elements = elements.execute(ctx)?; - } - - // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. - let elements = elements.execute::(ctx)?.into_array(); - let validity = groups_validity.execute_mask(groups.len(), ctx)?; - let mut accumulator = Accumulator::try_new( self.vtable.clone(), self.options.clone(), self.dtype.clone(), )?; - let mut states = builder_with_capacity(&self.partial_dtype, groups.len()); - - let mut offset = 0; - let size = groups - .list_size() - .to_usize() - .vortex_expect("List size is not usize"); + let mut states = builder_with_capacity(&self.partial_dtype, grouped.len()); + let group_ranges = grouped.group_ranges(ctx)?; + let group_validity = grouped.group_validity(ctx)?; - for valid in validity.iter() { + for ((offset, size), valid) in group_ranges.iter().zip(group_validity.iter()) { if valid { - let group = elements.slice(offset..offset + size)?; + let group = grouped.elements().slice(offset..offset + size)?; accumulator.accumulate(&group, ctx)?; states.append_scalar(&accumulator.flush()?)?; } else { states.append_null() } - offset += size; } self.push_result(states.finish()) @@ -330,3 +385,35 @@ impl GroupedAccumulator { Ok(()) } } +fn list_view_group_ranges( + groups: &ListViewArray, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let offsets = groups.offsets(); + let sizes = groups.sizes().cast(offsets.dtype().clone())?; + + let ranges = match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { + let offsets = offsets.clone().execute::>(ctx)?; + let sizes = sizes.execute::>(ctx)?; + offsets + .as_ref() + .iter() + .zip(sizes.as_ref().iter()) + .map(|(offset, size)| { + ( + offset.to_usize().vortex_expect("Offset value is not usize"), + size.to_usize().vortex_expect("Size value is not usize"), + ) + }) + .collect::>() + }); + + Ok(GroupRanges::ListView { ranges }) +} + +fn fixed_size_list_group_ranges(groups: &FixedSizeListArray) -> GroupRanges { + GroupRanges::FixedSizeList { + len: groups.len(), + size: groups.list_size() as usize, + } +} diff --git a/vortex-array/src/aggregate_fn/fns/count/grouped.rs b/vortex-array/src/aggregate_fn/fns/count/grouped.rs new file mode 100644 index 00000000000..fb94489dde0 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/count/grouped.rs @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_buffer::Buffer; +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use super::Count; +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::GroupRanges; +use crate::aggregate_fn::GroupedArray; +use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; +use crate::arrays::PrimitiveArray; +use crate::validity::Validity; + +/// Encoding-independent grouped [`Count`] kernel. +#[derive(Debug)] +pub(crate) struct CountGroupedKernel; + +impl DynGroupedAggregateKernel for CountGroupedKernel { + fn grouped_aggregate( + &self, + aggregate_fn: &AggregateFnRef, + groups: &GroupedArray, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + if !aggregate_fn.is::() { + return Ok(None); + } + try_grouped_count(groups, ctx) + } +} + +/// Count each valid group from the element validity mask. +/// +/// The [`Count`] partial dtype is non-nullable `U64`, so a null outer group cannot be represented +/// as a partial state. If any outer group is invalid, this returns `Ok(None)` and lets the caller +/// use the existing fallback behavior. +pub(super) fn try_grouped_count( + groups: &GroupedArray, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + if !groups.all_groups_valid(ctx)? { + return Ok(None); + } + let group_ranges = groups.group_ranges(ctx)?; + + Ok(Some(grouped_count(groups.elements(), &group_ranges, ctx)?)) +} + +/// Count the valid elements of each group described by `group_ranges` (element `(offset, size)` +/// pairs) into a non-nullable `U64` array, one entry per group. +fn grouped_count( + elements: &ArrayRef, + group_ranges: &GroupRanges, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let elem_mask = elements.validity()?.execute_mask(elements.len(), ctx)?; + + let counts: Buffer = if elem_mask.all_true() { + group_ranges.iter().map(|(_, size)| size as u64).collect() + } else { + group_ranges + .iter() + .map(|(offset, size)| valid_count(&elem_mask, offset, size) as u64) + .collect() + }; + + Ok(PrimitiveArray::new(counts, Validity::NonNullable).into_array()) +} + +/// Number of valid elements in the `[offset, offset + size)` range of the element mask. +fn valid_count(elem_mask: &Mask, offset: usize, size: usize) -> usize { + elem_mask.slice(offset..offset + size).true_count() +} + +#[cfg(test)] +mod tests { + #![allow(clippy::cast_possible_truncation)] + + use vortex_buffer::Buffer; + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::ArrayRef; + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::DynGroupedAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupedAccumulator; + use crate::aggregate_fn::fns::count::Count; + use crate::arrays::FixedSizeListArray; + use crate::arrays::ListViewArray; + use crate::arrays::PrimitiveArray; + use crate::arrays::VarBinViewArray; + use crate::assert_arrays_eq; + use crate::dtype::DType; + use crate::dtype::Nullability::NonNullable; + use crate::dtype::Nullability::Nullable; + use crate::dtype::PType; + use crate::validity::Validity; + + /// Run a grouped count through the accumulator. + fn grouped_count_actual(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { + let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, elem_dtype.clone())?; + acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; + acc.finish() + } + + /// Reference valid-counts (non-nullable `U64`), one per group. + fn grouped_count_reference( + elements: &ArrayRef, + ranges: &[(usize, usize)], + ) -> VortexResult { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let counts: Buffer = ranges + .iter() + .map(|&(offset, size)| { + Ok(elements + .slice(offset..offset + size)? + .valid_count(&mut ctx)? as u64) + }) + .collect::>()?; + Ok(PrimitiveArray::new(counts, Validity::NonNullable).into_array()) + } + + fn listview(elements: ArrayRef, ranges: &[(usize, usize)]) -> VortexResult { + let offsets = PrimitiveArray::from_iter(ranges.iter().map(|&(o, _)| o as i32)); + let sizes = PrimitiveArray::from_iter(ranges.iter().map(|&(_, s)| s as i32)); + Ok(ListViewArray::try_new( + elements, + offsets.into_array(), + sizes.into_array(), + Validity::NonNullable, + )? + .into_array()) + } + + #[test] + fn listview_counts_all_valid() -> VortexResult<()> { + let elements = + PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); + let elem_dtype = DType::Primitive(PType::I32, NonNullable); + let ranges = [(0, 2), (2, 1), (3, 3), (6, 0)]; + + let groups = listview(elements.clone(), &ranges)?; + let actual = grouped_count_actual(&groups, &elem_dtype)?; + let expected = grouped_count_reference(&elements, &ranges)?; + + let direct = + PrimitiveArray::new(buffer![2u64, 1, 3, 0], Validity::NonNullable).into_array(); + assert_arrays_eq!(&actual, &direct); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_counts_with_nulls() -> VortexResult<()> { + let elements = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, None, Some(9)]) + .into_array(); + let elem_dtype = DType::Primitive(PType::I32, Nullable); + let ranges = [(0, 3), (3, 2), (5, 1)]; + + let groups = listview(elements.clone(), &ranges)?; + let actual = grouped_count_actual(&groups, &elem_dtype)?; + let expected = grouped_count_reference(&elements, &ranges)?; + + // Group 0: {1, null, 3} -> 2. Group 1: {null, null} -> 0. Group 2: {9} -> 1. + let direct = PrimitiveArray::new(buffer![2u64, 0, 1], Validity::NonNullable).into_array(); + assert_arrays_eq!(&actual, &direct); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_counts_varbinview_with_nulls() -> VortexResult<()> { + let elements = VarBinViewArray::from_iter_nullable_str([ + Some("a"), + None, + Some("bbb"), + None, + Some("cc"), + ]) + .into_array(); + let elem_dtype = elements.dtype().clone(); + let ranges = [(0, 2), (2, 2), (4, 1)]; + + let groups = listview(elements.clone(), &ranges)?; + let actual = grouped_count_actual(&groups, &elem_dtype)?; + let expected = grouped_count_reference(&elements, &ranges)?; + + let direct = PrimitiveArray::new(buffer![1u64, 1, 1], Validity::NonNullable).into_array(); + assert_arrays_eq!(&actual, &direct); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn fixed_size_counts_with_nulls() -> VortexResult<()> { + let elements = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4)]).into_array(); + let elem_dtype = DType::Primitive(PType::I32, Nullable); + let groups = + FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?.into_array(); + + let actual = grouped_count_actual(&groups, &elem_dtype)?; + let direct = PrimitiveArray::new(buffer![1u64, 2], Validity::NonNullable).into_array(); + assert_arrays_eq!(&actual, &direct); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index e25c42e0845..1fe984fb099 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +mod grouped; +pub(crate) use grouped::CountGroupedKernel; use vortex_error::VortexExpect; use vortex_error::VortexResult; diff --git a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs new file mode 100644 index 00000000000..6f00cce7fdb --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; +use vortex_mask::AllOr; +use vortex_mask::Mask; + +use super::Sum; +use super::primitive::sum_float_all; +use super::primitive::sum_signed_all; +use super::primitive::sum_unsigned_all; +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::GroupRanges; +use crate::aggregate_fn::GroupedArray; +use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; +use crate::arrays::Primitive; +use crate::arrays::PrimitiveArray; +use crate::dtype::NativePType; +use crate::match_each_native_ptype; + +/// Encoding-specific grouped [`Sum`] kernel for primitive element arrays. +#[derive(Debug)] +pub(crate) struct PrimitiveGroupedSumEncodingKernel; + +impl DynGroupedAggregateKernel for PrimitiveGroupedSumEncodingKernel { + fn grouped_aggregate( + &self, + aggregate_fn: &AggregateFnRef, + groups: &GroupedArray, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + if !aggregate_fn.is::() { + return Ok(None); + } + try_grouped_sum(groups, ctx) + } +} + +/// Grouped [`Sum`] implementation for canonical primitive elements. +/// +/// Reuses the scalar primitive-sum reductions ([`sum_unsigned_all`]/[`sum_signed_all`]/ +/// [`sum_float_all`]) so the per-group semantics match scalar `sum` exactly (overflow saturates to +/// a null sum, NaNs are skipped). The element validity mask is materialized once and sliced per +/// group, rather than the per-group accumulator setup of the generic fallback path. +pub(super) fn try_grouped_sum( + groups: &GroupedArray, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + if !groups.elements().is::() { + return Ok(None); + } + let elements = groups.elements().clone().downcast::(); + let group_ranges = groups.group_ranges(ctx)?; + let group_validity = groups.group_validity(ctx)?; + + Ok(Some(grouped_sum( + &elements, + &group_ranges, + &group_validity, + ctx, + )?)) +} + +/// Sum each group described by `group_ranges` (element `(offset, size)` pairs), one sum per group. +fn grouped_sum( + elements: &PrimitiveArray, + group_ranges: &GroupRanges, + group_validity: &Mask, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let elem_mask = elements + .as_ref() + .validity()? + .execute_mask(elements.as_ref().len(), ctx)?; + let all_valid = matches!(elem_mask.slices(), AllOr::All); + + let result = match_each_native_ptype!(elements.ptype(), + unsigned: |T| { + let values = elements.as_slice::(); + collect_sums::(values, group_ranges, group_validity, &elem_mask, all_valid, + sum_unsigned_all) + }, + signed: |T| { + let values = elements.as_slice::(); + collect_sums::(values, group_ranges, group_validity, &elem_mask, all_valid, + sum_signed_all) + }, + floating: |T| { + let values = elements.as_slice::(); + collect_sums::(values, group_ranges, group_validity, &elem_mask, all_valid, + |acc, slice| { sum_float_all(acc, slice); false }) + } + ); + + Ok(result.into_array()) +} + +/// Reduce each group's element slice into a nullable sum. A group is null when the group +/// itself is invalid, or when summing it overflows (`sum_run` returns `true`). +fn collect_sums( + values: &[T], + group_ranges: &GroupRanges, + group_validity: &Mask, + elem_mask: &Mask, + all_valid: bool, + sum_run: impl Fn(&mut A, &[T]) -> bool, +) -> PrimitiveArray { + let sums = group_ranges.iter().enumerate().map(|(i, (offset, size))| { + if !group_validity.value(i) { + return None; + } + let mut acc = A::default(); + let overflow = if all_valid { + sum_run(&mut acc, &values[offset..offset + size]) + } else { + sum_masked_group(&mut acc, values, offset, size, elem_mask, &sum_run) + }; + (!overflow).then_some(acc) + }); + PrimitiveArray::from_option_iter(sums) +} + +/// Sum the valid elements of a single group, using the contiguous valid runs of the element mask +/// intersected with the group's `[offset, offset + size)` range. +fn sum_masked_group( + acc: &mut A, + values: &[T], + offset: usize, + size: usize, + elem_mask: &Mask, + sum_run: &impl Fn(&mut A, &[T]) -> bool, +) -> bool { + match elem_mask.slice(offset..offset + size).slices() { + AllOr::All => sum_run(acc, &values[offset..offset + size]), + AllOr::None => false, + AllOr::Some(runs) => { + for &(start, end) in runs { + if sum_run(acc, &values[offset + start..offset + end]) { + return true; + } + } + false + } + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::cast_possible_truncation)] + + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::ArrayRef; + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::DynGroupedAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupedAccumulator; + use crate::aggregate_fn::fns::sum::Sum; + use crate::aggregate_fn::fns::sum::sum; + use crate::arrays::FixedSizeListArray; + use crate::arrays::ListViewArray; + use crate::arrays::PrimitiveArray; + use crate::assert_arrays_eq; + use crate::builders::builder_with_capacity; + use crate::dtype::DType; + use crate::dtype::Nullability::NonNullable; + use crate::dtype::Nullability::Nullable; + use crate::dtype::PType; + use crate::validity::Validity; + + /// Run a grouped sum through the accumulator. + fn grouped_sum_actual(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { + let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?; + acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; + acc.finish() + } + + /// Reference sums computed exactly like the generic slow path: per-group scalar [`sum`] for + /// valid groups, a null sum for invalid groups. + fn grouped_sum_reference( + elements: &ArrayRef, + ranges: &[(usize, usize)], + group_valid: &[bool], + elem_dtype: &DType, + ) -> VortexResult { + use crate::aggregate_fn::AggregateFnVTable; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let sum_dtype = Sum + .partial_dtype(&EmptyOptions, elem_dtype) + .expect("sum partial dtype"); + let mut builder = builder_with_capacity(&sum_dtype, ranges.len()); + for (i, &(offset, size)) in ranges.iter().enumerate() { + if group_valid[i] { + let slice = elements.slice(offset..offset + size)?; + builder.append_scalar(&sum(&slice, &mut ctx)?)?; + } else { + builder.append_null(); + } + } + Ok(builder.finish()) + } + + fn offsets_sizes(ranges: &[(usize, usize)]) -> (ArrayRef, ArrayRef) { + let offsets = PrimitiveArray::from_iter(ranges.iter().map(|&(o, _)| o as i32)); + let sizes = PrimitiveArray::from_iter(ranges.iter().map(|&(_, s)| s as i32)); + (offsets.into_array(), sizes.into_array()) + } + + fn listview( + elements: ArrayRef, + ranges: &[(usize, usize)], + group_valid: &[bool], + ) -> VortexResult { + let (offsets, sizes) = offsets_sizes(ranges); + let validity = if group_valid.iter().all(|&v| v) { + Validity::NonNullable + } else { + Validity::from_iter(group_valid.iter().copied()) + }; + Ok(ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array()) + } + + #[test] + fn listview_matches_reference_unsigned() -> VortexResult<()> { + let elements = + PrimitiveArray::new(buffer![1u32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); + let elem_dtype = DType::Primitive(PType::U32, NonNullable); + let ranges = [(0, 2), (2, 1), (3, 3)]; + let valid = [true, true, true]; + + let groups = listview(elements.clone(), &ranges, &valid)?; + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; + + // Unsigned input sums to U64. + let direct = PrimitiveArray::from_option_iter([Some(3u64), Some(3u64), Some(15u64)]); + assert_arrays_eq!(&actual, &direct.into_array()); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_out_of_order_offsets_with_null_group() -> VortexResult<()> { + // Offsets are not in group order and a group is null: the group validity must be indexed by + // group index, not by element offset. + let elements = + PrimitiveArray::new(buffer![10i32, 20, 30, 40, 50, 60], Validity::NonNullable) + .into_array(); + let elem_dtype = DType::Primitive(PType::I32, NonNullable); + let ranges = [(4, 2), (0, 2), (2, 2)]; + let valid = [true, false, true]; + + let groups = listview(elements.clone(), &ranges, &valid)?; + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; + + let direct = PrimitiveArray::from_option_iter([Some(110i64), None, Some(70i64)]); + assert_arrays_eq!(&actual, &direct.into_array()); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_interior_and_full_nulls() -> VortexResult<()> { + // Group 1 has an interior null, group 2 is entirely null, group 3 is empty. + let elements = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, None, Some(9)]) + .into_array(); + let elem_dtype = DType::Primitive(PType::I32, Nullable); + let ranges = [(0, 3), (3, 2), (5, 0), (5, 1)]; + let valid = [true, true, true, true]; + + let groups = listview(elements.clone(), &ranges, &valid)?; + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; + + let direct = + PrimitiveArray::from_option_iter([Some(4i64), Some(0i64), Some(0i64), Some(9i64)]); + assert_arrays_eq!(&actual, &direct.into_array()); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_overflow_group_is_null() -> VortexResult<()> { + let elements = + PrimitiveArray::new(buffer![i64::MAX, 1, 2, 3], Validity::NonNullable).into_array(); + let elem_dtype = DType::Primitive(PType::I64, NonNullable); + let ranges = [(0, 2), (2, 2)]; + let valid = [true, true]; + + let groups = listview(elements.clone(), &ranges, &valid)?; + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; + + // First group overflows -> null sum; second group sums normally. + let direct = PrimitiveArray::from_option_iter([None, Some(5i64)]); + assert_arrays_eq!(&actual, &direct.into_array()); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_float_nan_and_inf() -> VortexResult<()> { + let elements = PrimitiveArray::new( + buffer![1.0f64, f64::NAN, 2.0, f64::INFINITY, f64::NEG_INFINITY, 4.0], + Validity::NonNullable, + ) + .into_array(); + let elem_dtype = DType::Primitive(PType::F64, NonNullable); + let ranges = [(0, 3), (3, 3)]; + let valid = [true, true]; + + let groups = listview(elements.clone(), &ranges, &valid)?; + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + + // Group 0: NaN skipped -> 3.0. Group 1: INF + -INF = NaN. (Avoid array equality here since + // NaN != NaN; compare element scalars against the reference path instead.) + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; + let g0 = actual.execute_scalar(0, &mut ctx)?; + assert_eq!(g0.as_primitive().typed_value::(), Some(3.0)); + assert_eq!( + g0.as_primitive().typed_value::(), + expected + .execute_scalar(0, &mut ctx)? + .as_primitive() + .typed_value::() + ); + let g1 = actual.execute_scalar(1, &mut ctx)?; + assert!(g1.as_primitive().typed_value::().unwrap().is_nan()); + assert!( + expected + .execute_scalar(1, &mut ctx)? + .as_primitive() + .typed_value::() + .unwrap() + .is_nan() + ); + Ok(()) + } + + #[test] + fn fixed_size_overflow_and_nan() -> VortexResult<()> { + // FixedSize path: first group overflows -> null sum, second sums normally. + let elements = + PrimitiveArray::new(buffer![i64::MAX, 1, 2, 3], Validity::NonNullable).into_array(); + let elem_dtype = DType::Primitive(PType::I64, NonNullable); + let groups = FixedSizeListArray::try_new(elements.clone(), 2, Validity::NonNullable, 2)? + .into_array(); + + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + let expected = + grouped_sum_reference(&elements, &[(0, 2), (2, 2)], &[true, true], &elem_dtype)?; + let direct = PrimitiveArray::from_option_iter([None, Some(5i64)]); + assert_arrays_eq!(&actual, &direct.into_array()); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 24799570ff7..9d525bec742 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -4,8 +4,9 @@ mod bool; mod constant; mod decimal; +mod grouped; mod primitive; - +pub(crate) use grouped::PrimitiveGroupedSumEncodingKernel; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; diff --git a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs index 44418fb5628..df7d929d896 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs @@ -50,11 +50,7 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR unsigned: |_T| { vortex_panic!("float sum state with unsigned input") }, signed: |_T| { vortex_panic!("float sum state with signed input") }, floating: |T| { - for &v in p.as_slice::() { - if !v.is_nan() { - *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); - } - } + sum_float_all(acc, p.as_slice::()); Ok(false) } ), @@ -62,11 +58,21 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR } } +/// Sum the non-NaN values of a float slice into an `f64` accumulator. NaNs are skipped to match the +/// scalar `sum` semantics. Floats cannot overflow the accumulator, so this never reports saturation. +pub(super) fn sum_float_all(acc: &mut f64, slice: &[T]) { + for &v in slice { + if !v.is_nan() { + *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); + } + } +} + /// Sum all values into a `u64` accumulator. For types narrower than 64 bits, values are summed in /// chunks of [`SUM_CHUNK`] with a single checked add per chunk, which lets the inner loop vectorize /// to packed widening adds. `u64` input keeps a per-element checked add since a chunk of `u64`s /// could itself overflow. Returns `true` on overflow. -fn sum_unsigned_all(acc: &mut u64, slice: &[T]) -> bool +pub(super) fn sum_unsigned_all(acc: &mut u64, slice: &[T]) -> bool where T: NativePType + AsPrimitive, { @@ -88,7 +94,7 @@ where } /// Signed counterpart of [`sum_unsigned_all`]. -fn sum_signed_all(acc: &mut i64, slice: &[T]) -> bool +pub(super) fn sum_signed_all(acc: &mut i64, slice: &[T]) -> bool where T: NativePType + AsPrimitive, { @@ -150,11 +156,7 @@ fn accumulate_primitive_valid( floating: |T| { let values = p.as_slice::(); for &(start, end) in slices { - for &v in &values[start..end] { - if !v.is_nan() { - *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); - } - } + sum_float_all(acc, &values[start..end]); } Ok(false) } diff --git a/vortex-array/src/aggregate_fn/kernels.rs b/vortex-array/src/aggregate_fn/kernels.rs index d806b18d84d..c5af0902cbb 100644 --- a/vortex-array/src/aggregate_fn/kernels.rs +++ b/vortex-array/src/aggregate_fn/kernels.rs @@ -11,8 +11,7 @@ use vortex_error::VortexResult; use crate::ArrayRef; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnRef; -use crate::arrays::FixedSizeListArray; -use crate::arrays::ListViewArray; +use crate::aggregate_fn::GroupedArray; use crate::scalar::Scalar; /// A pluggable kernel for an aggregate function. @@ -30,34 +29,24 @@ pub trait DynAggregateKernel: 'static + Send + Sync + Debug { /// A pluggable kernel for batch aggregation of many groups. /// -/// The kernel is matched on the encoding of the _elements_ array, which is the inner array of the -/// provided `ListViewArray`. This is more pragmatic than having every kernel match on the outer -/// list encoding and having to deal with the possibility of multiple list encodings. +/// A kernel can be registered either for an aggregate function regardless of the element encoding, +/// or for a specific aggregate function and element encoding. Element-encoding kernels are matched +/// on the inner array of the provided grouped array, not on the outer list encoding. This is more +/// pragmatic than having every kernel match on the outer list encoding and having to deal with the +/// possibility of multiple list encodings. /// -/// Each element of the list array represents a group and the result of the grouped aggregate +/// Each value in the grouped array represents a group and the result of the grouped aggregate /// should be an array of the same length, where each element is the aggregate state of the /// corresponding group. /// /// Return `Ok(None)` if the kernel cannot be applied to the given aggregate function. pub trait DynGroupedAggregateKernel: 'static + Send + Sync + Debug { - /// Aggregate each group in the provided `ListViewArray` and return an array of the - /// aggregate states. + /// Aggregate each group in the provided grouped array and return an array of the aggregate + /// states. fn grouped_aggregate( &self, aggregate_fn: &AggregateFnRef, - groups: &ListViewArray, + groups: &GroupedArray, + ctx: &mut ExecutionCtx, ) -> VortexResult>; - - /// Aggregate each group in the provided `FixedSizeListArray` and return an array of the - /// aggregate states. - fn grouped_aggregate_fixed_size( - &self, - aggregate_fn: &AggregateFnRef, - groups: &FixedSizeListArray, - ) -> VortexResult> { - // TODO(ngates): we could automatically delegate to `grouped_aggregate` if SequenceArray - // was in the vortex-array crate - let _ = (aggregate_fn, groups); - Ok(None) - } } diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index edbafdf386f..c6d7542a687 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -18,6 +18,8 @@ use crate::aggregate_fn::fns::all_non_null::AllNonNull; use crate::aggregate_fn::fns::all_null::AllNull; use crate::aggregate_fn::fns::bounded_max::BoundedMax; use crate::aggregate_fn::fns::bounded_min::BoundedMin; +use crate::aggregate_fn::fns::count::Count; +use crate::aggregate_fn::fns::count::CountGroupedKernel; use crate::aggregate_fn::fns::first::First; use crate::aggregate_fn::fns::is_constant::IsConstant; use crate::aggregate_fn::fns::is_sorted::IsSorted; @@ -27,6 +29,7 @@ use crate::aggregate_fn::fns::min::Min; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::nan_count::NanCount; use crate::aggregate_fn::fns::null_count::NullCount; +use crate::aggregate_fn::fns::sum::PrimitiveGroupedSumEncodingKernel; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes; use crate::aggregate_fn::kernels::DynAggregateKernel; @@ -36,6 +39,7 @@ use crate::array::ArrayId; use crate::array::VTable; use crate::arrays::Chunked; use crate::arrays::Dict; +use crate::arrays::Primitive; use crate::arrays::chunked::compute::aggregate::ChunkedArrayAggregate; use crate::arrays::dict::compute::is_constant::DictIsConstantKernel; use crate::arrays::dict::compute::is_sorted::DictIsSortedKernel; @@ -50,8 +54,10 @@ use crate::arrays::dict::compute::min_max::DictMinMaxKernel; pub struct AggregateFnSession { registry: ArcSwapMap, - kernels: ArcSwapMap, - grouped_kernels: ArcSwapMap, + kernels: ArcSwapMap, + grouped_kernels: ArcSwapMap, + grouped_encoding_kernels: + ArcSwapMap, } impl SessionVar for AggregateFnSession { @@ -64,7 +70,8 @@ impl SessionVar for AggregateFnSession { } } -type KernelKey = (ArrayId, Option); +type AggregateKernelKey = (ArrayId, Option); +type GroupedEncodingKernelKey = (ArrayId, AggregateFnId); impl Default for AggregateFnSession { fn default() -> Self { @@ -72,6 +79,7 @@ impl Default for AggregateFnSession { registry: ArcSwapMap::default(), kernels: ArcSwapMap::default(), grouped_kernels: ArcSwapMap::default(), + grouped_encoding_kernels: ArcSwapMap::default(), }; // Register the built-in aggregate functions @@ -100,6 +108,14 @@ impl Default for AggregateFnSession { this.register_aggregate_kernel(Dict.id(), Some(IsConstant.id()), &DictIsConstantKernel); this.register_aggregate_kernel(Dict.id(), Some(IsSorted.id()), &DictIsSortedKernel); + // Register the built-in grouped aggregate kernels. + this.register_grouped_kernel(Count.id(), &CountGroupedKernel); + this.register_grouped_encoding_kernel( + Primitive.id(), + Sum.id(), + &PrimitiveGroupedSumEncodingKernel, + ); + this } } @@ -152,27 +168,46 @@ impl AggregateFnSession { self.kernels.insert(id, kernel); } - /// Returns the grouped aggregate kernel registered for `array_id` and `agg_fn_id`, if any. + /// Returns the grouped aggregate kernel registered for `agg_fn_id`, if any. /// - /// Lookup first checks for a kernel registered for the exact aggregate function, then falls - /// back to a kernel registered for all aggregate functions on the same array encoding. + /// These kernels are independent of the element encoding and are checked for each element + /// representation, after any kernel registered for the current element encoding. pub fn find_grouped_kernel( + &self, + agg_fn_id: impl Into, + ) -> Option<&'static dyn DynGroupedAggregateKernel> { + let fn_id = agg_fn_id.into(); + self.grouped_kernels + .read(|kernels| kernels.get(&fn_id).copied()) + } + + /// Registers a grouped aggregate kernel for an aggregate function. + pub fn register_grouped_kernel( + &self, + agg_fn_id: impl Into, + kernel: &'static dyn DynGroupedAggregateKernel, + ) { + let fn_id = agg_fn_id.into(); + self.grouped_kernels.insert(fn_id, kernel) + } + + /// Returns the grouped aggregate kernel registered for `array_id` and `agg_fn_id`, if any. + /// + /// These kernels are matched against each intermediate element encoding while the grouped + /// accumulator executes the element array. + pub fn find_grouped_encoding_kernel( &self, array_id: impl Into, agg_fn_id: impl Into, ) -> Option<&'static dyn DynGroupedAggregateKernel> { let id = array_id.into(); let fn_id = agg_fn_id.into(); - self.grouped_kernels.read(|kernels| { - kernels - .get(&(id, Some(fn_id))) - .or_else(|| kernels.get(&(id, None))) - .copied() - }) + self.grouped_encoding_kernels + .read(|kernels| kernels.get(&(id, fn_id)).copied()) } /// Registers a grouped aggregate kernel for a specific aggregate function and array encoding. - pub fn register_grouped_kernel( + pub fn register_grouped_encoding_kernel( &self, array_id: impl Into, agg_fn_id: impl Into, @@ -180,7 +215,7 @@ impl AggregateFnSession { ) { let id = array_id.into(); let fn_id = agg_fn_id.into(); - self.grouped_kernels.insert((id, Some(fn_id)), kernel) + self.grouped_encoding_kernels.insert((id, fn_id), kernel) } }