diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 1b9996220d882..f1656e449e006 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -79,6 +79,7 @@ pub mod covariance; pub mod first_last; pub mod grouping; pub mod hyperloglog; +pub mod map_agg; pub mod median; pub mod min_max; pub mod nth_value; @@ -122,6 +123,7 @@ pub mod expr_fn { pub use super::first_last::first_value; pub use super::first_last::last_value; pub use super::grouping::grouping; + pub use super::map_agg::map_agg; pub use super::median::median; pub use super::min_max::max; pub use super::min_max::min; @@ -156,6 +158,7 @@ pub fn all_default_aggregate_functions() -> Vec> { sum::sum_udaf(), min_max::max_udaf(), min_max::min_udaf(), + map_agg::map_agg_udaf(), median::median_udaf(), count::count_udaf(), regr::regr_slope_udaf(), diff --git a/datafusion/functions-aggregate/src/map_agg.rs b/datafusion/functions-aggregate/src/map_agg.rs new file mode 100644 index 0000000000000..181eb4b1894a9 --- /dev/null +++ b/datafusion/functions-aggregate/src/map_agg.rs @@ -0,0 +1,814 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `MAP_AGG` aggregate implementation: [`MapAgg`] + +use std::collections::{HashSet, VecDeque}; +use std::mem::{size_of, size_of_val, take}; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, AsArray, MapArray, StructArray}; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; + +use datafusion_common::utils::{SingleRowListArrayBuilder, compare_rows, get_row_at_idx}; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; +use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; +use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; +use datafusion_functions_aggregate_common::utils::ordering_fields; +use datafusion_macros::user_doc; +use datafusion_physical_expr_common::sort_expr::LexOrdering; + +use crate::utils::{map_row_to_scalars, struct_to_rows}; + +make_udaf_expr_and_func!( + MapAgg, + map_agg, + "Aggregate key-value pairs into a map", + map_agg_udaf +); + +#[user_doc( + doc_section(label = "General Functions"), + description = "Aggregate key-value pairs from two columns into a single map per group. Pairs with a NULL key are skipped; NULL values are retained. On a duplicate key the first value wins; use ORDER BY to make which value wins deterministic.", + syntax_example = "map_agg(key, value [ORDER BY expression])", + sql_example = r#" +```sql +> SELECT map_agg(name, score) FROM scores GROUP BY department; ++-------------------------------+ +| map_agg(name, score) | ++-------------------------------+ +| {Alice: 95, Bob: 87} | ++-------------------------------+ +``` +"#, + standard_argument(name = "key",), + standard_argument(name = "value",) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct MapAgg { + signature: Signature, + /// Whether the optimizer guarantees the input is already ordered by the + /// `ORDER BY` clause, letting the accumulator skip its internal sort. + is_input_pre_ordered: bool, +} + +impl Default for MapAgg { + fn default() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + is_input_pre_ordered: false, + } + } +} + +impl AggregateUDFImpl for MapAgg { + fn name(&self) -> &str { + "map_agg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(map_type(&arg_types[0], &arg_types[1])) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let key_type = args.input_fields[0].data_type(); + let value_type = args.input_fields[1].data_type(); + + let mut fields = vec![Arc::new(Field::new( + format_state_name(args.name, "map_agg"), + map_type(key_type, value_type), + true, + ))]; + + if !args.ordering_fields.is_empty() { + fields.push(Arc::new(Field::new_list( + format_state_name(args.name, "map_agg_orderings"), + Field::new_list_field( + DataType::Struct(Fields::from(args.ordering_fields.to_vec())), + true, + ), + false, + ))); + } + + Ok(fields) + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::SoftRequirement + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + Ok(Some(Arc::new(Self { + signature: self.signature.clone(), + is_input_pre_ordered: beneficial_ordering, + }))) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let key_type = acc_args.expr_fields[0].data_type().clone(); + let value_type = acc_args.expr_fields[1].data_type().clone(); + + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return Ok(Box::new(MapAggAccumulator::new(key_type, value_type))); + }; + + let ordering_dtypes = ordering + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; + + Ok(Box::new(OrderSensitiveMapAggAccumulator::new( + key_type, + value_type, + ordering_dtypes, + ordering, + self.is_input_pre_ordered, + ))) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn map_type(key_type: &DataType, value_type: &DataType) -> DataType { + let fields = Fields::from(vec![ + Field::new("key", key_type.clone(), false), + Field::new("value", value_type.clone(), true), + ]); + let entries_field = Arc::new(Field::new("entries", DataType::Struct(fields), false)); + DataType::Map(entries_field, false) +} + +fn build_single_map( + keys: Vec, + values: Vec, + key_type: &DataType, + value_type: &DataType, +) -> Result { + debug_assert_eq!(keys.len(), values.len()); + + let fields = Fields::from(vec![ + Field::new("key", key_type.clone(), false), + Field::new("value", value_type.clone(), true), + ]); + let entries_field = Arc::new(Field::new( + "entries", + DataType::Struct(fields.clone()), + false, + )); + + let len = keys.len(); + let key_array = if len == 0 { + arrow::array::new_empty_array(key_type) + } else { + ScalarValue::iter_to_array(keys)? + }; + let value_array = if len == 0 { + arrow::array::new_empty_array(value_type) + } else { + ScalarValue::iter_to_array(values)? + }; + + let entries = StructArray::try_new(fields, vec![key_array, value_array], None)?; + + // With no entries, emit a single NULL map + let nulls = (len == 0).then(|| NullBuffer::new_null(1)); + + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0i32, len as i32])); + Ok(Arc::new(MapArray::try_new( + entries_field, + offsets, + entries, + nulls, + false, + )?)) +} + +/// Plain accumulator used when there is no `ORDER BY`. +#[derive(Debug)] +pub struct MapAggAccumulator { + key_type: DataType, + value_type: DataType, + keys: Vec, + values: Vec, +} + +impl MapAggAccumulator { + pub fn new(key_type: DataType, value_type: DataType) -> Self { + Self { + key_type, + value_type, + keys: Vec::new(), + values: Vec::new(), + } + } + + /// De-duplicates parallel key/value vectors, keeping the first value seen + /// for each key. Surviving keys retain their first-seen order. + fn dedup_first_wins( + keys: Vec, + values: Vec, + ) -> (Vec, Vec) { + // First pass: mark each position that is the first occurrence of its key. + let mut seen = HashSet::with_capacity(keys.len()); + let keep: Vec = keys.iter().map(|k| seen.insert(k)).collect(); + + // Second pass: keep only the first-occurrence positions. + let out_keys = keys + .into_iter() + .zip(&keep) + .filter_map(|(k, &keep)| keep.then_some(k)) + .collect(); + let out_values = values + .into_iter() + .zip(&keep) + .filter_map(|(v, &keep)| keep.then_some(v)) + .collect(); + (out_keys, out_values) + } +} + +impl Accumulator for MapAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.len() != 2 { + return exec_err!("map_agg expects 2 columns, got {}", values.len()); + } + let keys = &values[0]; + let vals = &values[1]; + + for i in 0..keys.len() { + // NULL keys cannot exist in a map; skip the whole pair. + if keys.is_null(i) { + continue; + } + self.keys + .push(ScalarValue::try_from_array(keys, i)?.compacted()); + self.values + .push(ScalarValue::try_from_array(vals, i)?.compacted()); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + let map_array = states[0].as_map(); + for row in 0..map_array.len() { + if map_array.is_null(row) { + continue; + } + let (keys, values) = map_row_to_scalars(map_array, row)?; + self.keys.extend(keys); + self.values.extend(values); + } + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&mut self) -> Result { + let (keys, values) = + Self::dedup_first_wins(self.keys.clone(), self.values.clone()); + let map_array = build_single_map(keys, values, &self.key_type, &self.value_type)?; + ScalarValue::try_from_array(&map_array, 0) + } + + fn size(&self) -> usize { + size_of_val(self) + ScalarValue::size_of_vec(&self.keys) - size_of_val(&self.keys) + + ScalarValue::size_of_vec(&self.values) + - size_of_val(&self.values) + + self.key_type.size() + - size_of_val(&self.key_type) + + self.value_type.size() + - size_of_val(&self.value_type) + } +} + +struct OrderSensitiveMapAggRows { + keys: Vec, + values: Vec, + ordering_values: Vec>, +} + +/// Accumulator used when `map_agg` has an `ORDER BY`. Stores the ordering column +/// values alongside each pair so the input can be globally sorted (across +/// partitions). +#[derive(Debug)] +pub struct OrderSensitiveMapAggAccumulator { + key_type: DataType, + value_type: DataType, + keys: Vec, + values: Vec, + /// Ordering-expression values for each pair, parallel to `keys`/`values`. + ordering_values: Vec>, + ordering_dtypes: Vec, + ordering_req: LexOrdering, + is_input_pre_ordered: bool, +} + +impl OrderSensitiveMapAggAccumulator { + pub fn new( + key_type: DataType, + value_type: DataType, + ordering_dtypes: Vec, + ordering_req: LexOrdering, + is_input_pre_ordered: bool, + ) -> Self { + Self { + key_type, + value_type, + keys: Vec::new(), + values: Vec::new(), + ordering_values: Vec::new(), + ordering_dtypes, + ordering_req, + is_input_pre_ordered, + } + } + + fn sort_options(&self) -> Vec { + self.ordering_req.iter().map(|s| s.options).collect() + } + + /// Sorts the accumulated pairs by their ordering values, then applies + /// first-wins de-duplication. Returns the surviving keys, values, and + /// ordering values, all aligned so they describe the same rows. + fn sorted_deduped(&self) -> Result { + let mut rows: Vec = (0..self.keys.len()).collect(); + + if !self.is_input_pre_ordered { + let sort_options = self.sort_options(); + let mut cmp_err = Ok(()); + rows.sort_by(|&a, &b| { + compare_rows( + &self.ordering_values[a], + &self.ordering_values[b], + &sort_options, + ) + .unwrap_or_else(|e| { + cmp_err = Err(e); + std::cmp::Ordering::Equal + }) + }); + cmp_err?; + } + + // Keep the first occurrence of each key in sorted order, and project the + // keys, values, and ordering values through the same surviving indices + // so all three stay aligned. + let mut seen = HashSet::with_capacity(rows.len()); + let mut keys = Vec::new(); + let mut values = Vec::new(); + let mut ordering_values = Vec::new(); + for &i in &rows { + if seen.insert(&self.keys[i]) { + keys.push(self.keys[i].clone()); + values.push(self.values[i].clone()); + ordering_values.push(self.ordering_values[i].clone()); + } + } + + Ok(OrderSensitiveMapAggRows { + keys, + values, + ordering_values, + }) + } + + /// Builds the `List>` state column from the given + /// ordering values. These must be the de-duplicated ordering values that + /// align with the map state, so both pieces of state describe the same rows. + fn evaluate_orderings( + &self, + ordering_values: &[Vec], + ) -> Result { + let fields = ordering_fields(&self.ordering_req, &self.ordering_dtypes); + let struct_field = Fields::from(fields.clone()); + + let mut column_wise: Vec = Vec::with_capacity(fields.len()); + for (col_idx, field) in fields.iter().enumerate() { + if ordering_values.is_empty() { + column_wise.push(arrow::array::new_empty_array(field.data_type())); + } else { + let col_vals = ordering_values.iter().map(|row| row[col_idx].clone()); + column_wise.push(ScalarValue::iter_to_array(col_vals)?); + } + } + + let struct_array = StructArray::try_new(struct_field, column_wise, None)?; + Ok(SingleRowListArrayBuilder::new(Arc::new(struct_array)).build_list_scalar()) + } +} + +impl Accumulator for OrderSensitiveMapAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.len() < 2 { + return exec_err!("map_agg expects at least 2 columns, got {}", values.len()); + } + let keys = &values[0]; + let vals = &values[1]; + let ordering_cols = &values[2..]; + + for i in 0..keys.len() { + // NULL keys cannot exist in a map; skip the whole pair. + if keys.is_null(i) { + continue; + } + self.keys + .push(ScalarValue::try_from_array(keys, i)?.compacted()); + self.values + .push(ScalarValue::try_from_array(vals, i)?.compacted()); + self.ordering_values.push( + get_row_at_idx(ordering_cols, i)? + .into_iter() + .map(|v| v.compacted()) + .collect(), + ); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + if states.len() != 2 { + return exec_err!( + "map_agg ordered merge expects 2 state columns, got {}", + states.len() + ); + } + + let map_array = states[0].as_map(); + let orderings = states[1].as_list::(); + + // Each partition contributes one map row plus a parallel list of + // ordering-value structs. Collect them, then merge by ordering. + let mut partition_keys: Vec> = + vec![take(&mut self.keys).into()]; + let mut partition_orderings: Vec>> = + vec![take(&mut self.ordering_values).into()]; + let mut partition_values: Vec> = + vec![take(&mut self.values).into()]; + + // Push keys and values from each partition's state into the merge buffers. + for row in 0..map_array.len() { + if map_array.is_null(row) { + continue; + } + let (keys, values) = map_row_to_scalars(map_array, row)?; + let ord_vals = struct_to_rows(orderings.value(row).as_struct())?; + + partition_keys.push(keys.into()); + partition_values.push(values.into()); + partition_orderings.push(ord_vals.into()); + } + + // Merge keys and values along the ordering. `merge_ordered_arrays` + // merges a single value stream; run it once for keys and once for + // values using the same ordering inputs so they stay aligned. + let sort_options = self.sort_options(); + + let (merged_keys, merged_orderings) = merge_ordered_arrays( + &mut partition_keys, + &mut partition_orderings.clone(), + &sort_options, + )?; + let (merged_values, _) = merge_ordered_arrays( + &mut partition_values, + &mut partition_orderings, + &sort_options, + )?; + + self.keys = merged_keys; + self.values = merged_values; + self.ordering_values = merged_orderings; + Ok(()) + } + + fn state(&mut self) -> Result> { + let rows = self.sorted_deduped()?; + let orderings = self.evaluate_orderings(&rows.ordering_values)?; + let map_array = + build_single_map(rows.keys, rows.values, &self.key_type, &self.value_type)?; + Ok(vec![ScalarValue::try_from_array(&map_array, 0)?, orderings]) + } + + fn evaluate(&mut self) -> Result { + let rows = self.sorted_deduped()?; + let map_array = + build_single_map(rows.keys, rows.values, &self.key_type, &self.value_type)?; + ScalarValue::try_from_array(&map_array, 0) + } + + fn size(&self) -> usize { + let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.keys) + - size_of_val(&self.keys) + + ScalarValue::size_of_vec(&self.values) + - size_of_val(&self.values) + + self.key_type.size() + - size_of_val(&self.key_type) + + self.value_type.size() + - size_of_val(&self.value_type); + + // ordering_values: Vec spine plus the heap owned by each row's scalars. + total += size_of::>() * self.ordering_values.capacity(); + for row in &self.ordering_values { + total += ScalarValue::size_of_vec(row) - size_of_val(row); + } + + // ordering_dtypes: Vec spine plus the heap owned by each DataType. + total += size_of::() * self.ordering_dtypes.capacity(); + for dtype in &self.ordering_dtypes { + total += dtype.size() - size_of_val(dtype); + } + + total + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::{Int32Type, Schema}; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + + fn make_acc() -> MapAggAccumulator { + MapAggAccumulator::new(DataType::Utf8, DataType::Int32) + } + + fn str_arr(vals: &[&str]) -> ArrayRef { + Arc::new(StringArray::from(vals.to_vec())) + } + + fn int_arr(vals: &[i32]) -> ArrayRef { + Arc::new(Int32Array::from(vals.to_vec())) + } + + fn str_arr_nullable(vals: Vec>) -> ArrayRef { + Arc::new(StringArray::from(vals)) + } + + fn extract_map(sv: ScalarValue) -> Vec<(String, Option)> { + let ScalarValue::Map(arr) = sv else { + panic!("expected ScalarValue::Map, got {sv:?}"); + }; + let entries = arr.value(0); + let entries = entries.as_any().downcast_ref::().unwrap(); + let keys = entries.column(0).as_string::(); + let vals = entries.column(1).as_primitive::(); + (0..keys.len()) + .map(|i| { + let v = if vals.is_null(i) { + None + } else { + Some(vals.value(i)) + }; + (keys.value(i).to_string(), v) + }) + .collect() + } + + #[test] + fn collects_distinct_pairs_in_order() -> Result<()> { + let mut acc = make_acc(); + acc.update_batch(&[str_arr(&["a", "b", "c"]), int_arr(&[1, 2, 3])])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!( + pairs, + vec![ + ("a".into(), Some(1)), + ("b".into(), Some(2)), + ("c".into(), Some(3)), + ] + ); + Ok(()) + } + + #[test] + fn null_key_skipped() -> Result<()> { + let mut acc = make_acc(); + // A null-key pair is dropped (matches Trino map_agg). + acc.update_batch(&[ + str_arr_nullable(vec![Some("a"), None, Some("c")]), + int_arr(&[1, 2, 3]), + ])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!(pairs.len(), 2); + assert_eq!(pairs[0].0, "a"); + assert_eq!(pairs[1].0, "c"); + Ok(()) + } + + #[test] + fn null_value_retained() -> Result<()> { + let mut acc = make_acc(); + let vals: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); + acc.update_batch(&[str_arr(&["a", "b", "c"]), vals])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!(pairs[1], ("b".into(), None)); + Ok(()) + } + + #[test] + fn duplicate_key_first_wins() -> Result<()> { + let mut acc = make_acc(); + // a appears twice; the first value (1) wins, later one is dropped. + acc.update_batch(&[str_arr(&["a", "b", "a"]), int_arr(&[1, 2, 3])])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!(pairs, vec![("a".into(), Some(1)), ("b".into(), Some(2))]); + Ok(()) + } + + #[test] + fn empty_produces_null_map() -> Result<()> { + let mut acc = make_acc(); + let ScalarValue::Map(arr) = acc.evaluate()? else { + panic!("expected ScalarValue::Map"); + }; + assert_eq!(arr.len(), 1); + assert!(arr.is_null(0)); + Ok(()) + } + + #[test] + fn merge_two_partitions() -> Result<()> { + let mut acc1 = make_acc(); + let mut acc2 = make_acc(); + + acc1.update_batch(&[str_arr(&["a", "b"]), int_arr(&[1, 2])])?; + acc2.update_batch(&[str_arr(&["c", "d"]), int_arr(&[3, 4])])?; + + let state2 = acc2 + .state()? + .into_iter() + .map(|sv| sv.to_array()) + .collect::>>()?; + acc1.merge_batch(&state2)?; + + let pairs = extract_map(acc1.evaluate()?); + assert_eq!(pairs.len(), 4); + Ok(()) + } + + #[test] + fn merge_duplicate_key_across_partitions_first_wins() -> Result<()> { + let mut acc1 = make_acc(); + let mut acc2 = make_acc(); + + acc1.update_batch(&[str_arr(&["a"]), int_arr(&[1])])?; + acc2.update_batch(&[str_arr(&["a"]), int_arr(&[2])])?; + + let state2 = acc2 + .state()? + .into_iter() + .map(|sv| sv.to_array()) + .collect::>>()?; + acc1.merge_batch(&state2)?; + + let pairs = extract_map(acc1.evaluate()?); + // acc1's pair comes first in the merged stream, so value 1 wins. + assert_eq!(pairs, vec![("a".into(), Some(1))]); + Ok(()) + } + + /// Builds an order-sensitive accumulator that orders by a single Int32 + /// column `ord` (column index 2 in the update batch). + fn make_ordered_acc(descending: bool) -> OrderSensitiveMapAggAccumulator { + let schema = Schema::new(vec![ + Field::new("k", DataType::Utf8, true), + Field::new("v", DataType::Int32, true), + Field::new("ord", DataType::Int32, true), + ]); + let ord_expr = Arc::new(Column::new_with_schema("ord", &schema).unwrap()); + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new( + ord_expr, + SortOptions::new(descending, false), + )]) + .unwrap(); + OrderSensitiveMapAggAccumulator::new( + DataType::Utf8, + DataType::Int32, + vec![DataType::Int32], + ordering, + false, + ) + } + + #[test] + fn ordered_dup_key_asc_first_wins() -> Result<()> { + let mut acc = make_ordered_acc(false); + // key "a" twice: ord=10 -> v=1, ord=20 -> v=2. ASC sort puts v=1 + // (ord=10) first, so first-wins keeps v=1. + acc.update_batch(&[str_arr(&["a", "a"]), int_arr(&[1, 2]), int_arr(&[10, 20])])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!(pairs, vec![("a".into(), Some(1))]); + Ok(()) + } + + #[test] + fn ordered_dup_key_desc_flips_winner() -> Result<()> { + let mut acc = make_ordered_acc(true); + // Same input, DESC sort puts v=2 (ord=20) first, so first-wins keeps v=2. + acc.update_batch(&[str_arr(&["a", "a"]), int_arr(&[1, 2]), int_arr(&[10, 20])])?; + let pairs = extract_map(acc.evaluate()?); + assert_eq!(pairs, vec![("a".into(), Some(2))]); + Ok(()) + } + + #[test] + fn ordered_merge_two_partitions() -> Result<()> { + // Partition 1 sees ord=20, partition 2 sees ord=10 for the same key. + // After merge + ASC sort, the ord=10 row is first, so its value wins. + let mut acc1 = make_ordered_acc(false); + let mut acc2 = make_ordered_acc(false); + + acc1.update_batch(&[str_arr(&["a"]), int_arr(&[2]), int_arr(&[20])])?; + acc2.update_batch(&[str_arr(&["a"]), int_arr(&[1]), int_arr(&[10])])?; + + let state2 = acc2 + .state()? + .into_iter() + .map(|sv| sv.to_array()) + .collect::>>()?; + acc1.merge_batch(&state2)?; + + let pairs = extract_map(acc1.evaluate()?); + assert_eq!(pairs, vec![("a".into(), Some(1))]); + Ok(()) + } + + #[test] + fn ordered_merge_with_intra_partition_duplicates() -> Result<()> { + let mut acc1 = make_ordered_acc(false); + let mut acc2 = make_ordered_acc(false); + + // P1: a -> {1@10, 3@30}, b -> {2@20} => after dedup: a=1, b=2 + acc1.update_batch(&[ + str_arr(&["a", "b", "a"]), + int_arr(&[1, 2, 3]), + int_arr(&[10, 20, 30]), + ])?; + // P2: a -> {4@40, 5@50}, c -> {6@60} => after dedup: a=4, c=6 + acc2.update_batch(&[ + str_arr(&["a", "a", "c"]), + int_arr(&[4, 5, 6]), + int_arr(&[40, 50, 60]), + ])?; + + let state2 = acc2 + .state()? + .into_iter() + .map(|sv| sv.to_array()) + .collect::>>()?; + acc1.merge_batch(&state2)?; + + let mut pairs = extract_map(acc1.evaluate()?); + pairs.sort(); + // Global first-wins by ord: a@10=1, b@20=2, c@60=6. + assert_eq!( + pairs, + vec![ + ("a".into(), Some(1)), + ("b".into(), Some(2)), + ("c".into(), Some(6)), + ] + ); + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/utils.rs b/datafusion/functions-aggregate/src/utils.rs index 6d816e54bdaf2..fd5fd58fa70d8 100644 --- a/datafusion/functions-aggregate/src/utils.rs +++ b/datafusion/functions-aggregate/src/utils.rs @@ -17,9 +17,11 @@ use std::sync::Arc; -use arrow::array::RecordBatch; +use arrow::array::{Array, MapArray, RecordBatch, StructArray}; use arrow::datatypes::Schema; -use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err, plan_err}; +use datafusion_common::{ + DataFusionError, Result, ScalarValue, internal_datafusion_err, internal_err, plan_err, +}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -75,3 +77,38 @@ pub(crate) fn validate_percentile_expr( } Ok(percentile) } + +/// Reads the key/value scalars out of one row of a `MapArray`. +pub(crate) fn map_row_to_scalars( + map_array: &MapArray, + row: usize, +) -> Result<(Vec, Vec)> { + let entries = map_array.value(row); + let entries = entries + .as_any() + .downcast_ref::() + .ok_or_else(|| internal_datafusion_err!("map entries must be a StructArray"))?; + let key_col = entries.column(0); + let val_col = entries.column(1); + + let mut keys = Vec::with_capacity(key_col.len()); + let mut values = Vec::with_capacity(val_col.len()); + for i in 0..key_col.len() { + keys.push(ScalarValue::try_from_array(key_col, i)?); + values.push(ScalarValue::try_from_array(val_col, i)?); + } + Ok((keys, values)) +} + +/// Converts a `StructArray` into per-row scalar tuples: `rows[i][c]` is column +/// `c` of row `i`. +pub(crate) fn struct_to_rows(s: &StructArray) -> Result>> { + (0..s.len()) + .map(|i| { + s.columns() + .iter() + .map(|col| ScalarValue::try_from_array(col, i)) + .collect() + }) + .collect() +} diff --git a/datafusion/sqllogictest/test_files/map_agg.slt b/datafusion/sqllogictest/test_files/map_agg.slt new file mode 100644 index 0000000000000..ac0d4448f761c --- /dev/null +++ b/datafusion/sqllogictest/test_files/map_agg.slt @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE kv (g INT, k VARCHAR, v INT, ts INT) AS VALUES + (1, 'a', 1, 10), + (1, 'b', 2, 20), + (2, 'c', 3, 30), + (2, 'a', 4, 40); + +# Basic grouped map_agg +query I? +SELECT g, map_agg(k, v) FROM kv GROUP BY g ORDER BY g; +---- +1 {a: 1, b: 2} +2 {c: 3, a: 4} + +# Duplicate key within a group: first value wins. ORDER BY decides which row is +# first, hence which value survives. +statement ok +CREATE TABLE dup (k VARCHAR, v INT, ts INT) AS VALUES + ('a', 1, 10), + ('a', 2, 20); + +# ORDER BY ts ASC -> v=1 (ts=10) is first -> wins +query ? +SELECT map_agg(k, v ORDER BY ts ASC) FROM dup; +---- +{a: 1} + +# ORDER BY ts DESC -> v=2 (ts=20) is first -> wins +query ? +SELECT map_agg(k, v ORDER BY ts DESC) FROM dup; +---- +{a: 2} + +# Grouped duplicate keys with ORDER BY: each group de-duplicates its map, so the +# partial ordering state must stay aligned with the de-duplicated map across the +# partial -> final merge. ASC keeps the lowest-ts value per key. +statement ok +CREATE TABLE dup_grouped (g INT, k VARCHAR, v INT, ts INT) AS VALUES + (1, 'a', 1, 10), + (1, 'a', 2, 20), + (1, 'b', 3, 30), + (2, 'c', 4, 40), + (2, 'c', 5, 50); + +query I? +SELECT g, map_agg(k, v ORDER BY ts ASC) FROM dup_grouped GROUP BY g ORDER BY g; +---- +1 {a: 1, b: 3} +2 {c: 4} + +# Pair with a NULL key is skipped +statement ok +CREATE TABLE nullkey (k VARCHAR, v INT) AS VALUES + ('a', 1), + (NULL, 2), + ('c', 3); + +query ? +SELECT map_agg(k, v) FROM nullkey; +---- +{a: 1, c: 3} + +# Empty input: a global aggregate over zero rows returns a single NULL map. +query ? +SELECT map_agg(k, v) FROM nullkey WHERE false; +---- +NULL + +# Window usage: a cumulative frame calls evaluate() once per row, so the map +# grows as the frame expands. Guards against optimizations that consume the +# accumulator state between evaluate() calls. +statement ok +CREATE TABLE win (k VARCHAR, v INT, ts INT) AS VALUES + ('a', 1, 10), + ('b', 2, 20), + ('c', 3, 30); + +query ? +SELECT map_agg(k, v) OVER (ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +FROM win; +---- +{a: 1} +{a: 1, b: 2} +{a: 1, b: 2, c: 3} + +statement ok +DROP TABLE win; + +statement ok +DROP TABLE kv; + +statement ok +DROP TABLE dup; + +statement ok +DROP TABLE dup_grouped; + +statement ok +DROP TABLE nullkey; diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index ba9c6ae12477b..1aa8e9a971534 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -91,6 +91,7 @@ SELECT SUM(x) WITHIN GROUP (ORDER BY x) FROM t; - [first_value](#first_value) - [grouping](#grouping) - [last_value](#last_value) +- [map_agg](#map_agg) - [max](#max) - [mean](#mean) - [median](#median) @@ -347,6 +348,30 @@ last_value(expression [ORDER BY expression]) +-----------------------------------------------+ ``` +### `map_agg` + +Aggregate key-value pairs from two columns into a single map per group. Pairs with a NULL key are skipped; NULL values are retained. On a duplicate key the first value wins; use ORDER BY to make which value wins deterministic. + +```sql +map_agg(key, value [ORDER BY expression]) +``` + +#### Arguments + +- **key**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **value**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT map_agg(name, score) FROM scores GROUP BY department; ++-------------------------------+ +| map_agg(name, score) | ++-------------------------------+ +| {Alice: 95, Bob: 87} | ++-------------------------------+ +``` + ### `max` Returns the maximum value in the specified column.