diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index ee253e5d7afdd..3113e55c5a1b0 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -35,6 +35,7 @@ pub use row::GroupValuesRows; mod single_group_by; use datafusion_physical_expr::binary_map::OutputType; use multi_group_by::GroupValuesColumn; +use multi_group_by::dict::{GroupDictionaryColumn, supported_dictionary_schema}; pub(crate) use single_group_by::primitive::HashValue; @@ -200,6 +201,11 @@ pub fn new_group_values( } } + // Route 2+ all-dictionary columns to the specialised implementation. + if schema.fields().len() >= 2 && supported_dictionary_schema(schema.as_ref()) { + return Ok(Box::new(GroupDictionaryColumn::new(schema)?)); + } + if multi_group_by::supported_schema(schema.as_ref()) { if matches!(group_ordering, GroupOrdering::None) { Ok(Box::new(GroupValuesColumn::::try_new(schema)?)) diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs new file mode 100644 index 0000000000000..773823d3e9d45 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs @@ -0,0 +1,948 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::mem::size_of; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, AsArray, DictionaryArray, Int8Array, Int16Array, Int32Array, + Int64Array, ListBuilder, NullArray, StringBuilder, UInt8Array, UInt16Array, + UInt32Array, UInt64Array, new_empty_array, +}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, Schema, + SchemaRef, UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; +use arrow::downcast_dictionary_array; +use datafusion_common::hash_utils::{RandomState, combine_hashes, create_hashes}; +use datafusion_common::{Result, internal_datafusion_err}; +use datafusion_execution::memory_pool::proxy::HashTableAllocExt; +use datafusion_expr::EmitTo; +use hashbrown::HashMap as HashbrownMap; +use hashbrown::hash_table::HashTable; + +use crate::aggregates::group_values::GroupValues; + +/// Caches the hashes for one dictionary column's values array. +/// Rebuilt only when the `Arc` pointer changes (i.e. a new values array arrives). +struct ColumnCache { + ///compared with `Arc::ptr_eq` to detect staleness. + values: ArrayRef, + /// `value_hashes[i]` = hash of the value at dictionary index `i`. + value_hashes: Vec, +} + +impl ColumnCache { + fn empty() -> Self { + Self { + values: Arc::new(NullArray::new(0)), + value_hashes: vec![], + } + } + + fn update(&mut self, new_values: ArrayRef, random_state: &RandomState) -> Result<()> { + if Arc::ptr_eq(&new_values, &self.values) { + return Ok(()); + } + let num_values = new_values.len(); + // Reuse the allocation; only grows capacity when a larger values array arrives. + self.value_hashes.clear(); + self.value_hashes.resize(num_values, 0u64); + create_hashes( + std::slice::from_ref(&new_values), + random_state, + &mut self.value_hashes, + )?; + self.values = new_values; + Ok(()) + } + + fn size(&self) -> usize { + self.value_hashes.capacity() * size_of::() + + self.values.get_array_memory_size() + } + + fn clear_shrink(&mut self, shrink_to: usize) { + self.values = Arc::new(new_empty_array(self.values.data_type())); + self.value_hashes.clear(); + self.value_hashes.shrink_to(shrink_to); + } +} + +/// [`GroupValues`] for GROUP BY over **two or more** dictionary-typed columns. +pub struct GroupDictionaryColumn { + schema: SchemaRef, + /// Per-column dictionary caches, one entry per GROUP BY column. + col_caches: Vec, + /// `(row_hash, group_id)`. Multiple entries may share the same hash value; + /// byte-level comparison is used to resolve collisions. + map: HashTable<(u64, usize)>, + /// Tracked allocation size of `map` in bytes, updated on every insert and shrink. + map_size: usize, + /// All group rows packed back-to-back into a single contiguous buffer. + /// + /// CSR-style layout: `row_offsets[g]` is the start of group `g` and + /// `row_offsets[g+1]` is its end. The last group has no `g+1` entry; its + /// end is `row_buffer.len()`. + row_buffer: Vec, + /// `row_offsets[g]` = start byte of group `g` inside `row_buffer`. + row_offsets: Vec, + /// Reused scratch buffer for encoding the current row. + row_scratch: Vec, + /// Converts row-encoded bytes back into Arrow arrays for [`GroupValues::emit`]. + row_decoder: RowSetDecoder, + random_state: RandomState, + /// Maps `(k1, k2, ..., kN)` dictionary key tuples to their group id. + /// Cleared at the start of every `intern` call (not stable across calls). + /// Bypassed entirely once it exceeds 10 000 entries. + key_tuple_cache: HashbrownMap]>, usize>, + /// Scratch buffer for building the current row's key tuple; reused each row. + key_tuple_scratch: Vec>, +} + +pub fn supported_dictionary_schema(schema: &Schema) -> bool { + schema.fields().iter().all(|field| { + if let DataType::Dictionary(_, value_type) = field.data_type() { + matches!(value_type.as_ref(), DataType::Utf8) + || matches!(value_type.as_ref(), DataType::List(list_field) if list_field.data_type() == &DataType::Utf8) + } else { + false + } + }) +} + +impl GroupDictionaryColumn { + pub fn new(schema: SchemaRef) -> Result { + let n_cols = schema.fields().len(); + if n_cols < 2 { + return Err(internal_datafusion_err!( + "GroupDictionaryColumn requires at least 2 columns, got {}", + n_cols + )); + } + let row_decoder = RowSetDecoder::new(&schema); + Ok(Self { + schema, + col_caches: (0..n_cols).map(|_| ColumnCache::empty()).collect(), + map: HashTable::with_capacity(128), + map_size: 0, + row_buffer: Vec::new(), + row_offsets: Vec::new(), + row_scratch: Vec::new(), + row_decoder, + random_state: crate::aggregates::AGGREGATION_HASH_SEED, + key_tuple_cache: HashbrownMap::new(), + key_tuple_scratch: Vec::with_capacity(n_cols), + }) + } +} + +fn dict_values_array(col: &dyn Array) -> ArrayRef { + downcast_dictionary_array!( + col => Arc::clone(col.values()), + _ => unreachable!("schema validated in GroupDictionaryColumn::new") + ) +} + +fn fill_keys(col: &dyn Array) -> Box> + '_> { + downcast_dictionary_array!( + col => { + let keys = col.keys(); + Box::new((0..keys.len()).map(move |row_idx| { + if keys.is_valid(row_idx) { + Some(keys.value(row_idx).as_usize()) + } else { + None + } + })) + }, + _ => unreachable!("schema validated in GroupDictionaryColumn::new") + ) +} + +impl GroupValues for GroupDictionaryColumn { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + debug_assert_eq!(cols.len(), self.schema.fields().len()); + groups.clear(); + + if cols.is_empty() || cols[0].is_empty() { + return Ok(()); + } + let n_rows = cols[0].len(); + + // Downcast once per column; advance with .next() per row to avoid per-row downcast. + let mut key_iters: Vec<_> = cols + .iter() + .enumerate() + .map(|(col_idx, col)| { + self.col_caches[col_idx] + .update(dict_values_array(col.as_ref()), &self.random_state)?; + Ok(fill_keys(col.as_ref())) + }) + .collect::>()?; + self.key_tuple_cache.clear(); + self.key_tuple_scratch.clear(); + + groups.try_reserve(n_rows).map_err(|e| { + datafusion_common::DataFusionError::ArrowError( + Box::new(arrow::error::ArrowError::MemoryError(e.to_string())), + None, + ) + })?; + for _row in 0..n_rows { + self.key_tuple_scratch.clear(); + for key_iter in key_iters.iter_mut() { + self.key_tuple_scratch.push(key_iter.next().unwrap()); + } + + // Fast path: key-tuple cache lookup (bypassed once the cache exceeds the threshold). + let use_cache = self.key_tuple_cache.len() <= 1000; + if use_cache + && let Some(&group_id) = + self.key_tuple_cache.get(self.key_tuple_scratch.as_slice()) + { + groups.push(group_id); + continue; + } + + let mut combined_hash = 0u64; + self.row_scratch.clear(); + for (col_idx, &key) in self.key_tuple_scratch.iter().enumerate() { + let cache = &self.col_caches[col_idx]; + let value_hash = key.map_or(0, |key_idx| cache.value_hashes[key_idx]); + combined_hash = combine_hashes(combined_hash, value_hash); + encode_value(key, cache.values.as_ref(), &mut self.row_scratch); + } + + let found = { + let row_scratch = self.row_scratch.as_slice(); + let row_buffer = self.row_buffer.as_slice(); + let row_offsets = self.row_offsets.as_slice(); + self.map + .find(combined_hash, |&(stored_hash, stored_group_id)| { + stored_hash == combined_hash && { + let end = row_offsets + .get(stored_group_id + 1) + .copied() + .unwrap_or(row_buffer.len()); + row_buffer[row_offsets[stored_group_id]..end] == *row_scratch + } + }) + .map(|&(_, stored_group_id)| stored_group_id) + }; + + let group_id = match found { + Some(existing_id) => existing_id, + None => { + let new_id = self.row_offsets.len(); + self.row_offsets.push(self.row_buffer.len()); + self.row_buffer.extend_from_slice(&self.row_scratch); + self.map.insert_accounted( + (combined_hash, new_id), + |(stored_hash, _)| *stored_hash, + &mut self.map_size, + ); + new_id + } + }; + + if use_cache { + self.key_tuple_cache + .insert(Box::from(self.key_tuple_scratch.as_slice()), group_id); + } + + groups.push(group_id); + } + + self.key_tuple_cache.clear(); + self.key_tuple_scratch.clear(); + Ok(()) + } + + fn size(&self) -> usize { + let cache_bytes: usize = self + .col_caches + .iter() + .map(|col_cache| col_cache.size()) + .sum(); + self.map_size + + self.row_buffer.capacity() + + self.row_offsets.capacity() * size_of::() + + self.row_scratch.capacity() + + self.key_tuple_scratch.capacity() * size_of::>() + + self.key_tuple_cache.capacity() + * (size_of::]>>() + size_of::()) + + cache_bytes + + self.row_decoder.size() + } + + fn is_empty(&self) -> bool { + self.row_offsets.is_empty() + } + + fn len(&self) -> usize { + self.row_offsets.len() + } + + // the entire emit path is not optimized + /// groups + /// 0: (a, x) + /// 1: (a, y) + /// 2: (b, x) + /// 3: (b, y) + /// + /// logically colun 1 :[a, a, b, b] + /// currently we do + /// keys: [0, 1, 2, 3] + /// values: [a, a, b, b] + /// + /// which is correct but we should do + /// keys: [0, 0, 1, 1] + /// values: [a, b] + fn emit(&mut self, emit_to: EmitTo) -> Result> { + let n_total = self.row_offsets.len(); + if n_total == 0 { + return Ok(self.row_decoder.finish()); + } + let n_emit = match emit_to { + EmitTo::All => n_total, + EmitTo::First(first_n) => first_n.min(n_total), + }; + + for row_idx in 0..n_emit { + let start = self.row_offsets[row_idx]; + let end = self + .row_offsets + .get(row_idx + 1) + .copied() + .unwrap_or(self.row_buffer.len()); + self.row_decoder.decode(&self.row_buffer[start..end]); + } + let inner = self.row_decoder.finish(); + let arrays: Vec = inner + .into_iter() + .zip(self.schema.fields()) + .map(|(values, field)| match field.data_type() { + DataType::Dictionary(key_type, _) => Ok(wrap_as_dictionary( + values, + &make_sequential_keys(n_emit, key_type)?, + key_type, + )), + _ => unreachable!("schema validated in GroupDictionaryColumn::new"), + }) + .collect::>()?; + + if n_emit == n_total { + self.row_buffer.clear(); + self.row_offsets.clear(); + self.map.clear(); + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); + } else { + let retain_start = self.row_offsets[n_emit]; + self.row_offsets.drain(0..n_emit); + self.row_offsets + .iter_mut() + .for_each(|offset| *offset -= retain_start); + self.row_buffer.drain(0..retain_start); + // Shift remaining group ids in-place; retain gives &mut access so no rehashing occurs. + self.map.retain(|(_, group_id)| { + if *group_id < n_emit { + return false; + } + *group_id -= n_emit; + true + }); + } + + Ok(arrays) + } + + fn clear_shrink(&mut self, num_rows: usize) { + self.map.clear(); + self.map.shrink_to(num_rows, |_| 0); + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); + self.row_buffer.clear(); + self.row_buffer.shrink_to(num_rows); + self.row_offsets.clear(); + self.row_offsets.shrink_to(num_rows); + self.row_scratch.clear(); + self.row_scratch.shrink_to(0); + self.key_tuple_cache.clear(); + self.key_tuple_scratch.clear(); + self.row_decoder.finish(); + for cache in &mut self.col_caches { + cache.clear_shrink(num_rows); + } + } +} + +/// Wire format per column — scalars are the primitive unit; non-scalars repeat them with metadata: +/// +/// **Scalar** (`Utf8`): +/// - null: `[0x00]` +/// - non-null: `[0x01][len: u64 LE][utf8_bytes…]` +/// +/// **Non-scalar** (`List`): a length-prefixed sequence of scalars: +/// - null: `[0x00]` +/// - non-null: `[0x01][content_len: u64 LE][n: u64 LE][scalar…]` +/// where each `scalar` follows the scalar encoding above +fn encode_value(key: Option, values: &dyn Array, buf: &mut Vec) { + let key_idx = match key { + None => { + buf.push(0); + return; + } + Some(key_index) => key_index, + }; + if values.is_null(key_idx) { + buf.push(0); + return; + } + buf.push(1); + match values.data_type() { + DataType::Utf8 => { + let bytes = values.as_string::().value(key_idx).as_bytes(); + buf.extend_from_slice(&(bytes.len() as u64).to_le_bytes()); + buf.extend_from_slice(bytes); + } + DataType::List(_) => { + // Back-fill content_len after encoding all elements. + let len_pos = buf.len(); + buf.extend_from_slice(&[0u8; 8]); + let content_start = buf.len(); + + let list_element = values.as_list::().value(key_idx); + let str_array = list_element.as_string::(); + buf.extend_from_slice(&(str_array.len() as u64).to_le_bytes()); + for elem_idx in 0..str_array.len() { + if str_array.is_null(elem_idx) { + buf.push(0); + } else { + let elem_bytes = str_array.value(elem_idx).as_bytes(); + buf.push(1); + buf.extend_from_slice(&(elem_bytes.len() as u64).to_le_bytes()); + buf.extend_from_slice(elem_bytes); + } + } + let content_len = (buf.len() - content_start) as u64; + buf[len_pos..len_pos + 8].copy_from_slice(&content_len.to_le_bytes()); + } + unsupported_type => { + panic!("unsupported dictionary value type: {unsupported_type}") + } + } +} + +#[derive(Debug)] +enum ColumnBuilder { + Utf8(StringBuilder), + ListUtf8(ListBuilder), +} + +impl ColumnBuilder { + fn from_value_type(value_type: &DataType) -> Self { + match value_type { + DataType::Utf8 => Self::Utf8(StringBuilder::new()), + DataType::List(_) => Self::ListUtf8(ListBuilder::new(StringBuilder::new())), + _ => unreachable!("value type validated in GroupDictionaryColumn::new"), + } + } + + fn append_null(&mut self) { + match self { + Self::Utf8(builder) => builder.append_null(), + Self::ListUtf8(builder) => builder.append_null(), + } + } + + fn append_bytes(&mut self, bytes: &[u8]) { + match self { + // SAFETY: bytes come from Arrow string arrays, always valid UTF-8. + Self::Utf8(builder) => { + builder.append_value(unsafe { std::str::from_utf8_unchecked(bytes) }) + } + Self::ListUtf8(builder) => { + let mut cursor = 0; + let n_elements = + u64::from_le_bytes(bytes[cursor..cursor + 8].try_into().unwrap()) + as usize; + cursor += 8; + for _ in 0..n_elements { + match bytes[cursor] { + 0 => { + builder.values().append_null(); + cursor += 1; + } + _ => { + cursor += 1; + let elem_len = u64::from_le_bytes( + bytes[cursor..cursor + 8].try_into().unwrap(), + ) as usize; + cursor += 8; + // SAFETY: bytes come from Arrow string arrays, always valid UTF-8. + builder.values().append_value(unsafe { + std::str::from_utf8_unchecked( + &bytes[cursor..cursor + elem_len], + ) + }); + cursor += elem_len; + } + } + } + builder.append(true); + } + } + } + + fn finish(&mut self) -> ArrayRef { + match self { + Self::Utf8(builder) => Arc::new(builder.finish()), + Self::ListUtf8(builder) => Arc::new(builder.finish()), + } + } +} + +/// Accumulates encoded row slices and reconstructs them into Arrow arrays on `finish`. +#[derive(Debug)] +struct RowSetDecoder { + builders: Vec, +} + +impl RowSetDecoder { + fn new(schema: &Schema) -> Self { + let builders = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Dictionary(_, value_type) => { + ColumnBuilder::from_value_type(value_type) + } + _ => unreachable!("schema validated in GroupDictionaryColumn::new"), + }) + .collect(); + Self { builders } + } + + /// Inverse of [`encode_value`]; see its doc comment for the wire format. + fn decode(&mut self, encoded: &[u8]) { + let mut cursor = 0; + for builder in &mut self.builders { + match encoded[cursor] { + 0 => { + builder.append_null(); + cursor += 1; + } + _ => { + cursor += 1; + let len = u64::from_le_bytes( + encoded[cursor..cursor + 8].try_into().unwrap(), + ) as usize; + cursor += 8; + builder.append_bytes(&encoded[cursor..cursor + len]); + cursor += len; + } + } + } + } + + fn finish(&mut self) -> Vec { + self.builders + .iter_mut() + .map(|builder| builder.finish()) + .collect() + } + + fn size(&self) -> usize { + self.builders.capacity() * size_of::() + } +} + +/// Build sequential keys `[0, 1, ..., n-1]` for the given key type. +macro_rules! make_keys { + ($n:expr, $ArrayType:ty, $max:expr) => {{ + if $n > $max { + return Err(internal_datafusion_err!( + "too many groups ({}) for dictionary key type with max capacity {}", + $n, + $max + )); + } + Ok( + Arc::new(<$ArrayType>::from_iter_values((0..$n).map(|idx| idx as _))) + as ArrayRef, + ) + }}; +} + +fn make_sequential_keys(n: usize, key_type: &DataType) -> Result { + match key_type { + DataType::Int8 => make_keys!(n, Int8Array, i8::MAX as usize), + DataType::Int16 => make_keys!(n, Int16Array, i16::MAX as usize), + DataType::Int32 => make_keys!(n, Int32Array, i32::MAX as usize), + DataType::Int64 => make_keys!(n, Int64Array, i64::MAX as usize), + DataType::UInt8 => make_keys!(n, UInt8Array, u8::MAX as usize), + DataType::UInt16 => make_keys!(n, UInt16Array, u16::MAX as usize), + DataType::UInt32 => make_keys!(n, UInt32Array, u32::MAX as usize), + DataType::UInt64 => make_keys!(n, UInt64Array, usize::MAX), + _ => unreachable!("schema validated in GroupDictionaryColumn::new"), + } +} + +fn wrap_as_dictionary( + values: ArrayRef, + keys: &ArrayRef, + key_type: &DataType, +) -> ArrayRef { + match key_type { + DataType::Int8 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::Int16 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::Int32 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::Int64 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::UInt8 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::UInt16 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::UInt32 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::UInt64 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + _ => unreachable!("schema validated in GroupDictionaryColumn::new"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::aggregates::group_values::GroupValuesRows; + use arrow::array::StringArray; + use arrow::datatypes::{Field, Schema}; + + fn make_dict(values: &[&str], keys: &[Option]) -> ArrayRef { + let vals = Arc::new(StringArray::from(values.to_vec())); + let keys_arr = Int32Array::from(keys.to_vec()); + Arc::new(DictionaryArray::::try_new(keys_arr, vals).unwrap()) + } + + fn dict_schema() -> SchemaRef { + let dt = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + Arc::new(Schema::new(vec![ + Field::new("a", dt.clone(), true), + Field::new("b", dt, true), + ])) + } + + #[test] + fn test_basic_dedup() { + let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); + let col_a = make_dict(&["x", "y"], &[Some(0), Some(1), Some(0)]); + let col_b = make_dict(&["p", "q"], &[Some(0), Some(1), Some(0)]); + let mut groups = vec![]; + gv.intern(&[col_a, col_b], &mut groups).unwrap(); + assert_eq!(groups, vec![0, 1, 0]); + assert_eq!(gv.len(), 2); + } + + /// Null keys collapse to the same group regardless of which column is null. + #[test] + fn test_null_keys() { + let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); + // Row 0: (null, "p") — key for col_a is null + // Row 1: (null, "p") — same as row 0 + // Row 2: ("x", "p") — distinct from row 0 + let col_a = make_dict(&["x"], &[None, None, Some(0)]); + let col_b = make_dict(&["p"], &[Some(0), Some(0), Some(0)]); + let mut groups = vec![]; + gv.intern(&[col_a, col_b], &mut groups).unwrap(); + assert_eq!( + groups[0], groups[1], + "both null-key rows must be the same group" + ); + assert_ne!( + groups[0], groups[2], + "non-null row must be a different group" + ); + } + + /// When the values array changes between batches (different Arc), the key + /// space re-translates correctly: logical values still deduplicate. + #[test] + fn test_values_array_change_between_batches() { + let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); + + // Batch 1: col_a values = ["a", "b"], col_b values = ["x"] + let col_a1 = make_dict(&["a", "b"], &[Some(0), Some(1)]); + let col_b1 = make_dict(&["x"], &[Some(0), Some(0)]); + let mut groups1 = vec![]; + gv.intern(&[col_a1, col_b1], &mut groups1).unwrap(); + assert_eq!(groups1, vec![0, 1]); + + // Batch 2: same logical rows but DIFFERENT Arc (values order swapped). + // key=0 now means "b", key=1 means "a" — opposite of batch 1. + let col_a2 = make_dict(&["b", "a"], &[Some(0), Some(1)]); + let col_b2 = make_dict(&["x"], &[Some(0), Some(0)]); + let mut groups2 = vec![]; + gv.intern(&[col_a2, col_b2], &mut groups2).unwrap(); + + // ("b", "x") was group 1 in batch 1; ("a", "x") was group 0. + assert_eq!( + groups2[0], 1, + "('b','x') should map to the existing group for 'b'" + ); + assert_eq!( + groups2[1], 0, + "('a','x') should map to the existing group for 'a'" + ); + assert_eq!(gv.len(), 2, "no new groups should have been created"); + } + + fn list_dict_schema() -> SchemaRef { + let item_field = Arc::new(Field::new("item", DataType::Utf8, true)); + let dt = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List(item_field)), + ); + Arc::new(Schema::new(vec![ + Field::new("a", dt.clone(), true), + Field::new("b", dt, true), + ])) + } + + fn make_list_dict(lists: &[Option>], keys: &[Option]) -> ArrayRef { + let mut builder = ListBuilder::new(StringBuilder::new()); + for list_opt in lists { + match list_opt { + None => builder.append_null(), + Some(strings) => { + for s in strings { + builder.values().append_value(s); + } + builder.append(true); + } + } + } + let values = Arc::new(builder.finish()); + let keys_arr = Int32Array::from(keys.to_vec()); + Arc::new(DictionaryArray::::try_new(keys_arr, values).unwrap()) + } + + fn dict_str(arr: &ArrayRef, group: usize) -> &str { + let d = arr.as_dictionary::(); + // Keys are sequential (0..n), so key value == group index. + let val_idx = d.keys().value(group) as usize; + d.values().as_string::().value(val_idx) + } + + fn dict_list(arr: &ArrayRef, group: usize) -> ArrayRef { + let d = arr.as_dictionary::(); + let val_idx = d.keys().value(group) as usize; + d.values().as_list::().value(val_idx) + } + + /// emit(All) reconstructs the correct string values for each group. + #[test] + fn test_emit_all_utf8() { + let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); + let col_a = make_dict(&["x", "y"], &[Some(0), Some(1), Some(0)]); + let col_b = make_dict(&["p", "q"], &[Some(0), Some(1), Some(0)]); + gv.intern(&[col_a, col_b], &mut vec![]).unwrap(); + + let arrays = gv.emit(EmitTo::All).unwrap(); + assert_eq!(arrays[0].len(), 2); + assert_eq!(dict_str(&arrays[0], 0), "x"); + assert_eq!(dict_str(&arrays[0], 1), "y"); + assert_eq!(dict_str(&arrays[1], 0), "p"); + assert_eq!(dict_str(&arrays[1], 1), "q"); + assert_eq!(gv.len(), 0); + } + + /// emit(First n) emits exactly the first n groups and keeps the rest accessible. + #[test] + fn test_emit_first_n_utf8() { + let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); + let col_a = make_dict(&["a", "b", "c"], &[Some(0), Some(1), Some(2)]); + let col_b = make_dict(&["x", "y", "z"], &[Some(0), Some(1), Some(2)]); + gv.intern(&[col_a, col_b], &mut vec![]).unwrap(); + + let arrays = gv.emit(EmitTo::First(2)).unwrap(); + assert_eq!(arrays[0].len(), 2); + assert_eq!(dict_str(&arrays[0], 0), "a"); + assert_eq!(dict_str(&arrays[0], 1), "b"); + + assert_eq!(gv.len(), 1); + let mut groups = vec![]; + let col_a2 = make_dict(&["c"], &[Some(0)]); + let col_b2 = make_dict(&["z"], &[Some(0)]); + gv.intern(&[col_a2, col_b2], &mut groups).unwrap(); + assert_eq!( + groups, + vec![0], + "retained group must map to id 0 after shift" + ); + } + + /// emit(All) correctly reconstructs List values for each group. + #[test] + fn test_emit_list_utf8() { + let mut gv = GroupDictionaryColumn::new(list_dict_schema()).unwrap(); + let col_a = make_list_dict( + &[Some(vec!["hello", "world"]), Some(vec!["foo"])], + &[Some(0), Some(0), Some(1)], + ); + let col_b = make_list_dict(&[Some(vec!["foo"])], &[Some(0), Some(0), Some(0)]); + let mut groups = vec![]; + gv.intern(&[col_a, col_b], &mut groups).unwrap(); + assert_eq!(groups, vec![0, 0, 1]); + + let arrays = gv.emit(EmitTo::All).unwrap(); + assert_eq!(arrays[0].len(), 2); + let g0 = dict_list(&arrays[0], 0); + let g0_strs = g0.as_string::(); + assert_eq!(g0_strs.value(0), "hello"); + assert_eq!(g0_strs.value(1), "world"); + assert_eq!(dict_list(&arrays[0], 1).as_string::().value(0), "foo"); + } + + /// List with null elements inside the list round-trips correctly. + #[test] + fn test_emit_list_with_null_elements() { + let mut gv = GroupDictionaryColumn::new(list_dict_schema()).unwrap(); + let mut builder = ListBuilder::new(StringBuilder::new()); + builder.values().append_value("a"); + builder.values().append_null(); + builder.values().append_value("b"); + builder.append(true); + let list_values = Arc::new(builder.finish()); + let keys_arr = Int32Array::from(vec![Some(0i32)]); + let col = Arc::new( + DictionaryArray::::try_new(keys_arr, list_values).unwrap(), + ) as ArrayRef; + + gv.intern(&[Arc::clone(&col), col], &mut vec![]).unwrap(); + let arrays = gv.emit(EmitTo::All).unwrap(); + let elems = dict_list(&arrays[0], 0); + let strs = elems.as_string::(); + assert_eq!(strs.len(), 3); + assert_eq!(strs.value(0), "a"); + assert!(strs.is_null(1)); + assert_eq!(strs.value(2), "b"); + } + + /// emit(First n) on List: keeps remaining groups intact after the shift. + #[test] + fn test_emit_first_n_list_utf8() { + let mut gv = GroupDictionaryColumn::new(list_dict_schema()).unwrap(); + // Three distinct groups: 0=(["a"],["x"]), 1=(["b"],["y"]), 2=(["c"],["z"]) + let col_a = make_list_dict( + &[Some(vec!["a"]), Some(vec!["b"]), Some(vec!["c"])], + &[Some(0), Some(1), Some(2)], + ); + let col_b = make_list_dict( + &[Some(vec!["x"]), Some(vec!["y"]), Some(vec!["z"])], + &[Some(0), Some(1), Some(2)], + ); + gv.intern(&[col_a, col_b], &mut vec![]).unwrap(); + + let arrays = gv.emit(EmitTo::First(2)).unwrap(); + assert_eq!(arrays[0].len(), 2); + assert_eq!(dict_list(&arrays[0], 0).as_string::().value(0), "a"); + assert_eq!(dict_list(&arrays[0], 1).as_string::().value(0), "b"); + + // Group 2 must survive as group 0 after the shift. + assert_eq!(gv.len(), 1); + let col_a2 = make_list_dict(&[Some(vec!["c"])], &[Some(0)]); + let col_b2 = make_list_dict(&[Some(vec!["z"])], &[Some(0)]); + let mut groups = vec![]; + gv.intern(&[col_a2, col_b2], &mut groups).unwrap(); + assert_eq!(groups, vec![0]); + } + + /// Resolve the logical string at position `i` in a Dictionary array. + /// Works for any key ordering, not just sequential. + fn logical_str(arr: &ArrayRef, i: usize) -> &str { + let d = arr.as_dictionary::(); + let key = d.keys().value(i) as usize; + d.values().as_string::().value(key) + } + + /// GroupDictionaryColumn and GroupValuesRows must assign identical group IDs + /// and produce identical emitted values for the same inputs. + #[test] + fn test_parity_with_group_values_rows_utf8() { + let schema = dict_schema(); + let mut gdc = GroupDictionaryColumn::new(Arc::clone(&schema)).unwrap(); + let mut gvr = GroupValuesRows::try_new(Arc::clone(&schema)).unwrap(); + + let batches: Vec<[ArrayRef; 2]> = vec![ + [ + make_dict(&["a", "b"], &[Some(0), Some(1), Some(0)]), + make_dict(&["x", "y"], &[Some(0), Some(1), Some(0)]), + ], + // second batch reuses same logical values via a fresh values Arc + [ + make_dict(&["b", "c"], &[Some(0), Some(1), Some(0)]), + make_dict(&["y", "z"], &[Some(0), Some(1), Some(0)]), + ], + ]; + + for [col_a, col_b] in &batches { + let mut gdc_groups = vec![]; + let mut gvr_groups = vec![]; + gdc.intern(&[Arc::clone(col_a), Arc::clone(col_b)], &mut gdc_groups) + .unwrap(); + gvr.intern(&[Arc::clone(col_a), Arc::clone(col_b)], &mut gvr_groups) + .unwrap(); + assert_eq!(gdc_groups, gvr_groups, "group IDs must agree"); + } + + let gdc_out = gdc.emit(EmitTo::All).unwrap(); + let gvr_out = gvr.emit(EmitTo::All).unwrap(); + + assert_eq!(gdc_out.len(), gvr_out.len()); + for col_idx in 0..gdc_out.len() { + let n = gdc_out[col_idx].len(); + assert_eq!(n, gvr_out[col_idx].len(), "col {col_idx} length mismatch"); + for i in 0..n { + assert_eq!( + logical_str(&gdc_out[col_idx], i), + logical_str(&gvr_out[col_idx], i), + "col {col_idx} group {i} value mismatch" + ); + } + } + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index f275d777c3279..838d07f9bb55e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -20,6 +20,7 @@ mod boolean; mod bytes; pub mod bytes_view; +pub mod dict; pub mod primitive; use std::mem::{self, size_of};