From 7b85f05bd379b4947f45fab86d1de3e7a5f6307e Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Thu, 4 Jun 2026 14:57:35 +0100 Subject: [PATCH 1/5] aggregate fns to have grouped aggregate kernels for sum and count Signed-off-by: Onur Satici --- .../src/aggregate_fn/accumulator_grouped.rs | 324 ++++++++++------ .../src/aggregate_fn/fns/count/grouped.rs | 195 ++++++++++ .../src/aggregate_fn/fns/count/mod.rs | 12 + .../src/aggregate_fn/fns/sum/grouped.rs | 346 ++++++++++++++++++ vortex-array/src/aggregate_fn/fns/sum/mod.rs | 11 + .../src/aggregate_fn/fns/sum/primitive.rs | 26 +- vortex-array/src/aggregate_fn/kernels.rs | 29 +- vortex-array/src/aggregate_fn/vtable.rs | 15 + 8 files changed, 803 insertions(+), 155 deletions(-) create mode 100644 vortex-array/src/aggregate_fn/fns/count/grouped.rs create mode 100644 vortex-array/src/aggregate_fn/fns/sum/grouped.rs diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 4b94159127b..06b535a7686 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -31,13 +31,150 @@ use crate::arrays::listview::ListViewArrayExt; use crate::builders::builder_with_capacity; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; -use crate::dtype::IntegerPType; use crate::executor::max_iterations; use crate::match_each_integer_ptype; /// Reference-counted type-erased grouped accumulator. pub type GroupedAccumulatorRef = Box; +/// A batch of grouped values to aggregate. +/// +/// Each outer list value is one group, and the inner element array is shared by all groups. +/// Aggregate implementations can inspect the concrete grouped representation directly, or ask for +/// derived ranges when their algorithm is expressed in terms of `(offset, size)` pairs. +pub enum GroupedArray { + /// Groups represented as a list-view array with per-group offsets and sizes. + ListView(ListViewArray), + /// Groups represented as a fixed-size list array. + FixedSizeList(FixedSizeListArray), +} + +impl From for GroupedArray { + fn from(groups: ListViewArray) -> Self { + Self::ListView(groups) + } +} + +impl From for GroupedArray { + fn from(groups: FixedSizeListArray) -> Self { + Self::FixedSizeList(groups) + } +} + +impl GroupedArray { + /// The inner element array shared by all groups. + pub fn elements(&self) -> &ArrayRef { + match self { + Self::ListView(groups) => groups.elements(), + Self::FixedSizeList(groups) => groups.elements(), + } + } + + /// Return the `(offset, size)` ranges describing each group in `elements`. + pub fn group_ranges(&self, ctx: &mut ExecutionCtx) -> VortexResult { + match self { + Self::ListView(groups) => list_view_group_ranges(groups, ctx), + Self::FixedSizeList(groups) => Ok(fixed_size_list_group_ranges(groups)), + } + } + + /// Return the per-group validity mask. + pub fn group_validity(&self, ctx: &mut ExecutionCtx) -> VortexResult { + match self { + Self::ListView(groups) => groups.validity()?.execute_mask(groups.len(), ctx), + Self::FixedSizeList(groups) => groups.validity()?.execute_mask(groups.len(), ctx), + } + } + + /// The number of groups in this batch. + pub fn len(&self) -> usize { + match self { + Self::ListView(groups) => groups.len(), + Self::FixedSizeList(groups) => groups.len(), + } + } + + /// Returns true when this batch contains no groups. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns true when every group is valid. + pub fn all_groups_valid(&self, ctx: &mut ExecutionCtx) -> VortexResult { + Ok(self.group_validity(ctx)?.all_true()) + } + + unsafe fn with_elements_unchecked(&self, elements: ArrayRef) -> VortexResult { + Ok(match self { + Self::ListView(groups) => unsafe { + ListViewArray::new_unchecked( + elements, + groups.offsets().clone(), + groups.sizes().clone(), + groups.validity()?, + ) + } + .into(), + Self::FixedSizeList(groups) => unsafe { + FixedSizeListArray::new_unchecked( + elements, + groups.list_size(), + groups.validity()?, + groups.len(), + ) + } + .into(), + }) + } +} + +/// The physical ranges of a grouped array. +pub enum GroupRanges { + /// Explicit ranges extracted from a list-view array. + ListView { + /// The `(offset, size)` ranges. + ranges: Vec<(usize, usize)>, + }, + /// Uniform ranges derived from a fixed-size list array. + FixedSizeList { + /// The number of groups. + len: usize, + /// The number of elements in each group. + size: usize, + }, +} + +impl GroupRanges { + /// The number of groups described by these ranges. + pub fn len(&self) -> usize { + match self { + Self::ListView { ranges } => ranges.len(), + Self::FixedSizeList { len, .. } => *len, + } + } + + /// Returns true when there are no groups. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Return the `(offset, size)` range for the group at `index`. + fn range(&self, index: usize) -> (usize, usize) { + match self { + Self::ListView { ranges } => ranges[index], + Self::FixedSizeList { len, size } => { + assert!(index < *len, "range index out of bounds"); + (index * size, *size) + } + } + } + + /// Iterate over all `(offset, size)` group ranges. + pub fn iter(&self) -> impl Iterator + '_ { + (0..self.len()).map(|index| self.range(index)) + } +} + /// An accumulator used for computing grouped aggregates. /// /// Note that the groups must be processed in order, and the accumulator does not support random @@ -128,8 +265,8 @@ impl DynGroupedAccumulator for GroupedAccumulator { Columnar::Constant(c) => c.into_array().execute::(ctx)?, }; match canonical { - Canonical::List(groups) => self.accumulate_list_view(&groups, ctx), - Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx), + Canonical::List(groups) => self.accumulate_grouped_array(groups.into(), ctx), + Canonical::FixedSizeList(groups) => self.accumulate_grouped_array(groups.into(), ctx), _ => vortex_panic!("We checked the DType above, so this should never happen"), } } @@ -155,13 +292,12 @@ impl DynGroupedAccumulator for GroupedAccumulator { } impl GroupedAccumulator { - fn accumulate_list_view( + fn accumulate_grouped_array( &mut self, - groups: &ListViewArray, + groups: GroupedArray, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { let mut elements = groups.elements().clone(); - let groups_validity = groups.validity()?; let session = ctx.session().clone(); for _ in 0..max_iterations() { @@ -169,151 +305,62 @@ impl GroupedAccumulator { break; } - if let Some(result) = session + // Try a registered grouped kernel for the current non-canonical element encoding. + if let Some(kernel) = session .aggregate_fns() .find_grouped_kernel(elements.encoding_id(), self.aggregate_fn.id()) - .and_then(|kernel| { - // SAFETY: we assume that elements execution is safe - let groups = unsafe { - ListViewArray::new_unchecked( - elements.clone(), - groups.offsets().clone(), - groups.sizes().clone(), - groups_validity.clone(), - ) - }; - kernel - .grouped_aggregate(&self.aggregate_fn, &groups) - .transpose() - }) - .transpose()? { - return self.push_result(result); + // SAFETY: we assume that elements execution is safe + let kernel_groups = unsafe { groups.with_elements_unchecked(elements.clone())? }; + if let Some(result) = + kernel.grouped_aggregate(&self.aggregate_fn, &kernel_groups, ctx)? + { + return self.push_result(result); + } } // Execute one step and try again elements = elements.execute(ctx)?; } - // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. let elements = elements.execute::(ctx)?.into_array(); - let offsets = groups.offsets(); - let sizes = groups.sizes().cast(offsets.dtype().clone())?; - let validity = groups_validity.execute_mask(offsets.len(), ctx)?; - - match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { - let offsets = offsets.clone().execute::>(ctx)?; - let sizes = sizes.execute::>(ctx)?; - self.accumulate_list_view_typed( - &elements, - offsets.as_ref(), - sizes.as_ref(), - &validity, - ctx, - ) - }) - } - - fn accumulate_list_view_typed( - &mut self, - elements: &ArrayRef, - offsets: &[O], - sizes: &[O], - validity: &Mask, - ctx: &mut ExecutionCtx, - ) -> VortexResult<()> { - let mut accumulator = Accumulator::try_new( - self.vtable.clone(), - self.options.clone(), - self.dtype.clone(), - )?; - let mut states = builder_with_capacity(&self.partial_dtype, offsets.len()); - - // `validity` is the per-group list-view validity, so it is zipped element-wise with the - // offsets and sizes (one entry per group). - for ((offset, size), valid) in offsets.iter().zip(sizes.iter()).zip(validity.iter()) { - let offset = offset.to_usize().vortex_expect("Offset value is not usize"); - let size = size.to_usize().vortex_expect("Size value is not usize"); - - if valid { - let group = elements.slice(offset..offset + size)?; - accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.flush()?)?; - } else { - states.append_null() - } + // SAFETY: we preserve the grouped shape and validity while replacing the elements with an + // executed form of the same logical array. + let grouped = unsafe { groups.with_elements_unchecked(elements)? }; + + if let Some(result) = self + .vtable + .try_accumulate_grouped(&self.options, &grouped, ctx)? + { + return self.push_result(result); } - self.push_result(states.finish()) + // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. + self.accumulate_grouped_fallback(&grouped, ctx) } - fn accumulate_fixed_size_list( + fn accumulate_grouped_fallback( &mut self, - groups: &FixedSizeListArray, + grouped: &GroupedArray, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { - let mut elements = groups.elements().clone(); - let groups_validity = groups.validity()?; - let session = ctx.session().clone(); - - for _ in 0..64 { - if elements.is::() { - break; - } - - if let Some(result) = session - .aggregate_fns() - .find_grouped_kernel(elements.encoding_id(), self.aggregate_fn.id()) - .and_then(|kernel| { - // SAFETY: we assume that elements execution is safe - let groups = unsafe { - FixedSizeListArray::new_unchecked( - elements.clone(), - groups.list_size(), - groups_validity.clone(), - groups.len(), - ) - }; - - kernel - .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups) - .transpose() - }) - .transpose()? - { - return self.push_result(result); - } - - // Execute one step and try again - elements = elements.execute(ctx)?; - } - - // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. - let elements = elements.execute::(ctx)?.into_array(); - let validity = groups_validity.execute_mask(groups.len(), ctx)?; - let mut accumulator = Accumulator::try_new( self.vtable.clone(), self.options.clone(), self.dtype.clone(), )?; - let mut states = builder_with_capacity(&self.partial_dtype, groups.len()); - - let mut offset = 0; - let size = groups - .list_size() - .to_usize() - .vortex_expect("List size is not usize"); + let mut states = builder_with_capacity(&self.partial_dtype, grouped.len()); + let group_ranges = grouped.group_ranges(ctx)?; + let group_validity = grouped.group_validity(ctx)?; - for valid in validity.iter() { + for ((offset, size), valid) in group_ranges.iter().zip(group_validity.iter()) { if valid { - let group = elements.slice(offset..offset + size)?; + let group = grouped.elements().slice(offset..offset + size)?; accumulator.accumulate(&group, ctx)?; states.append_scalar(&accumulator.flush()?)?; } else { states.append_null() } - offset += size; } self.push_result(states.finish()) @@ -330,3 +377,36 @@ impl GroupedAccumulator { Ok(()) } } + +fn list_view_group_ranges( + groups: &ListViewArray, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let offsets = groups.offsets(); + let sizes = groups.sizes().cast(offsets.dtype().clone())?; + + let ranges = match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { + let offsets = offsets.clone().execute::>(ctx)?; + let sizes = sizes.execute::>(ctx)?; + offsets + .as_ref() + .iter() + .zip(sizes.as_ref().iter()) + .map(|(offset, size)| { + ( + offset.to_usize().vortex_expect("Offset value is not usize"), + size.to_usize().vortex_expect("Size value is not usize"), + ) + }) + .collect::>() + }); + + Ok(GroupRanges::ListView { ranges }) +} + +fn fixed_size_list_group_ranges(groups: &FixedSizeListArray) -> GroupRanges { + GroupRanges::FixedSizeList { + len: groups.len(), + size: groups.list_size() as usize, + } +} diff --git a/vortex-array/src/aggregate_fn/fns/count/grouped.rs b/vortex-array/src/aggregate_fn/fns/count/grouped.rs new file mode 100644 index 00000000000..37b9a24b928 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/count/grouped.rs @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_buffer::Buffer; +use vortex_error::VortexResult; +use vortex_mask::Mask; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::aggregate_fn::GroupRanges; +use crate::aggregate_fn::GroupedArray; +use crate::arrays::PrimitiveArray; +use crate::validity::Validity; + +/// Count each valid group from the element validity mask. +/// +/// The [`Count`](super::Count) partial dtype is non-nullable `U64`, so a null outer group cannot be +/// represented as a partial state. If any outer group is invalid, this returns `Ok(None)` and lets +/// the caller use the existing fallback behavior. +pub(super) fn try_grouped_count( + groups: &GroupedArray, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + if !groups.all_groups_valid(ctx)? { + return Ok(None); + } + let group_ranges = groups.group_ranges(ctx)?; + + Ok(Some(grouped_count(groups.elements(), &group_ranges, ctx)?)) +} + +/// Count the valid elements of each group described by `group_ranges` (element `(offset, size)` +/// pairs) into a non-nullable `U64` array, one entry per group. +fn grouped_count( + elements: &ArrayRef, + group_ranges: &GroupRanges, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let elem_mask = elements.validity()?.execute_mask(elements.len(), ctx)?; + + let counts: Buffer = if elem_mask.all_true() { + group_ranges.iter().map(|(_, size)| size as u64).collect() + } else { + group_ranges + .iter() + .map(|(offset, size)| valid_count(&elem_mask, offset, size) as u64) + .collect() + }; + + Ok(PrimitiveArray::new(counts, Validity::NonNullable).into_array()) +} + +/// Number of valid elements in the `[offset, offset + size)` range of the element mask. +fn valid_count(elem_mask: &Mask, offset: usize, size: usize) -> usize { + elem_mask.slice(offset..offset + size).true_count() +} + +#[cfg(test)] +mod tests { + #![allow(clippy::cast_possible_truncation)] + + use vortex_buffer::Buffer; + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::ArrayRef; + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::DynGroupedAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupedAccumulator; + use crate::aggregate_fn::fns::count::Count; + use crate::arrays::FixedSizeListArray; + use crate::arrays::ListViewArray; + use crate::arrays::PrimitiveArray; + use crate::arrays::VarBinViewArray; + use crate::assert_arrays_eq; + use crate::dtype::DType; + use crate::dtype::Nullability::NonNullable; + use crate::dtype::Nullability::Nullable; + use crate::dtype::PType; + use crate::validity::Validity; + + /// Run a grouped count through the accumulator. + fn grouped_count_actual(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { + let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, elem_dtype.clone())?; + acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; + acc.finish() + } + + /// Reference valid-counts (non-nullable `U64`), one per group. + fn grouped_count_reference( + elements: &ArrayRef, + ranges: &[(usize, usize)], + ) -> VortexResult { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let counts: Buffer = ranges + .iter() + .map(|&(offset, size)| { + Ok(elements + .slice(offset..offset + size)? + .valid_count(&mut ctx)? as u64) + }) + .collect::>()?; + Ok(PrimitiveArray::new(counts, Validity::NonNullable).into_array()) + } + + fn listview(elements: ArrayRef, ranges: &[(usize, usize)]) -> VortexResult { + let offsets = PrimitiveArray::from_iter(ranges.iter().map(|&(o, _)| o as i32)); + let sizes = PrimitiveArray::from_iter(ranges.iter().map(|&(_, s)| s as i32)); + Ok(ListViewArray::try_new( + elements, + offsets.into_array(), + sizes.into_array(), + Validity::NonNullable, + )? + .into_array()) + } + + #[test] + fn listview_counts_all_valid() -> VortexResult<()> { + let elements = + PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); + let elem_dtype = DType::Primitive(PType::I32, NonNullable); + let ranges = [(0, 2), (2, 1), (3, 3), (6, 0)]; + + let groups = listview(elements.clone(), &ranges)?; + let actual = grouped_count_actual(&groups, &elem_dtype)?; + let expected = grouped_count_reference(&elements, &ranges)?; + + let direct = + PrimitiveArray::new(buffer![2u64, 1, 3, 0], Validity::NonNullable).into_array(); + assert_arrays_eq!(&actual, &direct); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_counts_with_nulls() -> VortexResult<()> { + let elements = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, None, Some(9)]) + .into_array(); + let elem_dtype = DType::Primitive(PType::I32, Nullable); + let ranges = [(0, 3), (3, 2), (5, 1)]; + + let groups = listview(elements.clone(), &ranges)?; + let actual = grouped_count_actual(&groups, &elem_dtype)?; + let expected = grouped_count_reference(&elements, &ranges)?; + + // Group 0: {1, null, 3} -> 2. Group 1: {null, null} -> 0. Group 2: {9} -> 1. + let direct = PrimitiveArray::new(buffer![2u64, 0, 1], Validity::NonNullable).into_array(); + assert_arrays_eq!(&actual, &direct); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_counts_varbinview_with_nulls() -> VortexResult<()> { + let elements = VarBinViewArray::from_iter_nullable_str([ + Some("a"), + None, + Some("bbb"), + None, + Some("cc"), + ]) + .into_array(); + let elem_dtype = elements.dtype().clone(); + let ranges = [(0, 2), (2, 2), (4, 1)]; + + let groups = listview(elements.clone(), &ranges)?; + let actual = grouped_count_actual(&groups, &elem_dtype)?; + let expected = grouped_count_reference(&elements, &ranges)?; + + let direct = PrimitiveArray::new(buffer![1u64, 1, 1], Validity::NonNullable).into_array(); + assert_arrays_eq!(&actual, &direct); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn fixed_size_counts_with_nulls() -> VortexResult<()> { + let elements = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4)]).into_array(); + let elem_dtype = DType::Primitive(PType::I32, Nullable); + let groups = + FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?.into_array(); + + let actual = grouped_count_actual(&groups, &elem_dtype)?; + let direct = PrimitiveArray::new(buffer![1u64, 2], Validity::NonNullable).into_array(); + assert_arrays_eq!(&actual, &direct); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index e25c42e0845..fe5ff56a634 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +mod grouped; + use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -10,6 +12,7 @@ use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::GroupedArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -82,6 +85,15 @@ impl AggregateFnVTable for Count { Ok(true) } + fn try_accumulate_grouped( + &self, + _options: &Self::Options, + groups: &GroupedArray, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + grouped::try_grouped_count(groups, ctx) + } + fn accumulate( &self, _partial: &mut Self::Partial, diff --git a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs new file mode 100644 index 00000000000..a86b7c595e2 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs @@ -0,0 +1,346 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; +use vortex_mask::AllOr; +use vortex_mask::Mask; + +use super::primitive::sum_float_all; +use super::primitive::sum_signed_all; +use super::primitive::sum_unsigned_all; +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::aggregate_fn::GroupRanges; +use crate::aggregate_fn::GroupedArray; +use crate::arrays::Primitive; +use crate::arrays::PrimitiveArray; +use crate::dtype::NativePType; +use crate::match_each_native_ptype; + +/// Grouped [`Sum`](super::Sum) implementation for canonical primitive elements. +/// +/// Reuses the scalar primitive-sum reductions ([`sum_unsigned_all`]/[`sum_signed_all`]/ +/// [`sum_float_all`]) so the per-group semantics match scalar `sum` exactly (overflow saturates to +/// a null sum, NaNs are skipped). The element validity mask is materialized once and sliced per +/// group, rather than the per-group accumulator setup of the generic fallback path. +pub(super) fn try_grouped_sum( + groups: &GroupedArray, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + if !groups.elements().is::() { + return Ok(None); + } + let elements = groups.elements().clone().downcast::(); + let group_ranges = groups.group_ranges(ctx)?; + let group_validity = groups.group_validity(ctx)?; + + Ok(Some(grouped_sum( + &elements, + &group_ranges, + &group_validity, + ctx, + )?)) +} + +/// Sum each group described by `group_ranges` (element `(offset, size)` pairs), one sum per group. +fn grouped_sum( + elements: &PrimitiveArray, + group_ranges: &GroupRanges, + group_validity: &Mask, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let elem_mask = elements + .as_ref() + .validity()? + .execute_mask(elements.as_ref().len(), ctx)?; + let all_valid = matches!(elem_mask.slices(), AllOr::All); + + let result = match_each_native_ptype!(elements.ptype(), + unsigned: |T| { + let values = elements.as_slice::(); + collect_sums::(values, group_ranges, group_validity, &elem_mask, all_valid, + sum_unsigned_all) + }, + signed: |T| { + let values = elements.as_slice::(); + collect_sums::(values, group_ranges, group_validity, &elem_mask, all_valid, + sum_signed_all) + }, + floating: |T| { + let values = elements.as_slice::(); + collect_sums::(values, group_ranges, group_validity, &elem_mask, all_valid, + |acc, slice| { sum_float_all(acc, slice); false }) + } + ); + + Ok(result.into_array()) +} + +/// Reduce each group's element slice into a nullable sum. A group is null when the group +/// itself is invalid, or when summing it overflows (`sum_run` returns `true`). +fn collect_sums( + values: &[T], + group_ranges: &GroupRanges, + group_validity: &Mask, + elem_mask: &Mask, + all_valid: bool, + sum_run: impl Fn(&mut A, &[T]) -> bool, +) -> PrimitiveArray { + let sums = group_ranges.iter().enumerate().map(|(i, (offset, size))| { + if !group_validity.value(i) { + return None; + } + let mut acc = A::default(); + let overflow = if all_valid { + sum_run(&mut acc, &values[offset..offset + size]) + } else { + sum_masked_group(&mut acc, values, offset, size, elem_mask, &sum_run) + }; + (!overflow).then_some(acc) + }); + PrimitiveArray::from_option_iter(sums) +} + +/// Sum the valid elements of a single group, using the contiguous valid runs of the element mask +/// intersected with the group's `[offset, offset + size)` range. +fn sum_masked_group( + acc: &mut A, + values: &[T], + offset: usize, + size: usize, + elem_mask: &Mask, + sum_run: &impl Fn(&mut A, &[T]) -> bool, +) -> bool { + match elem_mask.slice(offset..offset + size).slices() { + AllOr::All => sum_run(acc, &values[offset..offset + size]), + AllOr::None => false, + AllOr::Some(runs) => { + for &(start, end) in runs { + if sum_run(acc, &values[offset + start..offset + end]) { + return true; + } + } + false + } + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::cast_possible_truncation)] + + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::ArrayRef; + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::DynGroupedAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupedAccumulator; + use crate::aggregate_fn::fns::sum::Sum; + use crate::aggregate_fn::fns::sum::sum; + use crate::arrays::FixedSizeListArray; + use crate::arrays::ListViewArray; + use crate::arrays::PrimitiveArray; + use crate::assert_arrays_eq; + use crate::builders::builder_with_capacity; + use crate::dtype::DType; + use crate::dtype::Nullability::NonNullable; + use crate::dtype::Nullability::Nullable; + use crate::dtype::PType; + use crate::validity::Validity; + + /// Run a grouped sum through the accumulator. + fn grouped_sum_actual(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { + let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?; + acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; + acc.finish() + } + + /// Reference sums computed exactly like the generic slow path: per-group scalar [`sum`] for + /// valid groups, a null sum for invalid groups. + fn grouped_sum_reference( + elements: &ArrayRef, + ranges: &[(usize, usize)], + group_valid: &[bool], + elem_dtype: &DType, + ) -> VortexResult { + use crate::aggregate_fn::AggregateFnVTable; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let sum_dtype = Sum + .partial_dtype(&EmptyOptions, elem_dtype) + .expect("sum partial dtype"); + let mut builder = builder_with_capacity(&sum_dtype, ranges.len()); + for (i, &(offset, size)) in ranges.iter().enumerate() { + if group_valid[i] { + let slice = elements.slice(offset..offset + size)?; + builder.append_scalar(&sum(&slice, &mut ctx)?)?; + } else { + builder.append_null(); + } + } + Ok(builder.finish()) + } + + fn offsets_sizes(ranges: &[(usize, usize)]) -> (ArrayRef, ArrayRef) { + let offsets = PrimitiveArray::from_iter(ranges.iter().map(|&(o, _)| o as i32)); + let sizes = PrimitiveArray::from_iter(ranges.iter().map(|&(_, s)| s as i32)); + (offsets.into_array(), sizes.into_array()) + } + + fn listview( + elements: ArrayRef, + ranges: &[(usize, usize)], + group_valid: &[bool], + ) -> VortexResult { + let (offsets, sizes) = offsets_sizes(ranges); + let validity = if group_valid.iter().all(|&v| v) { + Validity::NonNullable + } else { + Validity::from_iter(group_valid.iter().copied()) + }; + Ok(ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array()) + } + + #[test] + fn listview_matches_reference_unsigned() -> VortexResult<()> { + let elements = + PrimitiveArray::new(buffer![1u32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); + let elem_dtype = DType::Primitive(PType::U32, NonNullable); + let ranges = [(0, 2), (2, 1), (3, 3)]; + let valid = [true, true, true]; + + let groups = listview(elements.clone(), &ranges, &valid)?; + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; + + // Unsigned input sums to U64. + let direct = PrimitiveArray::from_option_iter([Some(3u64), Some(3u64), Some(15u64)]); + assert_arrays_eq!(&actual, &direct.into_array()); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_out_of_order_offsets_with_null_group() -> VortexResult<()> { + // Offsets are not in group order and a group is null: the group validity must be indexed by + // group index, not by element offset. + let elements = + PrimitiveArray::new(buffer![10i32, 20, 30, 40, 50, 60], Validity::NonNullable) + .into_array(); + let elem_dtype = DType::Primitive(PType::I32, NonNullable); + let ranges = [(4, 2), (0, 2), (2, 2)]; + let valid = [true, false, true]; + + let groups = listview(elements.clone(), &ranges, &valid)?; + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; + + let direct = PrimitiveArray::from_option_iter([Some(110i64), None, Some(70i64)]); + assert_arrays_eq!(&actual, &direct.into_array()); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_interior_and_full_nulls() -> VortexResult<()> { + // Group 1 has an interior null, group 2 is entirely null, group 3 is empty. + let elements = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, None, Some(9)]) + .into_array(); + let elem_dtype = DType::Primitive(PType::I32, Nullable); + let ranges = [(0, 3), (3, 2), (5, 0), (5, 1)]; + let valid = [true, true, true, true]; + + let groups = listview(elements.clone(), &ranges, &valid)?; + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; + + let direct = + PrimitiveArray::from_option_iter([Some(4i64), Some(0i64), Some(0i64), Some(9i64)]); + assert_arrays_eq!(&actual, &direct.into_array()); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_overflow_group_is_null() -> VortexResult<()> { + let elements = + PrimitiveArray::new(buffer![i64::MAX, 1, 2, 3], Validity::NonNullable).into_array(); + let elem_dtype = DType::Primitive(PType::I64, NonNullable); + let ranges = [(0, 2), (2, 2)]; + let valid = [true, true]; + + let groups = listview(elements.clone(), &ranges, &valid)?; + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; + + // First group overflows -> null sum; second group sums normally. + let direct = PrimitiveArray::from_option_iter([None, Some(5i64)]); + assert_arrays_eq!(&actual, &direct.into_array()); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn listview_float_nan_and_inf() -> VortexResult<()> { + let elements = PrimitiveArray::new( + buffer![1.0f64, f64::NAN, 2.0, f64::INFINITY, f64::NEG_INFINITY, 4.0], + Validity::NonNullable, + ) + .into_array(); + let elem_dtype = DType::Primitive(PType::F64, NonNullable); + let ranges = [(0, 3), (3, 3)]; + let valid = [true, true]; + + let groups = listview(elements.clone(), &ranges, &valid)?; + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + + // Group 0: NaN skipped -> 3.0. Group 1: INF + -INF = NaN. (Avoid array equality here since + // NaN != NaN; compare element scalars against the reference path instead.) + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; + let g0 = actual.execute_scalar(0, &mut ctx)?; + assert_eq!(g0.as_primitive().typed_value::(), Some(3.0)); + assert_eq!( + g0.as_primitive().typed_value::(), + expected + .execute_scalar(0, &mut ctx)? + .as_primitive() + .typed_value::() + ); + let g1 = actual.execute_scalar(1, &mut ctx)?; + assert!(g1.as_primitive().typed_value::().unwrap().is_nan()); + assert!( + expected + .execute_scalar(1, &mut ctx)? + .as_primitive() + .typed_value::() + .unwrap() + .is_nan() + ); + Ok(()) + } + + #[test] + fn fixed_size_overflow_and_nan() -> VortexResult<()> { + // FixedSize path: first group overflows -> null sum, second sums normally. + let elements = + PrimitiveArray::new(buffer![i64::MAX, 1, 2, 3], Validity::NonNullable).into_array(); + let elem_dtype = DType::Primitive(PType::I64, NonNullable); + let groups = FixedSizeListArray::try_new(elements.clone(), 2, Validity::NonNullable, 2)? + .into_array(); + + let actual = grouped_sum_actual(&groups, &elem_dtype)?; + let expected = + grouped_sum_reference(&elements, &[(0, 2), (2, 2)], &[true, true], &elem_dtype)?; + let direct = PrimitiveArray::from_option_iter([None, Some(5i64)]); + assert_arrays_eq!(&actual, &direct.into_array()); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 24799570ff7..4c2de6c1a7b 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -4,6 +4,7 @@ mod bool; mod constant; mod decimal; +mod grouped; mod primitive; use vortex_error::VortexExpect; @@ -25,6 +26,7 @@ use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::GroupedArray; use crate::dtype::DType; use crate::dtype::DecimalDType; use crate::dtype::MAX_PRECISION; @@ -213,6 +215,15 @@ impl AggregateFnVTable for Sum { } } + fn try_accumulate_grouped( + &self, + _options: &Self::Options, + groups: &GroupedArray, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + grouped::try_grouped_sum(groups, ctx) + } + fn accumulate( &self, partial: &mut Self::Partial, diff --git a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs index 44418fb5628..df7d929d896 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs @@ -50,11 +50,7 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR unsigned: |_T| { vortex_panic!("float sum state with unsigned input") }, signed: |_T| { vortex_panic!("float sum state with signed input") }, floating: |T| { - for &v in p.as_slice::() { - if !v.is_nan() { - *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); - } - } + sum_float_all(acc, p.as_slice::()); Ok(false) } ), @@ -62,11 +58,21 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR } } +/// Sum the non-NaN values of a float slice into an `f64` accumulator. NaNs are skipped to match the +/// scalar `sum` semantics. Floats cannot overflow the accumulator, so this never reports saturation. +pub(super) fn sum_float_all(acc: &mut f64, slice: &[T]) { + for &v in slice { + if !v.is_nan() { + *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); + } + } +} + /// Sum all values into a `u64` accumulator. For types narrower than 64 bits, values are summed in /// chunks of [`SUM_CHUNK`] with a single checked add per chunk, which lets the inner loop vectorize /// to packed widening adds. `u64` input keeps a per-element checked add since a chunk of `u64`s /// could itself overflow. Returns `true` on overflow. -fn sum_unsigned_all(acc: &mut u64, slice: &[T]) -> bool +pub(super) fn sum_unsigned_all(acc: &mut u64, slice: &[T]) -> bool where T: NativePType + AsPrimitive, { @@ -88,7 +94,7 @@ where } /// Signed counterpart of [`sum_unsigned_all`]. -fn sum_signed_all(acc: &mut i64, slice: &[T]) -> bool +pub(super) fn sum_signed_all(acc: &mut i64, slice: &[T]) -> bool where T: NativePType + AsPrimitive, { @@ -150,11 +156,7 @@ fn accumulate_primitive_valid( floating: |T| { let values = p.as_slice::(); for &(start, end) in slices { - for &v in &values[start..end] { - if !v.is_nan() { - *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); - } - } + sum_float_all(acc, &values[start..end]); } Ok(false) } diff --git a/vortex-array/src/aggregate_fn/kernels.rs b/vortex-array/src/aggregate_fn/kernels.rs index d806b18d84d..bc62848f728 100644 --- a/vortex-array/src/aggregate_fn/kernels.rs +++ b/vortex-array/src/aggregate_fn/kernels.rs @@ -11,8 +11,7 @@ use vortex_error::VortexResult; use crate::ArrayRef; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnRef; -use crate::arrays::FixedSizeListArray; -use crate::arrays::ListViewArray; +use crate::aggregate_fn::GroupedArray; use crate::scalar::Scalar; /// A pluggable kernel for an aggregate function. @@ -31,33 +30,21 @@ pub trait DynAggregateKernel: 'static + Send + Sync + Debug { /// A pluggable kernel for batch aggregation of many groups. /// /// The kernel is matched on the encoding of the _elements_ array, which is the inner array of the -/// provided `ListViewArray`. This is more pragmatic than having every kernel match on the outer -/// list encoding and having to deal with the possibility of multiple list encodings. +/// provided grouped array. This is more pragmatic than having every kernel match on the outer list +/// encoding and having to deal with the possibility of multiple list encodings. /// -/// Each element of the list array represents a group and the result of the grouped aggregate +/// Each value in the grouped array represents a group and the result of the grouped aggregate /// should be an array of the same length, where each element is the aggregate state of the /// corresponding group. /// /// Return `Ok(None)` if the kernel cannot be applied to the given aggregate function. pub trait DynGroupedAggregateKernel: 'static + Send + Sync + Debug { - /// Aggregate each group in the provided `ListViewArray` and return an array of the - /// aggregate states. + /// Aggregate each group in the provided grouped array and return an array of the aggregate + /// states. fn grouped_aggregate( &self, aggregate_fn: &AggregateFnRef, - groups: &ListViewArray, + groups: &GroupedArray, + ctx: &mut ExecutionCtx, ) -> VortexResult>; - - /// Aggregate each group in the provided `FixedSizeListArray` and return an array of the - /// aggregate states. - fn grouped_aggregate_fixed_size( - &self, - aggregate_fn: &AggregateFnRef, - groups: &FixedSizeListArray, - ) -> VortexResult> { - // TODO(ngates): we could automatically delegate to `grouped_aggregate` if SequenceArray - // was in the vortex-array crate - let _ = (aggregate_fn, groups); - Ok(None) - } } diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 28b91d45166..7ad817747ae 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -18,6 +18,7 @@ use crate::aggregate_fn::AggregateFn; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnSatisfaction; +use crate::aggregate_fn::GroupedArray; use crate::dtype::DType; use crate::scalar::Scalar; @@ -138,6 +139,20 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { Ok(false) } + /// Try to accumulate many list groups in one batch. + /// + /// The returned array must contain one partial state per group and have the dtype returned by + /// [`Self::partial_dtype`]. Returning `Ok(None)` falls back to the default per-group accumulator + /// loop. + fn try_accumulate_grouped( + &self, + _options: &Self::Options, + _groups: &GroupedArray, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + Ok(None) + } + /// Accumulate a new canonical array into the accumulator state. fn accumulate( &self, From 34328c7ac013720f4738ee5bceeb515d9380d0dd Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Tue, 9 Jun 2026 15:12:46 +0100 Subject: [PATCH 2/5] comments Signed-off-by: Onur Satici --- .../src/aggregate_fn/accumulator_grouped.rs | 29 ++++----- .../src/aggregate_fn/fns/count/grouped.rs | 20 +++++++ .../src/aggregate_fn/fns/count/mod.rs | 12 +--- .../src/aggregate_fn/fns/sum/grouped.rs | 20 +++++++ vortex-array/src/aggregate_fn/fns/sum/mod.rs | 12 +--- vortex-array/src/aggregate_fn/kernels.rs | 8 ++- vortex-array/src/aggregate_fn/session.rs | 59 ++++++++++++++----- vortex-array/src/aggregate_fn/vtable.rs | 15 ----- 8 files changed, 107 insertions(+), 68 deletions(-) diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 06b535a7686..5dbc52b0c55 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -11,7 +11,6 @@ use vortex_error::vortex_err; use vortex_error::vortex_panic; use vortex_mask::Mask; -use crate::AnyCanonical; use crate::ArrayRef; use crate::Canonical; use crate::Columnar; @@ -30,6 +29,7 @@ use crate::arrays::fixed_size_list::FixedSizeListArrayExt; use crate::arrays::listview::ListViewArrayExt; use crate::builders::builder_with_capacity; use crate::builtins::ArrayBuiltins; +use crate::columnar::AnyColumnar; use crate::dtype::DType; use crate::executor::max_iterations; use crate::match_each_integer_ptype; @@ -300,15 +300,19 @@ impl GroupedAccumulator { let mut elements = groups.elements().clone(); let session = ctx.session().clone(); - for _ in 0..max_iterations() { - if elements.is::() { - break; - } + if let Some(kernel) = session + .aggregate_fns() + .find_grouped_kernel(self.aggregate_fn.id()) + && let Some(result) = kernel.grouped_aggregate(&self.aggregate_fn, &groups, ctx)? + { + return self.push_result(result); + } - // Try a registered grouped kernel for the current non-canonical element encoding. + for _ in 0..max_iterations() { + // Try a registered grouped kernel for the current element encoding. if let Some(kernel) = session .aggregate_fns() - .find_grouped_kernel(elements.encoding_id(), self.aggregate_fn.id()) + .find_grouped_encoding_kernel(elements.encoding_id(), self.aggregate_fn.id()) { // SAFETY: we assume that elements execution is safe let kernel_groups = unsafe { groups.with_elements_unchecked(elements.clone())? }; @@ -319,6 +323,10 @@ impl GroupedAccumulator { } } + if elements.is::() { + break; + } + // Execute one step and try again elements = elements.execute(ctx)?; } @@ -328,13 +336,6 @@ impl GroupedAccumulator { // executed form of the same logical array. let grouped = unsafe { groups.with_elements_unchecked(elements)? }; - if let Some(result) = self - .vtable - .try_accumulate_grouped(&self.options, &grouped, ctx)? - { - return self.push_result(result); - } - // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. self.accumulate_grouped_fallback(&grouped, ctx) } diff --git a/vortex-array/src/aggregate_fn/fns/count/grouped.rs b/vortex-array/src/aggregate_fn/fns/count/grouped.rs index 37b9a24b928..3ad307290a0 100644 --- a/vortex-array/src/aggregate_fn/fns/count/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/count/grouped.rs @@ -8,11 +8,31 @@ use vortex_mask::Mask; use crate::ArrayRef; use crate::ExecutionCtx; use crate::IntoArray; +use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::GroupRanges; use crate::aggregate_fn::GroupedArray; +use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; use crate::arrays::PrimitiveArray; use crate::validity::Validity; +/// Encoding-independent grouped [`Count`](super::Count) kernel. +#[derive(Debug)] +pub(crate) struct CountGroupedKernel; + +impl DynGroupedAggregateKernel for CountGroupedKernel { + fn grouped_aggregate( + &self, + aggregate_fn: &AggregateFnRef, + groups: &GroupedArray, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + if !aggregate_fn.is::() { + return Ok(None); + } + try_grouped_count(groups, ctx) + } +} + /// Count each valid group from the element validity mask. /// /// The [`Count`](super::Count) partial dtype is non-nullable `U64`, so a null outer group cannot be diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index fe5ff56a634..1fe984fb099 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -2,7 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors mod grouped; - +pub(crate) use grouped::CountGroupedKernel; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -12,7 +12,6 @@ use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::EmptyOptions; -use crate::aggregate_fn::GroupedArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -85,15 +84,6 @@ impl AggregateFnVTable for Count { Ok(true) } - fn try_accumulate_grouped( - &self, - _options: &Self::Options, - groups: &GroupedArray, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - grouped::try_grouped_count(groups, ctx) - } - fn accumulate( &self, _partial: &mut Self::Partial, diff --git a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs index a86b7c595e2..dfeacd39e42 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs @@ -11,13 +11,33 @@ use super::primitive::sum_unsigned_all; use crate::ArrayRef; use crate::ExecutionCtx; use crate::IntoArray; +use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::GroupRanges; use crate::aggregate_fn::GroupedArray; +use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; use crate::arrays::Primitive; use crate::arrays::PrimitiveArray; use crate::dtype::NativePType; use crate::match_each_native_ptype; +/// Grouped [`Sum`](super::Sum) kernel for primitive element arrays. +#[derive(Debug)] +pub(crate) struct PrimitiveGroupedSumKernel; + +impl DynGroupedAggregateKernel for PrimitiveGroupedSumKernel { + fn grouped_aggregate( + &self, + aggregate_fn: &AggregateFnRef, + groups: &GroupedArray, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + if !aggregate_fn.is::() { + return Ok(None); + } + try_grouped_sum(groups, ctx) + } +} + /// Grouped [`Sum`](super::Sum) implementation for canonical primitive elements. /// /// Reuses the scalar primitive-sum reductions ([`sum_unsigned_all`]/[`sum_signed_all`]/ diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 4c2de6c1a7b..9dcfc41b976 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -6,7 +6,7 @@ mod constant; mod decimal; mod grouped; mod primitive; - +pub(crate) use grouped::PrimitiveGroupedSumKernel; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; @@ -26,7 +26,6 @@ use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; -use crate::aggregate_fn::GroupedArray; use crate::dtype::DType; use crate::dtype::DecimalDType; use crate::dtype::MAX_PRECISION; @@ -215,15 +214,6 @@ impl AggregateFnVTable for Sum { } } - fn try_accumulate_grouped( - &self, - _options: &Self::Options, - groups: &GroupedArray, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - grouped::try_grouped_sum(groups, ctx) - } - fn accumulate( &self, partial: &mut Self::Partial, diff --git a/vortex-array/src/aggregate_fn/kernels.rs b/vortex-array/src/aggregate_fn/kernels.rs index bc62848f728..c5af0902cbb 100644 --- a/vortex-array/src/aggregate_fn/kernels.rs +++ b/vortex-array/src/aggregate_fn/kernels.rs @@ -29,9 +29,11 @@ pub trait DynAggregateKernel: 'static + Send + Sync + Debug { /// A pluggable kernel for batch aggregation of many groups. /// -/// The kernel is matched on the encoding of the _elements_ array, which is the inner array of the -/// provided grouped array. This is more pragmatic than having every kernel match on the outer list -/// encoding and having to deal with the possibility of multiple list encodings. +/// A kernel can be registered either for an aggregate function regardless of the element encoding, +/// or for a specific aggregate function and element encoding. Element-encoding kernels are matched +/// on the inner array of the provided grouped array, not on the outer list encoding. This is more +/// pragmatic than having every kernel match on the outer list encoding and having to deal with the +/// possibility of multiple list encodings. /// /// Each value in the grouped array represents a group and the result of the grouped aggregate /// should be an array of the same length, where each element is the aggregate state of the diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index edbafdf386f..4109a8fa349 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -18,6 +18,8 @@ use crate::aggregate_fn::fns::all_non_null::AllNonNull; use crate::aggregate_fn::fns::all_null::AllNull; use crate::aggregate_fn::fns::bounded_max::BoundedMax; use crate::aggregate_fn::fns::bounded_min::BoundedMin; +use crate::aggregate_fn::fns::count::Count; +use crate::aggregate_fn::fns::count::CountGroupedKernel; use crate::aggregate_fn::fns::first::First; use crate::aggregate_fn::fns::is_constant::IsConstant; use crate::aggregate_fn::fns::is_sorted::IsSorted; @@ -27,6 +29,7 @@ use crate::aggregate_fn::fns::min::Min; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::nan_count::NanCount; use crate::aggregate_fn::fns::null_count::NullCount; +use crate::aggregate_fn::fns::sum::PrimitiveGroupedSumKernel; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes; use crate::aggregate_fn::kernels::DynAggregateKernel; @@ -36,6 +39,7 @@ use crate::array::ArrayId; use crate::array::VTable; use crate::arrays::Chunked; use crate::arrays::Dict; +use crate::arrays::Primitive; use crate::arrays::chunked::compute::aggregate::ChunkedArrayAggregate; use crate::arrays::dict::compute::is_constant::DictIsConstantKernel; use crate::arrays::dict::compute::is_sorted::DictIsSortedKernel; @@ -50,8 +54,10 @@ use crate::arrays::dict::compute::min_max::DictMinMaxKernel; pub struct AggregateFnSession { registry: ArcSwapMap, - kernels: ArcSwapMap, - grouped_kernels: ArcSwapMap, + kernels: ArcSwapMap, + grouped_kernels: ArcSwapMap, + grouped_encoding_kernels: + ArcSwapMap, } impl SessionVar for AggregateFnSession { @@ -64,7 +70,8 @@ impl SessionVar for AggregateFnSession { } } -type KernelKey = (ArrayId, Option); +type AggregateKernelKey = (ArrayId, Option); +type GroupedEncodingKernelKey = (ArrayId, AggregateFnId); impl Default for AggregateFnSession { fn default() -> Self { @@ -72,6 +79,7 @@ impl Default for AggregateFnSession { registry: ArcSwapMap::default(), kernels: ArcSwapMap::default(), grouped_kernels: ArcSwapMap::default(), + grouped_encoding_kernels: ArcSwapMap::default(), }; // Register the built-in aggregate functions @@ -100,6 +108,10 @@ impl Default for AggregateFnSession { this.register_aggregate_kernel(Dict.id(), Some(IsConstant.id()), &DictIsConstantKernel); this.register_aggregate_kernel(Dict.id(), Some(IsSorted.id()), &DictIsSortedKernel); + // Register the built-in grouped aggregate kernels. + this.register_grouped_kernel(Count.id(), &CountGroupedKernel); + this.register_grouped_encoding_kernel(Primitive.id(), Sum.id(), &PrimitiveGroupedSumKernel); + this } } @@ -152,27 +164,46 @@ impl AggregateFnSession { self.kernels.insert(id, kernel); } - /// Returns the grouped aggregate kernel registered for `array_id` and `agg_fn_id`, if any. + /// Returns the grouped aggregate kernel registered for `agg_fn_id`, if any. /// - /// Lookup first checks for a kernel registered for the exact aggregate function, then falls - /// back to a kernel registered for all aggregate functions on the same array encoding. + /// These kernels are independent of the element encoding and are checked once before the + /// grouped accumulator executes the element array. pub fn find_grouped_kernel( + &self, + agg_fn_id: impl Into, + ) -> Option<&'static dyn DynGroupedAggregateKernel> { + let fn_id = agg_fn_id.into(); + self.grouped_kernels + .read(|kernels| kernels.get(&fn_id).copied()) + } + + /// Registers a grouped aggregate kernel for an aggregate function. + pub fn register_grouped_kernel( + &self, + agg_fn_id: impl Into, + kernel: &'static dyn DynGroupedAggregateKernel, + ) { + let fn_id = agg_fn_id.into(); + self.grouped_kernels.insert(fn_id, kernel) + } + + /// Returns the grouped aggregate kernel registered for `array_id` and `agg_fn_id`, if any. + /// + /// These kernels are matched against each intermediate element encoding while the grouped + /// accumulator executes the element array. + pub fn find_grouped_encoding_kernel( &self, array_id: impl Into, agg_fn_id: impl Into, ) -> Option<&'static dyn DynGroupedAggregateKernel> { let id = array_id.into(); let fn_id = agg_fn_id.into(); - self.grouped_kernels.read(|kernels| { - kernels - .get(&(id, Some(fn_id))) - .or_else(|| kernels.get(&(id, None))) - .copied() - }) + self.grouped_encoding_kernels + .read(|kernels| kernels.get(&(id, fn_id)).copied()) } /// Registers a grouped aggregate kernel for a specific aggregate function and array encoding. - pub fn register_grouped_kernel( + pub fn register_grouped_encoding_kernel( &self, array_id: impl Into, agg_fn_id: impl Into, @@ -180,7 +211,7 @@ impl AggregateFnSession { ) { let id = array_id.into(); let fn_id = agg_fn_id.into(); - self.grouped_kernels.insert((id, Some(fn_id)), kernel) + self.grouped_encoding_kernels.insert((id, fn_id), kernel) } } diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 7ad817747ae..28b91d45166 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -18,7 +18,6 @@ use crate::aggregate_fn::AggregateFn; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnSatisfaction; -use crate::aggregate_fn::GroupedArray; use crate::dtype::DType; use crate::scalar::Scalar; @@ -139,20 +138,6 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { Ok(false) } - /// Try to accumulate many list groups in one batch. - /// - /// The returned array must contain one partial state per group and have the dtype returned by - /// [`Self::partial_dtype`]. Returning `Ok(None)` falls back to the default per-group accumulator - /// loop. - fn try_accumulate_grouped( - &self, - _options: &Self::Options, - _groups: &GroupedArray, - _ctx: &mut ExecutionCtx, - ) -> VortexResult> { - Ok(None) - } - /// Accumulate a new canonical array into the accumulator state. fn accumulate( &self, From a21387c1b5e85a21ccc97cc5917f9e4fa5f05e0b Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Tue, 9 Jun 2026 16:10:03 +0100 Subject: [PATCH 3/5] proper order Signed-off-by: Onur Satici --- .../src/aggregate_fn/accumulator_grouped.rs | 24 ++++++++++++------- vortex-array/src/aggregate_fn/session.rs | 3 ++- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 5dbc52b0c55..00ff03fc69e 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -299,14 +299,7 @@ impl GroupedAccumulator { ) -> VortexResult<()> { let mut elements = groups.elements().clone(); let session = ctx.session().clone(); - - if let Some(kernel) = session - .aggregate_fns() - .find_grouped_kernel(self.aggregate_fn.id()) - && let Some(result) = kernel.grouped_aggregate(&self.aggregate_fn, &groups, ctx)? - { - return self.push_result(result); - } + let mut checked_aggregate_kernel = false; for _ in 0..max_iterations() { // Try a registered grouped kernel for the current element encoding. @@ -323,6 +316,21 @@ impl GroupedAccumulator { } } + if !checked_aggregate_kernel { + // check the aggregate function kernel once if any. This is done in this loop + // so the encoding specific kernel check above takes precedence, and we get + // to check the aggregate kernel before decompressing the world + checked_aggregate_kernel = true; + if let Some(kernel) = session + .aggregate_fns() + .find_grouped_kernel(self.aggregate_fn.id()) + && let Some(result) = + kernel.grouped_aggregate(&self.aggregate_fn, &groups, ctx)? + { + return self.push_result(result); + } + } + if elements.is::() { break; } diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index 4109a8fa349..0d15d34383e 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -167,7 +167,8 @@ impl AggregateFnSession { /// Returns the grouped aggregate kernel registered for `agg_fn_id`, if any. /// /// These kernels are independent of the element encoding and are checked once before the - /// grouped accumulator executes the element array. + /// grouped accumulator executes the element array, after any kernel registered for the initial + /// element encoding. pub fn find_grouped_kernel( &self, agg_fn_id: impl Into, From ef4ae99a350b34e55b75040793d12b3f1bad3706 Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Tue, 9 Jun 2026 16:26:06 +0100 Subject: [PATCH 4/5] check every iter Signed-off-by: Onur Satici --- .../src/aggregate_fn/accumulator_grouped.rs | 22 +++++++++---------- vortex-array/src/aggregate_fn/session.rs | 5 ++--- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 00ff03fc69e..b87c04ee204 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -299,7 +299,6 @@ impl GroupedAccumulator { ) -> VortexResult<()> { let mut elements = groups.elements().clone(); let session = ctx.session().clone(); - let mut checked_aggregate_kernel = false; for _ in 0..max_iterations() { // Try a registered grouped kernel for the current element encoding. @@ -316,16 +315,16 @@ impl GroupedAccumulator { } } - if !checked_aggregate_kernel { - // check the aggregate function kernel once if any. This is done in this loop - // so the encoding specific kernel check above takes precedence, and we get - // to check the aggregate kernel before decompressing the world - checked_aggregate_kernel = true; - if let Some(kernel) = session - .aggregate_fns() - .find_grouped_kernel(self.aggregate_fn.id()) - && let Some(result) = - kernel.grouped_aggregate(&self.aggregate_fn, &groups, ctx)? + // Try a grouped kernel for the current aggregate regardless of element encoding. + if let Some(kernel) = session + .aggregate_fns() + .find_grouped_kernel(self.aggregate_fn.id()) + { + // SAFETY: we preserve the grouped shape and validity while replacing the + // elements with another representation of the same logical array. + let kernel_groups = unsafe { groups.with_elements_unchecked(elements.clone())? }; + if let Some(result) = + kernel.grouped_aggregate(&self.aggregate_fn, &kernel_groups, ctx)? { return self.push_result(result); } @@ -386,7 +385,6 @@ impl GroupedAccumulator { Ok(()) } } - fn list_view_group_ranges( groups: &ListViewArray, ctx: &mut ExecutionCtx, diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index 0d15d34383e..b935f6ef9de 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -166,9 +166,8 @@ impl AggregateFnSession { /// Returns the grouped aggregate kernel registered for `agg_fn_id`, if any. /// - /// These kernels are independent of the element encoding and are checked once before the - /// grouped accumulator executes the element array, after any kernel registered for the initial - /// element encoding. + /// These kernels are independent of the element encoding and are checked for each element + /// representation, after any kernel registered for the current element encoding. pub fn find_grouped_kernel( &self, agg_fn_id: impl Into, From 5d009d02762e0c46e8ff66ececb87f997e4fc620 Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Thu, 11 Jun 2026 11:17:23 +0100 Subject: [PATCH 5/5] comments Signed-off-by: Onur Satici --- vortex-array/src/aggregate_fn/fns/count/grouped.rs | 11 ++++++----- vortex-array/src/aggregate_fn/fns/sum/grouped.rs | 11 ++++++----- vortex-array/src/aggregate_fn/fns/sum/mod.rs | 2 +- vortex-array/src/aggregate_fn/session.rs | 8 ++++++-- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/vortex-array/src/aggregate_fn/fns/count/grouped.rs b/vortex-array/src/aggregate_fn/fns/count/grouped.rs index 3ad307290a0..fb94489dde0 100644 --- a/vortex-array/src/aggregate_fn/fns/count/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/count/grouped.rs @@ -5,6 +5,7 @@ use vortex_buffer::Buffer; use vortex_error::VortexResult; use vortex_mask::Mask; +use super::Count; use crate::ArrayRef; use crate::ExecutionCtx; use crate::IntoArray; @@ -15,7 +16,7 @@ use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; use crate::arrays::PrimitiveArray; use crate::validity::Validity; -/// Encoding-independent grouped [`Count`](super::Count) kernel. +/// Encoding-independent grouped [`Count`] kernel. #[derive(Debug)] pub(crate) struct CountGroupedKernel; @@ -26,7 +27,7 @@ impl DynGroupedAggregateKernel for CountGroupedKernel { groups: &GroupedArray, ctx: &mut ExecutionCtx, ) -> VortexResult> { - if !aggregate_fn.is::() { + if !aggregate_fn.is::() { return Ok(None); } try_grouped_count(groups, ctx) @@ -35,9 +36,9 @@ impl DynGroupedAggregateKernel for CountGroupedKernel { /// Count each valid group from the element validity mask. /// -/// The [`Count`](super::Count) partial dtype is non-nullable `U64`, so a null outer group cannot be -/// represented as a partial state. If any outer group is invalid, this returns `Ok(None)` and lets -/// the caller use the existing fallback behavior. +/// The [`Count`] partial dtype is non-nullable `U64`, so a null outer group cannot be represented +/// as a partial state. If any outer group is invalid, this returns `Ok(None)` and lets the caller +/// use the existing fallback behavior. pub(super) fn try_grouped_count( groups: &GroupedArray, ctx: &mut ExecutionCtx, diff --git a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs index dfeacd39e42..6f00cce7fdb 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs @@ -5,6 +5,7 @@ use vortex_error::VortexResult; use vortex_mask::AllOr; use vortex_mask::Mask; +use super::Sum; use super::primitive::sum_float_all; use super::primitive::sum_signed_all; use super::primitive::sum_unsigned_all; @@ -20,25 +21,25 @@ use crate::arrays::PrimitiveArray; use crate::dtype::NativePType; use crate::match_each_native_ptype; -/// Grouped [`Sum`](super::Sum) kernel for primitive element arrays. +/// Encoding-specific grouped [`Sum`] kernel for primitive element arrays. #[derive(Debug)] -pub(crate) struct PrimitiveGroupedSumKernel; +pub(crate) struct PrimitiveGroupedSumEncodingKernel; -impl DynGroupedAggregateKernel for PrimitiveGroupedSumKernel { +impl DynGroupedAggregateKernel for PrimitiveGroupedSumEncodingKernel { fn grouped_aggregate( &self, aggregate_fn: &AggregateFnRef, groups: &GroupedArray, ctx: &mut ExecutionCtx, ) -> VortexResult> { - if !aggregate_fn.is::() { + if !aggregate_fn.is::() { return Ok(None); } try_grouped_sum(groups, ctx) } } -/// Grouped [`Sum`](super::Sum) implementation for canonical primitive elements. +/// Grouped [`Sum`] implementation for canonical primitive elements. /// /// Reuses the scalar primitive-sum reductions ([`sum_unsigned_all`]/[`sum_signed_all`]/ /// [`sum_float_all`]) so the per-group semantics match scalar `sum` exactly (overflow saturates to diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 9dcfc41b976..9d525bec742 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -6,7 +6,7 @@ mod constant; mod decimal; mod grouped; mod primitive; -pub(crate) use grouped::PrimitiveGroupedSumKernel; +pub(crate) use grouped::PrimitiveGroupedSumEncodingKernel; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index b935f6ef9de..c6d7542a687 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -29,7 +29,7 @@ use crate::aggregate_fn::fns::min::Min; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::nan_count::NanCount; use crate::aggregate_fn::fns::null_count::NullCount; -use crate::aggregate_fn::fns::sum::PrimitiveGroupedSumKernel; +use crate::aggregate_fn::fns::sum::PrimitiveGroupedSumEncodingKernel; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes; use crate::aggregate_fn::kernels::DynAggregateKernel; @@ -110,7 +110,11 @@ impl Default for AggregateFnSession { // Register the built-in grouped aggregate kernels. this.register_grouped_kernel(Count.id(), &CountGroupedKernel); - this.register_grouped_encoding_kernel(Primitive.id(), Sum.id(), &PrimitiveGroupedSumKernel); + this.register_grouped_encoding_kernel( + Primitive.id(), + Sum.id(), + &PrimitiveGroupedSumEncodingKernel, + ); this }