From 3ff4ccca9c39c8fc92d961d7b9a194fd9a2a1fa1 Mon Sep 17 00:00:00 2001 From: Richard Date: Tue, 16 Jun 2026 08:20:11 -0400 Subject: [PATCH 1/5] working draft --- .../src/aggregates/group_values/mod.rs | 6 + .../group_values/multi_group_by/dict.rs | 431 ++++++++++++++++++ .../group_values/multi_group_by/mod.rs | 1 + 3 files changed, 438 insertions(+) create mode 100644 datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index ee253e5d7afdd..40c497bb98428 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -34,6 +34,7 @@ mod row; pub use row::GroupValuesRows; mod single_group_by; use datafusion_physical_expr::binary_map::OutputType; +use multi_group_by::dict::{GroupDictionaryColumn, all_dictionary_schema}; use multi_group_by::GroupValuesColumn; pub(crate) use single_group_by::primitive::HashValue; @@ -200,6 +201,11 @@ pub fn new_group_values( } } + // Route 2+ all-dictionary columns to the specialised implementation. + if schema.fields().len() >= 2 && all_dictionary_schema(schema.as_ref()) { + return Ok(Box::new(GroupDictionaryColumn::new(schema)?)); + } + if multi_group_by::supported_schema(schema.as_ref()) { if matches!(group_ordering, GroupOrdering::None) { Ok(Box::new(GroupValuesColumn::::try_new(schema)?)) diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs new file mode 100644 index 0000000000000..380897c498c9b --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs @@ -0,0 +1,431 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::mem::size_of; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, AsArray, GenericStringArray, NullArray, OffsetSizeTrait, +}; +use arrow::datatypes::{ArrowNativeType, DataType, SchemaRef}; +use arrow::downcast_dictionary_array; +use datafusion_common::hash_utils::{RandomState, combine_hashes, create_hashes}; +use datafusion_common::{Result, internal_datafusion_err}; +use datafusion_execution::memory_pool::proxy::HashTableAllocExt; +use datafusion_expr::EmitTo; +use hashbrown::hash_table::HashTable; + +use crate::aggregates::group_values::GroupValues; + +/// Cached hashes and raw bytes for one dictionary column's values array. +/// Rebuilt only when the `Arc` pointer changes (i.e. a new values array arrives). +struct ColumnCache { + /// Keeps the values `Arc` alive; compared with `Arc::ptr_eq` for staleness. + values: ArrayRef, + /// `value_hashes[k]` = hash of the value at dictionary index `k`. + value_hashes: Vec, + /// `value_bytes[k]` = encoded bytes for dictionary index `k`; `None` for null values. + value_bytes: Vec>>, +} + +impl ColumnCache { + fn empty() -> Self { + Self { + values: Arc::new(NullArray::new(0)), + value_hashes: vec![], + value_bytes: vec![], + } + } + + fn update(&mut self, new_values: ArrayRef, rs: &RandomState) -> Result<()> { + if Arc::ptr_eq(&new_values, &self.values) { + return Ok(()); + } + let n = new_values.len(); + let mut hashes = vec![0u64; n]; + create_hashes(&[new_values.clone()], rs, &mut hashes)?; + let bytes = (0..n) + .map(|i| get_value_bytes(new_values.as_ref(), i)) + .collect::>()?; + self.value_hashes = hashes; + self.value_bytes = bytes; + self.values = new_values; + Ok(()) + } +} + +/// [`GroupValues`] for GROUP BY over **two or more** dictionary-typed columns. +/// +/// Rather than decoding dictionary values on every row, this implementation +/// works on the compact integer keys. The cost grows with the number of +/// distinct value *combinations*, not with the number of input rows. +pub struct GroupDictionaryColumn { + schema: SchemaRef, + col_caches: Vec, + /// `(row_hash, group_id)`. Multiple entries may share the same hash value; + /// byte-level comparison is used to resolve collisions. + map: HashTable<(u64, usize)>, + map_size: usize, + /// All group rows packed into a single contiguous buffer. + row_buffer: Vec, + /// `row_offsets[g]` = start byte of group `g` inside `row_buffer`. + row_offsets: Vec, + /// Reused scratch buffer for encoding the current row during `intern`. + row_scratch: Vec, + random_state: RandomState, +} + +/// Returns `true` when every field in `schema` is `DataType::Dictionary`. +pub fn all_dictionary_schema(schema: &arrow::datatypes::Schema) -> bool { + schema + .fields() + .iter() + .all(|f| matches!(f.data_type(), DataType::Dictionary(_, _))) +} + +impl GroupDictionaryColumn { + pub fn new(schema: SchemaRef) -> Result { + if schema.fields().len() < 2 { + return Err(internal_datafusion_err!( + "GroupDictionaryColumn requires at least 2 columns, got {}", + schema.fields().len() + )); + } + for f in schema.fields() { + if !matches!(f.data_type(), DataType::Dictionary(_, _)) { + return Err(internal_datafusion_err!( + "GroupDictionaryColumn requires all columns to be Dictionary, \ + but '{}' has type {}", + f.name(), + f.data_type() + )); + } + } + let n_cols = schema.fields().len(); + Ok(Self { + schema, + col_caches: (0..n_cols).map(|_| ColumnCache::empty()).collect(), + map: HashTable::with_capacity(128), + map_size: 0, + row_buffer: Vec::new(), + row_offsets: Vec::new(), + row_scratch: Vec::new(), + random_state: crate::aggregates::AGGREGATION_HASH_SEED, + }) + } +} + +// ── value-byte extraction ───────────────────────────────────────────────────── + +/// Return the encoded bytes for `values[idx]`, or `None` if the value is null. +/// +/// **Adding new value types**: add a match arm here. The encoding only needs +/// to be injective (distinct values → distinct byte sequences); it does not +/// need to be reversible for `intern` purposes (only `emit` needs decoding). +fn get_value_bytes(values: &dyn Array, idx: usize) -> Result>> { + if values.is_null(idx) { + return Ok(None); + } + match values.data_type() { + DataType::Utf8 => Ok(Some( + values.as_string::().value(idx).as_bytes().to_vec(), + )), + DataType::LargeUtf8 => Ok(Some( + values.as_string::().value(idx).as_bytes().to_vec(), + )), + DataType::Utf8View => { + Ok(Some(values.as_string_view().value(idx).as_bytes().to_vec())) + } + DataType::List(f) if matches!(f.data_type(), DataType::Utf8) => { + let list = values.as_list::().value(idx); + Ok(Some(encode_string_list(list.as_string::()))) + } + DataType::LargeList(f) if matches!(f.data_type(), DataType::LargeUtf8) => { + let list = values.as_list::().value(idx); + Ok(Some(encode_string_list(list.as_string::()))) + } + t => Err(internal_datafusion_err!( + "GroupDictionaryColumn: unsupported dictionary value type {t}" + )), + } +} + +/// Encode a string list as `[n: u32 LE]` then for each item +/// `[0x00]` (null) or `[0x01][len: u32 LE][bytes]`. +fn encode_string_list(arr: &GenericStringArray) -> Vec { + let mut buf = Vec::new(); + buf.extend_from_slice(&(arr.len() as u32).to_le_bytes()); + for j in 0..arr.len() { + if arr.is_null(j) { + buf.push(0); + } else { + let s = arr.value(j).as_bytes(); + buf.push(1); + buf.extend_from_slice(&(s.len() as u32).to_le_bytes()); + buf.extend_from_slice(s); + } + } + buf +} + +// ── dictionary key access ───────────────────────────────────────────────────── + +fn dict_values_array(col: &dyn Array) -> ArrayRef { + downcast_dictionary_array!( + col => col.values().clone(), + _ => unreachable!("schema validated in GroupDictionaryColumn::new") + ) +} + +/// Pre-collect all row keys for one dictionary column as `Option`. +/// Doing this upfront avoids repeated macro dispatch inside the hot row loop. +fn collect_keys(col: &dyn Array) -> Vec> { + downcast_dictionary_array!( + col => col.keys().iter().map(|k| k.map(|v| v.as_usize())).collect(), + _ => unreachable!("schema validated in GroupDictionaryColumn::new") + ) +} + +// ── row encoding ────────────────────────────────────────────────────────────── + +/// Append one column's contribution to the scratch row buffer. +/// +/// Per-column wire format: +/// null (null key **or** null dictionary value): `[0x00]` +/// non-null value: `[0x01][len: u32 LE][bytes…]` +#[inline] +fn push_col_bytes(buf: &mut Vec, cache: &ColumnCache, key: Option) { + match key.and_then(|k| cache.value_bytes[k].as_deref()) { + None => buf.push(0), + Some(vb) => { + buf.push(1); + buf.extend_from_slice(&(vb.len() as u32).to_le_bytes()); + buf.extend_from_slice(vb); + } + } +} + +// ── GroupValues impl ────────────────────────────────────────────────────────── + +impl GroupValues for GroupDictionaryColumn { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + debug_assert_eq!(cols.len(), self.schema.fields().len()); + groups.clear(); + + if cols.is_empty() || cols[0].is_empty() { + return Ok(()); + } + let n_rows = cols[0].len(); + + // Refresh column caches; a cache hit (same Arc pointer) is a no-op. + for (c, col) in cols.iter().enumerate() { + self.col_caches[c] + .update(dict_values_array(col.as_ref()), &self.random_state)?; + } + + // Pre-collect keys for all columns: avoids per-row macro dispatch. + let all_keys: Vec>> = + cols.iter().map(|c| collect_keys(c.as_ref())).collect(); + + groups.reserve(n_rows); + + for row in 0..n_rows { + // 1. Combine per-column value hashes into one row hash. + let combined_hash = { + let mut h = 0u64; + for (c, cache) in self.col_caches.iter().enumerate() { + let vh = all_keys[c][row].map_or(0, |k| cache.value_hashes[k]); + h = combine_hashes(h, vh); + } + h + }; + + // 2. Encode the row into the reusable scratch buffer. + self.row_scratch.clear(); + for (c, cache) in self.col_caches.iter().enumerate() { + push_col_bytes(&mut self.row_scratch, cache, all_keys[c][row]); + } + + // 3. Look up an existing group whose bytes match. + let found = { + let scratch = self.row_scratch.as_slice(); + let rb = self.row_buffer.as_slice(); + let ro = self.row_offsets.as_slice(); + self.map + .find(combined_hash, |&(h, g)| { + h == combined_hash && { + let end = ro.get(g + 1).copied().unwrap_or(rb.len()); + rb[ro[g]..end] == *scratch + } + }) + .map(|&(_, g)| g) + }; + + // 4. Reuse existing group or create a new one. + let group_id = match found { + Some(g) => g, + None => { + let g = self.row_offsets.len(); + self.row_offsets.push(self.row_buffer.len()); + self.row_buffer.extend_from_slice(&self.row_scratch); + self.map.insert_accounted( + (combined_hash, g), + |(h, _)| *h, + &mut self.map_size, + ); + g + } + }; + + groups.push(group_id); + } + + Ok(()) + } + + fn size(&self) -> usize { + let cache_bytes: usize = self + .col_caches + .iter() + .map(|c| { + c.value_hashes.len() * size_of::() + + c.value_bytes + .iter() + .map(|b| b.as_ref().map_or(0, |v| v.len())) + .sum::() + }) + .sum(); + self.map_size + + self.row_buffer.len() + + self.row_offsets.len() * size_of::() + + cache_bytes + } + + fn is_empty(&self) -> bool { + self.row_offsets.is_empty() + } + + fn len(&self) -> usize { + self.row_offsets.len() + } + + fn emit(&mut self, _emit_to: EmitTo) -> Result> { + todo!("GroupDictionaryColumn::emit") + } + + fn clear_shrink(&mut self, num_rows: usize) { + self.map.clear(); + self.map.shrink_to(num_rows, |_| 0); + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); + self.row_buffer.clear(); + self.row_buffer.shrink_to(num_rows * 16); + self.row_offsets.clear(); + self.row_offsets.shrink_to(num_rows); + for cache in &mut self.col_caches { + *cache = ColumnCache::empty(); + } + } +} + +// ── tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{DictionaryArray, Int32Array, StringArray}; + use arrow::datatypes::{Field, Int32Type, Schema}; + + fn make_dict(values: &[&str], keys: &[Option]) -> ArrayRef { + let vals = Arc::new(StringArray::from(values.to_vec())); + let keys_arr = Int32Array::from(keys.to_vec()); + Arc::new(DictionaryArray::::try_new(keys_arr, vals).unwrap()) + } + + fn dict_schema() -> SchemaRef { + let dt = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + Arc::new(Schema::new(vec![ + Field::new("a", dt.clone(), true), + Field::new("b", dt, true), + ])) + } + + /// Same row twice in one batch → same group id; distinct rows → different ids. + #[test] + fn test_basic_dedup() { + let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); + let col_a = make_dict(&["x", "y"], &[Some(0), Some(1), Some(0)]); + let col_b = make_dict(&["p", "q"], &[Some(0), Some(1), Some(0)]); + let mut groups = vec![]; + gv.intern(&[col_a, col_b], &mut groups).unwrap(); + assert_eq!(groups, vec![0, 1, 0]); + assert_eq!(gv.len(), 2); + } + + /// Null keys collapse to the same group regardless of which column is null. + #[test] + fn test_null_keys() { + let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); + // Row 0: (null, "p") — key for col_a is null + // Row 1: (null, "p") — same as row 0 + // Row 2: ("x", "p") — distinct from row 0 + let col_a = make_dict(&["x"], &[None, None, Some(0)]); + let col_b = make_dict(&["p"], &[Some(0), Some(0), Some(0)]); + let mut groups = vec![]; + gv.intern(&[col_a, col_b], &mut groups).unwrap(); + assert_eq!( + groups[0], groups[1], + "both null-key rows must be the same group" + ); + assert_ne!( + groups[0], groups[2], + "non-null row must be a different group" + ); + } + + /// When the values array changes between batches (different Arc), the key + /// space re-translates correctly: logical values still deduplicate. + #[test] + fn test_values_array_change_between_batches() { + let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); + + // Batch 1: col_a values = ["a", "b"], col_b values = ["x"] + let a1 = make_dict(&["a", "b"], &[Some(0), Some(1)]); + let b1 = make_dict(&["x"], &[Some(0), Some(0)]); + let mut g1 = vec![]; + gv.intern(&[a1, b1], &mut g1).unwrap(); + assert_eq!(g1, vec![0, 1]); + + // Batch 2: same logical rows but DIFFERENT Arc (values order swapped). + // key=0 now means "b", key=1 means "a" — opposite of batch 1. + let a2 = make_dict(&["b", "a"], &[Some(0), Some(1)]); + let b2 = make_dict(&["x"], &[Some(0), Some(0)]); + let mut g2 = vec![]; + gv.intern(&[a2, b2], &mut g2).unwrap(); + + // ("b", "x") was group 1 in batch 1; ("a", "x") was group 0. + assert_eq!( + g2[0], 1, + "('b','x') should map to the existing group for 'b'" + ); + assert_eq!( + g2[1], 0, + "('a','x') should map to the existing group for 'a'" + ); + assert_eq!(gv.len(), 2, "no new groups should have been created"); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index f275d777c3279..838d07f9bb55e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -20,6 +20,7 @@ mod boolean; mod bytes; pub mod bytes_view; +pub mod dict; pub mod primitive; use std::mem::{self, size_of}; From 16f74f52085556861186e70d62b0499e25296386 Mon Sep 17 00:00:00 2001 From: Richard Date: Tue, 16 Jun 2026 14:43:26 -0400 Subject: [PATCH 2/5] inital draft --- .../src/aggregates/group_values/mod.rs | 2 +- .../group_values/multi_group_by/dict.rs | 826 ++++++++++++++---- 2 files changed, 649 insertions(+), 179 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 40c497bb98428..807bc9029522b 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -34,8 +34,8 @@ mod row; pub use row::GroupValuesRows; mod single_group_by; use datafusion_physical_expr::binary_map::OutputType; -use multi_group_by::dict::{GroupDictionaryColumn, all_dictionary_schema}; use multi_group_by::GroupValuesColumn; +use multi_group_by::dict::{GroupDictionaryColumn, all_dictionary_schema}; pub(crate) use single_group_by::primitive::HashValue; diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs index 380897c498c9b..96a47e52eb928 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs @@ -19,9 +19,14 @@ use std::mem::size_of; use std::sync::Arc; use arrow::array::{ - Array, ArrayRef, AsArray, GenericStringArray, NullArray, OffsetSizeTrait, + Array, ArrayRef, AsArray, DictionaryArray, Int8Array, Int16Array, Int32Array, + Int64Array, ListBuilder, NullArray, StringBuilder, UInt8Array, UInt16Array, + UInt32Array, UInt64Array, +}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, Schema, + SchemaRef, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; -use arrow::datatypes::{ArrowNativeType, DataType, SchemaRef}; use arrow::downcast_dictionary_array; use datafusion_common::hash_utils::{RandomState, combine_hashes, create_hashes}; use datafusion_common::{Result, internal_datafusion_err}; @@ -31,15 +36,13 @@ use hashbrown::hash_table::HashTable; use crate::aggregates::group_values::GroupValues; -/// Cached hashes and raw bytes for one dictionary column's values array. +/// Caches the hashes for one dictionary column's values array. /// Rebuilt only when the `Arc` pointer changes (i.e. a new values array arrives). struct ColumnCache { - /// Keeps the values `Arc` alive; compared with `Arc::ptr_eq` for staleness. + /// Keeps the values `Arc` alive and is compared with `Arc::ptr_eq` to detect staleness. values: ArrayRef, /// `value_hashes[k]` = hash of the value at dictionary index `k`. value_hashes: Vec, - /// `value_bytes[k]` = encoded bytes for dictionary index `k`; `None` for null values. - value_bytes: Vec>>, } impl ColumnCache { @@ -47,54 +50,71 @@ impl ColumnCache { Self { values: Arc::new(NullArray::new(0)), value_hashes: vec![], - value_bytes: vec![], } } - fn update(&mut self, new_values: ArrayRef, rs: &RandomState) -> Result<()> { + fn update(&mut self, new_values: ArrayRef, random_state: &RandomState) -> Result<()> { if Arc::ptr_eq(&new_values, &self.values) { return Ok(()); } - let n = new_values.len(); - let mut hashes = vec![0u64; n]; - create_hashes(&[new_values.clone()], rs, &mut hashes)?; - let bytes = (0..n) - .map(|i| get_value_bytes(new_values.as_ref(), i)) - .collect::>()?; - self.value_hashes = hashes; - self.value_bytes = bytes; + let num_values = new_values.len(); + // Reuse the allocation; only grows capacity when a larger values array arrives. + self.value_hashes.clear(); + self.value_hashes.resize(num_values, 0u64); + create_hashes( + std::slice::from_ref(&new_values), + random_state, + &mut self.value_hashes, + )?; self.values = new_values; Ok(()) } + + fn size(&self) -> usize { + self.value_hashes.len() * size_of::() + } + + fn clear_shrink(&mut self, shrink_to: usize) { + self.values = Arc::new(NullArray::new(0)); + self.value_hashes.clear(); + self.value_hashes.shrink_to(shrink_to); + } } /// [`GroupValues`] for GROUP BY over **two or more** dictionary-typed columns. -/// -/// Rather than decoding dictionary values on every row, this implementation -/// works on the compact integer keys. The cost grows with the number of -/// distinct value *combinations*, not with the number of input rows. pub struct GroupDictionaryColumn { schema: SchemaRef, col_caches: Vec, /// `(row_hash, group_id)`. Multiple entries may share the same hash value; /// byte-level comparison is used to resolve collisions. map: HashTable<(u64, usize)>, + /// Tracked allocation size of `map` in bytes, updated on every insert and shrink. map_size: usize, - /// All group rows packed into a single contiguous buffer. + /// All group rows packed back-to-back into a single contiguous buffer. + /// + /// CSR-style layout: `row_offsets[g]` is the start of group `g` and + /// `row_offsets[g+1]` is its end. The last group has no `g+1` entry; its + /// end is `row_buffer.len()`. row_buffer: Vec, /// `row_offsets[g]` = start byte of group `g` inside `row_buffer`. row_offsets: Vec, - /// Reused scratch buffer for encoding the current row during `intern`. + /// Reused scratch buffer for encoding the current row. row_scratch: Vec, + row_decoder: RowSetDecoder, random_state: RandomState, } /// Returns `true` when every field in `schema` is `DataType::Dictionary`. -pub fn all_dictionary_schema(schema: &arrow::datatypes::Schema) -> bool { +pub fn all_dictionary_schema(schema: &Schema) -> bool { schema .fields() .iter() - .all(|f| matches!(f.data_type(), DataType::Dictionary(_, _))) + .all(|field| matches!(field.data_type(), DataType::Dictionary(_, _))) +} + +fn is_supported_value_type(data_type: &DataType) -> bool { + matches!(data_type, DataType::Utf8) + || matches!(data_type, DataType::List(f) if f.data_type() == &DataType::Utf8) } impl GroupDictionaryColumn { @@ -105,17 +125,30 @@ impl GroupDictionaryColumn { schema.fields().len() )); } - for f in schema.fields() { - if !matches!(f.data_type(), DataType::Dictionary(_, _)) { - return Err(internal_datafusion_err!( - "GroupDictionaryColumn requires all columns to be Dictionary, \ - but '{}' has type {}", - f.name(), - f.data_type() - )); + for field in schema.fields() { + match field.data_type() { + DataType::Dictionary(_, value_type) => { + if !is_supported_value_type(value_type) { + return Err(internal_datafusion_err!( + "GroupDictionaryColumn: unsupported dictionary value type \ + '{}' in column '{}'", + value_type, + field.name() + )); + } + } + _ => { + return Err(internal_datafusion_err!( + "GroupDictionaryColumn requires all columns to be Dictionary, \ + but '{}' has type {}", + field.name(), + field.data_type() + )); + } } } let n_cols = schema.fields().len(); + let row_decoder = RowSetDecoder::new(&schema); Ok(Self { schema, col_caches: (0..n_cols).map(|_| ColumnCache::empty()).collect(), @@ -124,103 +157,36 @@ impl GroupDictionaryColumn { row_buffer: Vec::new(), row_offsets: Vec::new(), row_scratch: Vec::new(), + row_decoder, random_state: crate::aggregates::AGGREGATION_HASH_SEED, }) } } -// ── value-byte extraction ───────────────────────────────────────────────────── - -/// Return the encoded bytes for `values[idx]`, or `None` if the value is null. -/// -/// **Adding new value types**: add a match arm here. The encoding only needs -/// to be injective (distinct values → distinct byte sequences); it does not -/// need to be reversible for `intern` purposes (only `emit` needs decoding). -fn get_value_bytes(values: &dyn Array, idx: usize) -> Result>> { - if values.is_null(idx) { - return Ok(None); - } - match values.data_type() { - DataType::Utf8 => Ok(Some( - values.as_string::().value(idx).as_bytes().to_vec(), - )), - DataType::LargeUtf8 => Ok(Some( - values.as_string::().value(idx).as_bytes().to_vec(), - )), - DataType::Utf8View => { - Ok(Some(values.as_string_view().value(idx).as_bytes().to_vec())) - } - DataType::List(f) if matches!(f.data_type(), DataType::Utf8) => { - let list = values.as_list::().value(idx); - Ok(Some(encode_string_list(list.as_string::()))) - } - DataType::LargeList(f) if matches!(f.data_type(), DataType::LargeUtf8) => { - let list = values.as_list::().value(idx); - Ok(Some(encode_string_list(list.as_string::()))) - } - t => Err(internal_datafusion_err!( - "GroupDictionaryColumn: unsupported dictionary value type {t}" - )), - } -} - -/// Encode a string list as `[n: u32 LE]` then for each item -/// `[0x00]` (null) or `[0x01][len: u32 LE][bytes]`. -fn encode_string_list(arr: &GenericStringArray) -> Vec { - let mut buf = Vec::new(); - buf.extend_from_slice(&(arr.len() as u32).to_le_bytes()); - for j in 0..arr.len() { - if arr.is_null(j) { - buf.push(0); - } else { - let s = arr.value(j).as_bytes(); - buf.push(1); - buf.extend_from_slice(&(s.len() as u32).to_le_bytes()); - buf.extend_from_slice(s); - } - } - buf -} - -// ── dictionary key access ───────────────────────────────────────────────────── - fn dict_values_array(col: &dyn Array) -> ArrayRef { downcast_dictionary_array!( - col => col.values().clone(), + col => Arc::clone(col.values()), _ => unreachable!("schema validated in GroupDictionaryColumn::new") ) } -/// Pre-collect all row keys for one dictionary column as `Option`. -/// Doing this upfront avoids repeated macro dispatch inside the hot row loop. -fn collect_keys(col: &dyn Array) -> Vec> { +// Box is required: different key widths (Int8/Int16/Int32/Int64) produce different concrete iterator types. +fn fill_keys(col: &dyn Array) -> Box> + '_> { downcast_dictionary_array!( - col => col.keys().iter().map(|k| k.map(|v| v.as_usize())).collect(), + col => { + let keys = col.keys(); + Box::new((0..keys.len()).map(move |row_idx| { + if keys.is_valid(row_idx) { + Some(keys.value(row_idx).as_usize()) + } else { + None + } + })) + }, _ => unreachable!("schema validated in GroupDictionaryColumn::new") ) } -// ── row encoding ────────────────────────────────────────────────────────────── - -/// Append one column's contribution to the scratch row buffer. -/// -/// Per-column wire format: -/// null (null key **or** null dictionary value): `[0x00]` -/// non-null value: `[0x01][len: u32 LE][bytes…]` -#[inline] -fn push_col_bytes(buf: &mut Vec, cache: &ColumnCache, key: Option) { - match key.and_then(|k| cache.value_bytes[k].as_deref()) { - None => buf.push(0), - Some(vb) => { - buf.push(1); - buf.extend_from_slice(&(vb.len() as u32).to_le_bytes()); - buf.extend_from_slice(vb); - } - } -} - -// ── GroupValues impl ────────────────────────────────────────────────────────── - impl GroupValues for GroupDictionaryColumn { fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { debug_assert_eq!(cols.len(), self.schema.fields().len()); @@ -231,63 +197,58 @@ impl GroupValues for GroupDictionaryColumn { } let n_rows = cols[0].len(); - // Refresh column caches; a cache hit (same Arc pointer) is a no-op. - for (c, col) in cols.iter().enumerate() { - self.col_caches[c] + for (col_idx, col) in cols.iter().enumerate() { + self.col_caches[col_idx] .update(dict_values_array(col.as_ref()), &self.random_state)?; } - // Pre-collect keys for all columns: avoids per-row macro dispatch. - let all_keys: Vec>> = - cols.iter().map(|c| collect_keys(c.as_ref())).collect(); - - groups.reserve(n_rows); - - for row in 0..n_rows { - // 1. Combine per-column value hashes into one row hash. - let combined_hash = { - let mut h = 0u64; - for (c, cache) in self.col_caches.iter().enumerate() { - let vh = all_keys[c][row].map_or(0, |k| cache.value_hashes[k]); - h = combine_hashes(h, vh); - } - h - }; + // Downcast once per column; advance with .next() per row to avoid per-row downcast. + let mut key_iters: Vec<_> = + cols.iter().map(|col| fill_keys(col.as_ref())).collect(); - // 2. Encode the row into the reusable scratch buffer. + let _ = groups.try_reserve(n_rows); + for _row in 0..n_rows { + let mut hash = 0u64; self.row_scratch.clear(); - for (c, cache) in self.col_caches.iter().enumerate() { - push_col_bytes(&mut self.row_scratch, cache, all_keys[c][row]); + + for (col_idx, key_iter) in key_iters.iter_mut().enumerate() { + let key = key_iter.next().unwrap(); + let cache = &self.col_caches[col_idx]; + let value_hash = key.map_or(0, |key_idx| cache.value_hashes[key_idx]); + hash = combine_hashes(hash, value_hash); + encode_value(key, cache.values.as_ref(), &mut self.row_scratch); } - // 3. Look up an existing group whose bytes match. + let combined_hash = hash; let found = { - let scratch = self.row_scratch.as_slice(); - let rb = self.row_buffer.as_slice(); - let ro = self.row_offsets.as_slice(); + let row_scratch = self.row_scratch.as_slice(); + let row_buffer = self.row_buffer.as_slice(); + let row_offsets = self.row_offsets.as_slice(); self.map - .find(combined_hash, |&(h, g)| { - h == combined_hash && { - let end = ro.get(g + 1).copied().unwrap_or(rb.len()); - rb[ro[g]..end] == *scratch + .find(combined_hash, |&(stored_hash, group_id)| { + stored_hash == combined_hash && { + let end = row_offsets + .get(group_id + 1) + .copied() + .unwrap_or(row_buffer.len()); // last group has no g+1 entry + row_buffer[row_offsets[group_id]..end] == *row_scratch } }) - .map(|&(_, g)| g) + .map(|&(_, group_id)| group_id) }; - // 4. Reuse existing group or create a new one. let group_id = match found { - Some(g) => g, + Some(existing_id) => existing_id, None => { - let g = self.row_offsets.len(); + let new_id = self.row_offsets.len(); self.row_offsets.push(self.row_buffer.len()); self.row_buffer.extend_from_slice(&self.row_scratch); self.map.insert_accounted( - (combined_hash, g), - |(h, _)| *h, + (combined_hash, new_id), + |(stored_hash, _)| *stored_hash, &mut self.map_size, ); - g + new_id } }; @@ -298,20 +259,11 @@ impl GroupValues for GroupDictionaryColumn { } fn size(&self) -> usize { - let cache_bytes: usize = self - .col_caches - .iter() - .map(|c| { - c.value_hashes.len() * size_of::() - + c.value_bytes - .iter() - .map(|b| b.as_ref().map_or(0, |v| v.len())) - .sum::() - }) - .sum(); + let cache_bytes: usize = self.col_caches.iter().map(|c| c.size()).sum(); self.map_size + self.row_buffer.len() + self.row_offsets.len() * size_of::() + + self.row_scratch.capacity() + cache_bytes } @@ -323,8 +275,64 @@ impl GroupValues for GroupDictionaryColumn { self.row_offsets.len() } - fn emit(&mut self, _emit_to: EmitTo) -> Result> { - todo!("GroupDictionaryColumn::emit") + fn emit(&mut self, emit_to: EmitTo) -> Result> { + let n_total = self.row_offsets.len(); + if n_total == 0 { + return Ok(self.row_decoder.finish()); + } + let n_emit = match emit_to { + EmitTo::All => n_total, + EmitTo::First(n) => n.min(n_total), + }; + + for row_idx in 0..n_emit { + let start = self.row_offsets[row_idx]; + let end = self + .row_offsets + .get(row_idx + 1) + .copied() + .unwrap_or(self.row_buffer.len()); + self.row_decoder.decode(&self.row_buffer[start..end]); + } + let inner = self.row_decoder.finish(); + let arrays: Vec = inner + .into_iter() + .zip(self.schema.fields()) + .map(|(values, field)| match field.data_type() { + DataType::Dictionary(key_type, _) => wrap_as_dictionary( + values, + &make_sequential_keys(n_emit, key_type), + key_type, + ), + _ => unreachable!("schema validated in GroupDictionaryColumn::new"), + }) + .collect(); + + if n_emit == n_total { + self.row_buffer.clear(); + self.row_offsets.clear(); + self.map.clear(); + self.map_size = 0; + } else { + let retain_start = self.row_offsets[n_emit]; + self.row_offsets.drain(0..n_emit); + for offset in &mut self.row_offsets { + *offset -= retain_start; + } + self.row_buffer.drain(0..retain_start); + // avoiding this somehow would be nice. worse case this runs once + // VecDeque? + // Shift remaining group ids in-place; retain gives &mut access so no rehashing occurs. + self.map.retain(|(_, gid)| { + if *gid < n_emit { + return false; + } + *gid -= n_emit; + true + }); + } + + Ok(arrays) } fn clear_shrink(&mut self, num_rows: usize) { @@ -332,13 +340,258 @@ impl GroupValues for GroupDictionaryColumn { self.map.shrink_to(num_rows, |_| 0); self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); self.row_buffer.clear(); - self.row_buffer.shrink_to(num_rows * 16); self.row_offsets.clear(); self.row_offsets.shrink_to(num_rows); for cache in &mut self.col_caches { - *cache = ColumnCache::empty(); + cache.clear_shrink(num_rows); + } + } +} + +// ── encoding / decoding ─────────────────────────────────────────────────────── + +/// Wire format per column: +/// null: `[0x00]` +/// non-null scalar: `[0x01][len: u32 LE][utf8_bytes…]` +/// non-null list: `[0x01][content_len: u32 LE][n: u32 LE][elem…]` +/// where each elem is `[0x00]` (null) or `[0x01][len: u32 LE][utf8_bytes…]` +fn encode_value(key: Option, values: &dyn Array, buf: &mut Vec) { + let key_idx = match key { + None => { + buf.push(0); + return; + } + Some(k) => k, + }; + if values.is_null(key_idx) { + buf.push(0); + return; + } + buf.push(1); + match values.data_type() { + DataType::Utf8 => { + let bytes = values.as_string::().value(key_idx).as_bytes(); + buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + buf.extend_from_slice(bytes); + } + DataType::List(_) => { + // Back-fill content_len after encoding all elements. + let len_pos = buf.len(); + buf.extend_from_slice(&[0u8; 4]); + let content_start = buf.len(); + + let list_element = values.as_list::().value(key_idx); + let str_array = list_element.as_string::(); + buf.extend_from_slice(&(str_array.len() as u32).to_le_bytes()); + for elem_idx in 0..str_array.len() { + if str_array.is_null(elem_idx) { + buf.push(0); + } else { + let elem_bytes = str_array.value(elem_idx).as_bytes(); + buf.push(1); + buf.extend_from_slice(&(elem_bytes.len() as u32).to_le_bytes()); + buf.extend_from_slice(elem_bytes); + } + } + let content_len = (buf.len() - content_start) as u32; + buf[len_pos..len_pos + 4].copy_from_slice(&content_len.to_le_bytes()); + } + dt => panic!("unsupported dictionary value type: {dt}"), + } +} + +#[derive(Debug)] +enum ColumnBuilder { + Utf8(StringBuilder), + ListUtf8(ListBuilder), +} + +macro_rules! dispatch_builder { + ($self:expr, $b:ident => $body:expr) => { + match $self { + ColumnBuilder::Utf8($b) => $body, + ColumnBuilder::ListUtf8($b) => $body, + } + }; +} + +impl ColumnBuilder { + fn from_value_type(value_type: &DataType) -> Self { + match value_type { + DataType::Utf8 => Self::Utf8(StringBuilder::new()), + DataType::List(_) => Self::ListUtf8(ListBuilder::new(StringBuilder::new())), + _ => unreachable!("value type validated in GroupDictionaryColumn::new"), + } + } + + fn append_null(&mut self) { + dispatch_builder!(self, b => b.append_null()) + } + + fn append_bytes(&mut self, bytes: &[u8]) { + match self { + // SAFETY: bytes come from Arrow string arrays, always valid UTF-8. + Self::Utf8(b) => { + b.append_value(unsafe { std::str::from_utf8_unchecked(bytes) }) + } + Self::ListUtf8(b) => { + let mut cursor = 0; + let n = u32::from_le_bytes(bytes[cursor..cursor + 4].try_into().unwrap()) + as usize; + cursor += 4; + for _ in 0..n { + match bytes[cursor] { + 0 => { + b.values().append_null(); + cursor += 1; + } + _ => { + cursor += 1; + let len = u32::from_le_bytes( + bytes[cursor..cursor + 4].try_into().unwrap(), + ) as usize; + cursor += 4; + // SAFETY: bytes come from Arrow string arrays, always valid UTF-8. + b.values().append_value(unsafe { + std::str::from_utf8_unchecked( + &bytes[cursor..cursor + len], + ) + }); + cursor += len; + } + } + } + b.append(true); + } + } + } + + fn finish(&mut self) -> ArrayRef { + dispatch_builder!(self, b => Arc::new(b.finish())) + } +} + +/// Accumulates encoded row slices and reconstructs them into Arrow arrays on `finish`. +#[derive(Debug)] +struct RowSetDecoder { + builders: Vec, +} + +impl RowSetDecoder { + fn new(schema: &Schema) -> Self { + let builders = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Dictionary(_, value_type) => { + ColumnBuilder::from_value_type(value_type) + } + _ => unreachable!("schema validated in GroupDictionaryColumn::new"), + }) + .collect(); + Self { builders } + } + + /// Expected format: one column entry per schema field, in schema order. + /// Each column: `[0x00]` (null) or `[0x01][content_len: u32 LE][content…]`. + /// Utf8 content: raw UTF-8 bytes. + /// List content: `[n: u32 LE][elem…]` where each elem is `[0x00]` or `[0x01][len: u32 LE][utf8_bytes…]`. + fn decode(&mut self, encoded: &[u8]) { + let mut cursor = 0; + for builder in &mut self.builders { + match encoded[cursor] { + 0 => { + builder.append_null(); + cursor += 1; + } + _ => { + cursor += 1; + let len = u32::from_le_bytes( + encoded[cursor..cursor + 4].try_into().unwrap(), + ) as usize; + cursor += 4; + builder.append_bytes(&encoded[cursor..cursor + len]); + cursor += len; + } + } } } + + fn finish(&mut self) -> Vec { + self.builders.iter_mut().map(|b| b.finish()).collect() + } +} + +/// Build sequential keys `[0, 1, ..., n-1]` with the key type taken from the schema. +/// All columns share the same key type, so callers should build this once and clone the Arc. +fn make_sequential_keys(n: usize, key_type: &DataType) -> ArrayRef { + match key_type { + DataType::Int8 => Arc::new(Int8Array::from_iter_values((0..n).map(|i| i as i8))), + DataType::Int16 => { + Arc::new(Int16Array::from_iter_values((0..n).map(|i| i as i16))) + } + DataType::Int32 => { + Arc::new(Int32Array::from_iter_values((0..n).map(|i| i as i32))) + } + DataType::Int64 => { + Arc::new(Int64Array::from_iter_values((0..n).map(|i| i as i64))) + } + DataType::UInt8 => { + Arc::new(UInt8Array::from_iter_values((0..n).map(|i| i as u8))) + } + DataType::UInt16 => { + Arc::new(UInt16Array::from_iter_values((0..n).map(|i| i as u16))) + } + DataType::UInt32 => { + Arc::new(UInt32Array::from_iter_values((0..n).map(|i| i as u32))) + } + DataType::UInt64 => { + Arc::new(UInt64Array::from_iter_values((0..n).map(|i| i as u64))) + } + _ => unreachable!("schema validated in GroupDictionaryColumn::new"), + } +} + +fn wrap_as_dictionary( + values: ArrayRef, + keys: &ArrayRef, + key_type: &DataType, +) -> ArrayRef { + match key_type { + DataType::Int8 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::Int16 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::Int32 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::Int64 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::UInt8 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::UInt16 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::UInt32 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + DataType::UInt64 => Arc::new(DictionaryArray::::new( + keys.as_primitive::().clone(), + values, + )), + _ => unreachable!("schema validated in GroupDictionaryColumn::new"), + } } // ── tests ───────────────────────────────────────────────────────────────────── @@ -346,8 +599,9 @@ impl GroupValues for GroupDictionaryColumn { #[cfg(test)] mod tests { use super::*; - use arrow::array::{DictionaryArray, Int32Array, StringArray}; - use arrow::datatypes::{Field, Int32Type, Schema}; + use crate::aggregates::group_values::GroupValuesRows; + use arrow::array::StringArray; + use arrow::datatypes::{Field, Schema}; fn make_dict(values: &[&str], keys: &[Option]) -> ArrayRef { let vals = Arc::new(StringArray::from(values.to_vec())); @@ -404,28 +658,244 @@ mod tests { let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); // Batch 1: col_a values = ["a", "b"], col_b values = ["x"] - let a1 = make_dict(&["a", "b"], &[Some(0), Some(1)]); - let b1 = make_dict(&["x"], &[Some(0), Some(0)]); - let mut g1 = vec![]; - gv.intern(&[a1, b1], &mut g1).unwrap(); - assert_eq!(g1, vec![0, 1]); + let col_a1 = make_dict(&["a", "b"], &[Some(0), Some(1)]); + let col_b1 = make_dict(&["x"], &[Some(0), Some(0)]); + let mut groups1 = vec![]; + gv.intern(&[col_a1, col_b1], &mut groups1).unwrap(); + assert_eq!(groups1, vec![0, 1]); // Batch 2: same logical rows but DIFFERENT Arc (values order swapped). // key=0 now means "b", key=1 means "a" — opposite of batch 1. - let a2 = make_dict(&["b", "a"], &[Some(0), Some(1)]); - let b2 = make_dict(&["x"], &[Some(0), Some(0)]); - let mut g2 = vec![]; - gv.intern(&[a2, b2], &mut g2).unwrap(); + let col_a2 = make_dict(&["b", "a"], &[Some(0), Some(1)]); + let col_b2 = make_dict(&["x"], &[Some(0), Some(0)]); + let mut groups2 = vec![]; + gv.intern(&[col_a2, col_b2], &mut groups2).unwrap(); // ("b", "x") was group 1 in batch 1; ("a", "x") was group 0. assert_eq!( - g2[0], 1, + groups2[0], 1, "('b','x') should map to the existing group for 'b'" ); assert_eq!( - g2[1], 0, + groups2[1], 0, "('a','x') should map to the existing group for 'a'" ); assert_eq!(gv.len(), 2, "no new groups should have been created"); } + + fn list_dict_schema() -> SchemaRef { + let item_field = Arc::new(Field::new("item", DataType::Utf8, true)); + let dt = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List(item_field)), + ); + Arc::new(Schema::new(vec![ + Field::new("a", dt.clone(), true), + Field::new("b", dt, true), + ])) + } + + fn make_list_dict(lists: &[Option>], keys: &[Option]) -> ArrayRef { + let mut builder = ListBuilder::new(StringBuilder::new()); + for list_opt in lists { + match list_opt { + None => builder.append_null(), + Some(strings) => { + for s in strings { + builder.values().append_value(s); + } + builder.append(true); + } + } + } + let values = Arc::new(builder.finish()); + let keys_arr = Int32Array::from(keys.to_vec()); + Arc::new(DictionaryArray::::try_new(keys_arr, values).unwrap()) + } + + fn dict_str(arr: &ArrayRef, group: usize) -> &str { + let d = arr.as_dictionary::(); + // Keys are sequential (0..n), so key value == group index. + let val_idx = d.keys().value(group) as usize; + d.values().as_string::().value(val_idx) + } + + fn dict_list(arr: &ArrayRef, group: usize) -> ArrayRef { + let d = arr.as_dictionary::(); + let val_idx = d.keys().value(group) as usize; + d.values().as_list::().value(val_idx) + } + + /// emit(All) reconstructs the correct string values for each group. + #[test] + fn test_emit_all_utf8() { + let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); + let col_a = make_dict(&["x", "y"], &[Some(0), Some(1), Some(0)]); + let col_b = make_dict(&["p", "q"], &[Some(0), Some(1), Some(0)]); + gv.intern(&[col_a, col_b], &mut vec![]).unwrap(); + + let arrays = gv.emit(EmitTo::All).unwrap(); + assert_eq!(arrays[0].len(), 2); + assert_eq!(dict_str(&arrays[0], 0), "x"); + assert_eq!(dict_str(&arrays[0], 1), "y"); + assert_eq!(dict_str(&arrays[1], 0), "p"); + assert_eq!(dict_str(&arrays[1], 1), "q"); + assert_eq!(gv.len(), 0); + } + + /// emit(First n) emits exactly the first n groups and keeps the rest accessible. + #[test] + fn test_emit_first_n_utf8() { + let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); + let col_a = make_dict(&["a", "b", "c"], &[Some(0), Some(1), Some(2)]); + let col_b = make_dict(&["x", "y", "z"], &[Some(0), Some(1), Some(2)]); + gv.intern(&[col_a, col_b], &mut vec![]).unwrap(); + + let arrays = gv.emit(EmitTo::First(2)).unwrap(); + assert_eq!(arrays[0].len(), 2); + assert_eq!(dict_str(&arrays[0], 0), "a"); + assert_eq!(dict_str(&arrays[0], 1), "b"); + + assert_eq!(gv.len(), 1); + let mut groups = vec![]; + let col_a2 = make_dict(&["c"], &[Some(0)]); + let col_b2 = make_dict(&["z"], &[Some(0)]); + gv.intern(&[col_a2, col_b2], &mut groups).unwrap(); + assert_eq!( + groups, + vec![0], + "retained group must map to id 0 after shift" + ); + } + + /// emit(All) correctly reconstructs List values for each group. + #[test] + fn test_emit_list_utf8() { + let mut gv = GroupDictionaryColumn::new(list_dict_schema()).unwrap(); + let col_a = make_list_dict( + &[Some(vec!["hello", "world"]), Some(vec!["foo"])], + &[Some(0), Some(0), Some(1)], + ); + let col_b = make_list_dict(&[Some(vec!["foo"])], &[Some(0), Some(0), Some(0)]); + let mut groups = vec![]; + gv.intern(&[col_a, col_b], &mut groups).unwrap(); + assert_eq!(groups, vec![0, 0, 1]); + + let arrays = gv.emit(EmitTo::All).unwrap(); + assert_eq!(arrays[0].len(), 2); + let g0 = dict_list(&arrays[0], 0); + let g0_strs = g0.as_string::(); + assert_eq!(g0_strs.value(0), "hello"); + assert_eq!(g0_strs.value(1), "world"); + assert_eq!(dict_list(&arrays[0], 1).as_string::().value(0), "foo"); + } + + /// List with null elements inside the list round-trips correctly. + #[test] + fn test_emit_list_with_null_elements() { + let mut gv = GroupDictionaryColumn::new(list_dict_schema()).unwrap(); + let mut builder = ListBuilder::new(StringBuilder::new()); + builder.values().append_value("a"); + builder.values().append_null(); + builder.values().append_value("b"); + builder.append(true); + let list_values = Arc::new(builder.finish()); + let keys_arr = Int32Array::from(vec![Some(0i32)]); + let col = Arc::new( + DictionaryArray::::try_new(keys_arr, list_values).unwrap(), + ) as ArrayRef; + + gv.intern(&[Arc::clone(&col), col], &mut vec![]).unwrap(); + let arrays = gv.emit(EmitTo::All).unwrap(); + let elems = dict_list(&arrays[0], 0); + let strs = elems.as_string::(); + assert_eq!(strs.len(), 3); + assert_eq!(strs.value(0), "a"); + assert!(strs.is_null(1)); + assert_eq!(strs.value(2), "b"); + } + + /// emit(First n) on List: keeps remaining groups intact after the shift. + #[test] + fn test_emit_first_n_list_utf8() { + let mut gv = GroupDictionaryColumn::new(list_dict_schema()).unwrap(); + // Three distinct groups: 0=(["a"],["x"]), 1=(["b"],["y"]), 2=(["c"],["z"]) + let col_a = make_list_dict( + &[Some(vec!["a"]), Some(vec!["b"]), Some(vec!["c"])], + &[Some(0), Some(1), Some(2)], + ); + let col_b = make_list_dict( + &[Some(vec!["x"]), Some(vec!["y"]), Some(vec!["z"])], + &[Some(0), Some(1), Some(2)], + ); + gv.intern(&[col_a, col_b], &mut vec![]).unwrap(); + + let arrays = gv.emit(EmitTo::First(2)).unwrap(); + assert_eq!(arrays[0].len(), 2); + assert_eq!(dict_list(&arrays[0], 0).as_string::().value(0), "a"); + assert_eq!(dict_list(&arrays[0], 1).as_string::().value(0), "b"); + + // Group 2 must survive as group 0 after the shift. + assert_eq!(gv.len(), 1); + let col_a2 = make_list_dict(&[Some(vec!["c"])], &[Some(0)]); + let col_b2 = make_list_dict(&[Some(vec!["z"])], &[Some(0)]); + let mut groups = vec![]; + gv.intern(&[col_a2, col_b2], &mut groups).unwrap(); + assert_eq!(groups, vec![0]); + } + + /// Resolve the logical string at position `i` in a Dictionary array. + /// Works for any key ordering, not just sequential. + fn logical_str(arr: &ArrayRef, i: usize) -> &str { + let d = arr.as_dictionary::(); + let key = d.keys().value(i) as usize; + d.values().as_string::().value(key) + } + + /// GroupDictionaryColumn and GroupValuesRows must assign identical group IDs + /// and produce identical emitted values for the same inputs. + #[test] + fn test_parity_with_group_values_rows_utf8() { + let schema = dict_schema(); + let mut gdc = GroupDictionaryColumn::new(Arc::clone(&schema)).unwrap(); + let mut gvr = GroupValuesRows::try_new(Arc::clone(&schema)).unwrap(); + + let batches: Vec<[ArrayRef; 2]> = vec![ + [ + make_dict(&["a", "b"], &[Some(0), Some(1), Some(0)]), + make_dict(&["x", "y"], &[Some(0), Some(1), Some(0)]), + ], + // second batch reuses same logical values via a fresh values Arc + [ + make_dict(&["b", "c"], &[Some(0), Some(1), Some(0)]), + make_dict(&["y", "z"], &[Some(0), Some(1), Some(0)]), + ], + ]; + + for [col_a, col_b] in &batches { + let mut gdc_groups = vec![]; + let mut gvr_groups = vec![]; + gdc.intern(&[Arc::clone(col_a), Arc::clone(col_b)], &mut gdc_groups) + .unwrap(); + gvr.intern(&[Arc::clone(col_a), Arc::clone(col_b)], &mut gvr_groups) + .unwrap(); + assert_eq!(gdc_groups, gvr_groups, "group IDs must agree"); + } + + let gdc_out = gdc.emit(EmitTo::All).unwrap(); + let gvr_out = gvr.emit(EmitTo::All).unwrap(); + + assert_eq!(gdc_out.len(), gvr_out.len()); + for col_idx in 0..gdc_out.len() { + let n = gdc_out[col_idx].len(); + assert_eq!(n, gvr_out[col_idx].len(), "col {col_idx} length mismatch"); + for i in 0..n { + assert_eq!( + logical_str(&gdc_out[col_idx], i), + logical_str(&gvr_out[col_idx], i), + "col {col_idx} group {i} value mismatch" + ); + } + } + } } From dde196b26609c72ad7ee5fd948a9dddea056c6f8 Mon Sep 17 00:00:00 2001 From: Richard Date: Thu, 18 Jun 2026 15:08:44 -0400 Subject: [PATCH 3/5] revised first draft --- .../group_values/multi_group_by/dict.rs | 207 +++++++++--------- 1 file changed, 105 insertions(+), 102 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs index 96a47e52eb928..249fa111978b7 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use arrow::array::{ Array, ArrayRef, AsArray, DictionaryArray, Int8Array, Int16Array, Int32Array, Int64Array, ListBuilder, NullArray, StringBuilder, UInt8Array, UInt16Array, - UInt32Array, UInt64Array, + UInt32Array, UInt64Array, new_empty_array, }; use arrow::datatypes::{ ArrowNativeType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, Schema, @@ -39,9 +39,9 @@ use crate::aggregates::group_values::GroupValues; /// Caches the hashes for one dictionary column's values array. /// Rebuilt only when the `Arc` pointer changes (i.e. a new values array arrives). struct ColumnCache { - /// Keeps the values `Arc` alive and is compared with `Arc::ptr_eq` to detect staleness. + ///compared with `Arc::ptr_eq` to detect staleness. values: ArrayRef, - /// `value_hashes[k]` = hash of the value at dictionary index `k`. + /// `value_hashes[i]` = hash of the value at dictionary index `i`. value_hashes: Vec, } @@ -75,7 +75,7 @@ impl ColumnCache { } fn clear_shrink(&mut self, shrink_to: usize) { - self.values = Arc::new(NullArray::new(0)); + self.values = Arc::new(new_empty_array(self.values.data_type())); self.value_hashes.clear(); self.value_hashes.shrink_to(shrink_to); } @@ -84,6 +84,7 @@ impl ColumnCache { /// [`GroupValues`] for GROUP BY over **two or more** dictionary-typed columns. pub struct GroupDictionaryColumn { schema: SchemaRef, + /// Per-column dictionary caches, one entry per GROUP BY column. col_caches: Vec, /// `(row_hash, group_id)`. Multiple entries may share the same hash value; /// byte-level comparison is used to resolve collisions. @@ -100,11 +101,11 @@ pub struct GroupDictionaryColumn { row_offsets: Vec, /// Reused scratch buffer for encoding the current row. row_scratch: Vec, + /// Converts row-encoded bytes back into Arrow arrays for [`GroupValues::emit`]. row_decoder: RowSetDecoder, random_state: RandomState, } -/// Returns `true` when every field in `schema` is `DataType::Dictionary`. pub fn all_dictionary_schema(schema: &Schema) -> bool { schema .fields() @@ -119,35 +120,35 @@ fn is_supported_value_type(data_type: &DataType) -> bool { impl GroupDictionaryColumn { pub fn new(schema: SchemaRef) -> Result { - if schema.fields().len() < 2 { + let n_cols = schema.fields().len(); + if n_cols < 2 { return Err(internal_datafusion_err!( "GroupDictionaryColumn requires at least 2 columns, got {}", - schema.fields().len() + n_cols )); } - for field in schema.fields() { - match field.data_type() { - DataType::Dictionary(_, value_type) => { - if !is_supported_value_type(value_type) { - return Err(internal_datafusion_err!( - "GroupDictionaryColumn: unsupported dictionary value type \ - '{}' in column '{}'", - value_type, - field.name() - )); - } - } - _ => { - return Err(internal_datafusion_err!( - "GroupDictionaryColumn requires all columns to be Dictionary, \ - but '{}' has type {}", - field.name(), - field.data_type() - )); + schema + .fields() + .iter() + .try_for_each(|field| match field.data_type() { + DataType::Dictionary(_, value_type) + if is_supported_value_type(value_type) => + { + Ok(()) } - } - } - let n_cols = schema.fields().len(); + DataType::Dictionary(_, value_type) => Err(internal_datafusion_err!( + "GroupDictionaryColumn: unsupported dictionary value type \ + '{}' in column '{}'", + value_type, + field.name() + )), + _ => Err(internal_datafusion_err!( + "GroupDictionaryColumn requires all columns to be Dictionary, \ + but '{}' has type {}", + field.name(), + field.data_type() + )), + })?; let row_decoder = RowSetDecoder::new(&schema); Ok(Self { schema, @@ -170,7 +171,6 @@ fn dict_values_array(col: &dyn Array) -> ArrayRef { ) } -// Box is required: different key widths (Int8/Int16/Int32/Int64) produce different concrete iterator types. fn fill_keys(col: &dyn Array) -> Box> + '_> { downcast_dictionary_array!( col => { @@ -197,29 +197,31 @@ impl GroupValues for GroupDictionaryColumn { } let n_rows = cols[0].len(); - for (col_idx, col) in cols.iter().enumerate() { - self.col_caches[col_idx] - .update(dict_values_array(col.as_ref()), &self.random_state)?; - } - // Downcast once per column; advance with .next() per row to avoid per-row downcast. - let mut key_iters: Vec<_> = - cols.iter().map(|col| fill_keys(col.as_ref())).collect(); + let mut key_iters: Vec<_> = cols + .iter() + .enumerate() + .map(|(col_idx, col)| { + // update hash cache for each column + self.col_caches[col_idx] + .update(dict_values_array(col.as_ref()), &self.random_state)?; + Ok(fill_keys(col.as_ref())) + }) + .collect::>()?; let _ = groups.try_reserve(n_rows); for _row in 0..n_rows { - let mut hash = 0u64; + let mut combined_hash = 0u64; self.row_scratch.clear(); for (col_idx, key_iter) in key_iters.iter_mut().enumerate() { let key = key_iter.next().unwrap(); let cache = &self.col_caches[col_idx]; let value_hash = key.map_or(0, |key_idx| cache.value_hashes[key_idx]); - hash = combine_hashes(hash, value_hash); + combined_hash = combine_hashes(combined_hash, value_hash); encode_value(key, cache.values.as_ref(), &mut self.row_scratch); } - let combined_hash = hash; let found = { let row_scratch = self.row_scratch.as_slice(); let row_buffer = self.row_buffer.as_slice(); @@ -265,6 +267,7 @@ impl GroupValues for GroupDictionaryColumn { + self.row_offsets.len() * size_of::() + self.row_scratch.capacity() + cache_bytes + + self.row_decoder.size() } fn is_empty(&self) -> bool { @@ -299,29 +302,25 @@ impl GroupValues for GroupDictionaryColumn { .into_iter() .zip(self.schema.fields()) .map(|(values, field)| match field.data_type() { - DataType::Dictionary(key_type, _) => wrap_as_dictionary( + DataType::Dictionary(key_type, _) => Ok(wrap_as_dictionary( values, - &make_sequential_keys(n_emit, key_type), + &make_sequential_keys(n_emit, key_type)?, key_type, - ), + )), _ => unreachable!("schema validated in GroupDictionaryColumn::new"), }) - .collect(); + .collect::>()?; if n_emit == n_total { self.row_buffer.clear(); self.row_offsets.clear(); self.map.clear(); - self.map_size = 0; + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); } else { let retain_start = self.row_offsets[n_emit]; self.row_offsets.drain(0..n_emit); - for offset in &mut self.row_offsets { - *offset -= retain_start; - } + self.row_offsets.iter_mut().for_each(|o| *o -= retain_start); self.row_buffer.drain(0..retain_start); - // avoiding this somehow would be nice. worse case this runs once - // VecDeque? // Shift remaining group ids in-place; retain gives &mut access so no rehashing occurs. self.map.retain(|(_, gid)| { if *gid < n_emit { @@ -342,19 +341,25 @@ impl GroupValues for GroupDictionaryColumn { self.row_buffer.clear(); self.row_offsets.clear(); self.row_offsets.shrink_to(num_rows); + self.row_scratch.clear(); + self.row_scratch.shrink_to(0); + self.row_decoder.finish(); for cache in &mut self.col_caches { cache.clear_shrink(num_rows); } } } -// ── encoding / decoding ─────────────────────────────────────────────────────── - -/// Wire format per column: -/// null: `[0x00]` -/// non-null scalar: `[0x01][len: u32 LE][utf8_bytes…]` -/// non-null list: `[0x01][content_len: u32 LE][n: u32 LE][elem…]` -/// where each elem is `[0x00]` (null) or `[0x01][len: u32 LE][utf8_bytes…]` +/// Wire format per column — scalars are the primitive unit; non-scalars repeat them with metadata: +/// +/// **Scalar** (`Utf8`): +/// - null: `[0x00]` +/// - non-null: `[0x01][len: u64 LE][utf8_bytes…]` +/// +/// **Non-scalar** (`List`): a length-prefixed sequence of scalars: +/// - null: `[0x00]` +/// - non-null: `[0x01][content_len: u64 LE][n: u64 LE][scalar…]` +/// where each `scalar` follows the scalar encoding above fn encode_value(key: Option, values: &dyn Array, buf: &mut Vec) { let key_idx = match key { None => { @@ -371,30 +376,30 @@ fn encode_value(key: Option, values: &dyn Array, buf: &mut Vec) { match values.data_type() { DataType::Utf8 => { let bytes = values.as_string::().value(key_idx).as_bytes(); - buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes()); + buf.extend_from_slice(&(bytes.len() as u64).to_le_bytes()); buf.extend_from_slice(bytes); } DataType::List(_) => { // Back-fill content_len after encoding all elements. let len_pos = buf.len(); - buf.extend_from_slice(&[0u8; 4]); + buf.extend_from_slice(&[0u8; 8]); let content_start = buf.len(); let list_element = values.as_list::().value(key_idx); let str_array = list_element.as_string::(); - buf.extend_from_slice(&(str_array.len() as u32).to_le_bytes()); + buf.extend_from_slice(&(str_array.len() as u64).to_le_bytes()); for elem_idx in 0..str_array.len() { if str_array.is_null(elem_idx) { buf.push(0); } else { let elem_bytes = str_array.value(elem_idx).as_bytes(); buf.push(1); - buf.extend_from_slice(&(elem_bytes.len() as u32).to_le_bytes()); + buf.extend_from_slice(&(elem_bytes.len() as u64).to_le_bytes()); buf.extend_from_slice(elem_bytes); } } - let content_len = (buf.len() - content_start) as u32; - buf[len_pos..len_pos + 4].copy_from_slice(&content_len.to_le_bytes()); + let content_len = (buf.len() - content_start) as u64; + buf[len_pos..len_pos + 8].copy_from_slice(&content_len.to_le_bytes()); } dt => panic!("unsupported dictionary value type: {dt}"), } @@ -436,9 +441,9 @@ impl ColumnBuilder { } Self::ListUtf8(b) => { let mut cursor = 0; - let n = u32::from_le_bytes(bytes[cursor..cursor + 4].try_into().unwrap()) + let n = u64::from_le_bytes(bytes[cursor..cursor + 8].try_into().unwrap()) as usize; - cursor += 4; + cursor += 8; for _ in 0..n { match bytes[cursor] { 0 => { @@ -447,10 +452,10 @@ impl ColumnBuilder { } _ => { cursor += 1; - let len = u32::from_le_bytes( - bytes[cursor..cursor + 4].try_into().unwrap(), + let len = u64::from_le_bytes( + bytes[cursor..cursor + 8].try_into().unwrap(), ) as usize; - cursor += 4; + cursor += 8; // SAFETY: bytes come from Arrow string arrays, always valid UTF-8. b.values().append_value(unsafe { std::str::from_utf8_unchecked( @@ -492,10 +497,7 @@ impl RowSetDecoder { Self { builders } } - /// Expected format: one column entry per schema field, in schema order. - /// Each column: `[0x00]` (null) or `[0x01][content_len: u32 LE][content…]`. - /// Utf8 content: raw UTF-8 bytes. - /// List content: `[n: u32 LE][elem…]` where each elem is `[0x00]` or `[0x01][len: u32 LE][utf8_bytes…]`. + /// Inverse of [`encode_value`]; see its doc comment for the wire format. fn decode(&mut self, encoded: &[u8]) { let mut cursor = 0; for builder in &mut self.builders { @@ -506,10 +508,10 @@ impl RowSetDecoder { } _ => { cursor += 1; - let len = u32::from_le_bytes( - encoded[cursor..cursor + 4].try_into().unwrap(), + let len = u64::from_le_bytes( + encoded[cursor..cursor + 8].try_into().unwrap(), ) as usize; - cursor += 4; + cursor += 8; builder.append_bytes(&encoded[cursor..cursor + len]); cursor += len; } @@ -520,34 +522,36 @@ impl RowSetDecoder { fn finish(&mut self) -> Vec { self.builders.iter_mut().map(|b| b.finish()).collect() } + + fn size(&self) -> usize { + self.builders.capacity() * size_of::() + } } -/// Build sequential keys `[0, 1, ..., n-1]` with the key type taken from the schema. -/// All columns share the same key type, so callers should build this once and clone the Arc. -fn make_sequential_keys(n: usize, key_type: &DataType) -> ArrayRef { - match key_type { - DataType::Int8 => Arc::new(Int8Array::from_iter_values((0..n).map(|i| i as i8))), - DataType::Int16 => { - Arc::new(Int16Array::from_iter_values((0..n).map(|i| i as i16))) - } - DataType::Int32 => { - Arc::new(Int32Array::from_iter_values((0..n).map(|i| i as i32))) - } - DataType::Int64 => { - Arc::new(Int64Array::from_iter_values((0..n).map(|i| i as i64))) - } - DataType::UInt8 => { - Arc::new(UInt8Array::from_iter_values((0..n).map(|i| i as u8))) - } - DataType::UInt16 => { - Arc::new(UInt16Array::from_iter_values((0..n).map(|i| i as u16))) - } - DataType::UInt32 => { - Arc::new(UInt32Array::from_iter_values((0..n).map(|i| i as u32))) - } - DataType::UInt64 => { - Arc::new(UInt64Array::from_iter_values((0..n).map(|i| i as u64))) +/// Build sequential keys `[0, 1, ..., n-1]` for the given key type. +macro_rules! make_keys { + ($n:expr, $ArrayType:ty, $max:expr) => {{ + if $n > $max { + return Err(internal_datafusion_err!( + "too many groups ({}) for dictionary key type with max capacity {}", + $n, + $max + )); } + Ok(Arc::new(<$ArrayType>::from_iter_values((0..$n).map(|i| i as _))) as ArrayRef) + }}; +} + +fn make_sequential_keys(n: usize, key_type: &DataType) -> Result { + match key_type { + DataType::Int8 => make_keys!(n, Int8Array, i8::MAX as usize), + DataType::Int16 => make_keys!(n, Int16Array, i16::MAX as usize), + DataType::Int32 => make_keys!(n, Int32Array, i32::MAX as usize), + DataType::Int64 => make_keys!(n, Int64Array, i64::MAX as usize), + DataType::UInt8 => make_keys!(n, UInt8Array, u8::MAX as usize), + DataType::UInt16 => make_keys!(n, UInt16Array, u16::MAX as usize), + DataType::UInt32 => make_keys!(n, UInt32Array, u32::MAX as usize), + DataType::UInt64 => make_keys!(n, UInt64Array, usize::MAX), _ => unreachable!("schema validated in GroupDictionaryColumn::new"), } } @@ -594,8 +598,7 @@ fn wrap_as_dictionary( } } -// ── tests ───────────────────────────────────────────────────────────────────── - +// in depth test exist on https://github.com/apache/datafusion/pull/22888 . these are mostly for correness and sanity check. --- IGNORE --- #[cfg(test)] mod tests { use super::*; From ef4ba19d4e7b39de3a18d58f7212632a1d601fde Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 22 Jun 2026 13:11:00 -0400 Subject: [PATCH 4/5] introduce key caching --- .../src/aggregates/group_values/mod.rs | 4 +- .../group_values/multi_group_by/dict.rs | 144 +++++++++++------- 2 files changed, 89 insertions(+), 59 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 807bc9029522b..3113e55c5a1b0 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -35,7 +35,7 @@ pub use row::GroupValuesRows; mod single_group_by; use datafusion_physical_expr::binary_map::OutputType; use multi_group_by::GroupValuesColumn; -use multi_group_by::dict::{GroupDictionaryColumn, all_dictionary_schema}; +use multi_group_by::dict::{GroupDictionaryColumn, supported_dictionary_schema}; pub(crate) use single_group_by::primitive::HashValue; @@ -202,7 +202,7 @@ pub fn new_group_values( } // Route 2+ all-dictionary columns to the specialised implementation. - if schema.fields().len() >= 2 && all_dictionary_schema(schema.as_ref()) { + if schema.fields().len() >= 2 && supported_dictionary_schema(schema.as_ref()) { return Ok(Box::new(GroupDictionaryColumn::new(schema)?)); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs index 249fa111978b7..0c4cad3b22156 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs @@ -32,6 +32,7 @@ use datafusion_common::hash_utils::{RandomState, combine_hashes, create_hashes}; use datafusion_common::{Result, internal_datafusion_err}; use datafusion_execution::memory_pool::proxy::HashTableAllocExt; use datafusion_expr::EmitTo; +use hashbrown::HashMap as HashbrownMap; use hashbrown::hash_table::HashTable; use crate::aggregates::group_values::GroupValues; @@ -71,7 +72,8 @@ impl ColumnCache { } fn size(&self) -> usize { - self.value_hashes.len() * size_of::() + self.value_hashes.capacity() * size_of::() + + self.values.get_array_memory_size() } fn clear_shrink(&mut self, shrink_to: usize) { @@ -104,18 +106,23 @@ pub struct GroupDictionaryColumn { /// Converts row-encoded bytes back into Arrow arrays for [`GroupValues::emit`]. row_decoder: RowSetDecoder, random_state: RandomState, + /// Maps `(k1, k2, ..., kN)` dictionary key tuples to their group id. + /// Cleared at the start of every `intern` call (not stable across calls). + /// Bypassed entirely once it exceeds 10 000 entries. + key_tuple_cache: HashbrownMap]>, usize>, + /// Scratch buffer for building the current row's key tuple; reused each row. + key_tuple_scratch: Vec>, } -pub fn all_dictionary_schema(schema: &Schema) -> bool { - schema - .fields() - .iter() - .all(|field| matches!(field.data_type(), DataType::Dictionary(_, _))) -} - -fn is_supported_value_type(data_type: &DataType) -> bool { - matches!(data_type, DataType::Utf8) - || matches!(data_type, DataType::List(f) if f.data_type() == &DataType::Utf8) +pub fn supported_dictionary_schema(schema: &Schema) -> bool { + schema.fields().iter().all(|field| { + if let DataType::Dictionary(_, value_type) = field.data_type() { + matches!(value_type.as_ref(), DataType::Utf8) + || matches!(value_type.as_ref(), DataType::List(f) if f.data_type() == &DataType::Utf8) + } else { + false + } + }) } impl GroupDictionaryColumn { @@ -127,28 +134,6 @@ impl GroupDictionaryColumn { n_cols )); } - schema - .fields() - .iter() - .try_for_each(|field| match field.data_type() { - DataType::Dictionary(_, value_type) - if is_supported_value_type(value_type) => - { - Ok(()) - } - DataType::Dictionary(_, value_type) => Err(internal_datafusion_err!( - "GroupDictionaryColumn: unsupported dictionary value type \ - '{}' in column '{}'", - value_type, - field.name() - )), - _ => Err(internal_datafusion_err!( - "GroupDictionaryColumn requires all columns to be Dictionary, \ - but '{}' has type {}", - field.name(), - field.data_type() - )), - })?; let row_decoder = RowSetDecoder::new(&schema); Ok(Self { schema, @@ -160,6 +145,8 @@ impl GroupDictionaryColumn { row_scratch: Vec::new(), row_decoder, random_state: crate::aggregates::AGGREGATION_HASH_SEED, + key_tuple_cache: HashbrownMap::new(), + key_tuple_scratch: Vec::with_capacity(n_cols), }) } } @@ -209,13 +196,34 @@ impl GroupValues for GroupDictionaryColumn { }) .collect::>()?; - let _ = groups.try_reserve(n_rows); + // keys are not stable across batches + self.key_tuple_cache.clear(); + + groups.reserve(n_rows); for _row in 0..n_rows { + // Collect the raw dictionary key indices for this row. + self.key_tuple_scratch.clear(); + for key_iter in key_iters.iter_mut() { + self.key_tuple_scratch.push(key_iter.next().unwrap()); + } + + // Fast path: key-tuple cache lookup + // TODO: expose this threshold via SessionConfig so users can tune it for their + // workload (e.g. raise it for high-cardinality dictionaries, lower it to save memory). + let use_cache = self.key_tuple_cache.len() <= 10_000; + if use_cache { + if let Some(&group_id) = + self.key_tuple_cache.get(self.key_tuple_scratch.as_slice()) + { + groups.push(group_id); + continue; + } + } + + // Cache miss: compute combined hash and encode the row from the collected tuple. let mut combined_hash = 0u64; self.row_scratch.clear(); - - for (col_idx, key_iter) in key_iters.iter_mut().enumerate() { - let key = key_iter.next().unwrap(); + for (col_idx, &key) in self.key_tuple_scratch.iter().enumerate() { let cache = &self.col_caches[col_idx]; let value_hash = key.map_or(0, |key_idx| cache.value_hashes[key_idx]); combined_hash = combine_hashes(combined_hash, value_hash); @@ -227,16 +235,16 @@ impl GroupValues for GroupDictionaryColumn { let row_buffer = self.row_buffer.as_slice(); let row_offsets = self.row_offsets.as_slice(); self.map - .find(combined_hash, |&(stored_hash, group_id)| { + .find(combined_hash, |&(stored_hash, stored_group_id)| { stored_hash == combined_hash && { let end = row_offsets - .get(group_id + 1) + .get(stored_group_id + 1) .copied() - .unwrap_or(row_buffer.len()); // last group has no g+1 entry - row_buffer[row_offsets[group_id]..end] == *row_scratch + .unwrap_or(row_buffer.len()); + row_buffer[row_offsets[stored_group_id]..end] == *row_scratch } }) - .map(|&(_, group_id)| group_id) + .map(|&(_, stored_group_id)| stored_group_id) }; let group_id = match found { @@ -254,18 +262,26 @@ impl GroupValues for GroupDictionaryColumn { } }; + // Write back to the key-tuple cache. + if use_cache { + self.key_tuple_cache + .insert(Box::from(self.key_tuple_scratch.as_slice()), group_id); + } + groups.push(group_id); } + println!("cache: {:?}", self.key_tuple_cache); Ok(()) } fn size(&self) -> usize { let cache_bytes: usize = self.col_caches.iter().map(|c| c.size()).sum(); self.map_size - + self.row_buffer.len() - + self.row_offsets.len() * size_of::() + + self.row_buffer.capacity() + + self.row_offsets.capacity() * size_of::() + self.row_scratch.capacity() + + self.key_tuple_scratch.capacity() * size_of::>() + cache_bytes + self.row_decoder.size() } @@ -278,6 +294,21 @@ impl GroupValues for GroupDictionaryColumn { self.row_offsets.len() } + // the entire emit path is not optimized + /// groups + /// 0: (a, x) + /// 1: (a, y) + /// 2: (b, x) + /// 3: (b, y) + /// + /// logically colun 1 :[a, a, b, b] + /// currently we do + /// keys: [0, 1, 2, 3] + /// values: [a, a, b, b] + /// + /// which is correct but we should do + /// keys: [0, 0, 1, 1] + /// values: [a, b] fn emit(&mut self, emit_to: EmitTo) -> Result> { let n_total = self.row_offsets.len(); if n_total == 0 { @@ -339,10 +370,13 @@ impl GroupValues for GroupDictionaryColumn { self.map.shrink_to(num_rows, |_| 0); self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); self.row_buffer.clear(); + self.row_buffer.shrink_to(num_rows); self.row_offsets.clear(); self.row_offsets.shrink_to(num_rows); self.row_scratch.clear(); self.row_scratch.shrink_to(0); + self.key_tuple_cache.clear(); + self.key_tuple_scratch.clear(); self.row_decoder.finish(); for cache in &mut self.col_caches { cache.clear_shrink(num_rows); @@ -411,15 +445,6 @@ enum ColumnBuilder { ListUtf8(ListBuilder), } -macro_rules! dispatch_builder { - ($self:expr, $b:ident => $body:expr) => { - match $self { - ColumnBuilder::Utf8($b) => $body, - ColumnBuilder::ListUtf8($b) => $body, - } - }; -} - impl ColumnBuilder { fn from_value_type(value_type: &DataType) -> Self { match value_type { @@ -430,7 +455,10 @@ impl ColumnBuilder { } fn append_null(&mut self) { - dispatch_builder!(self, b => b.append_null()) + match self { + Self::Utf8(b) => b.append_null(), + Self::ListUtf8(b) => b.append_null(), + } } fn append_bytes(&mut self, bytes: &[u8]) { @@ -472,7 +500,10 @@ impl ColumnBuilder { } fn finish(&mut self) -> ArrayRef { - dispatch_builder!(self, b => Arc::new(b.finish())) + match self { + Self::Utf8(b) => Arc::new(b.finish()), + Self::ListUtf8(b) => Arc::new(b.finish()), + } } } @@ -598,7 +629,7 @@ fn wrap_as_dictionary( } } -// in depth test exist on https://github.com/apache/datafusion/pull/22888 . these are mostly for correness and sanity check. --- IGNORE --- +// in depth test exist on https://github.com/apache/datafusion/pull/22888 . these are mostly for correness and sanity check. #[cfg(test)] mod tests { use super::*; @@ -621,7 +652,6 @@ mod tests { ])) } - /// Same row twice in one batch → same group id; distinct rows → different ids. #[test] fn test_basic_dedup() { let mut gv = GroupDictionaryColumn::new(dict_schema()).unwrap(); From 17decbd09a5a913790ef5256e4e17bd1e6ae5f39 Mon Sep 17 00:00:00 2001 From: Richard Date: Mon, 22 Jun 2026 13:56:25 -0400 Subject: [PATCH 5/5] add better naming conventions as well as account for memory taken up for key-tuple cache --- .../group_values/multi_group_by/dict.rs | 106 ++++++++++-------- 1 file changed, 60 insertions(+), 46 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs index 0c4cad3b22156..773823d3e9d45 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/dict.rs @@ -118,7 +118,7 @@ pub fn supported_dictionary_schema(schema: &Schema) -> bool { schema.fields().iter().all(|field| { if let DataType::Dictionary(_, value_type) = field.data_type() { matches!(value_type.as_ref(), DataType::Utf8) - || matches!(value_type.as_ref(), DataType::List(f) if f.data_type() == &DataType::Utf8) + || matches!(value_type.as_ref(), DataType::List(list_field) if list_field.data_type() == &DataType::Utf8) } else { false } @@ -189,38 +189,36 @@ impl GroupValues for GroupDictionaryColumn { .iter() .enumerate() .map(|(col_idx, col)| { - // update hash cache for each column self.col_caches[col_idx] .update(dict_values_array(col.as_ref()), &self.random_state)?; Ok(fill_keys(col.as_ref())) }) .collect::>()?; - - // keys are not stable across batches self.key_tuple_cache.clear(); + self.key_tuple_scratch.clear(); - groups.reserve(n_rows); + groups.try_reserve(n_rows).map_err(|e| { + datafusion_common::DataFusionError::ArrowError( + Box::new(arrow::error::ArrowError::MemoryError(e.to_string())), + None, + ) + })?; for _row in 0..n_rows { - // Collect the raw dictionary key indices for this row. self.key_tuple_scratch.clear(); for key_iter in key_iters.iter_mut() { self.key_tuple_scratch.push(key_iter.next().unwrap()); } - // Fast path: key-tuple cache lookup - // TODO: expose this threshold via SessionConfig so users can tune it for their - // workload (e.g. raise it for high-cardinality dictionaries, lower it to save memory). - let use_cache = self.key_tuple_cache.len() <= 10_000; - if use_cache { - if let Some(&group_id) = + // Fast path: key-tuple cache lookup (bypassed once the cache exceeds the threshold). + let use_cache = self.key_tuple_cache.len() <= 1000; + if use_cache + && let Some(&group_id) = self.key_tuple_cache.get(self.key_tuple_scratch.as_slice()) - { - groups.push(group_id); - continue; - } + { + groups.push(group_id); + continue; } - // Cache miss: compute combined hash and encode the row from the collected tuple. let mut combined_hash = 0u64; self.row_scratch.clear(); for (col_idx, &key) in self.key_tuple_scratch.iter().enumerate() { @@ -262,7 +260,6 @@ impl GroupValues for GroupDictionaryColumn { } }; - // Write back to the key-tuple cache. if use_cache { self.key_tuple_cache .insert(Box::from(self.key_tuple_scratch.as_slice()), group_id); @@ -271,17 +268,24 @@ impl GroupValues for GroupDictionaryColumn { groups.push(group_id); } - println!("cache: {:?}", self.key_tuple_cache); + self.key_tuple_cache.clear(); + self.key_tuple_scratch.clear(); Ok(()) } fn size(&self) -> usize { - let cache_bytes: usize = self.col_caches.iter().map(|c| c.size()).sum(); + let cache_bytes: usize = self + .col_caches + .iter() + .map(|col_cache| col_cache.size()) + .sum(); self.map_size + self.row_buffer.capacity() + self.row_offsets.capacity() * size_of::() + self.row_scratch.capacity() + self.key_tuple_scratch.capacity() * size_of::>() + + self.key_tuple_cache.capacity() + * (size_of::]>>() + size_of::()) + cache_bytes + self.row_decoder.size() } @@ -316,7 +320,7 @@ impl GroupValues for GroupDictionaryColumn { } let n_emit = match emit_to { EmitTo::All => n_total, - EmitTo::First(n) => n.min(n_total), + EmitTo::First(first_n) => first_n.min(n_total), }; for row_idx in 0..n_emit { @@ -350,14 +354,16 @@ impl GroupValues for GroupDictionaryColumn { } else { let retain_start = self.row_offsets[n_emit]; self.row_offsets.drain(0..n_emit); - self.row_offsets.iter_mut().for_each(|o| *o -= retain_start); + self.row_offsets + .iter_mut() + .for_each(|offset| *offset -= retain_start); self.row_buffer.drain(0..retain_start); // Shift remaining group ids in-place; retain gives &mut access so no rehashing occurs. - self.map.retain(|(_, gid)| { - if *gid < n_emit { + self.map.retain(|(_, group_id)| { + if *group_id < n_emit { return false; } - *gid -= n_emit; + *group_id -= n_emit; true }); } @@ -400,7 +406,7 @@ fn encode_value(key: Option, values: &dyn Array, buf: &mut Vec) { buf.push(0); return; } - Some(k) => k, + Some(key_index) => key_index, }; if values.is_null(key_idx) { buf.push(0); @@ -435,7 +441,9 @@ fn encode_value(key: Option, values: &dyn Array, buf: &mut Vec) { let content_len = (buf.len() - content_start) as u64; buf[len_pos..len_pos + 8].copy_from_slice(&content_len.to_le_bytes()); } - dt => panic!("unsupported dictionary value type: {dt}"), + unsupported_type => { + panic!("unsupported dictionary value type: {unsupported_type}") + } } } @@ -456,53 +464,54 @@ impl ColumnBuilder { fn append_null(&mut self) { match self { - Self::Utf8(b) => b.append_null(), - Self::ListUtf8(b) => b.append_null(), + Self::Utf8(builder) => builder.append_null(), + Self::ListUtf8(builder) => builder.append_null(), } } fn append_bytes(&mut self, bytes: &[u8]) { match self { // SAFETY: bytes come from Arrow string arrays, always valid UTF-8. - Self::Utf8(b) => { - b.append_value(unsafe { std::str::from_utf8_unchecked(bytes) }) + Self::Utf8(builder) => { + builder.append_value(unsafe { std::str::from_utf8_unchecked(bytes) }) } - Self::ListUtf8(b) => { + Self::ListUtf8(builder) => { let mut cursor = 0; - let n = u64::from_le_bytes(bytes[cursor..cursor + 8].try_into().unwrap()) - as usize; + let n_elements = + u64::from_le_bytes(bytes[cursor..cursor + 8].try_into().unwrap()) + as usize; cursor += 8; - for _ in 0..n { + for _ in 0..n_elements { match bytes[cursor] { 0 => { - b.values().append_null(); + builder.values().append_null(); cursor += 1; } _ => { cursor += 1; - let len = u64::from_le_bytes( + let elem_len = u64::from_le_bytes( bytes[cursor..cursor + 8].try_into().unwrap(), ) as usize; cursor += 8; // SAFETY: bytes come from Arrow string arrays, always valid UTF-8. - b.values().append_value(unsafe { + builder.values().append_value(unsafe { std::str::from_utf8_unchecked( - &bytes[cursor..cursor + len], + &bytes[cursor..cursor + elem_len], ) }); - cursor += len; + cursor += elem_len; } } } - b.append(true); + builder.append(true); } } } fn finish(&mut self) -> ArrayRef { match self { - Self::Utf8(b) => Arc::new(b.finish()), - Self::ListUtf8(b) => Arc::new(b.finish()), + Self::Utf8(builder) => Arc::new(builder.finish()), + Self::ListUtf8(builder) => Arc::new(builder.finish()), } } } @@ -551,7 +560,10 @@ impl RowSetDecoder { } fn finish(&mut self) -> Vec { - self.builders.iter_mut().map(|b| b.finish()).collect() + self.builders + .iter_mut() + .map(|builder| builder.finish()) + .collect() } fn size(&self) -> usize { @@ -569,7 +581,10 @@ macro_rules! make_keys { $max )); } - Ok(Arc::new(<$ArrayType>::from_iter_values((0..$n).map(|i| i as _))) as ArrayRef) + Ok( + Arc::new(<$ArrayType>::from_iter_values((0..$n).map(|idx| idx as _))) + as ArrayRef, + ) }}; } @@ -629,7 +644,6 @@ fn wrap_as_dictionary( } } -// in depth test exist on https://github.com/apache/datafusion/pull/22888 . these are mostly for correness and sanity check. #[cfg(test)] mod tests { use super::*;