From e594f4bbfbefd7f07248b21dccb6826079dc9d56 Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Wed, 17 Jun 2026 14:55:10 +0200 Subject: [PATCH 1/7] implement map_agg --- datafusion/functions-aggregate/src/lib.rs | 3 + datafusion/functions-aggregate/src/map_agg.rs | 725 ++++++++++++++++++ datafusion/functions-aggregate/src/utils.rs | 41 +- .../sqllogictest/test_files/map_agg.slt | 70 ++ 4 files changed, 837 insertions(+), 2 deletions(-) create mode 100644 datafusion/functions-aggregate/src/map_agg.rs create mode 100644 datafusion/sqllogictest/test_files/map_agg.slt diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 1b9996220d882..f1656e449e006 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -79,6 +79,7 @@ pub mod covariance; pub mod first_last; pub mod grouping; pub mod hyperloglog; +pub mod map_agg; pub mod median; pub mod min_max; pub mod nth_value; @@ -122,6 +123,7 @@ pub mod expr_fn { pub use super::first_last::first_value; pub use super::first_last::last_value; pub use super::grouping::grouping; + pub use super::map_agg::map_agg; pub use super::median::median; pub use super::min_max::max; pub use super::min_max::min; @@ -156,6 +158,7 @@ pub fn all_default_aggregate_functions() -> Vec> { sum::sum_udaf(), min_max::max_udaf(), min_max::min_udaf(), + map_agg::map_agg_udaf(), median::median_udaf(), count::count_udaf(), regr::regr_slope_udaf(), diff --git a/datafusion/functions-aggregate/src/map_agg.rs b/datafusion/functions-aggregate/src/map_agg.rs new file mode 100644 index 0000000000000..8a32954768a3a --- /dev/null +++ b/datafusion/functions-aggregate/src/map_agg.rs @@ -0,0 +1,725 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `MAP_AGG` aggregate implementation: [`MapAgg`] + +use std::collections::VecDeque; +use std::mem::{size_of, size_of_val, take}; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, AsArray, MapArray, StructArray}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; + +use datafusion_common::utils::{compare_rows, get_row_at_idx}; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; +use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; +use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; +use datafusion_functions_aggregate_common::utils::ordering_fields; +use datafusion_macros::user_doc; +use datafusion_physical_expr_common::sort_expr::LexOrdering; + +use crate::utils::{map_row_to_scalars, struct_to_rows}; + +make_udaf_expr_and_func!( + MapAgg, + map_agg, + "Aggregate key-value pairs into a map", + map_agg_udaf +); + +#[user_doc( + doc_section(label = "General Functions"), + description = "Aggregate key-value pairs from two columns into a single map per group. Pairs with a NULL key are skipped; NULL values are retained. On a duplicate key the first value wins; use ORDER BY to make which value wins deterministic.", + syntax_example = "map_agg(key, value [ORDER BY expression])", + sql_example = r#" +```sql +> SELECT map_agg(name, score) FROM scores GROUP BY department; ++-------------------------------+ +| map_agg(name, score) | ++-------------------------------+ +| {Alice: 95, Bob: 87} | ++-------------------------------+ +``` +"#, + standard_argument(name = "key",), + standard_argument(name = "value",) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct MapAgg { + signature: Signature, +} + +impl Default for MapAgg { + fn default() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for MapAgg { + fn name(&self) -> &str { + "map_agg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(map_type(&arg_types[0], &arg_types[1])) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let key_type = args.input_fields[0].data_type(); + let value_type = args.input_fields[1].data_type(); + + let mut fields = vec![Arc::new(Field::new( + format_state_name(args.name, "map_agg"), + map_type(key_type, value_type), + true, + ))]; + + if !args.ordering_fields.is_empty() { + fields.push(Arc::new(Field::new_list( + format_state_name(args.name, "map_agg_orderings"), + Field::new_list_field( + DataType::Struct(Fields::from(args.ordering_fields.to_vec())), + true, + ), + false, + ))); + } + + Ok(fields) + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + // Order decides which value wins on a duplicate key, so the optimizer + // must satisfy it (inserts a SortExec). + // TODO: handle pre-sorted input like `array_agg` to skip the + // redundant sort. + AggregateOrderSensitivity::HardRequirement + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let key_type = acc_args.expr_fields[0].data_type().clone(); + let value_type = acc_args.expr_fields[1].data_type().clone(); + + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return Ok(Box::new(MapAggAccumulator::new(key_type, value_type))); + }; + + let ordering_dtypes = ordering + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; + + Ok(Box::new(OrderSensitiveMapAggAccumulator::new( + key_type, + value_type, + ordering_dtypes, + ordering, + ))) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn map_type(key_type: &DataType, value_type: &DataType) -> DataType { + let key_field = Arc::new(Field::new("key", key_type.clone(), false)); + let value_field = Arc::new(Field::new("value", value_type.clone(), true)); + let entries_field = Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![key_field, value_field])), + false, + )); + DataType::Map(entries_field, false) +} + +fn build_single_map( + keys: Vec, + values: Vec, + key_type: &DataType, + value_type: &DataType, +) -> Result { + debug_assert_eq!(keys.len(), values.len()); + + let key_field = Arc::new(Field::new("key", key_type.clone(), false)); + let value_field = Arc::new(Field::new("value", value_type.clone(), true)); + let entries_field = Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Arc::clone(&key_field), + Arc::clone(&value_field), + ])), + false, + )); + + let len = keys.len(); + let key_array = if len == 0 { + arrow::array::new_empty_array(key_type) + } else { + ScalarValue::iter_to_array(keys)? + }; + let value_array = if len == 0 { + arrow::array::new_empty_array(value_type) + } else { + ScalarValue::iter_to_array(values)? + }; + + let entries = StructArray::try_new( + Fields::from(vec![key_field, value_field]), + vec![key_array, value_array], + None, + )?; + + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0i32, len as i32])); + Ok(Arc::new(MapArray::try_new( + entries_field, + offsets, + entries, + None, + false, + )?)) +} + +/// De-duplicates parallel key/value vectors keeping the first value seen for +/// each key. +fn dedup_first_wins( + keys: Vec, + values: Vec, +) -> (Vec, Vec) { + use std::collections::HashSet; + + let mut seen: HashSet = HashSet::with_capacity(keys.len()); + let mut out_keys: Vec = Vec::with_capacity(keys.len()); + let mut out_vals: Vec = Vec::with_capacity(keys.len()); + + for (k, v) in keys.into_iter().zip(values) { + // Keep only the first occurrence of each key; later ones are dropped. + if seen.insert(k.clone()) { + out_keys.push(k); + out_vals.push(v); + } + } + + (out_keys, out_vals) +} + +/// Plain accumulator used when there is no `ORDER BY`. +#[derive(Debug)] +pub struct MapAggAccumulator { + key_type: DataType, + value_type: DataType, + keys: Vec, + values: Vec, +} + +impl MapAggAccumulator { + pub fn new(key_type: DataType, value_type: DataType) -> Self { + Self { + key_type, + value_type, + keys: Vec::new(), + values: Vec::new(), + } + } +} + +impl Accumulator for MapAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.len() != 2 { + return exec_err!("map_agg expects 2 columns, got {}", values.len()); + } + let keys = &values[0]; + let vals = &values[1]; + + for i in 0..keys.len() { + // NULL keys cannot exist in a map; skip the whole pair. + if keys.is_null(i) { + continue; + } + self.keys + .push(ScalarValue::try_from_array(keys, i)?.compacted()); + self.values + .push(ScalarValue::try_from_array(vals, i)?.compacted()); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + let map_array = states[0].as_map(); + for row in 0..map_array.len() { + if map_array.is_null(row) { + continue; + } + let (keys, values) = map_row_to_scalars(map_array, row)?; + self.keys.extend(keys); + self.values.extend(values); + } + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&mut self) -> Result { + let (keys, values) = dedup_first_wins(self.keys.clone(), self.values.clone()); + let map_array = build_single_map(keys, values, &self.key_type, &self.value_type)?; + ScalarValue::try_from_array(&map_array, 0) + } + + fn size(&self) -> usize { + size_of_val(self) + ScalarValue::size_of_vec(&self.keys) - size_of_val(&self.keys) + + ScalarValue::size_of_vec(&self.values) + - size_of_val(&self.values) + + self.key_type.size() + - size_of_val(&self.key_type) + + self.value_type.size() + - size_of_val(&self.value_type) + } +} + +/// Accumulator used when `map_agg` has an `ORDER BY`. Stores the ordering column +/// values alongside each pair so the input can be globally sorted (across +/// partitions). +#[derive(Debug)] +pub struct OrderSensitiveMapAggAccumulator { + key_type: DataType, + value_type: DataType, + keys: Vec, + values: Vec, + /// Ordering-expression values for each pair, parallel to `keys`/`values`. + ordering_values: Vec>, + ordering_dtypes: Vec, + ordering_req: LexOrdering, +} + +impl OrderSensitiveMapAggAccumulator { + pub fn new( + key_type: DataType, + value_type: DataType, + ordering_dtypes: Vec, + ordering_req: LexOrdering, + ) -> Self { + Self { + key_type, + value_type, + keys: Vec::new(), + values: Vec::new(), + ordering_values: Vec::new(), + ordering_dtypes, + ordering_req, + } + } + + fn sort_options(&self) -> Vec { + self.ordering_req.iter().map(|s| s.options).collect() + } + + /// Sorts the accumulated pairs by their ordering values, then applies + /// first-wins de-duplication. + fn sorted_deduped(&self) -> Result<(Vec, Vec)> { + let sort_options = self.sort_options(); + let mut rows: Vec = (0..self.keys.len()).collect(); + let mut cmp_err = Ok(()); + rows.sort_by(|&a, &b| { + compare_rows( + &self.ordering_values[a], + &self.ordering_values[b], + &sort_options, + ) + .unwrap_or_else(|e| { + cmp_err = Err(e); + std::cmp::Ordering::Equal + }) + }); + cmp_err?; + + let keys = rows.iter().map(|&i| self.keys[i].clone()).collect(); + let values = rows.iter().map(|&i| self.values[i].clone()).collect(); + Ok(dedup_first_wins(keys, values)) + } + + /// Builds the `List>` state column carrying ordering + /// values for every accumulated pair. + fn evaluate_orderings(&self) -> Result { + let fields = ordering_fields(&self.ordering_req, &self.ordering_dtypes); + let num_rows = self.ordering_values.len(); + let struct_field = Fields::from(fields.clone()); + + let mut column_wise: Vec = Vec::with_capacity(fields.len()); + for (col_idx, field) in fields.iter().enumerate() { + if num_rows == 0 { + column_wise.push(arrow::array::new_empty_array(field.data_type())); + } else { + let col_vals = + self.ordering_values.iter().map(|row| row[col_idx].clone()); + column_wise.push(ScalarValue::iter_to_array(col_vals)?); + } + } + + let struct_array = StructArray::try_new(struct_field, column_wise, None)?; + Ok( + datafusion_common::utils::SingleRowListArrayBuilder::new(Arc::new( + struct_array, + )) + .build_list_scalar(), + ) + } +} + +impl Accumulator for OrderSensitiveMapAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.len() < 2 { + return exec_err!("map_agg expects at least 2 columns, got {}", values.len()); + } + let keys = &values[0]; + let vals = &values[1]; + let ordering_cols = &values[2..]; + + for i in 0..keys.len() { + // NULL keys cannot exist in a map; skip the whole pair. + if keys.is_null(i) { + continue; + } + self.keys + .push(ScalarValue::try_from_array(keys, i)?.compacted()); + self.values + .push(ScalarValue::try_from_array(vals, i)?.compacted()); + self.ordering_values.push( + get_row_at_idx(ordering_cols, i)? + .into_iter() + .map(|v| v.compacted()) + .collect(), + ); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + if states.len() != 2 { + return exec_err!( + "map_agg ordered merge expects 2 state columns, got {}", + states.len() + ); + } + + let map_array = states[0].as_map(); + let orderings = states[1].as_list::(); + + // Each partition contributes one map row plus a parallel list of + // ordering-value structs. Collect them, then merge by ordering. + let mut partition_keys: Vec> = + vec![take(&mut self.keys).into()]; + let mut partition_orderings: Vec>> = + vec![take(&mut self.ordering_values).into()]; + // Values are carried inside the ordering merge by pairing them with keys + // through a side table; simpler to merge keys+values together below. + let mut partition_values: Vec> = + vec![take(&mut self.values).into()]; + + for row in 0..map_array.len() { + if map_array.is_null(row) { + continue; + } + let (keys, values) = map_row_to_scalars(map_array, row)?; + let ord_vals = struct_to_rows(orderings.value(row).as_struct())?; + + partition_keys.push(keys.into()); + partition_values.push(values.into()); + partition_orderings.push(ord_vals.into()); + } + + // Merge keys and values along the ordering. `merge_ordered_arrays` + // merges a single value stream; run it once for keys and once for + // values using the same ordering inputs so they stay aligned. + let sort_options = self.sort_options(); + + let (merged_keys, merged_orderings) = merge_ordered_arrays( + &mut partition_keys, + &mut partition_orderings.clone(), + &sort_options, + )?; + let (merged_values, _) = merge_ordered_arrays( + &mut partition_values, + &mut partition_orderings, + &sort_options, + )?; + + self.keys = merged_keys; + self.values = merged_values; + self.ordering_values = merged_orderings; + Ok(()) + } + + fn state(&mut self) -> Result> { + // Ship the de-duplicated map plus ordering values for re-sorting. + let (keys, values) = self.sorted_deduped()?; + let map_array = build_single_map(keys, values, &self.key_type, &self.value_type)?; + Ok(vec![ + ScalarValue::try_from_array(&map_array, 0)?, + self.evaluate_orderings()?, + ]) + } + + fn evaluate(&mut self) -> Result { + let (keys, values) = self.sorted_deduped()?; + let map_array = build_single_map(keys, values, &self.key_type, &self.value_type)?; + ScalarValue::try_from_array(&map_array, 0) + } + + fn size(&self) -> usize { + let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.keys) + - size_of_val(&self.keys) + + ScalarValue::size_of_vec(&self.values) + - size_of_val(&self.values); + total += size_of::>() * self.ordering_values.capacity(); + for row in &self.ordering_values { + total += ScalarValue::size_of_vec(row) - size_of_val(row); + } + total += size_of::() * self.ordering_dtypes.capacity(); + total + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::{Int32Type, Schema}; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + + fn make_acc() -> MapAggAccumulator { + MapAggAccumulator::new(DataType::Utf8, DataType::Int32) + } + + fn str_arr(vals: &[&str]) -> ArrayRef { + Arc::new(StringArray::from(vals.to_vec())) + } + + fn int_arr(vals: &[i32]) -> ArrayRef { + Arc::new(Int32Array::from(vals.to_vec())) + } + + fn str_arr_nullable(vals: Vec>) -> ArrayRef { + Arc::new(StringArray::from(vals)) + } + + fn extract_map(sv: ScalarValue) -> Vec<(String, Option)> { + let ScalarValue::Map(arr) = sv else { + panic!("expected ScalarValue::Map, got {sv:?}"); + }; + let entries = arr.value(0); + let entries = entries.as_any().downcast_ref::().unwrap(); + let keys = entries.column(0).as_string::(); + let vals = entries.column(1).as_primitive::(); + (0..keys.len()) + .map(|i| { + let v = if vals.is_null(i) { + None + } else { + Some(vals.value(i)) + }; + (keys.value(i).to_string(), v) + }) + .collect() + } + + #[test] + fn collects_distinct_pairs_in_order() -> Result<()> { + let mut acc = make_acc(); + acc.update_batch(&[str_arr(&["a", "b", "c"]), int_arr(&[1, 2, 3])])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!( + pairs, + vec![ + ("a".into(), Some(1)), + ("b".into(), Some(2)), + ("c".into(), Some(3)), + ] + ); + Ok(()) + } + + #[test] + fn null_key_skipped() -> Result<()> { + let mut acc = make_acc(); + // A null-key pair is dropped (matches Trino map_agg). + acc.update_batch(&[ + str_arr_nullable(vec![Some("a"), None, Some("c")]), + int_arr(&[1, 2, 3]), + ])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!(pairs.len(), 2); + assert_eq!(pairs[0].0, "a"); + assert_eq!(pairs[1].0, "c"); + Ok(()) + } + + #[test] + fn null_value_retained() -> Result<()> { + let mut acc = make_acc(); + let vals: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); + acc.update_batch(&[str_arr(&["a", "b", "c"]), vals])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!(pairs[1], ("b".into(), None)); + Ok(()) + } + + #[test] + fn duplicate_key_first_wins() -> Result<()> { + let mut acc = make_acc(); + // a appears twice; the first value (1) wins, later one is dropped. + acc.update_batch(&[str_arr(&["a", "b", "a"]), int_arr(&[1, 2, 3])])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!(pairs, vec![("a".into(), Some(1)), ("b".into(), Some(2))]); + Ok(()) + } + + #[test] + fn empty_produces_empty_map() -> Result<()> { + let mut acc = make_acc(); + let pairs = extract_map(acc.evaluate()?); + assert!(pairs.is_empty()); + Ok(()) + } + + #[test] + fn merge_two_partitions() -> Result<()> { + let mut acc1 = make_acc(); + let mut acc2 = make_acc(); + + acc1.update_batch(&[str_arr(&["a", "b"]), int_arr(&[1, 2])])?; + acc2.update_batch(&[str_arr(&["c", "d"]), int_arr(&[3, 4])])?; + + let state2 = acc2 + .state()? + .into_iter() + .map(|sv| sv.to_array()) + .collect::>>()?; + acc1.merge_batch(&state2)?; + + let pairs = extract_map(acc1.evaluate()?); + assert_eq!(pairs.len(), 4); + Ok(()) + } + + #[test] + fn merge_duplicate_key_across_partitions_first_wins() -> Result<()> { + let mut acc1 = make_acc(); + let mut acc2 = make_acc(); + + acc1.update_batch(&[str_arr(&["a"]), int_arr(&[1])])?; + acc2.update_batch(&[str_arr(&["a"]), int_arr(&[2])])?; + + let state2 = acc2 + .state()? + .into_iter() + .map(|sv| sv.to_array()) + .collect::>>()?; + acc1.merge_batch(&state2)?; + + let pairs = extract_map(acc1.evaluate()?); + // acc1's pair comes first in the merged stream, so value 1 wins. + assert_eq!(pairs, vec![("a".into(), Some(1))]); + Ok(()) + } + + /// Builds an order-sensitive accumulator that orders by a single Int32 + /// column `ord` (column index 2 in the update batch). + fn make_ordered_acc(descending: bool) -> OrderSensitiveMapAggAccumulator { + let schema = Schema::new(vec![ + Field::new("k", DataType::Utf8, true), + Field::new("v", DataType::Int32, true), + Field::new("ord", DataType::Int32, true), + ]); + let ord_expr = Arc::new(Column::new_with_schema("ord", &schema).unwrap()); + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new( + ord_expr, + SortOptions::new(descending, false), + )]) + .unwrap(); + OrderSensitiveMapAggAccumulator::new( + DataType::Utf8, + DataType::Int32, + vec![DataType::Int32], + ordering, + ) + } + + #[test] + fn ordered_dup_key_asc_first_wins() -> Result<()> { + let mut acc = make_ordered_acc(false); + // key "a" twice: ord=10 -> v=1, ord=20 -> v=2. ASC sort puts v=1 + // (ord=10) first, so first-wins keeps v=1. + acc.update_batch(&[str_arr(&["a", "a"]), int_arr(&[1, 2]), int_arr(&[10, 20])])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!(pairs, vec![("a".into(), Some(1))]); + Ok(()) + } + + #[test] + fn ordered_dup_key_desc_flips_winner() -> Result<()> { + let mut acc = make_ordered_acc(true); + // Same input, DESC sort puts v=2 (ord=20) first, so first-wins keeps v=2. + acc.update_batch(&[str_arr(&["a", "a"]), int_arr(&[1, 2]), int_arr(&[10, 20])])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!(pairs, vec![("a".into(), Some(2))]); + Ok(()) + } + + #[test] + fn ordered_merge_two_partitions() -> Result<()> { + // Partition 1 sees ord=20, partition 2 sees ord=10 for the same key. + // After merge + ASC sort, the ord=10 row is first, so its value wins. + let mut acc1 = make_ordered_acc(false); + let mut acc2 = make_ordered_acc(false); + + acc1.update_batch(&[str_arr(&["a"]), int_arr(&[2]), int_arr(&[20])])?; + acc2.update_batch(&[str_arr(&["a"]), int_arr(&[1]), int_arr(&[10])])?; + + let state2 = acc2 + .state()? + .into_iter() + .map(|sv| sv.to_array()) + .collect::>>()?; + acc1.merge_batch(&state2)?; + + let pairs = extract_map(acc1.evaluate()?); + assert_eq!(pairs, vec![("a".into(), Some(1))]); + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/utils.rs b/datafusion/functions-aggregate/src/utils.rs index 6d816e54bdaf2..fd5fd58fa70d8 100644 --- a/datafusion/functions-aggregate/src/utils.rs +++ b/datafusion/functions-aggregate/src/utils.rs @@ -17,9 +17,11 @@ use std::sync::Arc; -use arrow::array::RecordBatch; +use arrow::array::{Array, MapArray, RecordBatch, StructArray}; use arrow::datatypes::Schema; -use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err, plan_err}; +use datafusion_common::{ + DataFusionError, Result, ScalarValue, internal_datafusion_err, internal_err, plan_err, +}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -75,3 +77,38 @@ pub(crate) fn validate_percentile_expr( } Ok(percentile) } + +/// Reads the key/value scalars out of one row of a `MapArray`. +pub(crate) fn map_row_to_scalars( + map_array: &MapArray, + row: usize, +) -> Result<(Vec, Vec)> { + let entries = map_array.value(row); + let entries = entries + .as_any() + .downcast_ref::() + .ok_or_else(|| internal_datafusion_err!("map entries must be a StructArray"))?; + let key_col = entries.column(0); + let val_col = entries.column(1); + + let mut keys = Vec::with_capacity(key_col.len()); + let mut values = Vec::with_capacity(val_col.len()); + for i in 0..key_col.len() { + keys.push(ScalarValue::try_from_array(key_col, i)?); + values.push(ScalarValue::try_from_array(val_col, i)?); + } + Ok((keys, values)) +} + +/// Converts a `StructArray` into per-row scalar tuples: `rows[i][c]` is column +/// `c` of row `i`. +pub(crate) fn struct_to_rows(s: &StructArray) -> Result>> { + (0..s.len()) + .map(|i| { + s.columns() + .iter() + .map(|col| ScalarValue::try_from_array(col, i)) + .collect() + }) + .collect() +} diff --git a/datafusion/sqllogictest/test_files/map_agg.slt b/datafusion/sqllogictest/test_files/map_agg.slt new file mode 100644 index 0000000000000..705f647a6ca73 --- /dev/null +++ b/datafusion/sqllogictest/test_files/map_agg.slt @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE kv (g INT, k VARCHAR, v INT, ts INT) AS VALUES + (1, 'a', 1, 10), + (1, 'b', 2, 20), + (2, 'c', 3, 30), + (2, 'a', 4, 40); + +# Basic grouped map_agg +query I? +SELECT g, map_agg(k, v) FROM kv GROUP BY g ORDER BY g; +---- +1 {a: 1, b: 2} +2 {c: 3, a: 4} + +# Duplicate key within a group: first-wins (matches Trino map_agg). ORDER BY +# decides which row is first, hence which value survives. +statement ok +CREATE TABLE dup (k VARCHAR, v INT, ts INT) AS VALUES + ('a', 1, 10), + ('a', 2, 20); + +# ORDER BY ts ASC -> v=1 (ts=10) is first -> wins +query ? +SELECT map_agg(k, v ORDER BY ts ASC) FROM dup; +---- +{a: 1} + +# ORDER BY ts DESC -> v=2 (ts=20) is first -> wins +query ? +SELECT map_agg(k, v ORDER BY ts DESC) FROM dup; +---- +{a: 2} + +# Pair with a NULL key is skipped (matches Trino map_agg) +statement ok +CREATE TABLE nullkey (k VARCHAR, v INT) AS VALUES + ('a', 1), + (NULL, 2), + ('c', 3); + +query ? +SELECT map_agg(k, v) FROM nullkey; +---- +{a: 1, c: 3} + +statement ok +DROP TABLE kv; + +statement ok +DROP TABLE dup; + +statement ok +DROP TABLE nullkey; From 0c2a55adf0c4f899d5195db9cd8dfb7ced9babaa Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Thu, 18 Jun 2026 11:10:35 +0200 Subject: [PATCH 2/7] Add doc --- datafusion/functions-aggregate/src/map_agg.rs | 4 +-- .../sqllogictest/test_files/map_agg.slt | 20 +++++++++++++++ .../user-guide/sql/aggregate_functions.md | 25 +++++++++++++++++++ 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/map_agg.rs b/datafusion/functions-aggregate/src/map_agg.rs index 8a32954768a3a..65403e14a8d4b 100644 --- a/datafusion/functions-aggregate/src/map_agg.rs +++ b/datafusion/functions-aggregate/src/map_agg.rs @@ -445,11 +445,10 @@ impl Accumulator for OrderSensitiveMapAggAccumulator { vec![take(&mut self.keys).into()]; let mut partition_orderings: Vec>> = vec![take(&mut self.ordering_values).into()]; - // Values are carried inside the ordering merge by pairing them with keys - // through a side table; simpler to merge keys+values together below. let mut partition_values: Vec> = vec![take(&mut self.values).into()]; + // Push keys and values from each partition's state into the merge buffers. for row in 0..map_array.len() { if map_array.is_null(row) { continue; @@ -485,7 +484,6 @@ impl Accumulator for OrderSensitiveMapAggAccumulator { } fn state(&mut self) -> Result> { - // Ship the de-duplicated map plus ordering values for re-sorting. let (keys, values) = self.sorted_deduped()?; let map_array = build_single_map(keys, values, &self.key_type, &self.value_type)?; Ok(vec![ diff --git a/datafusion/sqllogictest/test_files/map_agg.slt b/datafusion/sqllogictest/test_files/map_agg.slt index 705f647a6ca73..2a0ce7dd17630 100644 --- a/datafusion/sqllogictest/test_files/map_agg.slt +++ b/datafusion/sqllogictest/test_files/map_agg.slt @@ -60,6 +60,26 @@ SELECT map_agg(k, v) FROM nullkey; ---- {a: 1, c: 3} +# Window usage: a cumulative frame calls evaluate() once per row, so the map +# grows as the frame expands. Guards against optimizations that consume the +# accumulator state between evaluate() calls. +statement ok +CREATE TABLE win (k VARCHAR, v INT, ts INT) AS VALUES + ('a', 1, 10), + ('b', 2, 20), + ('c', 3, 30); + +query ? +SELECT map_agg(k, v) OVER (ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +FROM win; +---- +{a: 1} +{a: 1, b: 2} +{a: 1, b: 2, c: 3} + +statement ok +DROP TABLE win; + statement ok DROP TABLE kv; diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index ba9c6ae12477b..1aa8e9a971534 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -91,6 +91,7 @@ SELECT SUM(x) WITHIN GROUP (ORDER BY x) FROM t; - [first_value](#first_value) - [grouping](#grouping) - [last_value](#last_value) +- [map_agg](#map_agg) - [max](#max) - [mean](#mean) - [median](#median) @@ -347,6 +348,30 @@ last_value(expression [ORDER BY expression]) +-----------------------------------------------+ ``` +### `map_agg` + +Aggregate key-value pairs from two columns into a single map per group. Pairs with a NULL key are skipped; NULL values are retained. On a duplicate key the first value wins; use ORDER BY to make which value wins deterministic. + +```sql +map_agg(key, value [ORDER BY expression]) +``` + +#### Arguments + +- **key**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **value**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT map_agg(name, score) FROM scores GROUP BY department; ++-------------------------------+ +| map_agg(name, score) | ++-------------------------------+ +| {Alice: 95, Bob: 87} | ++-------------------------------+ +``` + ### `max` Returns the maximum value in the specified column. From d9292129dc3e6d49bad44a90a8ce9522a8790c04 Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Thu, 18 Jun 2026 11:31:31 +0200 Subject: [PATCH 3/7] avoid allocation in dedup_first_wins --- datafusion/functions-aggregate/src/map_agg.rs | 73 ++++++++----------- 1 file changed, 30 insertions(+), 43 deletions(-) diff --git a/datafusion/functions-aggregate/src/map_agg.rs b/datafusion/functions-aggregate/src/map_agg.rs index 65403e14a8d4b..339efed75e472 100644 --- a/datafusion/functions-aggregate/src/map_agg.rs +++ b/datafusion/functions-aggregate/src/map_agg.rs @@ -17,7 +17,7 @@ //! `MAP_AGG` aggregate implementation: [`MapAgg`] -use std::collections::VecDeque; +use std::collections::{HashSet, VecDeque}; use std::mem::{size_of, size_of_val, take}; use std::sync::Arc; @@ -26,7 +26,7 @@ use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, FieldRef, Fields}; -use datafusion_common::utils::{compare_rows, get_row_at_idx}; +use datafusion_common::utils::{SingleRowListArrayBuilder, compare_rows, get_row_at_idx}; use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; @@ -150,13 +150,11 @@ impl AggregateUDFImpl for MapAgg { } fn map_type(key_type: &DataType, value_type: &DataType) -> DataType { - let key_field = Arc::new(Field::new("key", key_type.clone(), false)); - let value_field = Arc::new(Field::new("value", value_type.clone(), true)); - let entries_field = Arc::new(Field::new( - "entries", - DataType::Struct(Fields::from(vec![key_field, value_field])), - false, - )); + let fields = Fields::from(vec![ + Field::new("key", key_type.clone(), false), + Field::new("value", value_type.clone(), true), + ]); + let entries_field = Arc::new(Field::new("entries", DataType::Struct(fields), false)); DataType::Map(entries_field, false) } @@ -168,14 +166,13 @@ fn build_single_map( ) -> Result { debug_assert_eq!(keys.len(), values.len()); - let key_field = Arc::new(Field::new("key", key_type.clone(), false)); - let value_field = Arc::new(Field::new("value", value_type.clone(), true)); + let fields = Fields::from(vec![ + Field::new("key", key_type.clone(), false), + Field::new("value", value_type.clone(), true), + ]); let entries_field = Arc::new(Field::new( "entries", - DataType::Struct(Fields::from(vec![ - Arc::clone(&key_field), - Arc::clone(&value_field), - ])), + DataType::Struct(fields.clone()), false, )); @@ -191,11 +188,7 @@ fn build_single_map( ScalarValue::iter_to_array(values)? }; - let entries = StructArray::try_new( - Fields::from(vec![key_field, value_field]), - vec![key_array, value_array], - None, - )?; + let entries = StructArray::try_new(fields, vec![key_array, value_array], None)?; let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0i32, len as i32])); Ok(Arc::new(MapArray::try_new( @@ -207,27 +200,26 @@ fn build_single_map( )?)) } -/// De-duplicates parallel key/value vectors keeping the first value seen for -/// each key. fn dedup_first_wins( keys: Vec, values: Vec, ) -> (Vec, Vec) { - use std::collections::HashSet; - - let mut seen: HashSet = HashSet::with_capacity(keys.len()); - let mut out_keys: Vec = Vec::with_capacity(keys.len()); - let mut out_vals: Vec = Vec::with_capacity(keys.len()); - - for (k, v) in keys.into_iter().zip(values) { - // Keep only the first occurrence of each key; later ones are dropped. - if seen.insert(k.clone()) { - out_keys.push(k); - out_vals.push(v); - } - } - - (out_keys, out_vals) + // First pass: mark each position that is the first occurrence of its key. + let mut seen = HashSet::with_capacity(keys.len()); + let keep: Vec = keys.iter().map(|k| seen.insert(k)).collect(); + + // Second pass: keep only the first-occurrence positions. + let out_keys = keys + .into_iter() + .zip(&keep) + .filter_map(|(k, &keep)| keep.then_some(k)) + .collect(); + let out_values = values + .into_iter() + .zip(&keep) + .filter_map(|(v, &keep)| keep.then_some(v)) + .collect(); + (out_keys, out_values) } /// Plain accumulator used when there is no `ORDER BY`. @@ -388,12 +380,7 @@ impl OrderSensitiveMapAggAccumulator { } let struct_array = StructArray::try_new(struct_field, column_wise, None)?; - Ok( - datafusion_common::utils::SingleRowListArrayBuilder::new(Arc::new( - struct_array, - )) - .build_list_scalar(), - ) + Ok(SingleRowListArrayBuilder::new(Arc::new(struct_array)).build_list_scalar()) } } From 3aa09db75dbcadf335c7712a785b28058cc559a7 Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Fri, 19 Jun 2026 11:12:53 +0200 Subject: [PATCH 4/7] fix dup keys ordering values not being aligned --- datafusion/functions-aggregate/src/map_agg.rs | 154 +++++++++++++----- .../sqllogictest/test_files/map_agg.slt | 26 ++- 2 files changed, 134 insertions(+), 46 deletions(-) diff --git a/datafusion/functions-aggregate/src/map_agg.rs b/datafusion/functions-aggregate/src/map_agg.rs index 339efed75e472..dcc823f5e926a 100644 --- a/datafusion/functions-aggregate/src/map_agg.rs +++ b/datafusion/functions-aggregate/src/map_agg.rs @@ -200,28 +200,6 @@ fn build_single_map( )?)) } -fn dedup_first_wins( - keys: Vec, - values: Vec, -) -> (Vec, Vec) { - // First pass: mark each position that is the first occurrence of its key. - let mut seen = HashSet::with_capacity(keys.len()); - let keep: Vec = keys.iter().map(|k| seen.insert(k)).collect(); - - // Second pass: keep only the first-occurrence positions. - let out_keys = keys - .into_iter() - .zip(&keep) - .filter_map(|(k, &keep)| keep.then_some(k)) - .collect(); - let out_values = values - .into_iter() - .zip(&keep) - .filter_map(|(v, &keep)| keep.then_some(v)) - .collect(); - (out_keys, out_values) -} - /// Plain accumulator used when there is no `ORDER BY`. #[derive(Debug)] pub struct MapAggAccumulator { @@ -240,6 +218,30 @@ impl MapAggAccumulator { values: Vec::new(), } } + + /// De-duplicates parallel key/value vectors, keeping the first value seen + /// for each key. Surviving keys retain their first-seen order. + fn dedup_first_wins( + keys: Vec, + values: Vec, + ) -> (Vec, Vec) { + // First pass: mark each position that is the first occurrence of its key. + let mut seen = HashSet::with_capacity(keys.len()); + let keep: Vec = keys.iter().map(|k| seen.insert(k)).collect(); + + // Second pass: keep only the first-occurrence positions. + let out_keys = keys + .into_iter() + .zip(&keep) + .filter_map(|(k, &keep)| keep.then_some(k)) + .collect(); + let out_values = values + .into_iter() + .zip(&keep) + .filter_map(|(v, &keep)| keep.then_some(v)) + .collect(); + (out_keys, out_values) + } } impl Accumulator for MapAggAccumulator { @@ -284,7 +286,8 @@ impl Accumulator for MapAggAccumulator { } fn evaluate(&mut self) -> Result { - let (keys, values) = dedup_first_wins(self.keys.clone(), self.values.clone()); + let (keys, values) = + Self::dedup_first_wins(self.keys.clone(), self.values.clone()); let map_array = build_single_map(keys, values, &self.key_type, &self.value_type)?; ScalarValue::try_from_array(&map_array, 0) } @@ -300,6 +303,12 @@ impl Accumulator for MapAggAccumulator { } } +struct OrderSensitiveMapAggRows { + keys: Vec, + values: Vec, + ordering_values: Vec>, +} + /// Accumulator used when `map_agg` has an `ORDER BY`. Stores the ordering column /// values alongside each pair so the input can be globally sorted (across /// partitions). @@ -338,8 +347,9 @@ impl OrderSensitiveMapAggAccumulator { } /// Sorts the accumulated pairs by their ordering values, then applies - /// first-wins de-duplication. - fn sorted_deduped(&self) -> Result<(Vec, Vec)> { + /// first-wins de-duplication. Returns the surviving keys, values, and + /// ordering values, all aligned so they describe the same rows. + fn sorted_deduped(&self) -> Result { let sort_options = self.sort_options(); let mut rows: Vec = (0..self.keys.len()).collect(); let mut cmp_err = Ok(()); @@ -356,25 +366,44 @@ impl OrderSensitiveMapAggAccumulator { }); cmp_err?; - let keys = rows.iter().map(|&i| self.keys[i].clone()).collect(); - let values = rows.iter().map(|&i| self.values[i].clone()).collect(); - Ok(dedup_first_wins(keys, values)) + // Keep the first occurrence of each key in sorted order, and project the + // keys, values, and ordering values through the same surviving indices + // so all three stay aligned. + let mut seen = HashSet::with_capacity(rows.len()); + let mut keys = Vec::new(); + let mut values = Vec::new(); + let mut ordering_values = Vec::new(); + for &i in &rows { + if seen.insert(&self.keys[i]) { + keys.push(self.keys[i].clone()); + values.push(self.values[i].clone()); + ordering_values.push(self.ordering_values[i].clone()); + } + } + + Ok(OrderSensitiveMapAggRows { + keys, + values, + ordering_values, + }) } - /// Builds the `List>` state column carrying ordering - /// values for every accumulated pair. - fn evaluate_orderings(&self) -> Result { + /// Builds the `List>` state column from the given + /// ordering values. These must be the de-duplicated ordering values that + /// align with the map state, so both pieces of state describe the same rows. + fn evaluate_orderings( + &self, + ordering_values: &[Vec], + ) -> Result { let fields = ordering_fields(&self.ordering_req, &self.ordering_dtypes); - let num_rows = self.ordering_values.len(); let struct_field = Fields::from(fields.clone()); let mut column_wise: Vec = Vec::with_capacity(fields.len()); for (col_idx, field) in fields.iter().enumerate() { - if num_rows == 0 { + if ordering_values.is_empty() { column_wise.push(arrow::array::new_empty_array(field.data_type())); } else { - let col_vals = - self.ordering_values.iter().map(|row| row[col_idx].clone()); + let col_vals = ordering_values.iter().map(|row| row[col_idx].clone()); column_wise.push(ScalarValue::iter_to_array(col_vals)?); } } @@ -471,17 +500,17 @@ impl Accumulator for OrderSensitiveMapAggAccumulator { } fn state(&mut self) -> Result> { - let (keys, values) = self.sorted_deduped()?; - let map_array = build_single_map(keys, values, &self.key_type, &self.value_type)?; - Ok(vec![ - ScalarValue::try_from_array(&map_array, 0)?, - self.evaluate_orderings()?, - ]) + let rows = self.sorted_deduped()?; + let orderings = self.evaluate_orderings(&rows.ordering_values)?; + let map_array = + build_single_map(rows.keys, rows.values, &self.key_type, &self.value_type)?; + Ok(vec![ScalarValue::try_from_array(&map_array, 0)?, orderings]) } fn evaluate(&mut self) -> Result { - let (keys, values) = self.sorted_deduped()?; - let map_array = build_single_map(keys, values, &self.key_type, &self.value_type)?; + let rows = self.sorted_deduped()?; + let map_array = + build_single_map(rows.keys, rows.values, &self.key_type, &self.value_type)?; ScalarValue::try_from_array(&map_array, 0) } @@ -707,4 +736,43 @@ mod tests { assert_eq!(pairs, vec![("a".into(), Some(1))]); Ok(()) } + + #[test] + fn ordered_merge_with_intra_partition_duplicates() -> Result<()> { + let mut acc1 = make_ordered_acc(false); + let mut acc2 = make_ordered_acc(false); + + // P1: a -> {1@10, 3@30}, b -> {2@20} => after dedup: a=1, b=2 + acc1.update_batch(&[ + str_arr(&["a", "b", "a"]), + int_arr(&[1, 2, 3]), + int_arr(&[10, 20, 30]), + ])?; + // P2: a -> {4@40, 5@50}, c -> {6@60} => after dedup: a=4, c=6 + acc2.update_batch(&[ + str_arr(&["a", "a", "c"]), + int_arr(&[4, 5, 6]), + int_arr(&[40, 50, 60]), + ])?; + + let state2 = acc2 + .state()? + .into_iter() + .map(|sv| sv.to_array()) + .collect::>>()?; + acc1.merge_batch(&state2)?; + + let mut pairs = extract_map(acc1.evaluate()?); + pairs.sort(); + // Global first-wins by ord: a@10=1, b@20=2, c@60=6. + assert_eq!( + pairs, + vec![ + ("a".into(), Some(1)), + ("b".into(), Some(2)), + ("c".into(), Some(6)), + ] + ); + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/map_agg.slt b/datafusion/sqllogictest/test_files/map_agg.slt index 2a0ce7dd17630..3b5d6c1b8330e 100644 --- a/datafusion/sqllogictest/test_files/map_agg.slt +++ b/datafusion/sqllogictest/test_files/map_agg.slt @@ -29,8 +29,8 @@ SELECT g, map_agg(k, v) FROM kv GROUP BY g ORDER BY g; 1 {a: 1, b: 2} 2 {c: 3, a: 4} -# Duplicate key within a group: first-wins (matches Trino map_agg). ORDER BY -# decides which row is first, hence which value survives. +# Duplicate key within a group: first value wins. ORDER BY decides which row is +# first, hence which value survives. statement ok CREATE TABLE dup (k VARCHAR, v INT, ts INT) AS VALUES ('a', 1, 10), @@ -48,7 +48,24 @@ SELECT map_agg(k, v ORDER BY ts DESC) FROM dup; ---- {a: 2} -# Pair with a NULL key is skipped (matches Trino map_agg) +# Grouped duplicate keys with ORDER BY: each group de-duplicates its map, so the +# partial ordering state must stay aligned with the de-duplicated map across the +# partial -> final merge. ASC keeps the lowest-ts value per key. +statement ok +CREATE TABLE dup_grouped (g INT, k VARCHAR, v INT, ts INT) AS VALUES + (1, 'a', 1, 10), + (1, 'a', 2, 20), + (1, 'b', 3, 30), + (2, 'c', 4, 40), + (2, 'c', 5, 50); + +query I? +SELECT g, map_agg(k, v ORDER BY ts ASC) FROM dup_grouped GROUP BY g ORDER BY g; +---- +1 {a: 1, b: 3} +2 {c: 4} + +# Pair with a NULL key is skipped statement ok CREATE TABLE nullkey (k VARCHAR, v INT) AS VALUES ('a', 1), @@ -86,5 +103,8 @@ DROP TABLE kv; statement ok DROP TABLE dup; +statement ok +DROP TABLE dup_grouped; + statement ok DROP TABLE nullkey; From 756627f4182fc310d57752247bbc109901a0f5a5 Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Fri, 19 Jun 2026 11:40:25 +0200 Subject: [PATCH 5/7] change HardRequirement -> SoftRequirement --- datafusion/functions-aggregate/src/map_agg.rs | 56 ++++++++++++------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/datafusion/functions-aggregate/src/map_agg.rs b/datafusion/functions-aggregate/src/map_agg.rs index dcc823f5e926a..11247da4f6666 100644 --- a/datafusion/functions-aggregate/src/map_agg.rs +++ b/datafusion/functions-aggregate/src/map_agg.rs @@ -68,12 +68,16 @@ make_udaf_expr_and_func!( #[derive(Debug, PartialEq, Eq, Hash)] pub struct MapAgg { signature: Signature, + /// Whether the optimizer guarantees the input is already ordered by the + /// `ORDER BY` clause, letting the accumulator skip its internal sort. + is_input_pre_ordered: bool, } impl Default for MapAgg { fn default() -> Self { Self { signature: Signature::any(2, Volatility::Immutable), + is_input_pre_ordered: false, } } } @@ -116,11 +120,17 @@ impl AggregateUDFImpl for MapAgg { } fn order_sensitivity(&self) -> AggregateOrderSensitivity { - // Order decides which value wins on a duplicate key, so the optimizer - // must satisfy it (inserts a SortExec). - // TODO: handle pre-sorted input like `array_agg` to skip the - // redundant sort. - AggregateOrderSensitivity::HardRequirement + AggregateOrderSensitivity::SoftRequirement + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + Ok(Some(Arc::new(Self { + signature: self.signature.clone(), + is_input_pre_ordered: beneficial_ordering, + }))) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -141,6 +151,7 @@ impl AggregateUDFImpl for MapAgg { value_type, ordering_dtypes, ordering, + self.is_input_pre_ordered, ))) } @@ -322,6 +333,7 @@ pub struct OrderSensitiveMapAggAccumulator { ordering_values: Vec>, ordering_dtypes: Vec, ordering_req: LexOrdering, + is_input_pre_ordered: bool, } impl OrderSensitiveMapAggAccumulator { @@ -330,6 +342,7 @@ impl OrderSensitiveMapAggAccumulator { value_type: DataType, ordering_dtypes: Vec, ordering_req: LexOrdering, + is_input_pre_ordered: bool, ) -> Self { Self { key_type, @@ -339,6 +352,7 @@ impl OrderSensitiveMapAggAccumulator { ordering_values: Vec::new(), ordering_dtypes, ordering_req, + is_input_pre_ordered, } } @@ -350,21 +364,24 @@ impl OrderSensitiveMapAggAccumulator { /// first-wins de-duplication. Returns the surviving keys, values, and /// ordering values, all aligned so they describe the same rows. fn sorted_deduped(&self) -> Result { - let sort_options = self.sort_options(); let mut rows: Vec = (0..self.keys.len()).collect(); - let mut cmp_err = Ok(()); - rows.sort_by(|&a, &b| { - compare_rows( - &self.ordering_values[a], - &self.ordering_values[b], - &sort_options, - ) - .unwrap_or_else(|e| { - cmp_err = Err(e); - std::cmp::Ordering::Equal - }) - }); - cmp_err?; + + if !self.is_input_pre_ordered { + let sort_options = self.sort_options(); + let mut cmp_err = Ok(()); + rows.sort_by(|&a, &b| { + compare_rows( + &self.ordering_values[a], + &self.ordering_values[b], + &sort_options, + ) + .unwrap_or_else(|e| { + cmp_err = Err(e); + std::cmp::Ordering::Equal + }) + }); + cmp_err?; + } // Keep the first occurrence of each key in sorted order, and project the // keys, values, and ordering values through the same surviving indices @@ -691,6 +708,7 @@ mod tests { DataType::Int32, vec![DataType::Int32], ordering, + false, ) } From a2cbaae9baf3ba537ab7f095ef389f072e4621f6 Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Fri, 19 Jun 2026 11:55:17 +0200 Subject: [PATCH 6/7] Return NULL for empty input --- datafusion/functions-aggregate/src/map_agg.rs | 16 +++++++++++----- datafusion/sqllogictest/test_files/map_agg.slt | 6 ++++++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/datafusion/functions-aggregate/src/map_agg.rs b/datafusion/functions-aggregate/src/map_agg.rs index 11247da4f6666..f5bf12ae9aef7 100644 --- a/datafusion/functions-aggregate/src/map_agg.rs +++ b/datafusion/functions-aggregate/src/map_agg.rs @@ -22,7 +22,7 @@ use std::mem::{size_of, size_of_val, take}; use std::sync::Arc; use arrow::array::{Array, ArrayRef, AsArray, MapArray, StructArray}; -use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, FieldRef, Fields}; @@ -201,12 +201,15 @@ fn build_single_map( let entries = StructArray::try_new(fields, vec![key_array, value_array], None)?; + // With no entries, emit a single NULL map + let nulls = (len == 0).then(|| NullBuffer::new_null(1)); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0i32, len as i32])); Ok(Arc::new(MapArray::try_new( entries_field, offsets, entries, - None, + nulls, false, )?)) } @@ -641,10 +644,13 @@ mod tests { } #[test] - fn empty_produces_empty_map() -> Result<()> { + fn empty_produces_null_map() -> Result<()> { let mut acc = make_acc(); - let pairs = extract_map(acc.evaluate()?); - assert!(pairs.is_empty()); + let ScalarValue::Map(arr) = acc.evaluate()? else { + panic!("expected ScalarValue::Map"); + }; + assert_eq!(arr.len(), 1); + assert!(arr.is_null(0)); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/map_agg.slt b/datafusion/sqllogictest/test_files/map_agg.slt index 3b5d6c1b8330e..ac0d4448f761c 100644 --- a/datafusion/sqllogictest/test_files/map_agg.slt +++ b/datafusion/sqllogictest/test_files/map_agg.slt @@ -77,6 +77,12 @@ SELECT map_agg(k, v) FROM nullkey; ---- {a: 1, c: 3} +# Empty input: a global aggregate over zero rows returns a single NULL map. +query ? +SELECT map_agg(k, v) FROM nullkey WHERE false; +---- +NULL + # Window usage: a cumulative frame calls evaluate() once per row, so the map # grows as the frame expands. Guards against optimizations that consume the # accumulator state between evaluate() calls. From 7468b5acf28e8812c5364ead0bb9c4acc0ce8fa0 Mon Sep 17 00:00:00 2001 From: LiaCastaneda Date: Fri, 19 Jun 2026 12:33:49 +0200 Subject: [PATCH 7/7] Account for data types in size() --- datafusion/functions-aggregate/src/map_agg.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/map_agg.rs b/datafusion/functions-aggregate/src/map_agg.rs index f5bf12ae9aef7..181eb4b1894a9 100644 --- a/datafusion/functions-aggregate/src/map_agg.rs +++ b/datafusion/functions-aggregate/src/map_agg.rs @@ -538,12 +538,24 @@ impl Accumulator for OrderSensitiveMapAggAccumulator { let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.keys) - size_of_val(&self.keys) + ScalarValue::size_of_vec(&self.values) - - size_of_val(&self.values); + - size_of_val(&self.values) + + self.key_type.size() + - size_of_val(&self.key_type) + + self.value_type.size() + - size_of_val(&self.value_type); + + // ordering_values: Vec spine plus the heap owned by each row's scalars. total += size_of::>() * self.ordering_values.capacity(); for row in &self.ordering_values { total += ScalarValue::size_of_vec(row) - size_of_val(row); } + + // ordering_dtypes: Vec spine plus the heap owned by each DataType. total += size_of::() * self.ordering_dtypes.capacity(); + for dtype in &self.ordering_dtypes { + total += dtype.size() - size_of_val(dtype); + } + total } }