diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index e86d742db3d45..ddaca918ad1cc 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -29,7 +29,6 @@ use datafusion_common::{ use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, function::{AccumulatorArgs, StateFieldsArgs}, - type_coercion::aggregates::NUMERICS, utils::format_state_name, }; use datafusion_functions_aggregate_common::stats::StatsType; @@ -94,7 +93,7 @@ impl CovarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("covar")], - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::numeric(2, Volatility::Immutable), } } } @@ -188,7 +187,7 @@ impl Default for CovariancePopulation { impl CovariancePopulation { pub fn new() -> Self { Self { - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::numeric(2, Volatility::Immutable), } } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a3fab065dc097..8b754febd8e15 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -562,6 +562,58 @@ from data ---- 1 0.666666666667 +# covariance_decimal_1 +statement ok +create table t_covar_decimal (c1 decimal(10,2), c2 decimal(10,2)) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t_covar_decimal; +---- +0.666666666667 Float64 + +query RT +select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t_covar_decimal; +---- +1 Float64 + +statement ok +drop table t_covar_decimal; + +# covariance_decimal_with_nulls +statement ok +create table t_covar_decimal_nulls (f decimal(10,2), b decimal(10,2)) as values + (1, 4), + (null, 99), + (2, 5), + (98, null), + (3, 6), + (null, null); + +query RR +select covar_samp(f, b), covar_pop(f, b) from t_covar_decimal_nulls; +---- +1 0.666666666667 + +statement ok +drop table t_covar_decimal_nulls; + +# covariance_mixed_decimal_float +statement ok +create table t_covar_mixed (x decimal(10,2), y double) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_pop(x, y), arrow_typeof(covar_pop(x, y)) from t_covar_mixed; +---- +0.666666666667 Float64 + +query RT +select covar_samp(x, y), arrow_typeof(covar_samp(x, y)) from t_covar_mixed; +---- +1 Float64 + +statement ok +drop table t_covar_mixed; + # csv_query_correlation query R SELECT corr(c2, c12) FROM aggregate_test_100