From a52386fea5d30934ec4b5322ec7db73e78e4a09a Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Thu, 10 Apr 2025 12:05:05 +0800 Subject: [PATCH 1/3] feat: support min/max for struct --- datafusion/functions-aggregate/src/min_max.rs | 74 ++++++++++++++++--- .../sqllogictest/test_files/aggregate.slt | 26 +++++++ 2 files changed, 91 insertions(+), 9 deletions(-) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index ea4cad5488031..65fc8741bdc98 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -21,15 +21,16 @@ mod min_max_bytes; use arrow::array::{ - ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray, - DurationNanosecondArray, DurationSecondArray, Float16Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, - LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray, - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Array, ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, + Date64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray, + DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, + LargeBinaryArray, LargeStringArray, StringArray, StringViewArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + 
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }; use arrow::compute; use arrow::datatypes::{ @@ -610,10 +611,57 @@ fn min_batch(values: &ArrayRef) -> Result<ScalarValue> { min_binary_view ) } + DataType::Struct(_) => min_max_batch_struct(values, Ordering::Greater)?, _ => min_max_batch!(values, min), }) } +fn min_max_batch_struct(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> { + if array.len() == array.null_count() { + return ScalarValue::try_from(array.data_type()); + } + let mut extreme = ScalarValue::try_from_array(array, 0)?; + for i in 1..array.len() { + let current = ScalarValue::try_from_array(array, i)?; + if current.is_null() { + continue; + } + if extreme.is_null() { + extreme = current; + continue; + } + match extreme.partial_cmp(&current) { + Some(cmp) => { + if cmp == ordering { + extreme = current; + } + } + None => { + return internal_err!("Comparison error while computing min/max"); + } + } + } + Ok(extreme) +} + +macro_rules! min_max_struct { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + if $VALUE.is_null() { + Ok($DELTA.clone()) + } else if $DELTA.is_null() { + Ok($VALUE.clone()) + } else { + match $VALUE.partial_cmp(&$DELTA) { + Some(choose_min_max!($OP)) => Ok($DELTA.clone()), + Some(_) => Ok($VALUE.clone()), + None => { + internal_err!("Comparison error while computing min/max") + } + } + } + }}; +} + /// dynamically-typed max(array) -> ScalarValue pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> { Ok(match values.data_type() { @@ -653,6 +701,7 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> { max_binary ) } + DataType::Struct(_) => min_max_batch_struct(values, Ordering::Less)?, _ => min_max_batch!(values, max), }) } @@ -923,6 +972,13 @@ macro_rules! min_max { ) => { typed_min_max!(lhs, rhs, DurationNanosecond, $OP) } + + ( + lhs @ ScalarValue::Struct(_), + rhs @ ScalarValue::Struct(_), + ) => { + min_max_struct!(lhs, rhs, $OP)? 
+ } e => { return internal_err!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 004846bc369ee..1c155e9d7813e 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -6752,3 +6752,29 @@ select c2, count(*) from test WHERE 1 = 1 group by c2; 5 1 6 1 + +# Min/Max struct +query ?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c FROM t) +---- +{a: 1, b: 2} {a: 10, b: 11} + +# Min/Max struct with NULL +query ?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 2 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c FROM t) +---- +{a: 2, b: 3} {a: 10, b: 11} + +# Min/Max struct with two recordbatch +query ?? rowsort +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(1 as 'a', 2 as 'b') AS c UNION SELECT STRUCT(3 as 'a', 4 as 'b') AS c ) +---- +{a: 1, b: 2} {a: 3, b: 4} + +# Min/Max struct empty +query ?? 
rowsort +SELECT MIN(c), MAX(c) FROM (SELECT * FROM (SELECT STRUCT(1 as 'a', 2 as 'b') AS c) LIMIT 0) +---- +NULL NULL From c0cf133797cfdb337bff52988390062c2fdbe7a3 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 14 Apr 2025 19:52:54 +0800 Subject: [PATCH 2/3] groups aggregator --- datafusion/common/src/scalar/mod.rs | 81 ++- .../src/aggregate/groups_accumulator/nulls.rs | 10 +- datafusion/functions-aggregate/src/min_max.rs | 39 +- .../src/min_max/min_max_struct.rs | 544 ++++++++++++++++++ .../sqllogictest/test_files/aggregate.slt | 32 ++ 5 files changed, 681 insertions(+), 25 deletions(-) create mode 100644 datafusion/functions-aggregate/src/min_max/min_max_struct.rs diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index b8d9aea810f03..03ec0fe2a1d82 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -506,7 +506,7 @@ impl PartialOrd for ScalarValue { } (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Struct(struct_arr1), Struct(struct_arr2)) => { - partial_cmp_struct(struct_arr1, struct_arr2) + partial_cmp_struct(struct_arr1.as_ref(), struct_arr2.as_ref()) } (Struct(_), _) => None, (Map(map_arr1), Map(map_arr2)) => partial_cmp_map(map_arr1, map_arr2), @@ -612,7 +612,20 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { Some(Ordering::Equal) } -fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option { +fn expand_struct_columns<'a>(array: &'a StructArray, columns: &mut Vec<&'a ArrayRef>) { + for i in 0..array.num_columns() { + let column = array.column(i); + if let Some(nested_struct) = column.as_any().downcast_ref::() { + // If it's a nested struct, recursively expand + expand_struct_columns(nested_struct, columns); + } else { + // If it's a primitive type, add directly + columns.push(column); + } + } +} + +pub fn partial_cmp_struct(s1: &StructArray, s2: &StructArray) -> Option { if s1.len() != s2.len() { return None; } @@ -621,9 +634,15 @@ fn 
partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option() } + + /// Performs a deep clone of the ScalarValue, creating new copies of all nested data structures. + /// This is different from the standard `clone()` which may share data through `Arc`. + /// Aggregation functions like `max` will cost a lot of memory if the data is not cloned. + pub fn deep_clone(&self) -> Self { + match self { + // Complex types need deep clone of their contents + ScalarValue::List(array) => { + let array = copy_array_data(&array.to_data()); + let new_array = ListArray::from(array); + ScalarValue::List(Arc::new(new_array)) + } + ScalarValue::LargeList(array) => { + let array = copy_array_data(&array.to_data()); + let new_array = LargeListArray::from(array); + ScalarValue::LargeList(Arc::new(new_array)) + } + ScalarValue::FixedSizeList(arr) => { + let array = copy_array_data(&arr.to_data()); + let new_array = FixedSizeListArray::from(array); + ScalarValue::FixedSizeList(Arc::new(new_array)) + } + ScalarValue::Struct(arr) => { + let array = copy_array_data(&arr.to_data()); + let new_array = StructArray::from(array); + ScalarValue::Struct(Arc::new(new_array)) + } + ScalarValue::Map(arr) => { + let array = copy_array_data(&arr.to_data()); + let new_array = MapArray::from(array); + ScalarValue::Map(Arc::new(new_array)) + } + ScalarValue::Union(Some((type_id, value)), fields, mode) => { + let new_value = Box::new(value.deep_clone()); + ScalarValue::Union(Some((*type_id, new_value)), fields.clone(), *mode) + } + ScalarValue::Union(None, fields, mode) => { + ScalarValue::Union(None, fields.clone(), *mode) + } + ScalarValue::Dictionary(key_type, value) => { + let new_value = Box::new(value.deep_clone()); + ScalarValue::Dictionary(key_type.clone(), new_value) + } + _ => self.clone(), + } + } +} + +pub fn copy_array_data(data: &ArrayData) -> ArrayData { + let mut copy = MutableArrayData::new(vec![&data], true, data.len()); + copy.extend(0, 0, data.len()); + copy.freeze() } macro_rules! 
impl_scalar { diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 6a8946034cbc3..44c7f1b0c7de4 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -20,7 +20,7 @@ use arrow::array::{ Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray, BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray, - StringViewArray, + StringViewArray, StructArray, }; use arrow::buffer::NullBuffer; use arrow::datatypes::DataType; @@ -193,6 +193,14 @@ pub fn set_nulls_dyn(input: &dyn Array, nulls: Option) -> Result unsafe { + let input = input.as_struct(); + Arc::new(StructArray::new_unchecked( + input.fields().clone(), + input.columns().to_vec(), + nulls, + )) + }, _ => { return not_impl_err!("Applying nulls {:?}", input.data_type()); } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 65fc8741bdc98..025c3037c7d36 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -19,6 +19,7 @@ //! 
[`Min`] and [`MinAccumulator`] accumulator for the `min` function mod min_max_bytes; +mod min_max_struct; use arrow::array::{ Array, ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, @@ -56,6 +57,7 @@ use arrow::datatypes::{ }; use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; +use crate::min_max::min_max_struct::MinMaxStructAccumulator; use datafusion_common::ScalarValue; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, @@ -267,6 +269,7 @@ impl AggregateUDFImpl for Max { | LargeBinary | BinaryView | Duration(_) + | Struct(_) ) } @@ -342,7 +345,9 @@ impl AggregateUDFImpl for Max { Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone()))) } - + Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_max( + data_type.clone(), + ))), // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), } @@ -630,33 +635,26 @@ fn min_max_batch_struct(array: &ArrayRef, ordering: Ordering) -> Result { - if cmp == ordering { - extreme = current; - } - } - None => { - return internal_err!("Comparison error while computing min/max"); + if let Some(cmp) = extreme.partial_cmp(&current) { + if cmp == ordering { + extreme = current; } } } - Ok(extreme) + // use deep_clone to free array reference + Ok(extreme.deep_clone()) } macro_rules! min_max_struct { ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ if $VALUE.is_null() { - Ok($DELTA.clone()) + $DELTA.clone() } else if $DELTA.is_null() { - Ok($VALUE.clone()) + $VALUE.clone() } else { match $VALUE.partial_cmp(&$DELTA) { - Some(choose_min_max!($OP)) => Ok($DELTA.clone()), - Some(_) => Ok($VALUE.clone()), - None => { - internal_err!("Comparison error while computing min/max") - } + Some(choose_min_max!($OP)) => $DELTA.clone(), + _ => $VALUE.clone(), } } }}; @@ -977,7 +975,7 @@ macro_rules! 
min_max { lhs @ ScalarValue::Struct(_), rhs @ ScalarValue::Struct(_), ) => { - min_max_struct!(lhs, rhs, $OP)? + min_max_struct!(lhs, rhs, $OP) } e => { return internal_err!( @@ -1189,6 +1187,7 @@ impl AggregateUDFImpl for Min { | LargeBinary | BinaryView | Duration(_) + | Struct(_) ) } @@ -1264,7 +1263,9 @@ impl AggregateUDFImpl for Min { Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone()))) } - + Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_min( + data_type.clone(), + ))), // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), } diff --git a/datafusion/functions-aggregate/src/min_max/min_max_struct.rs b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs new file mode 100644 index 0000000000000..8038f2f01d90c --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs @@ -0,0 +1,544 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::{cmp::Ordering, sync::Arc}; + +use arrow::{ + array::{ + Array, ArrayData, ArrayRef, AsArray, BooleanArray, MutableArrayData, StructArray, + }, + datatypes::DataType, +}; +use datafusion_common::{ + internal_err, + scalar::{copy_array_data, partial_cmp_struct}, + Result, +}; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; + +/// Accumulator for MIN/MAX operations on Struct data types. +/// +/// This accumulator tracks the minimum or maximum struct value encountered +/// during aggregation, depending on the `is_min` flag. +/// +/// The comparison is done based on the struct fields in order. +pub(crate) struct MinMaxStructAccumulator { + /// Inner data storage. + inner: MinMaxStructState, + /// if true, is `MIN` otherwise is `MAX` + is_min: bool, +} + +impl MinMaxStructAccumulator { + pub fn new_min(data_type: DataType) -> Self { + Self { + inner: MinMaxStructState::new(data_type), + is_min: true, + } + } + + pub fn new_max(data_type: DataType) -> Self { + Self { + inner: MinMaxStructState::new(data_type), + is_min: false, + } + } +} + +impl GroupsAccumulator for MinMaxStructAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = &values[0]; + assert_eq!(array.len(), group_indices.len()); + assert_eq!(array.data_type(), &self.inner.data_type); + // apply filter if needed + let array = apply_filter_as_nulls(array, opt_filter)?; + + fn struct_min(a: &StructArray, b: &StructArray) -> bool { + matches!(partial_cmp_struct(a, b), Some(Ordering::Less)) + } + + fn struct_max(a: &StructArray, b: &StructArray) -> bool { + matches!(partial_cmp_struct(a, b), Some(Ordering::Greater)) + } + + if self.is_min { + self.inner.update_batch( + array.as_struct(), + group_indices, + total_num_groups, + struct_min, + ) + } else { + 
self.inner.update_batch( + array.as_struct(), + group_indices, + total_num_groups, + struct_max, + ) + } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let (_, min_maxes) = self.inner.emit_to(emit_to); + let fields = match &self.inner.data_type { + DataType::Struct(fields) => fields, + _ => return internal_err!("Data type is not a struct"), + }; + let null_array = StructArray::new_null(fields.clone(), 1); + let min_maxes_data: Vec = min_maxes + .iter() + .map(|v| match v { + Some(v) => v.to_data(), + None => null_array.to_data(), + }) + .collect(); + let min_maxes_refs: Vec<&ArrayData> = min_maxes_data.iter().collect(); + let mut copy = MutableArrayData::new(min_maxes_refs, true, min_maxes_data.len()); + + for (i, item) in min_maxes_data.iter().enumerate() { + copy.extend(i, 0, item.len()); + } + let result = copy.freeze(); + assert_eq!(&self.inner.data_type, result.data_type()); + Ok(Arc::new(StructArray::from(result))) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + // min/max are their own states (no transition needed) + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // min/max are their own states (no transition needed) + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + // Min/max do not change the values as they are their own states + // apply the filter by combining with the null mask, if any + let output = apply_filter_as_nulls(&values[0], opt_filter)?; + Ok(vec![output]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +#[derive(Debug)] +struct MinMaxStructState { + /// The minimum/maximum value for each group + min_max: Vec>, + /// The data type of the 
array + data_type: DataType, + /// The total bytes of the struct data stored in `min_max` (used for + /// tracking memory usage) + total_data_bytes: usize, +} + +#[derive(Debug, Clone)] +enum MinMaxLocation { + /// the min/max value is stored in the existing `min_max` array + ExistingMinMax, + /// the min/max value is a one-row slice taken from the input array + Input(StructArray), +} + +/// Implement the MinMaxStructState with a comparison function +/// for comparing structs +impl MinMaxStructState { + /// Create a new MinMaxStructState + /// + /// # Arguments: + /// * `data_type`: The data type of the arrays that will be passed to this accumulator + fn new(data_type: DataType) -> Self { + Self { + min_max: vec![], + data_type, + total_data_bytes: 0, + } + } + + /// Set the specified group to the given value, updating memory usage appropriately + fn set_value(&mut self, group_index: usize, new_val: &StructArray) { + let new_val = StructArray::from(copy_array_data(&new_val.to_data())); + match self.min_max[group_index].as_mut() { + None => { + self.total_data_bytes += new_val.get_array_memory_size(); + self.min_max[group_index] = Some(new_val); + } + Some(existing_val) => { + // Replace the stored value and adjust the tracked byte count + self.total_data_bytes -= existing_val.get_array_memory_size(); + self.total_data_bytes += new_val.get_array_memory_size(); + *existing_val = new_val; + } + } + } + + /// Updates the min/max values for the given struct values + /// + /// `cmp` is the comparison function to use, called like `cmp(new_val, existing_val)` + /// returns true if the `new_val` should replace `existing_val` + fn update_batch<F>( + &mut self, + array: &StructArray, + group_indices: &[usize], + total_num_groups: usize, + mut cmp: F, + ) -> Result<()> + where + F: FnMut(&StructArray, &StructArray) -> bool + Send + Sync, + { + self.min_max.resize(total_num_groups, None); + // Minimize value copies by calculating the new min/maxes for each group + // in this batch (either 
the existing min/max or the new input value) + // and updating the owned values in `self.min_maxes` at most once + let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; + + // Figure out the new min value for each group + for (index, group_index) in (0..array.len()).zip(group_indices.iter()) { + let group_index = *group_index; + if array.is_null(index) { + continue; + } + let new_val = array.slice(index, 1); + + let existing_val = match &locations[group_index] { + // previous input value was the min/max, so compare it + MinMaxLocation::Input(existing_val) => existing_val, + MinMaxLocation::ExistingMinMax => { + let Some(existing_val) = self.min_max[group_index].as_ref() else { + // no existing min/max, so this is the new min/max + locations[group_index] = MinMaxLocation::Input(new_val); + continue; + }; + existing_val + } + }; + + // Compare the new value to the existing value, replacing if necessary + if cmp(&new_val, existing_val) { + locations[group_index] = MinMaxLocation::Input(new_val); + } + } + + // Update self.min_max with any new min/max values we found in the input + for (group_index, location) in locations.iter().enumerate() { + match location { + MinMaxLocation::ExistingMinMax => {} + MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val), + } + } + Ok(()) + } + + /// Emits the specified min_max values + /// + /// Returns (data_capacity, min_maxes), updating the current value of total_data_bytes + /// + /// - `data_capacity`: the total length of all strings and their contents, + /// - `min_maxes`: the actual min/max values for each group + fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec>) { + match emit_to { + EmitTo::All => { + ( + std::mem::take(&mut self.total_data_bytes), // reset total bytes and min_max + std::mem::take(&mut self.min_max), + ) + } + EmitTo::First(n) => { + let first_min_maxes: Vec<_> = self.min_max.drain(..n).collect(); + let first_data_capacity: usize = first_min_maxes + .iter() + 
.map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0)) + .sum(); + self.total_data_bytes -= first_data_capacity; + (first_data_capacity, first_min_maxes) + } + } + } + + fn size(&self) -> usize { + self.total_data_bytes + self.min_max.len() * size_of::>() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray, StructArray}; + use arrow::datatypes::{DataType, Field, Fields, Int32Type}; + use std::sync::Arc; + + fn create_test_struct_array( + int_values: Vec>, + str_values: Vec>, + ) -> StructArray { + let int_array = Int32Array::from(int_values); + let str_array = StringArray::from(str_values); + + let fields = vec![ + Field::new("int_field", DataType::Int32, true), + Field::new("str_field", DataType::Utf8, true), + ]; + + StructArray::new( + Fields::from(fields), + vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ], + None, + ) + } + + fn create_nested_struct_array( + int_values: Vec>, + str_values: Vec>, + ) -> StructArray { + let inner_struct = create_test_struct_array(int_values, str_values); + + let fields = vec![Field::new("inner", inner_struct.data_type().clone(), true)]; + + StructArray::new( + Fields::from(fields), + vec![Arc::new(inner_struct) as ArrayRef], + None, + ) + } + + #[test] + fn test_min_max_simple_struct() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = 
max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_nested_struct() { + let array = create_nested_struct_array( + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let inner = min_result.column(0).as_struct(); + let int_array = inner.column(0).as_primitive::(); + let str_array = inner.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let inner = max_result.column(0).as_struct(); + let int_array = inner.column(0).as_primitive::(); + let str_array = inner.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_with_nulls() { + let array = 
create_test_struct_array( + vec![Some(1), None, Some(3)], + vec![Some("a"), None, Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_multiple_groups() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3), Some(4)], + vec![Some("a"), Some("b"), Some("c"), Some("d")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 1, 0, 1]; + + min_accumulator + .update_batch(&values, &group_indices, None, 2) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 2) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = 
min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 2); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + assert_eq!(int_array.value(1), 2); + assert_eq!(str_array.value(1), "b"); + + assert_eq!(max_result.len(), 2); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + assert_eq!(int_array.value(1), 4); + assert_eq!(str_array.value(1), "d"); + } + + #[test] + fn test_min_max_with_filter() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3), Some(4)], + vec![Some("a"), Some("b"), Some("c"), Some("d")], + ); + + // Create a filter that only keeps even numbers + let filter = BooleanArray::from(vec![false, true, false, true]); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, Some(&filter), 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, Some(&filter), 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 2); + assert_eq!(str_array.value(0), "b"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = 
max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 4); + assert_eq!(str_array.value(0), "d"); + } +} diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 1c155e9d7813e..25ee2dc8908d1 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -6778,3 +6778,35 @@ query ?? rowsort SELECT MIN(c), MAX(c) FROM (SELECT * FROM (SELECT STRUCT(1 as 'a', 2 as 'b') AS c) LIMIT 0) ---- NULL NULL + +# Min/Max group struct +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 2, b: 3} {a: 10, b: 11} +1 {a: 1, b: 2} {a: 9, b: 10} + +# Min/Max group struct with NULL +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 2 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 2, b: 3} {a: 10, b: 11} +1 NULL NULL + +# Min/Max group struct with NULL +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 3 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 6, b: 7} {a: 6, b: 7} +1 {a: 3, b: 4} {a: 9, b: 10} + +# Min/Max struct empty +query ?? 
rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c, (c1 % 2) AS key FROM t LIMIT 0) GROUP BY key +---- + + From 3dcf38d8a00472e7c7b0e921e9cba593edb4b8dc Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Sun, 4 May 2025 12:47:58 +0800 Subject: [PATCH 3/3] update based on lamb's suggestion --- datafusion/common/src/scalar/mod.rs | 14 +++++++------- .../src/aggregate/groups_accumulator/nulls.rs | 18 +++++++++++------- datafusion/functions-aggregate/src/min_max.rs | 4 ++-- .../sqllogictest/test_files/aggregate.slt | 17 +++++++++++++++++ 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 8e8ac3a74835b..5c020d1f6398a 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -616,12 +616,12 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { Some(arr1.len().cmp(&arr2.len())) } -fn expand_struct_columns<'a>(array: &'a StructArray, columns: &mut Vec<&'a ArrayRef>) { +fn flatten<'a>(array: &'a StructArray, columns: &mut Vec<&'a ArrayRef>) { for i in 0..array.num_columns() { let column = array.column(i); if let Some(nested_struct) = column.as_any().downcast_ref::() { // If it's a nested struct, recursively expand - expand_struct_columns(nested_struct, columns); + flatten(nested_struct, columns); } else { // If it's a primitive type, add directly columns.push(column); @@ -641,8 +641,8 @@ pub fn partial_cmp_struct(s1: &StructArray, s2: &StructArray) -> Option Self { + pub fn force_clone(&self) -> Self { match self { // Complex types need deep clone of their contents ScalarValue::List(array) => { @@ -3467,14 +3467,14 @@ impl ScalarValue { ScalarValue::Map(Arc::new(new_array)) } ScalarValue::Union(Some((type_id, value)), fields, mode) => { - let new_value = Box::new(value.deep_clone()); + let new_value = 
Box::new(value.force_clone()); ScalarValue::Union(Some((*type_id, new_value)), fields.clone(), *mode) } ScalarValue::Union(None, fields, mode) => { ScalarValue::Union(None, fields.clone(), *mode) } ScalarValue::Dictionary(key_type, value) => { - let new_value = Box::new(value.deep_clone()); + let new_value = Box::new(value.force_clone()); ScalarValue::Dictionary(key_type.clone(), new_value) } _ => self.clone(), diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 44c7f1b0c7de4..c8c7736bba14f 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -193,14 +193,18 @@ pub fn set_nulls_dyn(input: &dyn Array, nulls: Option) -> Result unsafe { + DataType::Struct(_) => { let input = input.as_struct(); - Arc::new(StructArray::new_unchecked( - input.fields().clone(), - input.columns().to_vec(), - nulls, - )) - }, + // safety: values / offsets came from a valid struct array + // and we checked nulls has the same length as values + unsafe { + Arc::new(StructArray::new_unchecked( + input.fields().clone(), + input.columns().to_vec(), + nulls, + )) + } + } _ => { return not_impl_err!("Applying nulls {:?}", input.data_type()); } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index c11f0f1437fef..af178ed675284 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -645,8 +645,8 @@ fn min_max_batch_struct(array: &ArrayRef, ordering: Ordering) -> Result