From 81f1e38e1e7dbef57602ca77add69ed0365fc44f Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Thu, 7 Aug 2025 10:04:55 +0200 Subject: [PATCH 1/3] Require Eq to use udf_equals_hash The UDF comparison is expected to be reflexive. Require `Eq` for any uses of `udf_equals_hash` short-cut. --- .../user_defined/user_defined_scalar_functions.rs | 9 +++++---- .../user_defined/user_defined_window_functions.rs | 2 +- datafusion/expr/src/async_udf.rs | 1 + datafusion/expr/src/expr_fn.rs | 4 ++-- datafusion/expr/src/ptr_eq.rs | 3 ++- datafusion/expr/src/udf.rs | 1 + datafusion/expr/src/udwf.rs | 2 +- datafusion/expr/src/utils.rs | 10 ++++++---- datafusion/ffi/src/udf/mod.rs | 1 + datafusion/ffi/src/udwf/mod.rs | 1 + datafusion/functions-window/src/lead_lag.rs | 2 +- datafusion/functions-window/src/nth_value.rs | 2 +- datafusion/functions-window/src/rank.rs | 2 +- .../src/simplify_expressions/expr_simplifier.rs | 2 +- datafusion/proto/tests/cases/mod.rs | 2 +- datafusion/sql/tests/sql_integration.rs | 2 +- 16 files changed, 27 insertions(+), 19 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 32c2f1d302b40..bf7f58d51b0b9 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -181,7 +181,7 @@ async fn scalar_udf() -> Result<()> { Ok(()) } -#[derive(PartialEq, Hash)] +#[derive(PartialEq, Eq, Hash)] struct Simple0ArgsScalarUDF { name: String, signature: Signature, @@ -492,7 +492,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { } /// Volatile UDF that should append a different value to each row -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AddIndexToStringVolatileScalarUDF { name: String, signature: Signature, @@ -941,7 +941,7 @@ impl FunctionFactory for CustomFunctionFactory { // // it also defines custom [ScalarUDFImpl::simplify()] // to replace ScalarUDF expression with one instance contains. -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] struct ScalarFunctionWrapper { name: String, expr: Expr, @@ -1221,6 +1221,7 @@ impl PartialEq for MyRegexUdf { signature == &other.signature && regex.as_str() == other.regex.as_str() } } +impl Eq for MyRegexUdf {} impl Hash for MyRegexUdf { fn hash(&self, state: &mut H) { @@ -1380,7 +1381,7 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result) { - #[derive(Debug, Clone, PartialEq, Hash)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimpleWindowUDF { signature: Signature, test_state: PtrEq>, diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 4b5a55d90cc62..ad07d0690e56e 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -69,6 +69,7 @@ impl PartialEq for AsyncScalarUDF { arc_ptr_eq(inner, &other.inner) } } +impl Eq for AsyncScalarUDF {} impl Hash for AsyncScalarUDF { fn hash(&self, state: &mut H) { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 9e2285f7c0547..b91036360d057 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -403,7 +403,7 @@ pub fn create_udf( /// Implements [`ScalarUDFImpl`] for functions that have a single signature and /// return type. -#[derive(PartialEq, Hash)] +#[derive(PartialEq, Eq, Hash)] pub struct SimpleScalarUDF { name: String, signature: Signature, @@ -661,7 +661,7 @@ pub fn create_udwf( /// Implements [`WindowUDFImpl`] for functions that have a single signature and /// return type. -#[derive(PartialEq, Hash)] +#[derive(PartialEq, Eq, Hash)] pub struct SimpleWindowUDF { name: String, signature: Signature, diff --git a/datafusion/expr/src/ptr_eq.rs b/datafusion/expr/src/ptr_eq.rs index 5a177b266dc27..c85b3d9950cd9 100644 --- a/datafusion/expr/src/ptr_eq.rs +++ b/datafusion/expr/src/ptr_eq.rs @@ -34,7 +34,7 @@ pub fn arc_ptr_hash(a: &Arc, hasher: &mut impl Hasher) { std::ptr::hash(Arc::as_ptr(a), hasher) } -/// A wrapper around a pointer that implements `PartialEq` and `Hash` comparing +/// A wrapper around a pointer that implements `Eq` and `Hash` comparing /// the underlying pointer address. #[derive(Clone)] #[allow(private_bounds)] // This is so that PtrEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse. @@ -48,6 +48,7 @@ where arc_ptr_eq(&self.0, &other.0) } } +impl Eq for PtrEq> where T: ?Sized {} impl Hash for PtrEq> where diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 40e0da2678ebd..56f4867b562ba 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -755,6 +755,7 @@ impl PartialEq for AliasedScalarUDFImpl { inner.equals(other.inner.as_ref()) && aliases == &other.aliases } } +impl Eq for AliasedScalarUDFImpl {} impl Hash for AliasedScalarUDFImpl { fn hash(&self, state: &mut H) { diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 4be58c3a4add6..000c759985c0a 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -479,7 +479,7 @@ impl PartialOrd for dyn WindowUDFImpl { /// WindowUDF that adds an alias to the underlying function. It is better to /// implement [`WindowUDFImpl`], which supports aliases, directly if possible. -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AliasedWindowUDFImpl { inner: PtrEq>, aliases: Vec, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 874811665c096..80ad0f87846a7 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1276,9 +1276,9 @@ pub fn collect_subquery_cols( /// # use datafusion_expr_common::signature::Signature; /// # use std::any::Any; /// -/// // Implementing PartialEq & Hash is a prerequisite for using this macro, +/// // Implementing Eq & Hash is a prerequisite for using this macro, /// // but the implementation can be derived. -/// #[derive(Debug, PartialEq, Hash)] +/// #[derive(Debug, PartialEq, Eq, Hash)] /// struct VarcharToTimestampTz { /// safe: bool, /// } @@ -1322,11 +1322,13 @@ macro_rules! udf_equals_hash { ($udf_type:tt) => { fn equals(&self, other: &dyn $udf_type) -> bool { use ::core::any::Any; - use ::core::cmp::PartialEq; + use ::core::cmp::{Eq, PartialEq}; let Some(other) = ::downcast_ref::(other.as_any()) else { return false; }; + fn assert_self_impls_eq() {} + assert_self_impls_eq::(); PartialEq::eq(self, other) } @@ -1804,7 +1806,7 @@ mod tests { } } - #[derive(Debug, PartialEq, Hash)] + #[derive(Debug, PartialEq, Eq, Hash)] struct StatefulFunctionWithEqHashWithUdfEqualsHash { signature: Signature, state: bool, diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 4d634e0be2583..8f877d44f87f3 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -304,6 +304,7 @@ impl PartialEq for ForeignScalarUDF { && signature == &other.signature } } +impl Eq for ForeignScalarUDF {} impl Hash for ForeignScalarUDF { fn hash(&self, state: &mut H) { diff --git a/datafusion/ffi/src/udwf/mod.rs b/datafusion/ffi/src/udwf/mod.rs index ec1b6698f5803..a5e18cdf1e7e1 100644 --- a/datafusion/ffi/src/udwf/mod.rs +++ b/datafusion/ffi/src/udwf/mod.rs @@ -261,6 +261,7 @@ impl PartialEq for ForeignWindowUDF { std::ptr::eq(self, other) } } +impl Eq for ForeignWindowUDF {} impl Hash for ForeignWindowUDF { fn hash(&self, state: &mut H) { std::ptr::hash(self, state) diff --git a/datafusion/functions-window/src/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs index 140b7975140b1..8f9a1a7a72c02 100644 --- a/datafusion/functions-window/src/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -120,7 +120,7 @@ impl WindowShiftKind { } /// window shift expression -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct WindowShift { signature: Signature, kind: WindowShiftKind, diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index 783e6e5652ce9..2da2fae23d615 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -94,7 +94,7 @@ impl NthValueKind { } } -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct NthValue { signature: Signature, kind: NthValueKind, diff --git a/datafusion/functions-window/src/rank.rs b/datafusion/functions-window/src/rank.rs index 4b29b6dac8af8..e026bdf5949cc 100644 --- a/datafusion/functions-window/src/rank.rs +++ b/datafusion/functions-window/src/rank.rs @@ -64,7 +64,7 @@ define_udwf_and_expr!( ); /// Rank calculates the rank in the window function with order by -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Rank { name: String, signature: Signature, diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index dfe841fe36ffb..e05a28dcdef99 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -4443,7 +4443,7 @@ mod tests { /// A Mock UDWF which defines `simplify` to be used in tests /// related to UDWF simplification - #[derive(Debug, Clone, PartialEq, Hash)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimplifyMockUdwf { simplify: bool, } diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index eba227a84a04a..67254f741e4ce 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -151,7 +151,7 @@ pub struct MyAggregateUdfNode { pub result: String, } -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(in crate::cases) struct CustomUDWF { signature: Signature, payload: String, diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 25144042504f8..751254ff201a1 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3312,7 +3312,7 @@ fn make_udf(name: &'static str, args: Vec, return_type: DataType) -> S } /// Mocked UDF -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] struct DummyUDF { name: &'static str, signature: Signature, From b51310383853c4de568609061d280b88f6e84230 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Thu, 7 Aug 2025 10:31:39 +0200 Subject: [PATCH 2/3] Add UdfEq wrapper around Arc to UDF impl The wrapper implements PartialEq, Eq, Hash by forwarding to UDF impl equals and hash_value functions. --- datafusion/expr/src/lib.rs | 1 + datafusion/expr/src/udaf.rs | 24 ++--- datafusion/expr/src/udf.rs | 26 ++--- datafusion/expr/src/udf_eq.rs | 181 ++++++++++++++++++++++++++++++++++ datafusion/expr/src/udwf.rs | 4 +- 5 files changed, 204 insertions(+), 32 deletions(-) create mode 100644 datafusion/expr/src/udf_eq.rs diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 636d7aac59a0a..b4ad8387215ec 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -71,6 +71,7 @@ pub mod ptr_eq; pub mod test; pub mod tree_node; pub mod type_coercion; +pub mod udf_eq; pub mod utils; pub mod var_provider; pub mod window_frame; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 984c21d581660..6d7d815d90686 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -38,6 +38,7 @@ use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, }; use crate::groups_accumulator::GroupsAccumulator; +use crate::udf_eq::UdfEq; use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; use crate::{expr_vec_fmt, Accumulator, Expr}; @@ -1037,9 +1038,9 @@ pub enum ReversedUDAF { /// AggregateUDF that adds an alias to the underlying function. It is better to /// implement [`AggregateUDFImpl`], which supports aliases, directly if possible. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AliasedAggregateUDFImpl { - inner: Arc, + inner: UdfEq>, aliases: Vec, } @@ -1051,7 +1052,10 @@ impl AliasedAggregateUDFImpl { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } + Self { + inner: inner.into(), + aliases, + } } } @@ -1111,7 +1115,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { .map(|udf| { udf.map(|udf| { Arc::new(AliasedAggregateUDFImpl { - inner: udf, + inner: udf.into(), aliases: self.aliases.clone(), }) as Arc }) @@ -1135,17 +1139,15 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { } fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases - } else { - false - } + let Some(other) = other.as_any().downcast_ref::() else { + return false; + }; + self == other } fn hash_value(&self) -> u64 { let hasher = &mut DefaultHasher::new(); - self.inner.hash_value().hash(hasher); - self.aliases.hash(hasher); + self.hash(hasher); hasher.finish() } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 56f4867b562ba..272e131a83a76 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -21,6 +21,7 @@ use crate::async_udf::AsyncScalarUDF; use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; +use crate::udf_eq::UdfEq; use crate::{udf_equals_hash, ColumnarValue, Documentation, Expr, Signature}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::config::ConfigOptions; @@ -743,28 +744,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// ScalarUDF that adds an alias to the underlying function. It is better to /// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AliasedScalarUDFImpl { - inner: Arc, + inner: UdfEq>, aliases: Vec, } -impl PartialEq for AliasedScalarUDFImpl { - fn eq(&self, other: &Self) -> bool { - let Self { inner, aliases } = self; - inner.equals(other.inner.as_ref()) && aliases == &other.aliases - } -} -impl Eq for AliasedScalarUDFImpl {} - -impl Hash for AliasedScalarUDFImpl { - fn hash(&self, state: &mut H) { - let Self { inner, aliases } = self; - inner.hash_value().hash(state); - aliases.hash(state); - } -} - impl AliasedScalarUDFImpl { pub fn new( inner: Arc, @@ -772,7 +757,10 @@ impl AliasedScalarUDFImpl { ) -> Self { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } + Self { + inner: inner.into(), + aliases, + } } } diff --git a/datafusion/expr/src/udf_eq.rs b/datafusion/expr/src/udf_eq.rs new file mode 100644 index 0000000000000..1871aab3fd932 --- /dev/null +++ b/datafusion/expr/src/udf_eq.rs @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl}; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; +use std::sync::Arc; + +/// A wrapper around a pointer to UDF that implements `Eq` and `Hash` delegating to +/// corresponding methods on the UDF trait. +#[derive(Clone)] +#[allow(private_bounds)] // This is so that UdfEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse. +pub struct UdfEq(Ptr); + +impl PartialEq for UdfEq +where + Ptr: UdfPointer, +{ + fn eq(&self, other: &Self) -> bool { + self.0.equals(&other.0) + } +} +impl Eq for UdfEq where Ptr: UdfPointer {} +impl Hash for UdfEq +where + Ptr: UdfPointer, +{ + fn hash(&self, state: &mut H) { + self.0.hash_value().hash(state); + } +} + +impl From for UdfEq +where + Ptr: UdfPointer, +{ + fn from(ptr: Ptr) -> Self { + UdfEq(ptr) + } +} + +impl Debug for UdfEq +where + Ptr: UdfPointer + Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl Deref for UdfEq +where + Ptr: UdfPointer, +{ + type Target = Ptr; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +trait UdfPointer: Deref { + fn equals(&self, other: &Self::Target) -> bool; + fn hash_value(&self) -> u64; +} + +macro_rules! impl_for_udf_eq { + ($udf:ty) => { + impl UdfPointer for Arc<$udf> { + fn equals(&self, other: &$udf) -> bool { + self.as_ref().equals(other) + } + + fn hash_value(&self) -> u64 { + self.as_ref().hash_value() + } + } + }; +} + +impl_for_udf_eq!(dyn AggregateUDFImpl + '_); +impl_for_udf_eq!(dyn ScalarUDFImpl + '_); +impl_for_udf_eq!(dyn WindowUDFImpl + '_); + +#[cfg(test)] +mod tests { + use super::*; + use crate::ScalarFunctionArgs; + use arrow::datatypes::DataType; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::signature::{Signature, Volatility}; + use std::any::Any; + use std::hash::DefaultHasher; + + #[derive(Debug)] + struct TestScalarUDF { + signature: Signature, + name: &'static str, + } + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type( + &self, + _arg_types: &[DataType], + ) -> datafusion_common::Result { + unimplemented!() + } + + fn invoke_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + unimplemented!() + } + } + + #[test] + pub fn test_eq_eq_wrapper() { + let signature = Signature::any(1, Volatility::Immutable); + + let a1: Arc = Arc::new(TestScalarUDF { + signature: signature.clone(), + name: "a", + }); + let a2: Arc = Arc::new(TestScalarUDF { + signature: signature.clone(), + name: "a", + }); + let b: Arc = Arc::new(TestScalarUDF { + signature: signature.clone(), + name: "b", + }); + + // Reflexivity + let wrapper = UdfEq(Arc::clone(&a1)); + assert_eq!(wrapper, wrapper); + + // Two wrappers around equal pointer + assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a1))); + assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a1)))); + + // Two wrappers around different pointers but equal in ScalarUDFImpl::equals sense + assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a2))); + assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a2)))); + + // different functions (not equal) + assert_ne!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&b))); + } + + fn hash(value: T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } +} diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 000c759985c0a..032daa5c25bc0 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -29,7 +29,7 @@ use std::{ use arrow::datatypes::{DataType, FieldRef}; use crate::expr::WindowFunction; -use crate::ptr_eq::PtrEq; +use crate::udf_eq::UdfEq; use crate::{ function::WindowFunctionSimplification, udf_equals_hash, Expr, PartitionEvaluator, Signature, @@ -481,7 +481,7 @@ impl PartialOrd for dyn WindowUDFImpl { /// implement [`WindowUDFImpl`], which supports aliases, directly if possible. #[derive(Debug, PartialEq, Eq, Hash)] struct AliasedWindowUDFImpl { - inner: PtrEq>, + inner: UdfEq>, aliases: Vec, } From b855599f8607b64f5cc40cc0d371e4ddfd3140a6 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Thu, 7 Aug 2025 11:21:45 +0200 Subject: [PATCH 3/3] Derive UDAF equality from Eq, Hash Reduce boilerplate in cases where implementation of `AggregateUDFImpl::{equals,hash_value}` can be derived using standard `Eq` and `Hash` traits. --- .../user_defined/user_defined_aggregates.rs | 74 ++++++++----------- datafusion/doc/src/lib.rs | 4 +- datafusion/expr/src/expr_fn.rs | 45 ++--------- datafusion/expr/src/udaf.rs | 15 +--- datafusion/ffi/src/udaf/mod.rs | 44 ++++------- .../src/approx_percentile_cont.rs | 7 +- .../src/approx_percentile_cont_with_weight.rs | 30 ++------ .../functions-aggregate/src/array_agg.rs | 6 +- .../functions-aggregate/src/bit_and_or_xor.rs | 40 ++-------- .../functions-aggregate/src/first_last.rs | 56 ++------------ datafusion/functions-aggregate/src/regr.rs | 33 +-------- datafusion/functions-aggregate/src/stddev.rs | 24 ++---- .../functions-aggregate/src/string_agg.rs | 30 ++------ .../simplify_expressions/expr_simplifier.rs | 19 +---- datafusion/proto/tests/cases/mod.rs | 19 +---- 15 files changed, 102 insertions(+), 344 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 7f1a12e9cd960..cdba41a0d1bb8 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::collections::HashMap; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::{Hash, Hasher}; use std::mem::{size_of, size_of_val}; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -55,7 +55,7 @@ use datafusion_common::{assert_contains, exec_datafusion_err}; use datafusion_common::{cast::as_primitive_array, exec_err}; use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr, + col, create_udaf, function::AccumulatorArgs, udf_equals_hash, AggregateUDFImpl, Expr, GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -778,7 +778,7 @@ impl Accumulator for FirstSelector { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct TestGroupsAccumulator { signature: Signature, result: u64, @@ -817,20 +817,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator { Ok(Box::new(self.clone())) } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - self.result == other.result && self.signature == other.signature - } else { - false - } - } - - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.signature.hash(hasher); - self.result.hash(hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } impl Accumulator for TestGroupsAccumulator { @@ -902,6 +889,32 @@ struct MetadataBasedAggregateUdf { metadata: HashMap, } +impl PartialEq for MetadataBasedAggregateUdf { + fn eq(&self, other: &Self) -> bool { + let Self { + name, + signature, + metadata, + } = self; + name == &other.name + && signature == &other.signature + && metadata == &other.metadata + } +} +impl Eq for MetadataBasedAggregateUdf {} +impl Hash for MetadataBasedAggregateUdf { + fn hash(&self, state: &mut H) { + let Self { + name, + signature, + metadata: _, // unhashable + } = self; + std::any::type_name::().hash(state); + name.hash(state); + signature.hash(state); + } +} + impl MetadataBasedAggregateUdf { fn new(metadata: HashMap) -> Self { // The name we return must be unique. Otherwise we will not call distinct @@ -958,32 +971,7 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf { })) } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { - name, - signature, - metadata, - } = self; - name == &other.name - && signature == &other.signature - && metadata == &other.metadata - } - - fn hash_value(&self) -> u64 { - let Self { - name, - signature, - metadata: _, // unhashable - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - name.hash(&mut hasher); - signature.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[derive(Debug)] diff --git a/datafusion/doc/src/lib.rs b/datafusion/doc/src/lib.rs index ca74c3b06d6dc..c86a40ece204e 100644 --- a/datafusion/doc/src/lib.rs +++ b/datafusion/doc/src/lib.rs @@ -39,7 +39,7 @@ /// thus all text should be in English. /// /// [SQL function documentation]: https://datafusion.apache.org/user-guide/sql/index.html -#[derive(Debug, Clone, PartialEq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Documentation { /// The section in the documentation where the UDF will be documented pub doc_section: DocSection, @@ -158,7 +158,7 @@ impl Documentation { } } -#[derive(Debug, Clone, PartialEq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DocSection { /// True to include this doc section in the public /// documentation, false otherwise diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index b91036360d057..6e5cd068b3d8e 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -45,7 +45,7 @@ use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::Debug; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::ops::Not; use std::sync::Arc; @@ -511,11 +511,12 @@ pub fn create_udaf( /// Implements [`AggregateUDFImpl`] for functions that have a single signature and /// return type. +#[derive(PartialEq, Eq, Hash)] pub struct SimpleAggregateUDF { name: String, signature: Signature, return_type: DataType, - accumulator: AccumulatorFactoryFunction, + accumulator: PtrEq, state_fields: Vec, } @@ -547,7 +548,7 @@ impl SimpleAggregateUDF { name, signature, return_type, - accumulator, + accumulator: accumulator.into(), state_fields, } } @@ -566,7 +567,7 @@ impl SimpleAggregateUDF { name, signature, return_type, - accumulator, + accumulator: accumulator.into(), state_fields, } } @@ -600,41 +601,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { Ok(self.state_fields.clone()) } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { - name, - signature, - return_type, - accumulator, - state_fields, - } = self; - name == &other.name - && signature == &other.signature - && return_type == &other.return_type - && Arc::ptr_eq(accumulator, &other.accumulator) - && state_fields == &other.state_fields - } - - fn hash_value(&self) -> u64 { - let Self { - name, - signature, - return_type, - accumulator, - state_fields, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - name.hash(&mut hasher); - signature.hash(&mut hasher); - return_type.hash(&mut hasher); - Arc::as_ptr(accumulator).hash(&mut hasher); - state_fields.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } /// Creates a new UDWF with a specific signature, state type and return type. diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 6d7d815d90686..bd728013725d1 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -41,7 +41,7 @@ use crate::groups_accumulator::GroupsAccumulator; use crate::udf_eq::UdfEq; use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; -use crate::{expr_vec_fmt, Accumulator, Expr}; +use crate::{expr_vec_fmt, udf_equals_hash, Accumulator, Expr}; use crate::{Documentation, Signature}; /// Logical representation of a user-defined [aggregate function] (UDAF). @@ -1138,18 +1138,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.coerce_types(arg_types) } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - self == other - } - - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.hash(hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); fn is_descending(&self) -> Option { self.inner.is_descending() diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 66e1c28bb9fe2..63f7d26544f6e 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -39,7 +39,7 @@ use datafusion::{ }; use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::{Hash, Hasher}; use std::{ffi::c_void, sync::Arc}; use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; @@ -49,6 +49,7 @@ use crate::{ util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, volatility::FFI_Volatility, }; +use datafusion::logical_expr::udf_equals_hash; use prost::{DecodeError, Message}; mod accumulator; @@ -384,6 +385,19 @@ pub struct ForeignAggregateUDF { unsafe impl Send for ForeignAggregateUDF {} unsafe impl Sync for ForeignAggregateUDF {} +impl PartialEq for ForeignAggregateUDF { + fn eq(&self, other: &Self) -> bool { + // FFI_AggregateUDF cannot be compared, so identity equality is the best we can do. + std::ptr::eq(self, other) + } +} +impl Eq for ForeignAggregateUDF {} +impl Hash for ForeignAggregateUDF { + fn hash(&self, state: &mut H) { + std::ptr::hash(self, state) + } +} + impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { type Error = DataFusionError; @@ -554,33 +568,7 @@ impl AggregateUDFImpl for ForeignAggregateUDF { } } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { - signature, - aliases, - udaf, - } = self; - signature == &other.signature - && aliases == &other.aliases - && std::ptr::eq(udaf, &other.udaf) - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - aliases, - udaf, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - signature.hash(&mut hasher); - aliases.hash(&mut hasher); - std::ptr::hash(udaf, &mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[repr(C)] diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 863ee15d89ec4..36c005274deae 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -39,8 +39,8 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature, - TypeSignature, Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, + Signature, TypeSignature, Volatility, }; use datafusion_functions_aggregate_common::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, @@ -102,6 +102,7 @@ pub fn approx_percentile_cont( description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] +#[derive(PartialEq, Eq, Hash)] pub struct ApproxPercentileCont { signature: Signature, } @@ -336,6 +337,8 @@ impl AggregateUDFImpl for ApproxPercentileCont { fn documentation(&self) -> Option<&Documentation> { self.doc() } + + udf_equals_hash!(AggregateUDFImpl); } #[derive(Debug)] diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index d30ea624cae90..9a19f43a52551 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::size_of_val; use std::sync::Arc; @@ -30,7 +30,8 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, + TypeSignature, }; use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest}; use datafusion_macros::user_doc; @@ -100,6 +101,7 @@ pub fn approx_percentile_cont_with_weight( description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] +#[derive(PartialEq, Eq, Hash)] pub struct ApproxPercentileContWithWeight { signature: Signature, approx_percentile_cont: ApproxPercentileCont, @@ -237,29 +239,7 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { self.doc() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { - signature, - approx_percentile_cont, - } = self; - signature == &other.signature - && approx_percentile_cont.equals(&other.approx_percentile_cont) - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - approx_percentile_cont, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - signature.hash(&mut hasher); - hasher.write_u64(approx_percentile_cont.hash_value()); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[derive(Debug)] diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 3ada331040d08..3d195738f8f67 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -36,7 +36,7 @@ use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; @@ -75,7 +75,7 @@ This aggregation function can only mix DISTINCT and ORDER BY if the ordering exp "#, standard_argument(name = "expression",) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] /// ARRAY_AGG aggregate expression pub struct ArrayAgg { signature: Signature, @@ -227,6 +227,8 @@ impl AggregateUDFImpl for ArrayAgg { fn documentation(&self) -> Option<&Documentation> { self.doc() } + + udf_equals_hash!(AggregateUDFImpl); } #[derive(Debug)] diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 8ca5d992a7fea..8d573580d4235 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::collections::HashSet; use std::fmt::{Display, Formatter}; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::{size_of, size_of_val}; use ahash::RandomState; @@ -36,8 +36,8 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::INTEGERS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, - Signature, Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, + ReversedUDAF, Signature, Volatility, }; use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; @@ -211,7 +211,7 @@ impl Display for BitwiseOperationType { } /// [BitwiseOperation] struct encapsulates information about a bitwise operation. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct BitwiseOperation { signature: Signature, /// `operation` indicates the type of bitwise operation to be performed. @@ -314,37 +314,7 @@ impl AggregateUDFImpl for BitwiseOperation { Some(self.documentation) } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { - signature, - operation, - func_name, - documentation, - } = self; - signature == &other.signature - && operation == &other.operation - && func_name == &other.func_name - && documentation == &other.documentation - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - operation, - func_name, - documentation, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - signature.hash(&mut hasher); - operation.hash(&mut hasher); - func_name.hash(&mut hasher); - documentation.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } struct BitAndAccumulator { diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 0856237d08cb5..87f14ae634c35 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::fmt::Debug; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::size_of_val; use std::sync::Arc; @@ -45,8 +45,8 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt, - GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, + ExprFunctionExt, GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_macros::user_doc; @@ -89,6 +89,7 @@ pub fn last_value(expression: Expr, order_by: Vec) -> Expr { ```"#, standard_argument(name = "expression",) )] +#[derive(PartialEq, Eq, Hash)] pub struct FirstValue { signature: Signature, is_input_pre_ordered: bool, @@ -294,29 +295,7 @@ impl AggregateUDFImpl for FirstValue { self.doc() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { - signature, - is_input_pre_ordered, - } = self; - signature == &other.signature - && is_input_pre_ordered == &other.is_input_pre_ordered - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - is_input_pre_ordered, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - signature.hash(&mut hasher); - is_input_pre_ordered.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } // TODO: rename to PrimitiveGroupsAccumulator @@ -1029,6 +1008,7 @@ impl Accumulator for FirstValueAccumulator { ```"#, standard_argument(name = "expression",) )] +#[derive(PartialEq, Eq, Hash)] pub struct LastValue { signature: Signature, is_input_pre_ordered: bool, @@ -1238,29 +1218,7 @@ impl AggregateUDFImpl for LastValue { } } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { - signature, - is_input_pre_ordered, - } = self; - signature == &other.signature - && is_input_pre_ordered == &other.is_input_pre_ordered - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - is_input_pre_ordered, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - signature.hash(&mut hasher); - is_input_pre_ordered.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } /// This accumulator is used when there is no ordering specified for the diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index f7e0f0a104cd7..c8dde7aed6aed 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -34,11 +34,11 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; use std::any::Any; use std::fmt::Debug; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::size_of_val; use std::sync::{Arc, LazyLock}; @@ -59,6 +59,7 @@ make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); +#[derive(PartialEq, Eq, Hash)] pub struct Regr { signature: Signature, regr_type: RegrType, @@ -527,33 +528,7 @@ impl AggregateUDFImpl for Regr { self.regr_type.documentation() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { - signature, - regr_type, - func_name, - } = self; - signature == &other.signature - && regr_type == &other.regr_type - && func_name == &other.func_name - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - regr_type, - func_name, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - signature.hash(&mut hasher); - regr_type.hash(&mut hasher); - func_name.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } /// `RegrAccumulator` is used to compute linear regression aggregate functions diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 2f9f1cac84d49..d0512b38154bc 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::align_of_val; use std::sync::Arc; @@ -31,8 +31,8 @@ use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, - Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, + Signature, Volatility, }; use datafusion_functions_aggregate_common::stats::StatsType; use datafusion_macros::user_doc; @@ -62,6 +62,7 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression +#[derive(PartialEq, Eq, Hash)] pub struct Stddev { signature: Signature, alias: Vec, @@ -155,22 +156,7 @@ impl AggregateUDFImpl for Stddev { self.doc() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { signature, alias } = self; - signature == &other.signature && alias == &other.alias - } - - fn hash_value(&self) -> u64 { - let Self { signature, alias } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - signature.hash(&mut hasher); - alias.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } make_udaf_expr_and_func!( diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 56c5ee1aaa676..756457274475e 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -18,7 +18,7 @@ //! [`StringAgg`] accumulator for the `string_agg` function use std::any::Any; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::size_of_val; use crate::array_agg::ArrayAgg; @@ -29,7 +29,8 @@ use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, Signature, + TypeSignature, Volatility, }; use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs; use datafusion_macros::user_doc; @@ -82,7 +83,7 @@ This aggregation function can only mix DISTINCT and ORDER BY if the ordering exp ) )] /// STRING_AGG aggregate expression -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct StringAgg { signature: Signature, array_agg: ArrayAgg, @@ -182,28 +183,7 @@ impl AggregateUDFImpl for StringAgg { self.doc() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { - signature, - array_agg, - } = self; - signature == &other.signature && array_agg.equals(&other.array_agg) - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - array_agg, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - signature.hash(&mut hasher); - hasher.write_u64(array_agg.hash_value()); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[derive(Debug)] diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index e05a28dcdef99..d33cefb341046 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -2181,7 +2181,7 @@ mod tests { }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; - use std::hash::{DefaultHasher, Hash, Hasher}; + use std::hash::Hash; use std::{ collections::HashMap, ops::{BitAnd, BitOr, BitXor}, @@ -4347,7 +4347,7 @@ mod tests { /// A Mock UDAF which defines `simplify` to be used in tests /// related to UDAF simplification - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimplifyMockUdaf { simplify: bool, } @@ -4406,20 +4406,7 @@ mod tests { } } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { simplify } = self; - simplify == &other.simplify - } - - fn hash_value(&self) -> u64 { - let Self { simplify } = self; - let mut hasher = DefaultHasher::new(); - simplify.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[test] diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index 67254f741e4ce..ee5005fdde75c 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -27,7 +27,7 @@ use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use std::any::Any; use std::fmt::Debug; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; mod roundtrip_logical_plan; mod roundtrip_physical_plan; @@ -127,22 +127,7 @@ impl AggregateUDFImpl for MyAggregateUDF { unimplemented!() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::() else { - return false; - }; - let Self { signature, result } = self; - signature == &other.signature && result == &other.result - } - - fn hash_value(&self) -> u64 { - let Self { signature, result } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::().hash(&mut hasher); - signature.hash(&mut hasher); - result.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[derive(Clone, PartialEq, ::prost::Message)]