From 55a9cd72e7d2ecb6829ee76987727ba6042e4e59 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Thu, 25 Apr 2024 20:29:09 +1000 Subject: [PATCH 1/2] Support for median(distinct) aggregation function --- .../core/benches/aggregate_query_sql.rs | 10 + .../physical-expr/src/aggregate/build_in.rs | 8 +- .../physical-expr/src/aggregate/median.rs | 349 +++++++++++++++++- .../physical-expr/src/expressions/mod.rs | 1 + .../sqllogictest/test_files/aggregate.slt | 4 +- 5 files changed, 352 insertions(+), 20 deletions(-) diff --git a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs index 3734cfbe313c1..1d8d87ada7847 100644 --- a/datafusion/core/benches/aggregate_query_sql.rs +++ b/datafusion/core/benches/aggregate_query_sql.rs @@ -163,6 +163,16 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + c.bench_function("aggregate_query_distinct_median", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT MEDIAN(DISTINCT u64_wide), MEDIAN(DISTINCT u64_narrow) \ + FROM t", + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index c549e62193752..066760eaba81e 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -357,9 +357,11 @@ pub fn create_aggregate_expr( name, data_type, )), - (AggregateFunction::Median, true) => { - return not_impl_err!("MEDIAN(DISTINCT) aggregations are not available"); - } + (AggregateFunction::Median, true) => Arc::new(expressions::DistinctMedian::new( + input_phy_exprs[0].clone(), + name, + data_type, + )), (AggregateFunction::FirstValue, _) => Arc::new( expressions::FirstValue::new( input_phy_exprs[0].clone(), diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index ed373ba13d5ec..1049187a529a3 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -17,7 +17,7 @@ //! # Median -use crate::aggregate::utils::down_cast_any_ref; +use crate::aggregate::utils::{down_cast_any_ref, Hashable}; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::{Array, ArrayRef}; @@ -28,6 +28,7 @@ use arrow_buffer::ArrowNativeType; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; use std::any::Any; +use std::collections::HashSet; use std::fmt::Formatter; use std::sync::Arc; @@ -172,21 +173,8 @@ impl Accumulator for MedianAccumulator { } fn evaluate(&mut self) -> Result { - let mut d = std::mem::take(&mut self.all_values); - let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); - - let len = d.len(); - let median = if len == 0 { - None - } else if len % 2 == 0 { - let (low, high, _) = d.select_nth_unstable_by(len / 2, cmp); - let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp); - let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2)); - Some(median) - } else { - let (_, median, _) = d.select_nth_unstable_by(len / 2, cmp); - Some(*median) - }; + let d = std::mem::take(&mut self.all_values); + let median = calculate_median::(d); ScalarValue::new_primitive::(median, &self.data_type) } @@ -196,6 +184,192 @@ impl Accumulator for MedianAccumulator { } } +/// MEDIAN(DISTINCT) aggregate expression. Similar to MEDIAN but computes after taking +/// all unique values. This may use a lot of memory if the cardinality is high. +#[derive(Debug)] +pub struct DistinctMedian { + name: String, + expr: Arc, + data_type: DataType, +} + +impl DistinctMedian { + /// Create a new MEDIAN(DISTINCT) aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + Self { + name: name.into(), + expr, + data_type, + } + } +} + +impl AggregateExpr for DistinctMedian { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + use arrow_array::types::*; + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(DistinctMedianAccumulator::<$t> { + data_type: $dt.clone(), + distinct_values: Default::default(), + })) + }; + } + let dt = &self.data_type; + downcast_integer! { + dt => (helper, dt), + DataType::Float16 => helper!(Float16Type, dt), + DataType::Float32 => helper!(Float32Type, dt), + DataType::Float64 => helper!(Float64Type, dt), + DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), + DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), + _ => Err(DataFusionError::NotImplemented(format!( + "DistinctMedianAccumulator not supported for {} with {}", + self.name(), + self.data_type + ))), + } + } + + fn state_fields(&self) -> Result> { + // Intermediate state is a list of the unique elements we have + // collected so far + let field = Field::new("item", self.data_type.clone(), true); + let data_type = DataType::List(Arc::new(field)); + + Ok(vec![Field::new( + format_state_name(&self.name, "distinct_median"), + data_type, + true, + )]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for DistinctMedian { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.data_type == x.data_type + && self.expr.eq(&x.expr) + }) + .unwrap_or(false) + } +} + +/// The distinct median accumulator accumulates the raw input values +/// as `ScalarValue`s +/// +/// The intermediate state is represented as a List of scalar values updated by +/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values +/// in the final evaluation step so that we avoid expensive conversions and +/// allocations during `update_batch`. +struct DistinctMedianAccumulator { + data_type: DataType, + distinct_values: HashSet>, +} + +impl std::fmt::Debug for DistinctMedianAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctMedianAccumulator({})", self.data_type) + } +} + +impl Accumulator for DistinctMedianAccumulator { + fn state(&mut self) -> Result> { + let all_values = self + .distinct_values + .iter() + .map(|x| ScalarValue::new_primitive::(Some(x.0), &self.data_type)) + .collect::>>()?; + + let arr = ScalarValue::new_list(&all_values, &self.data_type); + Ok(vec![ScalarValue::List(arr)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.distinct_values.insert(Hashable(array.value(idx))); + } + } + None => array.values().iter().for_each(|x| { + self.distinct_values.insert(Hashable(*x)); + }), + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let array = states[0].as_list::(); + for v in array.iter().flatten() { + self.update_batch(&[v])? + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let d = std::mem::take(&mut self.distinct_values) + .into_iter() + .map(|v| v.0) + .collect::>(); + let median = calculate_median::(d); + ScalarValue::new_primitive::(median, &self.data_type) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.distinct_values.capacity() * std::mem::size_of::() + } +} + +fn calculate_median( + mut values: Vec, +) -> Option { + let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); + + let len = values.len(); + if len == 0 { + None + } else if len % 2 == 0 { + let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp); + let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp); + let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2)); + Some(median) + } else { + let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp); + Some(*median) + } +} + #[cfg(test)] mod tests { use super::*; @@ -329,4 +503,147 @@ mod tests { ])); generic_test_op!(a, DataType::Float64, Median, ScalarValue::from(3.5_f64)) } + + #[test] + fn distinct_median_decimal() -> Result<()> { + let array: ArrayRef = Arc::new( + vec![1, 1, 1, 1, 1, 1, 2, 3, 3] + .into_iter() + .map(Some) + .collect::() + .with_precision_and_scale(10, 4)?, + ); + + generic_test_op!( + array, + DataType::Decimal128(10, 4), + DistinctMedian, + ScalarValue::Decimal128(Some(2), 10, 4) + ) + } + + #[test] + fn distinct_median_decimal_with_nulls() -> Result<()> { + let array: ArrayRef = Arc::new( + vec![Some(1), Some(2), None, Some(3), Some(3), Some(3), Some(3)] + .into_iter() + .collect::() + .with_precision_and_scale(10, 4)?, + ); + generic_test_op!( + array, + DataType::Decimal128(10, 4), + DistinctMedian, + ScalarValue::Decimal128(Some(2), 10, 4) + ) + } + + #[test] + fn distinct_median_decimal_all_nulls() -> Result<()> { + let array: ArrayRef = Arc::new( + std::iter::repeat::>(None) + .take(6) + .collect::() + .with_precision_and_scale(10, 4)?, + ); + generic_test_op!( + array, + DataType::Decimal128(10, 4), + DistinctMedian, + ScalarValue::Decimal128(None, 10, 4) + ) + } + + #[test] + fn distinct_median_i32_odd() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 1, 1, 2, 3])); + generic_test_op!(a, DataType::Int32, DistinctMedian, ScalarValue::from(2_i32)) + } + + #[test] + fn distinct_median_i32_even() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 1, 1, 1, 3])); + generic_test_op!(a, DataType::Int32, DistinctMedian, ScalarValue::from(2_i32)) + } + + #[test] + fn distinct_median_i32_with_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(1), + Some(1), + Some(3), + ])); + generic_test_op!(a, DataType::Int32, DistinctMedian, ScalarValue::from(2i32)) + } + + #[test] + fn distinct_median_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + generic_test_op!(a, DataType::Int32, DistinctMedian, ScalarValue::Int32(None)) + } + + #[test] + fn distinct_median_u32_odd() -> Result<()> { + let a: ArrayRef = + Arc::new(UInt32Array::from(vec![1_u32, 1_u32, 1_u32, 2_u32, 3_u32])); + generic_test_op!(a, DataType::UInt32, DistinctMedian, ScalarValue::from(2u32)) + } + + #[test] + fn distinct_median_u32_even() -> Result<()> { + let a: ArrayRef = Arc::new(UInt32Array::from(vec![ + 1_u32, 1_u32, 1_u32, 1_u32, 3_u32, 3_u32, + ])); + generic_test_op!(a, DataType::UInt32, DistinctMedian, ScalarValue::from(2u32)) + } + + #[test] + fn distinct_median_f32_odd() -> Result<()> { + let a: ArrayRef = + Arc::new(Float32Array::from(vec![1_f32, 1_f32, 1_f32, 2_f32, 3_f32])); + generic_test_op!( + a, + DataType::Float32, + DistinctMedian, + ScalarValue::from(2_f32) + ) + } + + #[test] + fn distinct_median_f32_even() -> Result<()> { + let a: ArrayRef = + Arc::new(Float32Array::from(vec![1_f32, 1_f32, 1_f32, 1_f32, 2_f32])); + generic_test_op!( + a, + DataType::Float32, + DistinctMedian, + ScalarValue::from(1.5_f32) + ) + } + + #[test] + fn distinct_median_f64_odd() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 1_f64, 1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + DistinctMedian, + ScalarValue::from(2_f64) + ) + } + + #[test] + fn distinct_median_f64_even() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 1_f64, 1_f64, 1_f64, 2_f64])); + generic_test_op!( + a, + DataType::Float64, + DistinctMedian, + ScalarValue::from(1.5_f64) + ) + } } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 688d5ce6eabf2..d145913a5b28e 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -54,6 +54,7 @@ pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; pub use crate::aggregate::grouping::Grouping; +pub use crate::aggregate::median::DistinctMedian; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 8b5b84e766506..1c10d2c0e523f 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -472,8 +472,10 @@ SELECT median(distinct col_i8) FROM median_table ---- 100 -statement error DataFusion error: This feature is not implemented: MEDIAN\(DISTINCT\) aggregations are not available +query II SELECT median(col_i8), median(distinct col_i8) FROM median_table +---- +-14 100 # approx_distinct_median_i8 query I From b3caf36e6476dc3ba6e3469f371eb37ca1658ea2 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Fri, 26 Apr 2024 18:29:41 +1000 Subject: [PATCH 2/2] Reduce duplication --- .../physical-expr/src/aggregate/build_in.rs | 8 +- .../physical-expr/src/aggregate/median.rs | 318 ++++++++++-------- .../physical-expr/src/expressions/mod.rs | 35 +- .../physical-plan/src/aggregates/mod.rs | 1 + 4 files changed, 212 insertions(+), 150 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 066760eaba81e..57ed35b0b761a 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -352,15 +352,11 @@ pub fn create_aggregate_expr( "APPROX_MEDIAN(DISTINCT) aggregations are not available" ); } - (AggregateFunction::Median, false) => Arc::new(expressions::Median::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Median, true) => Arc::new(expressions::DistinctMedian::new( + (AggregateFunction::Median, distinct) => Arc::new(expressions::Median::new( input_phy_exprs[0].clone(), name, data_type, + distinct, )), (AggregateFunction::FirstValue, _) => Arc::new( expressions::FirstValue::new( diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index 1049187a529a3..f4f56fa46ed59 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -32,14 +32,20 @@ use std::collections::HashSet; use std::fmt::Formatter; use std::sync::Arc; -/// MEDIAN aggregate expression. This uses a lot of memory because all values need to be -/// stored in memory before a result can be computed. If an approximation is sufficient -/// then APPROX_MEDIAN provides a much more efficient solution. +/// MEDIAN aggregate expression. If using the non-distinct variation, then this uses a +/// lot of memory because all values need to be stored in memory before a result can be +/// computed. If an approximation is sufficient then APPROX_MEDIAN provides a much more +/// efficient solution. +/// +/// If using the distinct variation, the memory usage will be similarly high if the +/// cardinality is high as it stores all distinct values in memory before computing the +/// result, but if cardinality is low then memory usage will also be lower. #[derive(Debug)] pub struct Median { name: String, expr: Arc, data_type: DataType, + distinct: bool, } impl Median { @@ -48,11 +54,13 @@ impl Median { expr: Arc, name: impl Into, data_type: DataType, + distinct: bool, ) -> Self { Self { name: name.into(), expr, data_type, + distinct, } } } @@ -71,10 +79,17 @@ impl AggregateExpr for Median { use arrow_array::types::*; macro_rules! helper { ($t:ty, $dt:expr) => { - Ok(Box::new(MedianAccumulator::<$t> { - data_type: $dt.clone(), - all_values: vec![], - })) + if self.distinct { + Ok(Box::new(DistinctMedianAccumulator::<$t> { + data_type: $dt.clone(), + distinct_values: HashSet::new(), + })) + } else { + Ok(Box::new(MedianAccumulator::<$t> { + data_type: $dt.clone(), + all_values: vec![], + })) + } }; } let dt = &self.data_type; @@ -97,9 +112,14 @@ impl AggregateExpr for Median { //Intermediate state is a list of the elements we have collected so far let field = Field::new("item", self.data_type.clone(), true); let data_type = DataType::List(Arc::new(field)); + let state_name = if self.distinct { + "distinct_median" + } else { + "median" + }; Ok(vec![Field::new( - format_state_name(&self.name, "median"), + format_state_name(&self.name, state_name), data_type, true, )]) @@ -122,6 +142,7 @@ impl PartialEq for Median { self.name == x.name && self.data_type == x.data_type && self.expr.eq(&x.expr) + && self.distinct == x.distinct }) .unwrap_or(false) } @@ -184,101 +205,6 @@ impl Accumulator for MedianAccumulator { } } -/// MEDIAN(DISTINCT) aggregate expression. Similar to MEDIAN but computes after taking -/// all unique values. This may use a lot of memory if the cardinality is high. -#[derive(Debug)] -pub struct DistinctMedian { - name: String, - expr: Arc, - data_type: DataType, -} - -impl DistinctMedian { - /// Create a new MEDIAN(DISTINCT) aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - } - } -} - -impl AggregateExpr for DistinctMedian { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - use arrow_array::types::*; - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(DistinctMedianAccumulator::<$t> { - data_type: $dt.clone(), - distinct_values: Default::default(), - })) - }; - } - let dt = &self.data_type; - downcast_integer! { - dt => (helper, dt), - DataType::Float16 => helper!(Float16Type, dt), - DataType::Float32 => helper!(Float32Type, dt), - DataType::Float64 => helper!(Float64Type, dt), - DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), - DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), - _ => Err(DataFusionError::NotImplemented(format!( - "DistinctMedianAccumulator not supported for {} with {}", - self.name(), - self.data_type - ))), - } - } - - fn state_fields(&self) -> Result> { - // Intermediate state is a list of the unique elements we have - // collected so far - let field = Field::new("item", self.data_type.clone(), true); - let data_type = DataType::List(Arc::new(field)); - - Ok(vec![Field::new( - format_state_name(&self.name, "distinct_median"), - data_type, - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctMedian { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - /// The distinct median accumulator accumulates the raw input values /// as `ScalarValue`s /// @@ -375,7 +301,7 @@ mod tests { use super::*; use crate::expressions::col; use crate::expressions::tests::aggregate; - use crate::generic_test_op; + use crate::generic_test_distinct_op; use arrow::{array::*, datatypes::*}; #[test] @@ -388,10 +314,11 @@ mod tests { .with_precision_and_scale(10, 4)?, ); - generic_test_op!( + generic_test_distinct_op!( array, DataType::Decimal128(10, 4), Median, + false, ScalarValue::Decimal128(Some(3), 10, 4) ) } @@ -404,10 +331,11 @@ mod tests { .collect::() .with_precision_and_scale(10, 4)?, ); - generic_test_op!( + generic_test_distinct_op!( array, DataType::Decimal128(10, 4), Median, + false, ScalarValue::Decimal128(Some(3), 10, 4) ) } @@ -421,10 +349,11 @@ mod tests { .collect::() .with_precision_and_scale(10, 4)?, ); - generic_test_op!( + generic_test_distinct_op!( array, DataType::Decimal128(10, 4), Median, + false, ScalarValue::Decimal128(None, 10, 4) ) } @@ -432,13 +361,25 @@ mod tests { #[test] fn median_i32_odd() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3_i32)) + generic_test_distinct_op!( + a, + DataType::Int32, + Median, + false, + ScalarValue::from(3_i32) + ) } #[test] fn median_i32_even() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3_i32)) + generic_test_distinct_op!( + a, + DataType::Int32, + Median, + false, + ScalarValue::from(3_i32) + ) } #[test] @@ -450,20 +391,38 @@ mod tests { Some(4), Some(5), ])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3i32)) + generic_test_distinct_op!( + a, + DataType::Int32, + Median, + false, + ScalarValue::from(3i32) + ) } #[test] fn median_i32_all_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::Int32(None)) + generic_test_distinct_op!( + a, + DataType::Int32, + Median, + false, + ScalarValue::Int32(None) + ) } #[test] fn median_u32_odd() -> Result<()> { let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Median, ScalarValue::from(3u32)) + generic_test_distinct_op!( + a, + DataType::UInt32, + Median, + false, + ScalarValue::from(3u32) + ) } #[test] @@ -471,14 +430,26 @@ mod tests { let a: ArrayRef = Arc::new(UInt32Array::from(vec![ 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, 6_u32, ])); - generic_test_op!(a, DataType::UInt32, Median, ScalarValue::from(3u32)) + generic_test_distinct_op!( + a, + DataType::UInt32, + Median, + false, + ScalarValue::from(3u32) + ) } #[test] fn median_f32_odd() -> Result<()> { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Median, ScalarValue::from(3_f32)) + generic_test_distinct_op!( + a, + DataType::Float32, + Median, + false, + ScalarValue::from(3_f32) + ) } #[test] @@ -486,14 +457,26 @@ mod tests { let a: ArrayRef = Arc::new(Float32Array::from(vec![ 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, 6_f32, ])); - generic_test_op!(a, DataType::Float32, Median, ScalarValue::from(3.5_f32)) + generic_test_distinct_op!( + a, + DataType::Float32, + Median, + false, + ScalarValue::from(3.5_f32) + ) } #[test] fn median_f64_odd() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Median, ScalarValue::from(3_f64)) + generic_test_distinct_op!( + a, + DataType::Float64, + Median, + false, + ScalarValue::from(3_f64) + ) } #[test] @@ -501,23 +484,30 @@ mod tests { let a: ArrayRef = Arc::new(Float64Array::from(vec![ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, 6_f64, ])); - generic_test_op!(a, DataType::Float64, Median, ScalarValue::from(3.5_f64)) + generic_test_distinct_op!( + a, + DataType::Float64, + Median, + false, + ScalarValue::from(3.5_f64) + ) } #[test] fn distinct_median_decimal() -> Result<()> { let array: ArrayRef = Arc::new( - vec![1, 1, 1, 1, 1, 1, 2, 3, 3] + vec![1, 1, 1, 1, 2, 3, 1, 1, 3] .into_iter() .map(Some) .collect::() .with_precision_and_scale(10, 4)?, ); - generic_test_op!( + generic_test_distinct_op!( array, DataType::Decimal128(10, 4), - DistinctMedian, + Median, + true, ScalarValue::Decimal128(Some(2), 10, 4) ) } @@ -525,15 +515,16 @@ mod tests { #[test] fn distinct_median_decimal_with_nulls() -> Result<()> { let array: ArrayRef = Arc::new( - vec![Some(1), Some(2), None, Some(3), Some(3), Some(3), Some(3)] + vec![Some(3), Some(1), None, Some(3), Some(2), Some(3), Some(3)] .into_iter() .collect::() .with_precision_and_scale(10, 4)?, ); - generic_test_op!( + generic_test_distinct_op!( array, DataType::Decimal128(10, 4), - DistinctMedian, + Median, + true, ScalarValue::Decimal128(Some(2), 10, 4) ) } @@ -546,24 +537,37 @@ mod tests { .collect::() .with_precision_and_scale(10, 4)?, ); - generic_test_op!( + generic_test_distinct_op!( array, DataType::Decimal128(10, 4), - DistinctMedian, + Median, + true, ScalarValue::Decimal128(None, 10, 4) ) } #[test] fn distinct_median_i32_odd() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 1, 1, 2, 3])); - generic_test_op!(a, DataType::Int32, DistinctMedian, ScalarValue::from(2_i32)) + let a: ArrayRef = Arc::new(Int32Array::from(vec![2, 1, 1, 2, 1, 3])); + generic_test_distinct_op!( + a, + DataType::Int32, + Median, + true, + ScalarValue::from(2_i32) + ) } #[test] fn distinct_median_i32_even() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 1, 1, 1, 3])); - generic_test_op!(a, DataType::Int32, DistinctMedian, ScalarValue::from(2_i32)) + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 1, 3, 1, 1])); + generic_test_distinct_op!( + a, + DataType::Int32, + Median, + true, + ScalarValue::from(2_i32) + ) } #[test] @@ -575,20 +579,38 @@ mod tests { Some(1), Some(3), ])); - generic_test_op!(a, DataType::Int32, DistinctMedian, ScalarValue::from(2i32)) + generic_test_distinct_op!( + a, + DataType::Int32, + Median, + true, + ScalarValue::from(2i32) + ) } #[test] fn distinct_median_i32_all_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, DistinctMedian, ScalarValue::Int32(None)) + generic_test_distinct_op!( + a, + DataType::Int32, + Median, + true, + ScalarValue::Int32(None) + ) } #[test] fn distinct_median_u32_odd() -> Result<()> { let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 1_u32, 1_u32, 2_u32, 3_u32])); - generic_test_op!(a, DataType::UInt32, DistinctMedian, ScalarValue::from(2u32)) + Arc::new(UInt32Array::from(vec![1_u32, 1_u32, 2_u32, 1_u32, 3_u32])); + generic_test_distinct_op!( + a, + DataType::UInt32, + Median, + true, + ScalarValue::from(2u32) + ) } #[test] @@ -596,17 +618,24 @@ mod tests { let a: ArrayRef = Arc::new(UInt32Array::from(vec![ 1_u32, 1_u32, 1_u32, 1_u32, 3_u32, 3_u32, ])); - generic_test_op!(a, DataType::UInt32, DistinctMedian, ScalarValue::from(2u32)) + generic_test_distinct_op!( + a, + DataType::UInt32, + Median, + true, + ScalarValue::from(2u32) + ) } #[test] fn distinct_median_f32_odd() -> Result<()> { let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 1_f32, 1_f32, 2_f32, 3_f32])); - generic_test_op!( + Arc::new(Float32Array::from(vec![3_f32, 2_f32, 1_f32, 1_f32, 1_f32])); + generic_test_distinct_op!( a, DataType::Float32, - DistinctMedian, + Median, + true, ScalarValue::from(2_f32) ) } @@ -615,10 +644,11 @@ mod tests { fn distinct_median_f32_even() -> Result<()> { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 1_f32, 1_f32, 1_f32, 2_f32])); - generic_test_op!( + generic_test_distinct_op!( a, DataType::Float32, - DistinctMedian, + Median, + true, ScalarValue::from(1.5_f32) ) } @@ -627,10 +657,11 @@ mod tests { fn distinct_median_f64_odd() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 1_f64, 1_f64, 2_f64, 3_f64])); - generic_test_op!( + generic_test_distinct_op!( a, DataType::Float64, - DistinctMedian, + Median, + true, ScalarValue::from(2_f64) ) } @@ -639,10 +670,11 @@ mod tests { fn distinct_median_f64_even() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 1_f64, 1_f64, 1_f64, 2_f64])); - generic_test_op!( + generic_test_distinct_op!( a, DataType::Float64, - DistinctMedian, + Median, + true, ScalarValue::from(1.5_f64) ) } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index d145913a5b28e..55ebd9ed8c444 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -54,7 +54,6 @@ pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; pub use crate::aggregate::grouping::Grouping; -pub use crate::aggregate::median::DistinctMedian; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; @@ -135,6 +134,40 @@ pub(crate) mod tests { }}; } + /// Same as [`generic_test_op`] but with support for providing a 4th argument, usually + /// a boolean to indicate if using the distinct version of the op. + #[macro_export] + macro_rules! generic_test_distinct_op { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $DISTINCT:expr, $EXPECTED:expr) => { + generic_test_distinct_op!( + $ARRAY, + $DATATYPE, + $OP, + $DISTINCT, + $EXPECTED, + $EXPECTED.data_type() + ) + }; + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $DISTINCT:expr, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + $DISTINCT, + )); + let actual = aggregate(&batch, agg)?; + let expected = ScalarValue::from($EXPECTED); + + assert_eq!(expected, actual); + + Ok(()) as Result<(), ::datafusion_common::DataFusionError> + }}; + } + /// macro to perform an aggregation using [`crate::GroupsAccumulator`] and verify the result. /// /// The difference between this and the above `generic_test_op` is that the former checks diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 14485c8337947..25f5508365052 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1800,6 +1800,7 @@ mod tests { col("a", &input_schema)?, "MEDIAN(a)".to_string(), DataType::UInt32, + false, ))]; // use slow-path in `hash.rs`