From 4c8059c3dae12d095b19a12567c6da9d7478af59 Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov Date: Thu, 11 Jun 2026 16:29:53 +0100 Subject: [PATCH 1/3] Add f64 sum benchmarks Add f64 cases to the aggregate_sum and aggregate_grouped benchmarks to establish a baseline for upcoming float summation changes. Signed-off-by: Dimitar Dimitrov --- vortex-array/benches/aggregate_grouped.rs | 46 +++++++++++++++++++++++ vortex-array/benches/aggregate_sum.rs | 36 ++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/vortex-array/benches/aggregate_grouped.rs b/vortex-array/benches/aggregate_grouped.rs index aa1f895f294..b067314c1d9 100644 --- a/vortex-array/benches/aggregate_grouped.rs +++ b/vortex-array/benches/aggregate_grouped.rs @@ -98,6 +98,36 @@ fn i32_clustered_nulls_input() -> ArrayRef { ) } +fn f64_all_valid_input() -> ArrayRef { + let group_sizes = random_group_sizes(); + let element_count = total_element_count(&group_sizes); + let mut rng = StdRng::seed_from_u64(GROUP_SIZE_SEED); + let values: Buffer = (0..element_count) + .map(|_| rng.random_range(-1000.0..1000.0)) + .collect(); + contiguous_list_view( + PrimitiveArray::new(values, Validity::NonNullable).into_array(), + &group_sizes, + ) +} + +fn f64_clustered_nulls_input() -> ArrayRef { + let group_sizes = random_group_sizes(); + let element_count = total_element_count(&group_sizes); + let mut rng = StdRng::seed_from_u64(GROUP_SIZE_SEED); + let values = (0..element_count).map(|i| { + if (i / 16) % 8 == 0 { + None + } else { + Some(rng.random_range(-1000.0f64..1000.0)) + } + }); + contiguous_list_view( + PrimitiveArray::from_option_iter(values).into_array(), + &group_sizes, + ) +} + fn varbinview_input() -> ArrayRef { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); @@ -144,6 +174,22 @@ fn sum_i32_clustered_nulls(bencher: Bencher) { .bench_refs(|input| grouped_accumulator(input, Sum)); } +#[divan::bench] +fn sum_f64_all_valid(bencher: Bencher) { + let input = f64_all_valid_input(); + bencher + .with_inputs(|| &input) + .bench_refs(|input| grouped_accumulator(input, Sum)); +} + +#[divan::bench] +fn sum_f64_clustered_nulls(bencher: Bencher) { + let input = f64_clustered_nulls_input(); + bencher + .with_inputs(|| &input) + .bench_refs(|input| grouped_accumulator(input, Sum)); +} + #[divan::bench] fn count_i32_clustered_nulls(bencher: Bencher) { let input = i32_clustered_nulls_input(); diff --git a/vortex-array/benches/aggregate_sum.rs b/vortex-array/benches/aggregate_sum.rs index 7d0e249c1ff..b076bb80b5f 100644 --- a/vortex-array/benches/aggregate_sum.rs +++ b/vortex-array/benches/aggregate_sum.rs @@ -63,6 +63,42 @@ fn sum_i64(bencher: Bencher) { .bench_refs(|(a, ctx)| a.statistics().compute_as::(Stat::Sum, ctx)); } +#[divan::bench] +fn sum_f64(bencher: Bencher) { + let mut rng = StdRng::seed_from_u64(6); + let data: Vec = (0..N).map(|_| rng.random_range(-1000.0..1000.0)).collect(); + bencher + .with_inputs(|| { + ( + PrimitiveArray::from_iter(data.iter().copied()).into_array(), + SESSION.create_execution_ctx(), + ) + }) + .bench_refs(|(a, ctx)| a.statistics().compute_as::(Stat::Sum, ctx)); +} + +#[divan::bench] +fn sum_f64_nulls_clustered(bencher: Bencher) { + let mut rng = StdRng::seed_from_u64(7); + let data: Vec> = (0..N) + .map(|i| { + if (i / 64) % 10 == 0 { + None + } else { + Some(rng.random_range(-1000.0..1000.0)) + } + }) + .collect(); + bencher + .with_inputs(|| { + ( + PrimitiveArray::from_option_iter(data.iter().copied()).into_array(), + SESSION.create_execution_ctx(), + ) + }) + .bench_refs(|(a, ctx)| a.statistics().compute_as::(Stat::Sum, ctx)); +} + // Clustered nulls: long runs of valid values broken up by occasional null blocks. This is the // case the run-based valid path is expected to accelerate. #[divan::bench] From 49b8dbcab22c571696592dca15b389cccd4d46c5 Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov Date: Thu, 11 Jun 2026 16:35:25 +0100 Subject: [PATCH 2/3] Use Kahan (Neumaier) compensated summation for float sums SumState::Float now carries a KahanSum (running sum plus compensation term) instead of a bare f64, so float sum aggregates accumulate with Neumaier's variant of Kahan summation. The compensation update is skipped when the addition produces a non-finite value, keeping the existing overflow-to-infinity and inf + -inf => NaN (saturated) semantics. Signed-off-by: Dimitar Dimitrov --- vortex-array/src/aggregate_fn/fns/sum/mod.rs | 47 +++++++++++++++++-- .../src/aggregate_fn/fns/sum/primitive.rs | 44 ++++++++++++++++- 2 files changed, 84 insertions(+), 7 deletions(-) diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 24799570ff7..7855ef24e9e 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -161,7 +161,7 @@ impl AggregateFnVTable for Sum { .as_primitive() .typed_value::() .vortex_expect("checked non-null"); - *acc += val; + acc.add(val); false } SumState::Decimal { value, dtype } => { @@ -189,7 +189,7 @@ impl AggregateFnVTable for Sum { None => Scalar::null(partial.return_dtype.as_nullable()), Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable), Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable), - Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable), + Some(SumState::Float(v)) => Scalar::primitive(v.value(), Nullability::Nullable), Some(SumState::Decimal { value, .. }) => { let decimal_dtype = *partial .return_dtype @@ -208,7 +208,7 @@ impl AggregateFnVTable for Sum { fn is_saturated(&self, partial: &Self::Partial) -> bool { match partial.current.as_ref() { None => true, - Some(SumState::Float(v)) => v.is_nan(), + Some(SumState::Float(v)) => v.value().is_nan(), Some(_) => false, } } @@ -277,19 +277,56 @@ pub struct SumPartial { pub enum SumState { Unsigned(u64), Signed(i64), - Float(f64), + Float(KahanSum), Decimal { value: DecimalValue, dtype: DecimalDType, }, } +/// Floating point sum state using the Neumaier variant of Kahan compensated summation. +/// +/// A running compensation term captures the low-order bits that are lost when adding values of +/// differing magnitude, greatly reducing the accumulated rounding error of naive recursive +/// summation. +#[derive(Clone, Copy, Debug, Default)] +pub struct KahanSum { + sum: f64, + compensation: f64, +} + +impl KahanSum { + /// Add a value to the running sum, folding the rounding error of the addition into the + /// compensation term. + #[inline] + pub fn add(&mut self, value: f64) { + let t = self.sum + value; + // When `t` is non-finite (overflow to infinity, or inf + -inf = NaN) the error term + // below would itself be NaN and poison the compensation. The non-finite result is + // sticky in `sum` instead, so skip the compensation update. + if t.is_finite() { + self.compensation += if self.sum.abs() >= value.abs() { + (self.sum - t) + value + } else { + (value - t) + self.sum + }; + } + self.sum = t; + } + + /// The compensated value of the sum. + #[inline] + pub fn value(&self) -> f64 { + self.sum + self.compensation + } +} + fn make_zero_state(return_dtype: &DType) -> SumState { match return_dtype { DType::Primitive(ptype, _) => match ptype { PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0), PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0), - PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0), + PType::F16 | PType::F32 | PType::F64 => SumState::Float(KahanSum::default()), }, DType::Decimal(decimal, _) => SumState::Decimal { value: DecimalValue::zero(decimal), diff --git a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs index 44418fb5628..52a8ea3b0fe 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs @@ -52,7 +52,7 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR floating: |T| { for &v in p.as_slice::() { if !v.is_nan() { - *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); + acc.add(ToPrimitive::to_f64(&v).vortex_expect("float to f64")); } } Ok(false) @@ -152,7 +152,7 @@ fn accumulate_primitive_valid( for &(start, end) in slices { for &v in &values[start..end] { if !v.is_nan() { - *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); + acc.add(ToPrimitive::to_f64(&v).vortex_expect("float to f64")); } } } @@ -326,6 +326,46 @@ mod tests { Ok(()) } + #[test] + fn sum_f64_kahan_compensation() -> VortexResult<()> { + // Naive recursive summation loses the two 1.0s next to 1e100 and returns 0.0. + let arr = PrimitiveArray::new(buffer![1.0f64, 1e100, 1.0, -1e100], Validity::NonNullable) + .into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!(result.as_primitive().typed_value::(), Some(2.0)); + Ok(()) + } + + #[test] + fn sum_f64_kahan_compensation_across_batches() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + + let batch1 = + PrimitiveArray::new(buffer![1.0f64, 1e100], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + let batch2 = + PrimitiveArray::new(buffer![1.0f64, -1e100], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + + let result = acc.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(2.0)); + Ok(()) + } + + #[test] + fn sum_f64_overflow_to_infinity() -> VortexResult<()> { + let arr = PrimitiveArray::new(buffer![f64::MAX, f64::MAX, 1.0], Validity::NonNullable) + .into_array(); + let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; + assert_eq!( + result.as_primitive().typed_value::(), + Some(f64::INFINITY) + ); + Ok(()) + } + #[test] fn sum_f64_with_infinity() -> VortexResult<()> { let batch = PrimitiveArray::new( From 86623b17e5b01c65d7881e68dd3019f7e0ccce08 Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov Date: Thu, 11 Jun 2026 17:12:59 +0100 Subject: [PATCH 3/3] Carry the Kahan compensation term in float sum partials Float sum partials are now a struct {sum: f64, compensation: f64} instead of a plain f64, so compensated-summation precision survives partial flush/combine boundaries (chunked arrays, kernel partials, accumulator merges). - Sum::partial_dtype diverges from return_dtype for floats; to_scalar emits the struct, finalize/finalize_scalar collapse it back to sum + compensation. - Sum::combine_partials accepts both the struct partial and a plain f64 (constant-multiply path). - The legacy stats bridge in Accumulator::accumulate now falls through on a partial dtype mismatch instead of erroring, and Sum::try_accumulate consumes the cached f64 Stat::Sum for floats. Signed-off-by: Dimitar Dimitrov --- vortex-array/src/aggregate_fn/accumulator.rs | 85 +++++--- vortex-array/src/aggregate_fn/fns/sum/mod.rs | 197 ++++++++++++++++++- 2 files changed, 244 insertions(+), 38 deletions(-) diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index c89418e67a6..985aad9806c 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -120,21 +120,17 @@ impl DynAccumulator for Accumulator { ); // 0. Legacy stats bridge: if this aggregate is still cached under a legacy Stat slot, - // consume that exact stat before kernel dispatch or decode. + // consume that exact stat before kernel dispatch or decode. When the stat dtype is + // incompatible with the partial dtype (e.g. float Sum partials carry a compensation + // term the f64 stat lacks), fall through to regular dispatch, where the vtable may + // still consume the stat itself (see `Sum::try_accumulate`). if let Some(stat) = Stat::from_aggregate_fn(&self.aggregate_fn) && let Precision::Exact(partial) = batch.statistics().get(stat) + && partial.dtype().eq_ignore_nullability(&self.partial_dtype) { let partial = if partial.dtype() == &self.partial_dtype { partial } else { - vortex_ensure!( - partial.dtype().eq_ignore_nullability(&self.partial_dtype), - "Aggregate {} read legacy stat {} with dtype {}, expected {}", - self.aggregate_fn, - stat, - partial.dtype(), - self.partial_dtype, - ); partial.cast(&self.partial_dtype)? }; self.vtable.combine_partials(&mut self.partial, partial)?; @@ -332,10 +328,38 @@ mod tests { _batch: &ArrayRef, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable))) + Ok(Some(sum_partial(42.0))) } } + /// Build a float Sum partial scalar `{sum: value, compensation: 0.0}`. + fn sum_partial(value: f64) -> Scalar { + let dtype = Sum + .partial_dtype( + &EmptyOptions, + &DType::Primitive(PType::F64, Nullability::NonNullable), + ) + .expect("sum supports f64"); + Scalar::struct_( + dtype, + vec![ + Scalar::primitive(value, Nullability::Nullable), + Scalar::primitive(0.0f64, Nullability::Nullable), + ], + ) + } + + /// Read the `sum` value out of a mean partial `{sum: {sum, compensation}, count}`. + fn partial_sum_value(partial: &Scalar) -> Option { + partial + .as_struct() + .field("sum")? + .as_struct() + .field("sum")? + .as_primitive() + .as_::() + } + fn fresh_session() -> VortexSession { VortexSession::empty().with::() } @@ -357,7 +381,7 @@ mod tests { fn sentinel_partial() -> Scalar { let acc = mean_f64_accumulator().expect("build accumulator"); - let sum = Scalar::primitive(42.0f64, Nullability::Nullable); + let sum = sum_partial(42.0); let count = Scalar::primitive(1u64, Nullability::NonNullable); Scalar::struct_(acc.partial_dtype, vec![sum, count]) } @@ -377,13 +401,14 @@ mod tests { acc.accumulate(&dict_of_seven(), &mut ctx)?; let partial = acc.flush()?; - let s = partial.as_struct(); - assert_eq!( - s.field("sum").unwrap().as_primitive().as_::(), - Some(42.0) - ); + assert_eq!(partial_sum_value(&partial), Some(42.0)); assert_eq!( - s.field("count").unwrap().as_primitive().as_::(), + partial + .as_struct() + .field("count") + .unwrap() + .as_primitive() + .as_::(), Some(1) ); Ok(()) @@ -404,13 +429,14 @@ mod tests { acc.accumulate(&dict_of_seven(), &mut ctx)?; let partial = acc.flush()?; - let s = partial.as_struct(); + assert_eq!(partial_sum_value(&partial), Some(7.0)); assert_eq!( - s.field("sum").unwrap().as_primitive().as_::(), - Some(7.0) - ); - assert_eq!( - s.field("count").unwrap().as_primitive().as_::(), + partial + .as_struct() + .field("count") + .unwrap() + .as_primitive() + .as_::(), Some(1) ); Ok(()) @@ -432,16 +458,17 @@ mod tests { acc.accumulate(&dict_of_seven(), &mut ctx)?; let partial = acc.flush()?; - let s = partial.as_struct(); // `Sum` child returned the sentinel 42.0 — proves the (Dict, Sum) kernel fired // via `Combined`'s fan-out. `Count`'s native `try_accumulate` reads the // batch's valid_count, so count is the real 1. + assert_eq!(partial_sum_value(&partial), Some(42.0)); assert_eq!( - s.field("sum").unwrap().as_primitive().as_::(), - Some(42.0) - ); - assert_eq!( - s.field("count").unwrap().as_primitive().as_::(), + partial + .as_struct() + .field("count") + .unwrap() + .as_primitive() + .as_::(), Some(1) ); Ok(()) diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 7855ef24e9e..cac6b8ce76a 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -25,16 +25,21 @@ use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; +use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::DecimalDType; +use crate::dtype::FieldName; +use crate::dtype::FieldNames; use crate::dtype::MAX_PRECISION; use crate::dtype::Nullability; use crate::dtype::PType; +use crate::dtype::StructFields; use crate::expr::stats::Precision; use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; use crate::scalar::DecimalValue; use crate::scalar::Scalar; +use crate::scalar_fn::fns::operators::Operator; /// Return the sum of an array. /// @@ -113,7 +118,13 @@ impl AggregateFnVTable for Sum { } fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { - self.return_dtype(options, input_dtype) + let return_dtype = self.return_dtype(options, input_dtype)?; + Some(match &return_dtype { + // Float partials carry the Kahan compensation term so that precision is preserved + // when partial sums are flushed and re-combined. + DType::Primitive(ptype, _) if ptype.is_float() => float_partial_dtype(), + _ => return_dtype, + }) } fn empty_partial( @@ -124,10 +135,14 @@ impl AggregateFnVTable for Sum { let return_dtype = self .return_dtype(options, input_dtype) .ok_or_else(|| vortex_err!("Unsupported sum dtype: {}", input_dtype))?; + let partial_dtype = self + .partial_dtype(options, input_dtype) + .ok_or_else(|| vortex_err!("Unsupported sum dtype: {}", input_dtype))?; let initial = make_zero_state(&return_dtype); Ok(SumPartial { return_dtype, + partial_dtype, current: Some(initial), }) } @@ -157,11 +172,27 @@ impl AggregateFnVTable for Sum { checked_add_i64(acc, val) } SumState::Float(acc) => { - let val = other - .as_primitive() - .typed_value::() - .vortex_expect("checked non-null"); - acc.add(val); + // Partials produced by `to_scalar` and kernels carry the compensation term as a + // struct; the constant-multiply path feeds a plain f64 scalar. + if matches!(other.dtype(), DType::Struct(..)) { + let s = other.as_struct(); + let sum = s + .field("sum") + .and_then(|f| f.as_primitive().typed_value::()) + .vortex_expect("checked non-null"); + let compensation = s + .field("compensation") + .and_then(|f| f.as_primitive().typed_value::()) + .vortex_expect("checked non-null"); + acc.add(sum); + acc.add(compensation); + } else { + let val = other + .as_primitive() + .typed_value::() + .vortex_expect("checked non-null"); + acc.add(val); + } false } SumState::Decimal { value, dtype } => { @@ -186,10 +217,16 @@ impl AggregateFnVTable for Sum { fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { Ok(match &partial.current { - None => Scalar::null(partial.return_dtype.as_nullable()), + None => Scalar::null(partial.partial_dtype.as_nullable()), Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable), Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable), - Some(SumState::Float(v)) => Scalar::primitive(v.value(), Nullability::Nullable), + Some(SumState::Float(v)) => Scalar::struct_( + partial.partial_dtype.clone(), + vec![ + Scalar::primitive(v.sum, Nullability::Nullable), + Scalar::primitive(v.compensation, Nullability::Nullable), + ], + ), Some(SumState::Decimal { value, .. }) => { let decimal_dtype = *partial .return_dtype @@ -213,6 +250,32 @@ impl AggregateFnVTable for Sum { } } + fn try_accumulate( + &self, + partial: &mut Self::Partial, + batch: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + // The generic legacy-stat bridge in `Accumulator::accumulate` only fires when the cached + // stat dtype matches the partial dtype. Float partials carry a compensation term, so the + // cached f64 `Stat::Sum` no longer matches; consume it here instead. + let Some(SumState::Float(acc)) = partial.current.as_mut() else { + return Ok(false); + }; + let Precision::Exact(stat) = batch.statistics().get(Stat::Sum) else { + return Ok(false); + }; + if !stat.dtype().eq_ignore_nullability(&partial.return_dtype) { + return Ok(false); + } + match stat.as_primitive().typed_value::() { + Some(v) => acc.add(v), + // A null cached sum means the sum saturated. + None => partial.current = None, + } + Ok(true) + } + fn accumulate( &self, partial: &mut Self::Partial, @@ -254,11 +317,23 @@ impl AggregateFnVTable for Sum { } fn finalize(&self, partials: ArrayRef) -> VortexResult { + // Float partials carry the compensation term as a struct; collapse them to + // `sum + compensation`. The struct-level validity (null groups, saturation) propagates + // through `get_item` into both fields. + if matches!(partials.dtype(), DType::Struct(..)) { + let sum = partials.get_item(FieldName::from("sum"))?; + let compensation = partials.get_item(FieldName::from("compensation"))?; + return sum.binary(compensation, Operator::Add); + } Ok(partials) } fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult { - self.to_scalar(partial) + match &partial.current { + None => Ok(Scalar::null(partial.return_dtype.as_nullable())), + Some(SumState::Float(v)) => Ok(Scalar::primitive(v.value(), Nullability::Nullable)), + Some(_) => self.to_scalar(partial), + } } } @@ -266,10 +341,30 @@ impl AggregateFnVTable for Sum { /// needed for reset/result without external context. pub struct SumPartial { return_dtype: DType, + /// The DType of the partial state. Differs from `return_dtype` for floats, where partials + /// are a struct carrying the Kahan compensation term alongside the sum. + partial_dtype: DType, /// The current accumulated state, or `None` if saturated (checked overflow). current: Option, } +/// The partial dtype for float sums: a struct carrying the running sum and the Kahan +/// compensation term, so that precision survives partial flush/combine boundaries. +fn float_partial_dtype() -> DType { + // The fields are nullable so that projecting them out of the (nullable) struct in + // `finalize` keeps the field dtype unchanged. + DType::Struct( + StructFields::new( + FieldNames::from_iter([FieldName::from("sum"), FieldName::from("compensation")]), + vec![ + DType::Primitive(PType::F64, Nullability::Nullable), + DType::Primitive(PType::F64, Nullability::Nullable), + ], + ), + Nullability::Nullable, + ) +} + /// The accumulated sum value. /// // TODO(ngates): instead of an enum, we should use a Box to avoid dispatcher over the @@ -511,6 +606,57 @@ mod tests { Ok(()) } + #[test] + fn sum_partial_preserves_compensation() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + + // Each sub-accumulator's partial carries a compensation of 1.0 next to a sum of + // ±1e100. With plain f64 partials the merged sum would collapse to 0.0. + let mut acc1 = Accumulator::try_new(Sum, EmptyOptions, dtype.clone())?; + let batch1 = + PrimitiveArray::new(buffer![1.0f64, 1e100], Validity::NonNullable).into_array(); + acc1.accumulate(&batch1, &mut ctx)?; + + let mut acc2 = Accumulator::try_new(Sum, EmptyOptions, dtype.clone())?; + let batch2 = + PrimitiveArray::new(buffer![1.0f64, -1e100], Validity::NonNullable).into_array(); + acc2.accumulate(&batch2, &mut ctx)?; + + let mut merged = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + merged.combine_partials(acc1.flush()?)?; + merged.combine_partials(acc2.flush()?)?; + + let result = merged.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(2.0)); + Ok(()) + } + + #[test] + fn sum_f64_consumes_cached_stat() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let array = PrimitiveArray::new(buffer![1.0f64, 2.0], Validity::NonNullable).into_array(); + // Seed a deliberately wrong cached sum so we can observe it being consumed instead of + // the data being recomputed. + array.statistics().set( + Stat::Sum, + Precision::Exact( + Scalar::primitive(42.0f64, Nullable) + .value() + .cloned() + .vortex_expect("non-null"), + ), + ); + + let mut acc = Accumulator::try_new(Sum, EmptyOptions, array.dtype().clone())?; + acc.accumulate(&array, &mut ctx)?; + assert_eq!( + acc.finish()?.as_primitive().typed_value::(), + Some(42.0) + ); + Ok(()) + } + // Stats caching test #[test] @@ -615,6 +761,39 @@ mod tests { Ok(()) } + #[test] + fn grouped_sum_f64_kahan() -> VortexResult<()> { + // Group 0 sums [1.0, 1e100, -1e100]: naive summation loses the 1.0 and returns 0.0. + let elements = PrimitiveArray::new( + buffer![1.0f64, 1e100, -1e100, 1.0, 2.0, 3.0], + Validity::NonNullable, + ) + .into_array(); + let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?; + + let elem_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + + let expected = PrimitiveArray::from_option_iter([Some(1.0f64), Some(6.0)]).into_array(); + assert_arrays_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn grouped_sum_f64_with_null_group() -> VortexResult<()> { + let elements = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0], Validity::NonNullable).into_array(); + let validity = Validity::from_iter([true, false]); + let groups = FixedSizeListArray::try_new(elements, 2, validity, 2)?; + + let elem_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + + let expected = PrimitiveArray::from_option_iter([Some(3.0f64), None]).into_array(); + assert_arrays_eq!(&result, &expected); + Ok(()) + } + #[test] fn grouped_sum_bool() -> VortexResult<()> { let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect();