From 571013d6300c2cef64ec052cdc02b527c5e0ae9f Mon Sep 17 00:00:00 2001 From: Devan Date: Thu, 26 Sep 2024 21:16:28 -0500 Subject: [PATCH 01/17] feat: wip: working on it --- datafusion/expr-common/src/lib.rs | 2 +- .../src/aggregate.rs | 1 + .../src/aggregate/groups_accumulator_view.rs | 240 ++++++++++++++++++ datafusion/functions-aggregate/src/min_max.rs | 5 + datafusion/physical-expr/src/aggregate.rs | 1 + datafusion/physical-expr/src/lib.rs | 2 +- .../physical-plan/src/aggregates/row_hash.rs | 3 +- .../sqllogictest/test_files/min_max.slt | 10 + 8 files changed, 260 insertions(+), 4 deletions(-) create mode 100644 datafusion/functions-aggregate-common/src/aggregate/groups_accumulator_view.rs create mode 100644 datafusion/sqllogictest/test_files/min_max.slt diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs index 179dd75ace85a..f8812af502c1d 100644 --- a/datafusion/expr-common/src/lib.rs +++ b/datafusion/expr-common/src/lib.rs @@ -33,4 +33,4 @@ pub mod interval_arithmetic; pub mod operator; pub mod signature; pub mod sort_properties; -pub mod type_coercion; +pub mod type_coercion; \ No newline at end of file diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs index c9cbaa8396fc5..239e82326fa49 100644 --- a/datafusion/functions-aggregate-common/src/aggregate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -17,3 +17,4 @@ pub mod count_distinct; pub mod groups_accumulator; +pub mod groups_accumulator_view; diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator_view.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator_view.rs new file mode 100644 index 0000000000000..20a55025059f5 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator_view.rs @@ -0,0 +1,240 @@ +use std::sync::Arc; +use arrow::array::{Array, ArrayRef, AsArray, BinaryViewArray, BinaryViewBuilder, BooleanArray}; +use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; +use datafusion_common::{DataFusionError, Result}; + +/// An accumulator that concatenates strings for each group. +pub struct GroupsAccumulatorMin { + states: Vec, +} + +impl GroupsAccumulatorMin { + /// Creates a new `StringGroupsAccumulator`. + pub fn new() -> Self { + Self { + states: Vec::new(), + } + } +} + +impl GroupsAccumulator for GroupsAccumulatorMin { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // Ensure that self.states has capacity for total_num_groups + if self.states.len() < total_num_groups { + self.states.resize(total_num_groups, String::new()); + } + + // Assume values has one element (the input column) + let input_array = &values[0]; + + // Iterate over rows + for (i, &group_index) in group_indices.iter().enumerate() { + // Check filter + if let Some(filter) = opt_filter { + if !filter.value(i) { + continue; + } + } + + // Skip null values + if input_array.is_null(i) { + continue; + } + + // Get the binary value at index i + let value = input_array.as_binary_view().value(i); + + // Convert binary data to a string (assuming UTF-8 encoding) + let value_str = std::str::from_utf8(value).map_err(|e| { + DataFusionError::Execution(format!("Invalid UTF-8 sequence: {}", e)) + })?; + + if self.states[group_index].len() == 0 { + self.states[group_index] = value_str.to_string(); + } else { + let curr_value_bytes = self.states[group_index].as_bytes(); + if value < curr_value_bytes { + self.states[group_index] = value_str.parse().unwrap(); + } + } + } + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let num_groups = match emit_to { + EmitTo::All => self.states.len(), + EmitTo::First(n) => std::cmp::min(n, self.states.len()), + }; + + let mut builder = BinaryViewBuilder::new(); + + // Build the output array + for i in 0..num_groups { + let value = &self.states[i]; + if value.is_empty() { + builder.append_null(); + } else { + builder.append_value(value.as_bytes()); + } + } + + let array = Arc::new(builder.finish()) as ArrayRef; + + // Handle internal state according to emit_to + match emit_to { + EmitTo::All => { + // Reset the internal state + self.states.clear(); + } + EmitTo::First(n) => { + // Remove the first n elements from self.states + self.states.drain(0..n); + } + } + Ok(array) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let num_groups = match emit_to { + EmitTo::All => self.states.len(), + EmitTo::First(n) => std::cmp::min(n, self.states.len()), + }; + + let mut builder = BinaryViewBuilder::new(); + + // Build the state array + for i in 0..num_groups { + let value = &self.states[i]; + if value.is_empty() { + builder.append_null(); + } else { + builder.append_value(value.as_bytes()); + } + } + + let array = Arc::new(builder.finish()) as ArrayRef; + + // Handle internal state according to emit_to + match emit_to { + EmitTo::All => { + // Reset the internal state + self.states.clear(); + } + EmitTo::First(n) => { + // Remove the first n elements from self.states + self.states.drain(0..n); + } + } + Ok(vec![array]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // Ensure that self.states has capacity for total_num_groups + if self.states.len() < total_num_groups { + self.states.resize(total_num_groups, String::new()); + } + + // The values are the state arrays from other accumulators + // We expect that values[0] is an ArrayRef containing the binary data + let input_array = &values[0]; + + // Iterate over rows + for (i, &group_index) in group_indices.iter().enumerate() { + // Check filter + if let Some(filter) = opt_filter { + if !filter.value(i) { + continue; + } + } + + // Skip null values + if input_array.is_null(i) { + continue; + } + + // Get the binary value at index i + let value = input_array.as_binary_view().value(i); + + // Convert binary data to a string (assuming UTF-8 encoding) + let value_str = std::str::from_utf8(value).map_err(|e| { + DataFusionError::Execution(format!("Invalid UTF-8 sequence: {}", e)) + })?; + + // Update the state for the group + if self.states[group_index].len() == 0 { + self.states[group_index] = value_str.to_string(); + } else { + let curr_value_bytes = self.states[group_index].as_bytes(); + if value < curr_value_bytes { + self.states[group_index] = value_str.parse().unwrap(); + } + } + } + Ok(()) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + // For each input value, produce an array representing the state + let input_array = &values[0]; + + // Downcast to BinaryViewArray + let binary_array = input_array + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("Expected BinaryViewArray".to_string()))?; + + if opt_filter.is_none() { + // No filter, return the input array as is + return Ok(vec![input_array.clone()]); + } + + // Apply the filter + let filter = opt_filter.unwrap(); + + // Build a new array with filtered values + let mut builder = BinaryViewBuilder::new(); + + for i in 0..binary_array.len() { + if !filter.value(i) { + builder.append_null(); + continue; + } + + if binary_array.is_null(i) { + builder.append_null(); + } else { + let value = binary_array.value(i); + builder.append_value(value); + } + } + + let array = Arc::new(builder.finish()) as ArrayRef; + Ok(vec![array]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + // Compute the total length of the strings in self.states + self.states.iter().map(|s| s.len()).sum() + } +} \ No newline at end of file diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 961e8639604c8..fff0e9405e4aa 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -69,6 +69,7 @@ use datafusion_expr::{ }; use half::f16; use std::ops::Deref; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator_view::GroupsAccumulatorMin; fn get_min_max_result_type(input_types: &[DataType]) -> Result> { // make sure that the input types only has one element. @@ -972,6 +973,7 @@ impl AggregateUDFImpl for Min { | Time32(_) | Time64(_) | Timestamp(_, _) + | Utf8View ) } @@ -1032,6 +1034,9 @@ impl AggregateUDFImpl for Min { Decimal256(_, _) => { instantiate_min_accumulator!(data_type, i256, Decimal256Type) } + Utf8View => { + Ok(Box::new(GroupsAccumulatorMin::new())) + } // It would be nice to have a fast implementation for Strings as well // https://github.com/apache/datafusion/issues/6906 diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 866596d0b6901..f9dee12e1a4bf 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -23,6 +23,7 @@ pub(crate) mod groups_accumulator { pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{ accumulate::NullState, GroupsAccumulatorAdapter, }; + pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator_view::GroupsAccumulatorMin; } pub(crate) mod stats { pub use datafusion_functions_aggregate_common::stats::StatsType; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 46185712413ef..39cfcf1b6d211 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -44,7 +44,7 @@ pub mod execution_props { pub use datafusion_expr::var_provider::{VarProvider, VarType}; } -pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, NullState}; +pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, GroupsAccumulatorMin, NullState}; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; pub use equivalence::{calculate_union, ConstExpr, EquivalenceProperties}; pub use partitioning::{Distribution, Partitioning}; diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 60efc77112167..89f03bea60674 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -52,7 +52,6 @@ use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use futures::ready; use futures::stream::{Stream, StreamExt}; use log::debug; - use super::order::GroupOrdering; use super::AggregateExec; @@ -575,7 +574,7 @@ impl GroupedHashAggregateStream { /// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if /// that is supported by the aggregate, or a -/// [`GroupsAccumulatorAdapter`] if not. +/// [`GroupsAccumulatorMin`] if not. pub(crate) fn create_group_accumulator( agg_expr: &AggregateFunctionExpr, ) -> Result> { diff --git a/datafusion/sqllogictest/test_files/min_max.slt b/datafusion/sqllogictest/test_files/min_max.slt new file mode 100644 index 0000000000000..9c880b183550e --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max.slt @@ -0,0 +1,10 @@ +statement ok +set datafusion.execution.parquet.schema_force_view_types = true; + +statement ok +CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION '/Users/devan/Documents/OSS/datafusion/parquet-testing/data/alltypes_plain.parquet' + +statement ok +SELECT REGEXP_REPLACE(t.string_col, '1', '0') AS k, AVG(length(t.string_col)) AS l, COUNT(*) AS c, MIN(t.string_col) +FROM t +GROUP BY k; \ No newline at end of file From f7634e1103d90e88328bf3060cff2023a48e8c73 Mon Sep 17 00:00:00 2001 From: Devan Date: Fri, 27 Sep 2024 10:32:55 -0500 Subject: [PATCH 02/17] feat: working on it --- .../src/aggregate/groups_accumulator_view.rs | 60 ++++++------------- datafusion/functions-aggregate/src/min_max.rs | 8 +-- datafusion/physical-expr/src/aggregate.rs | 2 +- datafusion/physical-expr/src/lib.rs | 2 +- 4 files changed, 25 insertions(+), 47 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator_view.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator_view.rs index 20a55025059f5..8a969f5659180 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator_view.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator_view.rs @@ -1,15 +1,19 @@ use std::sync::Arc; -use arrow::array::{Array, ArrayRef, AsArray, BinaryViewArray, BinaryViewBuilder, BooleanArray}; +use arrow::array::{Array, ArrayRef, AsArray, BinaryViewBuilder, BooleanArray}; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; use datafusion_common::{DataFusionError, Result}; -/// An accumulator that concatenates strings for each group. -pub struct GroupsAccumulatorMin { +pub struct StringGroupsAccumulatorMin { states: Vec, } -impl GroupsAccumulatorMin { - /// Creates a new `StringGroupsAccumulator`. +impl Default for StringGroupsAccumulatorMin { + fn default() -> Self { + Self::new() + } +} + +impl StringGroupsAccumulatorMin { pub fn new() -> Self { Self { states: Vec::new(), @@ -17,7 +21,7 @@ impl GroupsAccumulatorMin { } } -impl GroupsAccumulator for GroupsAccumulatorMin { +impl GroupsAccumulator for StringGroupsAccumulatorMin { fn update_batch( &mut self, values: &[ArrayRef], @@ -55,7 +59,7 @@ impl GroupsAccumulator for GroupsAccumulatorMin { DataFusionError::Execution(format!("Invalid UTF-8 sequence: {}", e)) })?; - if self.states[group_index].len() == 0 { + if self.states[group_index].is_empty() { self.states[group_index] = value_str.to_string(); } else { let curr_value_bytes = self.states[group_index].as_bytes(); @@ -75,7 +79,6 @@ impl GroupsAccumulator for GroupsAccumulatorMin { let mut builder = BinaryViewBuilder::new(); - // Build the output array for i in 0..num_groups { let value = &self.states[i]; if value.is_empty() { @@ -87,14 +90,11 @@ impl GroupsAccumulator for GroupsAccumulatorMin { let array = Arc::new(builder.finish()) as ArrayRef; - // Handle internal state according to emit_to match emit_to { EmitTo::All => { - // Reset the internal state self.states.clear(); } EmitTo::First(n) => { - // Remove the first n elements from self.states self.states.drain(0..n); } } @@ -109,7 +109,6 @@ impl GroupsAccumulator for GroupsAccumulatorMin { let mut builder = BinaryViewBuilder::new(); - // Build the state array for i in 0..num_groups { let value = &self.states[i]; if value.is_empty() { @@ -121,14 +120,11 @@ impl GroupsAccumulator for GroupsAccumulatorMin { let array = Arc::new(builder.finish()) as ArrayRef; - // Handle internal state according to emit_to match emit_to { EmitTo::All => { - // Reset the internal state self.states.clear(); } EmitTo::First(n) => { - // Remove the first n elements from self.states self.states.drain(0..n); } } @@ -142,39 +138,30 @@ impl GroupsAccumulator for GroupsAccumulatorMin { opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { - // Ensure that self.states has capacity for total_num_groups if self.states.len() < total_num_groups { self.states.resize(total_num_groups, String::new()); } - // The values are the state arrays from other accumulators - // We expect that values[0] is an ArrayRef containing the binary data let input_array = &values[0]; - // Iterate over rows for (i, &group_index) in group_indices.iter().enumerate() { - // Check filter if let Some(filter) = opt_filter { if !filter.value(i) { continue; } } - // Skip null values if input_array.is_null(i) { continue; } - // Get the binary value at index i let value = input_array.as_binary_view().value(i); - // Convert binary data to a string (assuming UTF-8 encoding) let value_str = std::str::from_utf8(value).map_err(|e| { DataFusionError::Execution(format!("Invalid UTF-8 sequence: {}", e)) })?; - // Update the state for the group - if self.states[group_index].len() == 0 { + if self.states[group_index].is_empty() { self.states[group_index] = value_str.to_string(); } else { let curr_value_bytes = self.states[group_index].as_bytes(); @@ -191,36 +178,28 @@ impl GroupsAccumulator for GroupsAccumulatorMin { values: &[ArrayRef], opt_filter: Option<&BooleanArray>, ) -> Result> { - // For each input value, produce an array representing the state let input_array = &values[0]; - // Downcast to BinaryViewArray - let binary_array = input_array - .as_any() - .downcast_ref::() - .ok_or_else(|| DataFusionError::Internal("Expected BinaryViewArray".to_string()))?; - if opt_filter.is_none() { - // No filter, return the input array as is - return Ok(vec![input_array.clone()]); + return Ok(vec![Arc::::clone(&input_array)]); } - // Apply the filter let filter = opt_filter.unwrap(); - // Build a new array with filtered values + let mut builder = BinaryViewBuilder::new(); - for i in 0..binary_array.len() { + for i in 0..values.len() { + let value = input_array.as_binary_view().value(i); + if !filter.value(i) { builder.append_null(); continue; } - if binary_array.is_null(i) { + if value.is_empty() { builder.append_null(); } else { - let value = binary_array.value(i); builder.append_value(value); } } @@ -234,7 +213,6 @@ impl GroupsAccumulator for GroupsAccumulatorMin { } fn size(&self) -> usize { - // Compute the total length of the strings in self.states self.states.iter().map(|s| s.len()).sum() } -} \ No newline at end of file +} diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index fff0e9405e4aa..c22e8eaa0e039 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -69,7 +69,7 @@ use datafusion_expr::{ }; use half::f16; use std::ops::Deref; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator_view::GroupsAccumulatorMin; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator_view::StringGroupsAccumulatorMin; fn get_min_max_result_type(input_types: &[DataType]) -> Result> { // make sure that the input types only has one element. @@ -973,7 +973,7 @@ impl AggregateUDFImpl for Min { | Time32(_) | Time64(_) | Timestamp(_, _) - | Utf8View + | BinaryView ) } @@ -1034,8 +1034,8 @@ impl AggregateUDFImpl for Min { Decimal256(_, _) => { instantiate_min_accumulator!(data_type, i256, Decimal256Type) } - Utf8View => { - Ok(Box::new(GroupsAccumulatorMin::new())) + BinaryView => { + Ok(Box::new(StringGroupsAccumulatorMin::new())) } // It would be nice to have a fast implementation for Strings as well diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index f9dee12e1a4bf..e4274a9179a72 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -23,7 +23,7 @@ pub(crate) mod groups_accumulator { pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{ accumulate::NullState, GroupsAccumulatorAdapter, }; - pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator_view::GroupsAccumulatorMin; + pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator_view::StringGroupsAccumulatorMin; } pub(crate) mod stats { pub use datafusion_functions_aggregate_common::stats::StatsType; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 39cfcf1b6d211..252d939e8a5ce 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -44,7 +44,7 @@ pub mod execution_props { pub use datafusion_expr::var_provider::{VarProvider, VarType}; } -pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, GroupsAccumulatorMin, NullState}; +pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, StringGroupsAccumulatorMin, NullState}; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; pub use equivalence::{calculate_union, ConstExpr, EquivalenceProperties}; pub use partitioning::{Distribution, Partitioning}; From 6d522fa3ff772540e7d02184eee2b86d099d51e2 Mon Sep 17 00:00:00 2001 From: Devan Date: Sun, 29 Sep 2024 20:13:38 -0500 Subject: [PATCH 03/17] feat(udf): POC for native min max accumulators --- datafusion/expr-common/src/lib.rs | 2 +- .../src/aggregate.rs | 2 +- .../src/aggregate/min_max.rs | 2 + .../min_max/groups_accumulator_max_view.rs | 214 ++++++++++++++++++ .../groups_accumulator_min_view.rs} | 34 ++- datafusion/functions-aggregate/src/lib.rs | 1 + datafusion/functions-aggregate/src/min_max.rs | 8 +- datafusion/physical-expr/src/aggregate.rs | 2 +- datafusion/physical-expr/src/lib.rs | 4 +- .../physical-plan/src/aggregates/row_hash.rs | 4 +- 10 files changed, 246 insertions(+), 27 deletions(-) create mode 100644 datafusion/functions-aggregate-common/src/aggregate/min_max.rs create mode 100644 datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs rename datafusion/functions-aggregate-common/src/aggregate/{groups_accumulator_view.rs => min_max/groups_accumulator_min_view.rs} (88%) diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs index f8812af502c1d..179dd75ace85a 100644 --- a/datafusion/expr-common/src/lib.rs +++ b/datafusion/expr-common/src/lib.rs @@ -33,4 +33,4 @@ pub mod interval_arithmetic; pub mod operator; pub mod signature; pub mod sort_properties; -pub mod type_coercion; \ No newline at end of file +pub mod type_coercion; diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs index 239e82326fa49..802087e71ac1f 100644 --- a/datafusion/functions-aggregate-common/src/aggregate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -17,4 +17,4 @@ pub mod count_distinct; pub mod groups_accumulator; -pub mod groups_accumulator_view; +pub mod min_max; diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max.rs new file mode 100644 index 0000000000000..215085d62683e --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max.rs @@ -0,0 +1,2 @@ +pub mod groups_accumulator_max_view; +pub mod groups_accumulator_min_view; diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs new file mode 100644 index 0000000000000..55aa9677248ec --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs @@ -0,0 +1,214 @@ +use arrow::array::{Array, ArrayRef, AsArray, BinaryViewBuilder, BooleanArray}; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; +use std::sync::Arc; + +pub struct GroupsAccumulatorMaxStringView { + states: Vec, +} + +impl Default for GroupsAccumulatorMaxStringView { + fn default() -> Self { + Self::new() + } +} + +impl GroupsAccumulatorMaxStringView { + pub fn new() -> Self { + Self { states: Vec::new() } + } +} + +impl GroupsAccumulator for GroupsAccumulatorMaxStringView { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + if self.states.len() < total_num_groups { + self.states.resize(total_num_groups, String::new()); + } + + let input_array = &values[0]; + + for (i, &group_index) in group_indices.iter().enumerate() { + if let Some(filter) = opt_filter { + if !filter.value(i) { + continue; + } + } + + if input_array.is_null(i) { + continue; + } + + let value = input_array.as_binary_view().value(i); + + let value_str = std::str::from_utf8(value).map_err(|e| { + DataFusionError::Execution(format!( + "could not build utf8 from binary view {}", + e + )) + })?; + + if self.states[group_index].is_empty() { + self.states[group_index] = value_str.to_string(); + } else { + let curr_value_bytes = self.states[group_index].as_bytes(); + if value > curr_value_bytes { + self.states[group_index] = value_str.parse().unwrap(); + } + } + } + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let num_groups = match emit_to { + EmitTo::All => self.states.len(), + EmitTo::First(n) => std::cmp::min(n, self.states.len()), + }; + + let mut builder = BinaryViewBuilder::new(); + + for i in 0..num_groups { + let value = &self.states[i]; + if value.is_empty() { + builder.append_null(); + } else { + builder.append_value(value.as_bytes()); + } + } + + let array = Arc::new(builder.finish()) as ArrayRef; + + match emit_to { + EmitTo::All => { + self.states.clear(); + } + EmitTo::First(n) => { + self.states.drain(0..n); + } + } + Ok(array) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let num_groups = match emit_to { + EmitTo::All => self.states.len(), + EmitTo::First(n) => std::cmp::min(n, self.states.len()), + }; + + let mut builder = BinaryViewBuilder::new(); + + for i in 0..num_groups { + let value = &self.states[i]; + if value.is_empty() { + builder.append_null(); + } else { + builder.append_value(value.as_bytes()); + } + } + + let array = Arc::new(builder.finish()) as ArrayRef; + + match emit_to { + EmitTo::All => { + self.states.clear(); + } + EmitTo::First(n) => { + self.states.drain(0..n); + } + } + Ok(vec![array]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + if self.states.len() < total_num_groups { + self.states.resize(total_num_groups, String::new()); + } + + let input_array = &values[0]; + + for (i, &group_index) in group_indices.iter().enumerate() { + if let Some(filter) = opt_filter { + if !filter.value(i) { + continue; + } + } + + if input_array.is_null(i) { + continue; + } + + let value = input_array.as_binary_view().value(i); + + let value_str = std::str::from_utf8(value).map_err(|e| { + DataFusionError::Execution(format!( + "could not build utf8 from binary view {}", + e + )) + })?; + + if self.states[group_index].is_empty() { + self.states[group_index] = value_str.to_string(); + } else { + let curr_value_bytes = self.states[group_index].as_bytes(); + if value > curr_value_bytes { + self.states[group_index] = value_str.parse().unwrap(); + } + } + } + Ok(()) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let input_array = &values[0]; + + if opt_filter.is_none() { + return Ok(vec![Arc::::clone(&input_array)]); + } + + let filter = opt_filter.unwrap(); + + let mut builder = BinaryViewBuilder::new(); + + for i in 0..values.len() { + let value = input_array.as_binary_view().value(i); + + if !filter.value(i) { + builder.append_null(); + continue; + } + + if value.is_empty() { + builder.append_null(); + } else { + builder.append_value(value); + } + } + + let array = Arc::new(builder.finish()) as ArrayRef; + Ok(vec![array]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.states.iter().map(|s| s.len()).sum() + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs similarity index 88% rename from datafusion/functions-aggregate-common/src/aggregate/groups_accumulator_view.rs rename to datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs index 8a969f5659180..92260b4d76179 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator_view.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs @@ -1,27 +1,25 @@ -use std::sync::Arc; use arrow::array::{Array, ArrayRef, AsArray, BinaryViewBuilder, BooleanArray}; -use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; +use std::sync::Arc; -pub struct StringGroupsAccumulatorMin { +pub struct GroupsAccumulatorMinStringView { states: Vec, } -impl Default for StringGroupsAccumulatorMin { +impl Default for GroupsAccumulatorMinStringView { fn default() -> Self { Self::new() } } -impl StringGroupsAccumulatorMin { +impl GroupsAccumulatorMinStringView { pub fn new() -> Self { - Self { - states: Vec::new(), - } + Self { states: Vec::new() } } } -impl GroupsAccumulator for StringGroupsAccumulatorMin { +impl GroupsAccumulator for GroupsAccumulatorMinStringView { fn update_batch( &mut self, values: &[ArrayRef], @@ -29,34 +27,30 @@ impl GroupsAccumulator for StringGroupsAccumulatorMin { opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { - // Ensure that self.states has capacity for total_num_groups if self.states.len() < total_num_groups { self.states.resize(total_num_groups, String::new()); } - // Assume values has one element (the input column) let input_array = &values[0]; - // Iterate over rows for (i, &group_index) in group_indices.iter().enumerate() { - // Check filter if let Some(filter) = opt_filter { if !filter.value(i) { continue; } } - // Skip null values if input_array.is_null(i) { continue; } - // Get the binary value at index i let value = input_array.as_binary_view().value(i); - // Convert binary data to a string (assuming UTF-8 encoding) let value_str = std::str::from_utf8(value).map_err(|e| { - DataFusionError::Execution(format!("Invalid UTF-8 sequence: {}", e)) + DataFusionError::Execution(format!( + "could not build utf8 from binary view {}", + e + )) })?; if self.states[group_index].is_empty() { @@ -158,7 +152,10 @@ impl GroupsAccumulator for StringGroupsAccumulatorMin { let value = input_array.as_binary_view().value(i); let value_str = std::str::from_utf8(value).map_err(|e| { - DataFusionError::Execution(format!("Invalid UTF-8 sequence: {}", e)) + DataFusionError::Execution(format!( + "could not build utf8 from binary view {}", + e + )) })?; if self.states[group_index].is_empty() { @@ -186,7 +183,6 @@ impl GroupsAccumulator for StringGroupsAccumulatorMin { let filter = opt_filter.unwrap(); - let mut builder = BinaryViewBuilder::new(); for i in 0..values.len() { diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 60e2602eb6eda..b8fc4ff9734c8 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -81,6 +81,7 @@ pub mod grouping; pub mod kurtosis_pop; pub mod nth_value; pub mod string_agg; +mod min_max_group_accumulator; use crate::approx_percentile_cont::approx_percentile_cont_udaf; use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index c22e8eaa0e039..9cea056f5428f 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -53,6 +53,8 @@ use datafusion_common::{ downcast_value, exec_err, internal_err, DataFusionError, Result, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use datafusion_functions_aggregate_common::aggregate::min_max::groups_accumulator_max_view::GroupsAccumulatorMaxStringView; +use datafusion_functions_aggregate_common::aggregate::min_max::groups_accumulator_min_view::GroupsAccumulatorMinStringView; use std::fmt::Debug; use arrow::datatypes::i256; @@ -69,7 +71,6 @@ use datafusion_expr::{ }; use half::f16; use std::ops::Deref; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator_view::StringGroupsAccumulatorMin; fn get_min_max_result_type(input_types: &[DataType]) -> Result> { // make sure that the input types only has one element. @@ -254,6 +255,9 @@ impl AggregateUDFImpl for Max { Decimal256(_, _) => { instantiate_max_accumulator!(data_type, i256, Decimal256Type) } + BinaryView => { + Ok(Box::new(GroupsAccumulatorMaxStringView::default())) + } // It would be nice to have a fast implementation for Strings as well // https://github.com/apache/datafusion/issues/6906 @@ -1035,7 +1039,7 @@ impl AggregateUDFImpl for Min { instantiate_min_accumulator!(data_type, i256, Decimal256Type) } BinaryView => { - Ok(Box::new(StringGroupsAccumulatorMin::new())) + Ok(Box::new(GroupsAccumulatorMinStringView::default())) } // It would be nice to have a fast implementation for Strings as well diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index e4274a9179a72..aab39c5902fa1 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -23,7 +23,7 @@ pub(crate) mod groups_accumulator { pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{ accumulate::NullState, GroupsAccumulatorAdapter, }; - pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator_view::StringGroupsAccumulatorMin; + pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator_min_view::GroupsAccumulatorMinStringView; } pub(crate) mod stats { pub use datafusion_functions_aggregate_common::stats::StatsType; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 252d939e8a5ce..718216d584e1a 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -44,7 +44,9 @@ pub mod execution_props { pub use datafusion_expr::var_provider::{VarProvider, VarType}; } -pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, StringGroupsAccumulatorMin, NullState}; +pub use aggregate::groups_accumulator::{ + GroupsAccumulatorAdapter, GroupsAccumulatorMinStringView, NullState, +}; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; pub use equivalence::{calculate_union, ConstExpr, EquivalenceProperties}; pub use partitioning::{Distribution, Partitioning}; diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 89f03bea60674..1413fee285a41 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -48,12 +48,12 @@ use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; +use super::order::GroupOrdering; +use super::AggregateExec; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use futures::ready; use futures::stream::{Stream, StreamExt}; use log::debug; -use super::order::GroupOrdering; -use super::AggregateExec; #[derive(Debug, Clone)] /// This object tracks the aggregation phase (input/output) From bd9ea7da1654647a88d8c5b9e40a092b1548d299 Mon Sep 17 00:00:00 2001 From: Devan Date: Sun, 29 Sep 2024 20:23:40 -0500 Subject: [PATCH 04/17] feat: revert some changes --- datafusion/physical-expr/src/lib.rs | 2 +- datafusion/physical-plan/src/aggregates/row_hash.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 718216d584e1a..8564533eb8bc3 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -45,7 +45,7 @@ pub mod execution_props { } pub use aggregate::groups_accumulator::{ - GroupsAccumulatorAdapter, GroupsAccumulatorMinStringView, NullState, + GroupsAccumulatorAdapter, NullState, }; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; pub use equivalence::{calculate_union, ConstExpr, EquivalenceProperties}; diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 1413fee285a41..87b9aa69459a5 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -574,7 +574,7 @@ impl GroupedHashAggregateStream { /// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if /// that is supported by the aggregate, or a -/// [`GroupsAccumulatorMin`] if not. +/// [`GroupsAccumulatorAdapter`] if not. pub(crate) fn create_group_accumulator( agg_expr: &AggregateFunctionExpr, ) -> Result> { From fe649035a11bd8674972d5e4a2f39ee7d9319242 Mon Sep 17 00:00:00 2001 From: Devan Date: Sun, 29 Sep 2024 20:25:05 -0500 Subject: [PATCH 05/17] feat: add BinaryView to groups accum supported --- datafusion/functions-aggregate/src/min_max.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 9cea056f5428f..18b231e166fe6 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -195,6 +195,7 @@ impl AggregateUDFImpl for Max { | Time32(_) | Time64(_) | Timestamp(_, _) + | BinaryView ) } From fbdf86722888601e55a048ea2ac32355fe152835 Mon Sep 17 00:00:00 2001 From: Devan Date: Sun, 29 Sep 2024 20:26:24 -0500 Subject: [PATCH 06/17] feat: revert some changes while testing --- datafusion/physical-expr/src/aggregate.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index aab39c5902fa1..866596d0b6901 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -23,7 +23,6 @@ pub(crate) mod groups_accumulator { pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{ accumulate::NullState, GroupsAccumulatorAdapter, }; - pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator_min_view::GroupsAccumulatorMinStringView; } pub(crate) mod stats { pub use datafusion_functions_aggregate_common::stats::StatsType; From 0a93803bb5a58eabf52353f633e3ba17238163b8 Mon Sep 17 00:00:00 2001 From: Devan Date: Sun, 29 Sep 2024 20:27:20 -0500 Subject: [PATCH 07/17] feat: rename file --- datafusion/sqllogictest/test_files/min_max.slt | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 datafusion/sqllogictest/test_files/min_max.slt diff --git a/datafusion/sqllogictest/test_files/min_max.slt b/datafusion/sqllogictest/test_files/min_max.slt deleted file mode 100644 index 9c880b183550e..0000000000000 --- a/datafusion/sqllogictest/test_files/min_max.slt +++ /dev/null @@ -1,10 +0,0 @@ -statement ok -set datafusion.execution.parquet.schema_force_view_types = true; - -statement ok -CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION '/Users/devan/Documents/OSS/datafusion/parquet-testing/data/alltypes_plain.parquet' - -statement ok -SELECT REGEXP_REPLACE(t.string_col, '1', '0') AS k, AVG(length(t.string_col)) AS l, COUNT(*) AS c, MIN(t.string_col) -FROM t -GROUP BY k; \ No newline at end of file From 481952199338bf1cd0efc8d133d1820ebb4dc3b2 Mon Sep 17 00:00:00 2001 From: Devan Date: Sun, 29 Sep 2024 20:32:07 -0500 Subject: [PATCH 08/17] chore: add license header --- .../src/aggregate/min_max.rs | 15 +++++++++++++++ .../min_max/groups_accumulator_max_view.rs | 15 +++++++++++++++ .../min_max/groups_accumulator_min_view.rs | 15 +++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max.rs index 215085d62683e..ef2f2b2a076dd 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max.rs @@ -1,2 +1,17 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. pub mod groups_accumulator_max_view; pub mod groups_accumulator_min_view; diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs index 55aa9677248ec..660280d16e462 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. use arrow::array::{Array, ArrayRef, AsArray, BinaryViewBuilder, BooleanArray}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs index 92260b4d76179..6df2c625c7de0 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. use arrow::array::{Array, ArrayRef, AsArray, BinaryViewBuilder, BooleanArray}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; From 6f048fdbc0b3b22726eb231daad2b2305322744a Mon Sep 17 00:00:00 2001 From: Devan Date: Sun, 29 Sep 2024 20:35:22 -0500 Subject: [PATCH 09/17] feat: clippy + fmt + check --- .../src/aggregate/min_max/groups_accumulator_max_view.rs | 2 +- .../src/aggregate/min_max/groups_accumulator_min_view.rs | 2 +- datafusion/functions-aggregate/src/lib.rs | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs index 660280d16e462..9f0eb8cfbd651 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs @@ -193,7 +193,7 @@ impl GroupsAccumulator for GroupsAccumulatorMaxStringView { let input_array = &values[0]; if opt_filter.is_none() { - return Ok(vec![Arc::::clone(&input_array)]); + return Ok(vec![Arc::::clone(input_array)]); } let filter = opt_filter.unwrap(); diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs index 6df2c625c7de0..178ce0a12018d 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs @@ -193,7 +193,7 @@ impl GroupsAccumulator for GroupsAccumulatorMinStringView { let input_array = &values[0]; if opt_filter.is_none() { - return Ok(vec![Arc::::clone(&input_array)]); + return Ok(vec![Arc::::clone(input_array)]); } let filter = opt_filter.unwrap(); diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index b8fc4ff9734c8..60e2602eb6eda 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -81,7 +81,6 @@ pub mod grouping; pub mod kurtosis_pop; pub mod nth_value; pub mod string_agg; -mod min_max_group_accumulator; use crate::approx_percentile_cont::approx_percentile_cont_udaf; use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; From 1cefad5913c921a8b76ddcbdbc1f321693a11c73 Mon Sep 17 00:00:00 2001 From: Devan Date: Sun, 29 Sep 2024 20:37:55 -0500 Subject: [PATCH 10/17] chore: fmt --- datafusion/functions-aggregate/src/min_max.rs | 8 ++------ datafusion/physical-expr/src/lib.rs | 4 +--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 18b231e166fe6..50f119d195cd6 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -256,9 +256,7 @@ impl AggregateUDFImpl for Max { Decimal256(_, _) => { instantiate_max_accumulator!(data_type, i256, Decimal256Type) } - BinaryView => { - Ok(Box::new(GroupsAccumulatorMaxStringView::default())) - } + BinaryView => Ok(Box::new(GroupsAccumulatorMaxStringView::default())), // It would be nice to have a fast implementation for Strings as well // https://github.com/apache/datafusion/issues/6906 @@ -1039,9 +1037,7 @@ impl AggregateUDFImpl for Min { Decimal256(_, _) => { instantiate_min_accumulator!(data_type, i256, Decimal256Type) } - BinaryView => { - Ok(Box::new(GroupsAccumulatorMinStringView::default())) - } + BinaryView => Ok(Box::new(GroupsAccumulatorMinStringView::default())), // It would be nice to have a fast implementation for Strings as well // https://github.com/apache/datafusion/issues/6906 diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 8564533eb8bc3..46185712413ef 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -44,9 +44,7 @@ pub mod execution_props { pub use datafusion_expr::var_provider::{VarProvider, VarType}; } -pub use aggregate::groups_accumulator::{ - GroupsAccumulatorAdapter, NullState, -}; +pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, NullState}; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; pub use equivalence::{calculate_union, ConstExpr, EquivalenceProperties}; pub use partitioning::{Distribution, Partitioning}; From a60d59994c964822848079ab13fcb7dc6cbbf8dd Mon Sep 17 00:00:00 2001 From: Devan Date: Sun, 29 Sep 2024 21:28:55 -0500 Subject: [PATCH 11/17] feat: fix max accum --- .../src/aggregate/min_max/groups_accumulator_max_view.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs index 9f0eb8cfbd651..38d92a4e46a2e 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs @@ -83,7 +83,7 @@ impl GroupsAccumulator for GroupsAccumulatorMaxStringView { fn evaluate(&mut self, emit_to: EmitTo) -> Result { let num_groups = match emit_to { EmitTo::All => self.states.len(), - EmitTo::First(n) => std::cmp::min(n, self.states.len()), + EmitTo::First(n) => std::cmp::max(n, self.states.len()), }; let mut builder = BinaryViewBuilder::new(); @@ -113,7 +113,7 @@ impl GroupsAccumulator for GroupsAccumulatorMaxStringView { fn state(&mut self, emit_to: EmitTo) -> Result> { let num_groups = match emit_to { EmitTo::All => self.states.len(), - EmitTo::First(n) => std::cmp::min(n, self.states.len()), + EmitTo::First(n) => std::cmp::max(n, self.states.len()), }; let mut builder = BinaryViewBuilder::new(); From 268aa91f148167d27f348858c4d2605a4cd7dbed Mon Sep 17 00:00:00 2001 From: Devan Date: Mon, 30 Sep 2024 09:18:54 -0500 Subject: [PATCH 12/17] chore: fix emit_to calls --- .../min_max/groups_accumulator_max_view.rs | 34 +++--------------- .../min_max/groups_accumulator_min_view.rs | 35 +++---------------- 2 files changed, 9 insertions(+), 60 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs index 38d92a4e46a2e..feea96581910f 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs @@ -81,15 +81,11 @@ impl GroupsAccumulator for GroupsAccumulatorMaxStringView { } fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let num_groups = match emit_to { - EmitTo::All => self.states.len(), - EmitTo::First(n) => std::cmp::max(n, self.states.len()), - }; + let states = emit_to.take_needed(&mut self.states); let mut builder = BinaryViewBuilder::new(); - for i in 0..num_groups { - let value = &self.states[i]; + for value in states { if value.is_empty() { builder.append_null(); } else { @@ -98,28 +94,15 @@ impl GroupsAccumulator for GroupsAccumulatorMaxStringView { } let array = Arc::new(builder.finish()) as ArrayRef; - - match emit_to { - EmitTo::All => { - self.states.clear(); - } - EmitTo::First(n) => { - self.states.drain(0..n); - } - } Ok(array) } fn state(&mut self, emit_to: EmitTo) -> Result> { - let num_groups = match emit_to { - EmitTo::All => self.states.len(), - EmitTo::First(n) => std::cmp::max(n, self.states.len()), - }; + let states = emit_to.take_needed(&mut self.states); let mut builder = BinaryViewBuilder::new(); - for i in 0..num_groups { - let value = &self.states[i]; + for value in states { if value.is_empty() { builder.append_null(); } else { @@ -128,15 +111,6 @@ impl GroupsAccumulator for GroupsAccumulatorMaxStringView { } let array = Arc::new(builder.finish()) as ArrayRef; - - match emit_to { - EmitTo::All => { - self.states.clear(); - } - EmitTo::First(n) => { - self.states.drain(0..n); - } - } Ok(vec![array]) } diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs index 178ce0a12018d..e52424e83c4bd 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs @@ -1,3 +1,4 @@ +use std::os::macos::raw::stat; // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information @@ -81,15 +82,11 @@ impl GroupsAccumulator for GroupsAccumulatorMinStringView { } fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let num_groups = match emit_to { - EmitTo::All => self.states.len(), - EmitTo::First(n) => std::cmp::min(n, self.states.len()), - }; + let states = emit_to.take_needed(&mut self.states); let mut builder = BinaryViewBuilder::new(); - for i in 0..num_groups { - let value = &self.states[i]; + for value in states { if value.is_empty() { builder.append_null(); } else { @@ -98,28 +95,15 @@ impl GroupsAccumulator for GroupsAccumulatorMinStringView { } let array = Arc::new(builder.finish()) as ArrayRef; - - match emit_to { - EmitTo::All => { - self.states.clear(); - } - EmitTo::First(n) => { - self.states.drain(0..n); - } - } Ok(array) } fn state(&mut self, emit_to: EmitTo) -> Result> { - let num_groups = match emit_to { - EmitTo::All => self.states.len(), - EmitTo::First(n) => std::cmp::min(n, self.states.len()), - }; + let states = emit_to.take_needed(&mut self.states); let mut builder = BinaryViewBuilder::new(); - for i in 0..num_groups { - let value = &self.states[i]; + for value in states { if value.is_empty() { builder.append_null(); } else { @@ -128,15 +112,6 @@ impl GroupsAccumulator for GroupsAccumulatorMinStringView { } let array = Arc::new(builder.finish()) as ArrayRef; - - match emit_to { - EmitTo::All => { - self.states.clear(); - } - EmitTo::First(n) => { - self.states.drain(0..n); - } - } Ok(vec![array]) } From 77b980fedb8e7a0c20e16c715c3b0697907a28ed Mon Sep 17 00:00:00 2001 From: Devan Date: Mon, 30 Sep 2024 09:19:56 -0500 Subject: [PATCH 13/17] fix: rm not needed import --- .../src/aggregate/min_max/groups_accumulator_min_view.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs index e52424e83c4bd..06db5f5063d07 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs @@ -1,4 +1,3 @@ -use std::os::macos::raw::stat; // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information From 660eaa4ed19db1abc09fe5c3bfbe7415720ed6c2 Mon Sep 17 00:00:00 2001 From: Devan Date: Thu, 3 Oct 2024 10:23:35 -0500 Subject: [PATCH 14/17] feat: moves all functionality to single primitive strings function --- .../src/aggregate.rs | 3 +- .../src/aggregate/groups_accumulator.rs | 1 + .../groups_accumulator/prim_string.rs | 237 ++++++++++++++++++ .../src/aggregate/min_max.rs | 17 -- .../min_max/groups_accumulator_max_view.rs | 203 --------------- .../min_max/groups_accumulator_min_view.rs | 203 --------------- datafusion/functions-aggregate/src/min_max.rs | 39 ++- 7 files changed, 274 insertions(+), 429 deletions(-) create mode 100644 datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_string.rs delete mode 100644 datafusion/functions-aggregate-common/src/aggregate/min_max.rs delete mode 100644 datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs delete mode 100644 datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs index 802087e71ac1f..80f18623edf02 100644 --- a/datafusion/functions-aggregate-common/src/aggregate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -16,5 +16,4 @@ // under the License. pub mod count_distinct; -pub mod groups_accumulator; -pub mod min_max; +pub mod groups_accumulator; \ No newline at end of file diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index 92dd91bd86bca..20f4a245c5c8d 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -22,6 +22,7 @@ pub mod accumulate; pub mod bool_op; pub mod nulls; pub mod prim_op; +pub mod prim_string; use arrow::{ array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_string.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_string.rs new file mode 100644 index 0000000000000..d42d5e152795b --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_string.rs @@ -0,0 +1,237 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +use arrow::array::{Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray}; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; +use std::sync::Arc; +use crate::aggregate::groups_accumulator::accumulate::accumulate_indices; + +pub struct StringGroupsAccumulator { + states: Vec, + fun: F +} + +impl StringGroupsAccumulator +where + F: Fn(&[u8], &[u8]) -> bool + Send + Sync, +{ + pub fn new(s_fn: F) -> Self { + Self { + states: Vec::new(), + fun: s_fn + } + } +} + +impl GroupsAccumulator for StringGroupsAccumulator +where + F: Fn(&[u8], &[u8]) -> bool + Send + Sync, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + if self.states.len() < total_num_groups { + self.states.resize(total_num_groups, String::new()); + } + + let input_array = &values[0]; + + accumulate_indices(group_indices, input_array.logical_nulls().as_ref(), opt_filter, |group_index| { + invoke_accumulator::(self, input_array, group_index) + }); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let states = emit_to.take_needed(&mut self.states); + + let array = if VIEW { + let mut builder = BinaryViewBuilder::new(); + + for value in states { + if value.is_empty() { + builder.append_null(); + } else { + builder.append_value(value.as_bytes()); + } + } + + Arc::new(builder.finish()) as ArrayRef + } else { + let mut builder = BinaryBuilder::new(); + + for value in states { + if value.is_empty() { + builder.append_null(); + } else { + builder.append_value(value.as_bytes()); + } + } + + Arc::new(builder.finish()) as ArrayRef + }; + + Ok(array) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let states = emit_to.take_needed(&mut self.states); + + let array = if VIEW { + let mut builder = BinaryViewBuilder::new(); + + for value in states { + if value.is_empty() { + builder.append_null(); + } else { + builder.append_value(value.as_bytes()); + } + } + + Arc::new(builder.finish()) as ArrayRef + } else { + let mut builder = BinaryBuilder::new(); + + for value in states { + if value.is_empty() { + builder.append_null(); + } else { + builder.append_value(value.as_bytes()); + } + } + + Arc::new(builder.finish()) as ArrayRef + }; + + Ok(vec![array]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + if self.states.len() < total_num_groups { + self.states.resize(total_num_groups, String::new()); + } + + let input_array = &values[0]; + + accumulate_indices(group_indices, input_array.logical_nulls().as_ref(), opt_filter, |group_index| { + invoke_accumulator::(self, input_array, group_index) + }); + + Ok(()) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let input_array = &values[0]; + + if opt_filter.is_none() { + return Ok(vec![Arc::::clone(input_array)]); + } + + let filter = opt_filter.unwrap(); + + let array = if VIEW { + let mut builder = BinaryViewBuilder::new(); + + for i in 0..values.len() { + let value = input_array.as_binary_view().value(i); + + if !filter.value(i) { + builder.append_null(); + continue; + } + + if value.is_empty() { + builder.append_null(); + } else { + builder.append_value(value); + } + } + + Arc::new(builder.finish()) as ArrayRef + } else { + let mut builder = BinaryBuilder::new(); + + for i in 0..values.len() { + let value = input_array.as_binary::().value(i); + + if !filter.value(i) { + builder.append_null(); + continue; + } + + if value.is_empty() { + builder.append_null(); + } else { + builder.append_value(value); + } + } + + Arc::new(builder.finish()) as ArrayRef + }; + + Ok(vec![array]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.states.iter().map(|s| s.len()).sum() + } +} + +fn invoke_accumulator(accumulator: &mut StringGroupsAccumulator, input_array: &ArrayRef, group_index: usize) +where + F: Fn(&[u8], &[u8]) -> bool + Send + Sync +{ + let value: &[u8] = if VIEW { + input_array.as_binary_view().value(group_index) + } else { + input_array.as_binary::().value(group_index) + }; + + let value_str = std::str::from_utf8(value).map_err(|e| { + DataFusionError::Execution(format!( + "could not build utf8 {}", + e + )) + }).expect("failed to build utf8"); + + if accumulator.states[group_index].is_empty() { + accumulator.states[group_index] = value_str.to_string(); + } else { + let curr_value_bytes = accumulator.states[group_index].as_bytes(); + if (accumulator.fun)(value, curr_value_bytes) { + accumulator.states[group_index] = value_str.parse().unwrap(); + } + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max.rs deleted file mode 100644 index ef2f2b2a076dd..0000000000000 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max.rs +++ /dev/null @@ -1,17 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -pub mod groups_accumulator_max_view; -pub mod groups_accumulator_min_view; diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs deleted file mode 100644 index feea96581910f..0000000000000 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_max_view.rs +++ /dev/null @@ -1,203 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -use arrow::array::{Array, ArrayRef, AsArray, BinaryViewBuilder, BooleanArray}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; -use std::sync::Arc; - -pub struct GroupsAccumulatorMaxStringView { - states: Vec, -} - -impl Default for GroupsAccumulatorMaxStringView { - fn default() -> Self { - Self::new() - } -} - -impl GroupsAccumulatorMaxStringView { - pub fn new() -> Self { - Self { states: Vec::new() } - } -} - -impl GroupsAccumulator for GroupsAccumulatorMaxStringView { - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - if self.states.len() < total_num_groups { - self.states.resize(total_num_groups, String::new()); - } - - let input_array = &values[0]; - - for (i, &group_index) in group_indices.iter().enumerate() { - if let Some(filter) = opt_filter { - if !filter.value(i) { - continue; - } - } - - if input_array.is_null(i) { - continue; - } - - let value = input_array.as_binary_view().value(i); - - let value_str = std::str::from_utf8(value).map_err(|e| { - DataFusionError::Execution(format!( - "could not build utf8 from binary view {}", - e - )) - })?; - - if self.states[group_index].is_empty() { - self.states[group_index] = value_str.to_string(); - } else { - let curr_value_bytes = self.states[group_index].as_bytes(); - if value > curr_value_bytes { - self.states[group_index] = value_str.parse().unwrap(); - } - } - } - Ok(()) - } - - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let states = emit_to.take_needed(&mut self.states); - - let mut builder = BinaryViewBuilder::new(); - - for value in states { - if value.is_empty() { - builder.append_null(); - } else { - builder.append_value(value.as_bytes()); - } - } - - let array = Arc::new(builder.finish()) as ArrayRef; - Ok(array) - } - - fn state(&mut self, emit_to: EmitTo) -> Result> { - let states = emit_to.take_needed(&mut self.states); - - let mut builder = BinaryViewBuilder::new(); - - for value in states { - if value.is_empty() { - builder.append_null(); - } else { - builder.append_value(value.as_bytes()); - } - } - - let array = Arc::new(builder.finish()) as ArrayRef; - Ok(vec![array]) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - if self.states.len() < total_num_groups { - self.states.resize(total_num_groups, String::new()); - } - - let input_array = &values[0]; - - for (i, &group_index) in group_indices.iter().enumerate() { - if let Some(filter) = opt_filter { - if !filter.value(i) { - continue; - } - } - - if input_array.is_null(i) { - continue; - } - - let value = input_array.as_binary_view().value(i); - - let value_str = std::str::from_utf8(value).map_err(|e| { - DataFusionError::Execution(format!( - "could not build utf8 from binary view {}", - e - )) - })?; - - if self.states[group_index].is_empty() { - self.states[group_index] = value_str.to_string(); - } else { - let curr_value_bytes = self.states[group_index].as_bytes(); - if value > curr_value_bytes { - self.states[group_index] = value_str.parse().unwrap(); - } - } - } - Ok(()) - } - - fn convert_to_state( - &self, - values: &[ArrayRef], - opt_filter: Option<&BooleanArray>, - ) -> Result> { - let input_array = &values[0]; - - if opt_filter.is_none() { - return Ok(vec![Arc::::clone(input_array)]); - } - - let filter = opt_filter.unwrap(); - - let mut builder = BinaryViewBuilder::new(); - - for i in 0..values.len() { - let value = input_array.as_binary_view().value(i); - - if !filter.value(i) { - builder.append_null(); - continue; - } - - if value.is_empty() { - builder.append_null(); - } else { - builder.append_value(value); - } - } - - let array = Arc::new(builder.finish()) as ArrayRef; - Ok(vec![array]) - } - - fn supports_convert_to_state(&self) -> bool { - true - } - - fn size(&self) -> usize { - self.states.iter().map(|s| s.len()).sum() - } -} diff --git a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs b/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs deleted file mode 100644 index 06db5f5063d07..0000000000000 --- a/datafusion/functions-aggregate-common/src/aggregate/min_max/groups_accumulator_min_view.rs +++ /dev/null @@ -1,203 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -use arrow::array::{Array, ArrayRef, AsArray, BinaryViewBuilder, BooleanArray}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; -use std::sync::Arc; - -pub struct GroupsAccumulatorMinStringView { - states: Vec, -} - -impl Default for GroupsAccumulatorMinStringView { - fn default() -> Self { - Self::new() - } -} - -impl GroupsAccumulatorMinStringView { - pub fn new() -> Self { - Self { states: Vec::new() } - } -} - -impl GroupsAccumulator for GroupsAccumulatorMinStringView { - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - if self.states.len() < total_num_groups { - self.states.resize(total_num_groups, String::new()); - } - - let input_array = &values[0]; - - for (i, &group_index) in group_indices.iter().enumerate() { - if let Some(filter) = opt_filter { - if !filter.value(i) { - continue; - } - } - - if input_array.is_null(i) { - continue; - } - - let value = input_array.as_binary_view().value(i); - - let value_str = std::str::from_utf8(value).map_err(|e| { - DataFusionError::Execution(format!( - "could not build utf8 from binary view {}", - e - )) - })?; - - if self.states[group_index].is_empty() { - self.states[group_index] = value_str.to_string(); - } else { - let curr_value_bytes = self.states[group_index].as_bytes(); - if value < curr_value_bytes { - self.states[group_index] = value_str.parse().unwrap(); - } - } - } - Ok(()) - } - - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let states = emit_to.take_needed(&mut self.states); - - let mut builder = BinaryViewBuilder::new(); - - for value in states { - if value.is_empty() { - builder.append_null(); - } else { - builder.append_value(value.as_bytes()); - } - } - - let array = Arc::new(builder.finish()) as ArrayRef; - Ok(array) - } - - fn state(&mut self, emit_to: EmitTo) -> Result> { - let states = emit_to.take_needed(&mut self.states); - - let mut builder = BinaryViewBuilder::new(); - - for value in states { - if value.is_empty() { - builder.append_null(); - } else { - builder.append_value(value.as_bytes()); - } - } - - let array = Arc::new(builder.finish()) as ArrayRef; - Ok(vec![array]) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - if self.states.len() < total_num_groups { - self.states.resize(total_num_groups, String::new()); - } - - let input_array = &values[0]; - - for (i, &group_index) in group_indices.iter().enumerate() { - if let Some(filter) = opt_filter { - if !filter.value(i) { - continue; - } - } - - if input_array.is_null(i) { - continue; - } - - let value = input_array.as_binary_view().value(i); - - let value_str = std::str::from_utf8(value).map_err(|e| { - DataFusionError::Execution(format!( - "could not build utf8 from binary view {}", - e - )) - })?; - - if self.states[group_index].is_empty() { - self.states[group_index] = value_str.to_string(); - } else { - let curr_value_bytes = self.states[group_index].as_bytes(); - if value < curr_value_bytes { - self.states[group_index] = value_str.parse().unwrap(); - } - } - } - Ok(()) - } - - fn convert_to_state( - &self, - values: &[ArrayRef], - opt_filter: Option<&BooleanArray>, - ) -> Result> { - let input_array = &values[0]; - - if opt_filter.is_none() { - return Ok(vec![Arc::::clone(input_array)]); - } - - let filter = opt_filter.unwrap(); - - let mut builder = BinaryViewBuilder::new(); - - for i in 0..values.len() { - let value = input_array.as_binary_view().value(i); - - if !filter.value(i) { - builder.append_null(); - continue; - } - - if value.is_empty() { - builder.append_null(); - } else { - builder.append_value(value); - } - } - - let array = Arc::new(builder.finish()) as ArrayRef; - Ok(vec![array]) - } - - fn supports_convert_to_state(&self) -> bool { - true - } - - fn size(&self) -> usize { - self.states.iter().map(|s| s.len()).sum() - } -} diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 50f119d195cd6..7f1385ea4dd69 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -53,8 +53,7 @@ use datafusion_common::{ downcast_value, exec_err, internal_err, DataFusionError, Result, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use datafusion_functions_aggregate_common::aggregate::min_max::groups_accumulator_max_view::GroupsAccumulatorMaxStringView; -use datafusion_functions_aggregate_common::aggregate::min_max::groups_accumulator_min_view::GroupsAccumulatorMinStringView; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_string::StringGroupsAccumulator; use std::fmt::Debug; use arrow::datatypes::i256; @@ -130,6 +129,16 @@ macro_rules! instantiate_max_accumulator { }}; } +macro_rules! instantiate_max_string_accumulator { + ($VIEW:expr) => {{ + Ok(Box::new( + StringGroupsAccumulator::<_, $VIEW>::new(|a, b| { + a > b + }) + )) + }}; +} + /// Creates a [`PrimitiveGroupsAccumulator`] for computing `MIN` /// the specified [`ArrowPrimitiveType`]. /// @@ -149,6 +158,17 @@ macro_rules! instantiate_min_accumulator { }}; } + +macro_rules! instantiate_min_string_accumulator { + ($VIEW:expr) => {{ + Ok(Box::new( + StringGroupsAccumulator::<_, $VIEW>::new(|a, b| { + a < b + }) + )) + }}; +} + impl AggregateUDFImpl for Max { fn as_any(&self) -> &dyn std::any::Any { self @@ -256,7 +276,12 @@ impl AggregateUDFImpl for Max { Decimal256(_, _) => { instantiate_max_accumulator!(data_type, i256, Decimal256Type) } - BinaryView => Ok(Box::new(GroupsAccumulatorMaxStringView::default())), + BinaryView => { + instantiate_max_string_accumulator!(true) + } + Binary => { + instantiate_max_string_accumulator!(false) + } // It would be nice to have a fast implementation for Strings as well // https://github.com/apache/datafusion/issues/6906 @@ -977,6 +1002,7 @@ impl AggregateUDFImpl for Min { | Time64(_) | Timestamp(_, _) | BinaryView + | Binary ) } @@ -1037,7 +1063,12 @@ impl AggregateUDFImpl for Min { Decimal256(_, _) => { instantiate_min_accumulator!(data_type, i256, Decimal256Type) } - BinaryView => Ok(Box::new(GroupsAccumulatorMinStringView::default())), + BinaryView => { + instantiate_min_string_accumulator!(true) + }, + Binary => { + instantiate_min_string_accumulator!(false) + }, // It would be nice to have a fast implementation for Strings as well // https://github.com/apache/datafusion/issues/6906 From 7205cca2becacfc79d0d2d4a88fae097b606ecf7 Mon Sep 17 00:00:00 2001 From: Devan Date: Thu, 3 Oct 2024 10:24:47 -0500 Subject: [PATCH 15/17] feat: rename to string_op --- .../src/aggregate/groups_accumulator.rs | 2 +- .../groups_accumulator/{prim_string.rs => string_op.rs} | 0 datafusion/functions-aggregate/src/min_max.rs | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/{prim_string.rs => string_op.rs} (100%) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index 20f4a245c5c8d..5c5d66aef4447 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -22,7 +22,7 @@ pub mod accumulate; pub mod bool_op; pub mod nulls; pub mod prim_op; -pub mod prim_string; +pub mod string_op; use arrow::{ array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_string.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs similarity index 100% rename from datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_string.rs rename to datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 7f1385ea4dd69..c132426854b9a 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -53,7 +53,7 @@ use datafusion_common::{ downcast_value, exec_err, internal_err, DataFusionError, Result, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_string::StringGroupsAccumulator; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::string_op::StringGroupsAccumulator; use std::fmt::Debug; use arrow::datatypes::i256; From a52cefea64ef4ed42a675c46a7b6668fb348fd3b Mon Sep 17 00:00:00 2001 From: Devan Date: Thu, 3 Oct 2024 10:55:11 -0500 Subject: [PATCH 16/17] fix: need to implement own accumulator --- .../aggregate/groups_accumulator/string_op.rs | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs index d42d5e152795b..37b9e3abf7a16 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs @@ -17,7 +17,6 @@ use arrow::array::{Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, B use datafusion_common::{DataFusionError, Result}; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; use std::sync::Arc; -use crate::aggregate::groups_accumulator::accumulate::accumulate_indices; pub struct StringGroupsAccumulator { states: Vec, @@ -53,10 +52,10 @@ where let input_array = &values[0]; - accumulate_indices(group_indices, input_array.logical_nulls().as_ref(), opt_filter, |group_index| { - invoke_accumulator::(self, input_array, group_index) - }); - + for (i, &group_index) in group_indices.iter().enumerate() { + invoke_accumulator::(self, input_array, opt_filter, group_index, i) + } + Ok(()) } @@ -137,9 +136,9 @@ where let input_array = &values[0]; - accumulate_indices(group_indices, input_array.logical_nulls().as_ref(), opt_filter, |group_index| { - invoke_accumulator::(self, input_array, group_index) - }); + for (i, &group_index) in group_indices.iter().enumerate() { + invoke_accumulator::(self, input_array, opt_filter, group_index, i) + } Ok(()) } @@ -209,14 +208,23 @@ where } } -fn invoke_accumulator(accumulator: &mut StringGroupsAccumulator, input_array: &ArrayRef, group_index: usize) +fn invoke_accumulator(accumulator: &mut StringGroupsAccumulator, input_array: &ArrayRef, opt_filter: Option<&BooleanArray>, group_index: usize, i: usize) where F: Fn(&[u8], &[u8]) -> bool + Send + Sync { + if let Some(filter) = opt_filter { + if !filter.value(i) { + return + } + } + if input_array.is_null(i) { + return + } + let value: &[u8] = if VIEW { - input_array.as_binary_view().value(group_index) + input_array.as_binary_view().value(i) } else { - input_array.as_binary::().value(group_index) + input_array.as_binary::().value(i) }; let value_str = std::str::from_utf8(value).map_err(|e| { From a0a1572f5a58ffd461e62505ad1f2f04beb7dda9 Mon Sep 17 00:00:00 2001 From: Devan Date: Thu, 3 Oct 2024 10:56:13 -0500 Subject: [PATCH 17/17] fix: fmt --- .../src/aggregate.rs | 2 +- .../aggregate/groups_accumulator/string_op.rs | 48 ++++++++++--------- datafusion/functions-aggregate/src/min_max.rs | 21 ++++---- 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs index 80f18623edf02..c9cbaa8396fc5 100644 --- a/datafusion/functions-aggregate-common/src/aggregate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -16,4 +16,4 @@ // under the License. pub mod count_distinct; -pub mod groups_accumulator; \ No newline at end of file +pub mod groups_accumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs index 37b9e3abf7a16..2b95ed1065131 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/string_op.rs @@ -13,30 +13,32 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray}; +use arrow::array::{ + Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray, +}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; use std::sync::Arc; pub struct StringGroupsAccumulator { states: Vec, - fun: F + fun: F, } -impl StringGroupsAccumulator -where +impl StringGroupsAccumulator +where F: Fn(&[u8], &[u8]) -> bool + Send + Sync, { pub fn new(s_fn: F) -> Self { - Self { + Self { states: Vec::new(), - fun: s_fn + fun: s_fn, } } } -impl GroupsAccumulator for StringGroupsAccumulator -where +impl GroupsAccumulator for StringGroupsAccumulator +where F: Fn(&[u8], &[u8]) -> bool + Send + Sync, { fn update_batch( @@ -55,7 +57,7 @@ where for (i, &group_index) in group_indices.iter().enumerate() { invoke_accumulator::(self, input_array, opt_filter, group_index, i) } - + Ok(()) } @@ -155,7 +157,7 @@ where } let filter = opt_filter.unwrap(); - + let array = if VIEW { let mut builder = BinaryViewBuilder::new(); @@ -208,31 +210,33 @@ where } } -fn invoke_accumulator(accumulator: &mut StringGroupsAccumulator, input_array: &ArrayRef, opt_filter: Option<&BooleanArray>, group_index: usize, i: usize) -where - F: Fn(&[u8], &[u8]) -> bool + Send + Sync +fn invoke_accumulator( + accumulator: &mut StringGroupsAccumulator, + input_array: &ArrayRef, + opt_filter: Option<&BooleanArray>, + group_index: usize, + i: usize, +) where + F: Fn(&[u8], &[u8]) -> bool + Send + Sync, { if let Some(filter) = opt_filter { if !filter.value(i) { - return + return; } } if input_array.is_null(i) { - return + return; } - + let value: &[u8] = if VIEW { input_array.as_binary_view().value(i) } else { input_array.as_binary::().value(i) }; - let value_str = std::str::from_utf8(value).map_err(|e| { - DataFusionError::Execution(format!( - "could not build utf8 {}", - e - )) - }).expect("failed to build utf8"); + let value_str = std::str::from_utf8(value) + .map_err(|e| DataFusionError::Execution(format!("could not build utf8 {}", e))) + .expect("failed to build utf8"); if accumulator.states[group_index].is_empty() { accumulator.states[group_index] = value_str.to_string(); diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index c132426854b9a..605c17f9327e8 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -131,11 +131,9 @@ macro_rules! instantiate_max_accumulator { macro_rules! instantiate_max_string_accumulator { ($VIEW:expr) => {{ - Ok(Box::new( - StringGroupsAccumulator::<_, $VIEW>::new(|a, b| { - a > b - }) - )) + Ok(Box::new(StringGroupsAccumulator::<_, $VIEW>::new( + |a, b| a > b, + ))) }}; } @@ -158,14 +156,11 @@ macro_rules! instantiate_min_accumulator { }}; } - macro_rules! instantiate_min_string_accumulator { ($VIEW:expr) => {{ - Ok(Box::new( - StringGroupsAccumulator::<_, $VIEW>::new(|a, b| { - a < b - }) - )) + Ok(Box::new(StringGroupsAccumulator::<_, $VIEW>::new( + |a, b| a < b, + ))) }}; } @@ -1065,10 +1060,10 @@ impl AggregateUDFImpl for Min { } BinaryView => { instantiate_min_string_accumulator!(true) - }, + } Binary => { instantiate_min_string_accumulator!(false) - }, + } // It would be nice to have a fast implementation for Strings as well // https://github.com/apache/datafusion/issues/6906