Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix tests
  • Loading branch information
dharanad committed Jun 20, 2024
commit b205315e1d5d4b56a198e6746c7d57878347bbb4
1 change: 1 addition & 0 deletions datafusion-examples/examples/dataframe_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use arrow_schema::DataType;
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::functions_aggregate::average::avg;
use datafusion::prelude::*;
use datafusion::test_util::arrow_test_data;
use datafusion_common::ScalarValue;
Expand Down
36 changes: 16 additions & 20 deletions datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};

use arrow_schema::{Field, Schema};

use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
use datafusion::error::Result;
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion::{assert_batches_eq, prelude::*};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::test::function_stub::avg_udaf;
use datafusion_expr::{
expr::{AggregateFunction, AggregateFunctionDefinition},
function::AccumulatorArgs,
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
expr::AggregateFunction, function::AccumulatorArgs, Accumulator, AggregateUDF,
AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
Expand Down Expand Up @@ -92,18 +92,14 @@ impl AggregateUDFImpl for BetterAvgUdaf {
// with built-in aggregate function to illustrate the use
let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
_: &dyn SimplifyInfo| {
Ok(Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(
// yes it is the same Avg, `BetterAvgUdaf` was just a
// marketing pitch :)
datafusion_expr::aggregate_function::AggregateFunction::Avg,
),
args: aggregate_function.args,
distinct: aggregate_function.distinct,
filter: aggregate_function.filter,
order_by: aggregate_function.order_by,
null_treatment: aggregate_function.null_treatment,
}))
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
avg_udaf(),
vec![],
false,
None,
None,
None,
)))
};

Some(Box::new(simplify))
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/simplify_udwf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
_: &dyn SimplifyInfo| {
Ok(Expr::WindowFunction(WindowFunction {
fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction(
AggregateFunction::Avg,
AggregateFunction::Max,
),
args: window_function.args,
partition_by: window_function.partition_by,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1816,7 +1816,7 @@ mod tests {

assert_batches_sorted_eq!(
["+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT aggregate_test_100.c12) |",
"| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT aggregate_test_100.c12) |",
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |",
"| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |",
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col,
placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{count, sum};
use datafusion_functions_aggregate::expr_fn::{avg, count, sum};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
use datafusion_functions_aggregate::average::AvgAccumulator;

/// Test to show the contents of the setup
#[tokio::test]
async fn test_setup() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> {
let actual = plan_and_collect(&ctx, sql).await.unwrap();
let expected = [
"+------------------------------------------+",
"| AVG(custom_sqrt(aggregate_test_100.c11)) |",
"| avg(custom_sqrt(aggregate_test_100.c11)) |",
"+------------------------------------------+",
"| 0.6584408483418835 |",
"+------------------------------------------+",
Expand All @@ -69,7 +69,7 @@ async fn csv_query_avg_sqrt() -> Result<()> {
let actual = plan_and_collect(&ctx, sql).await.unwrap();
let expected = [
"+------------------------------------------+",
"| AVG(custom_sqrt(aggregate_test_100.c12)) |",
"| avg(custom_sqrt(aggregate_test_100.c12)) |",
"+------------------------------------------+",
"| 0.6706002946036459 |",
"+------------------------------------------+",
Expand Down
1 change: 0 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2238,7 +2238,6 @@ mod test {
"nth_value",
"min",
"max",
"avg",
];
for name in names {
let fun = find_df_window_func(name).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/expr_rewriter/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@ mod test {
expected: sort(col("c1") + col("MIN(t.c2)")),
},
TestCase {
desc: r#"avg(c3) --> "AVG(t.c3)" as average (column *named* "AVG(t.c3)", aliased)"#,
desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
input: sort(avg(col("c3"))),
expected: sort(col("AVG(t.c3)").alias("average")),
expected: sort(col("avg(t.c3)").alias("average")),
},
];

Expand Down
10 changes: 5 additions & 5 deletions datafusion/expr/src/test/function_stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ create_func!(Average, avg_udaf);

pub fn avg(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new_udf(
crate::test::function_stub::avg_udaf(),
avg_udaf(),
vec![expr],
false,
None,
Expand Down Expand Up @@ -317,7 +317,7 @@ impl AggregateUDFImpl for Average {
}

fn name(&self) -> &str {
"average"
"avg"
}

fn signature(&self) -> &Signature {
Expand All @@ -332,10 +332,10 @@ impl AggregateUDFImpl for Average {
not_impl_err!("no impl for stub")
}

fn aliases(&self) -> &[String] {
&self.aliases
}
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
not_impl_err!("no impl for stub")
}
fn aliases(&self) -> &[String] {
&self.aliases
}
}
35 changes: 32 additions & 3 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ use arrow::array::{
};
use arrow::compute::sum;
use arrow::datatypes::{
i256, ArrowNativeType, Decimal128Type, Decimal256Type, DecimalType, Float64Type,
UInt64Type,
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
Float64Type, UInt64Type,
};
use arrow_schema::{DataType, Field};
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
Expand Down Expand Up @@ -593,3 +592,33 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
other => exec_err!("AVG does not support {other:?}"),
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the type `Average::return_type` reports for each supported
    /// input type, driven by a table of `(input, expected)` pairs.
    #[test]
    fn test_avg_return_type() -> Result<()> {
        let cases = [
            // Floating point and integer inputs are all expected as Float64.
            (DataType::Float32, DataType::Float64),
            (DataType::Float64, DataType::Float64),
            (DataType::Int32, DataType::Float64),
            // Decimal inputs: expected precision/scale as pinned below.
            (DataType::Decimal128(10, 6), DataType::Decimal128(14, 10)),
            (DataType::Decimal128(36, 6), DataType::Decimal128(38, 10)),
        ];

        for (input, expected) in cases {
            let observed = Average::default().return_type(&[input])?;
            assert_eq!(expected, observed);
        }
        Ok(())
    }

    /// A string input must be rejected with an error rather than a type.
    #[test]
    fn test_avg_no_utf8() {
        assert!(Average::default()
            .return_type(&[DataType::Utf8])
            .is_err());
    }
}
34 changes: 18 additions & 16 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -815,13 +815,14 @@ mod test {
use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue};
use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
use datafusion_expr::logical_plan::{EmptyRelation, Projection};
use datafusion_expr::test::function_stub::avg_udaf;
use datafusion_expr::{
cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction,
AggregateFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr,
ExprSchemable, Filter, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl,
Signature, SimpleAggregateUDF, Subquery, Volatility,
cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
Operator, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery,
Volatility,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
use datafusion_functions_aggregate::average::AvgAccumulator;

use crate::analyzer::type_coercion::{
coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
Expand Down Expand Up @@ -1000,44 +1001,45 @@ mod test {
Ok(())
}

#[ignore]
#[test]
fn agg_function_case() -> Result<()> {
// FIXME
let empty = empty();
let fun: AggregateFunction = AggregateFunction::Avg;
let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
fun,
let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
avg_udaf(),
vec![lit(12i64)],
false,
None,
None,
None,
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
let expected = "Projection: AVG(CAST(Int64(12) AS Float64))\n EmptyRelation";
let expected = "Projection: avg(CAST(Int64(12) AS Float64))\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;

let empty = empty_with_type(DataType::Int32);
let fun: AggregateFunction = AggregateFunction::Avg;
let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
fun,
let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
avg_udaf(),
vec![col("a")],
false,
None,
None,
None,
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
let expected = "Projection: AVG(CAST(a AS Float64))\n EmptyRelation";
let expected = "Projection: avg(CAST(a AS Float64))\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
Ok(())
}

#[ignore]
#[test]
fn agg_function_invalid_input_avg() -> Result<()> {
// FIXME
let empty = empty();
let fun: AggregateFunction = AggregateFunction::Avg;
let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
fun,
let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
avg_udaf(),
vec![lit("1")],
false,
None,
Expand Down
Loading