Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 31 additions & 43 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

use std::any::Any;
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::hash::{Hash, Hasher};
use std::mem::{size_of, size_of_val};
use std::sync::{
atomic::{AtomicBool, Ordering},
Expand Down Expand Up @@ -55,7 +55,7 @@ use datafusion_common::{assert_contains, exec_datafusion_err};
use datafusion_common::{cast::as_primitive_array, exec_err};
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr,
col, create_udaf, function::AccumulatorArgs, udf_equals_hash, AggregateUDFImpl, Expr,
GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::average::AvgAccumulator;
Expand Down Expand Up @@ -778,7 +778,7 @@ impl Accumulator for FirstSelector {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct TestGroupsAccumulator {
signature: Signature,
result: u64,
Expand Down Expand Up @@ -817,20 +817,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
Ok(Box::new(self.clone()))
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<TestGroupsAccumulator>() {
self.result == other.result && self.signature == other.signature
} else {
false
}
}

fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.signature.hash(hasher);
self.result.hash(hasher);
hasher.finish()
}
udf_equals_hash!(AggregateUDFImpl);
}

impl Accumulator for TestGroupsAccumulator {
Expand Down Expand Up @@ -902,6 +889,32 @@ struct MetadataBasedAggregateUdf {
metadata: HashMap<String, String>,
}

impl PartialEq for MetadataBasedAggregateUdf {
fn eq(&self, other: &Self) -> bool {
let Self {
name,
signature,
metadata,
} = self;
name == &other.name
&& signature == &other.signature
&& metadata == &other.metadata
}
}
impl Eq for MetadataBasedAggregateUdf {}
impl Hash for MetadataBasedAggregateUdf {
fn hash<H: Hasher>(&self, state: &mut H) {
let Self {
name,
signature,
metadata: _, // unhashable
} = self;
std::any::type_name::<Self>().hash(state);
name.hash(state);
signature.hash(state);
}
}

impl MetadataBasedAggregateUdf {
fn new(metadata: HashMap<String, String>) -> Self {
// The name we return must be unique. Otherwise we will not call distinct
Expand Down Expand Up @@ -958,32 +971,7 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf {
}))
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
let Some(other) = other.as_any().downcast_ref::<Self>() else {
return false;
};
let Self {
name,
signature,
metadata,
} = self;
name == &other.name
&& signature == &other.signature
&& metadata == &other.metadata
}

fn hash_value(&self) -> u64 {
let Self {
name,
signature,
metadata: _, // unhashable
} = self;
let mut hasher = DefaultHasher::new();
std::any::type_name::<Self>().hash(&mut hasher);
name.hash(&mut hasher);
signature.hash(&mut hasher);
hasher.finish()
}
udf_equals_hash!(AggregateUDFImpl);
}

#[derive(Debug)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ async fn scalar_udf() -> Result<()> {
Ok(())
}

#[derive(PartialEq, Hash)]
#[derive(PartialEq, Eq, Hash)]
struct Simple0ArgsScalarUDF {
name: String,
signature: Signature,
Expand Down Expand Up @@ -492,7 +492,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
}

/// Volatile UDF that should append a different value to each row
#[derive(Debug, PartialEq, Hash)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct AddIndexToStringVolatileScalarUDF {
name: String,
signature: Signature,
Expand Down Expand Up @@ -941,7 +941,7 @@ impl FunctionFactory for CustomFunctionFactory {
//
// it also defines custom [ScalarUDFImpl::simplify()]
// to replace ScalarUDF expression with one instance contains.
#[derive(Debug, PartialEq, Hash)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct ScalarFunctionWrapper {
name: String,
expr: Expr,
Expand Down Expand Up @@ -1221,6 +1221,7 @@ impl PartialEq for MyRegexUdf {
signature == &other.signature && regex.as_str() == other.regex.as_str()
}
}
impl Eq for MyRegexUdf {}

impl Hash for MyRegexUdf {
fn hash<H: Hasher>(&self, state: &mut H) {
Expand Down Expand Up @@ -1380,7 +1381,7 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordB
ctx.sql(sql).await?.collect().await
}

#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Eq)]
struct MetadataBasedUdf {
name: String,
signature: Signature,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ impl OddCounter {
}

fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) {
#[derive(Debug, Clone, PartialEq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct SimpleWindowUDF {
signature: Signature,
test_state: PtrEq<Arc<TestState>>,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/doc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
/// thus all text should be in English.
///
/// [SQL function documentation]: https://datafusion.apache.org/user-guide/sql/index.html
#[derive(Debug, Clone, PartialEq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Documentation {
/// The section in the documentation where the UDF will be documented
pub doc_section: DocSection,
Expand Down Expand Up @@ -158,7 +158,7 @@ impl Documentation {
}
}

#[derive(Debug, Clone, PartialEq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DocSection {
/// True to include this doc section in the public
/// documentation, false otherwise
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/async_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ impl PartialEq for AsyncScalarUDF {
arc_ptr_eq(inner, &other.inner)
}
}
impl Eq for AsyncScalarUDF {}

impl Hash for AsyncScalarUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
Expand Down
49 changes: 8 additions & 41 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::Debug;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::hash::Hash;
use std::ops::Not;
use std::sync::Arc;

Expand Down Expand Up @@ -403,7 +403,7 @@ pub fn create_udf(

/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
/// return type.
#[derive(PartialEq, Hash)]
#[derive(PartialEq, Eq, Hash)]
pub struct SimpleScalarUDF {
name: String,
signature: Signature,
Expand Down Expand Up @@ -511,11 +511,12 @@ pub fn create_udaf(

/// Implements [`AggregateUDFImpl`] for functions that have a single signature and
/// return type.
#[derive(PartialEq, Eq, Hash)]
pub struct SimpleAggregateUDF {
name: String,
signature: Signature,
return_type: DataType,
accumulator: AccumulatorFactoryFunction,
accumulator: PtrEq<AccumulatorFactoryFunction>,
state_fields: Vec<FieldRef>,
}

Expand Down Expand Up @@ -547,7 +548,7 @@ impl SimpleAggregateUDF {
name,
signature,
return_type,
accumulator,
accumulator: accumulator.into(),
state_fields,
}
}
Expand All @@ -566,7 +567,7 @@ impl SimpleAggregateUDF {
name,
signature,
return_type,
accumulator,
accumulator: accumulator.into(),
state_fields,
}
}
Expand Down Expand Up @@ -600,41 +601,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
Ok(self.state_fields.clone())
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
let Some(other) = other.as_any().downcast_ref::<Self>() else {
return false;
};
let Self {
name,
signature,
return_type,
accumulator,
state_fields,
} = self;
name == &other.name
&& signature == &other.signature
&& return_type == &other.return_type
&& Arc::ptr_eq(accumulator, &other.accumulator)
&& state_fields == &other.state_fields
}

fn hash_value(&self) -> u64 {
let Self {
name,
signature,
return_type,
accumulator,
state_fields,
} = self;
let mut hasher = DefaultHasher::new();
std::any::type_name::<Self>().hash(&mut hasher);
name.hash(&mut hasher);
signature.hash(&mut hasher);
return_type.hash(&mut hasher);
Arc::as_ptr(accumulator).hash(&mut hasher);
state_fields.hash(&mut hasher);
hasher.finish()
}
udf_equals_hash!(AggregateUDFImpl);
}

/// Creates a new UDWF with a specific signature, state type and return type.
Expand All @@ -661,7 +628,7 @@ pub fn create_udwf(

/// Implements [`WindowUDFImpl`] for functions that have a single signature and
/// return type.
#[derive(PartialEq, Hash)]
#[derive(PartialEq, Eq, Hash)]
pub struct SimpleWindowUDF {
name: String,
signature: Signature,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub mod ptr_eq;
pub mod test;
pub mod tree_node;
pub mod type_coercion;
pub mod udf_eq;
pub mod utils;
pub mod var_provider;
pub mod window_frame;
Expand Down
3 changes: 2 additions & 1 deletion datafusion/expr/src/ptr_eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub fn arc_ptr_hash<T: ?Sized>(a: &Arc<T>, hasher: &mut impl Hasher) {
std::ptr::hash(Arc::as_ptr(a), hasher)
}

/// A wrapper around a pointer that implements `PartialEq` and `Hash` comparing
/// A wrapper around a pointer that implements `Eq` and `Hash` comparing
/// the underlying pointer address.
#[derive(Clone)]
#[allow(private_bounds)] // This is so that PtrEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse.
Expand All @@ -48,6 +48,7 @@ where
arc_ptr_eq(&self.0, &other.0)
}
}
impl<T> Eq for PtrEq<Arc<T>> where T: ?Sized {}

impl<T> Hash for PtrEq<Arc<T>>
where
Expand Down
29 changes: 10 additions & 19 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
use crate::groups_accumulator::GroupsAccumulator;
use crate::udf_eq::UdfEq;
use crate::utils::format_state_name;
use crate::utils::AggregateOrderSensitivity;
use crate::{expr_vec_fmt, Accumulator, Expr};
use crate::{expr_vec_fmt, udf_equals_hash, Accumulator, Expr};
use crate::{Documentation, Signature};

/// Logical representation of a user-defined [aggregate function] (UDAF).
Expand Down Expand Up @@ -1037,9 +1038,9 @@ pub enum ReversedUDAF {

/// AggregateUDF that adds an alias to the underlying function. It is better to
/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct AliasedAggregateUDFImpl {
inner: Arc<dyn AggregateUDFImpl>,
inner: UdfEq<Arc<dyn AggregateUDFImpl>>,
aliases: Vec<String>,
}

Expand All @@ -1051,7 +1052,10 @@ impl AliasedAggregateUDFImpl {
let mut aliases = inner.aliases().to_vec();
aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));

Self { inner, aliases }
Self {
inner: inner.into(),
aliases,
}
}
}

Expand Down Expand Up @@ -1111,7 +1115,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
.map(|udf| {
udf.map(|udf| {
Arc::new(AliasedAggregateUDFImpl {
inner: udf,
inner: udf.into(),
aliases: self.aliases.clone(),
}) as Arc<dyn AggregateUDFImpl>
})
Expand All @@ -1134,20 +1138,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
self.inner.coerce_types(arg_types)
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
} else {
false
}
}

fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.inner.hash_value().hash(hasher);
self.aliases.hash(hasher);
hasher.finish()
}
udf_equals_hash!(AggregateUDFImpl);

fn is_descending(&self) -> Option<bool> {
self.inner.is_descending()
Expand Down
Loading