From 2cbf0622e9aef32fb4aa4981c5338878052ee95c Mon Sep 17 00:00:00 2001 From: remzi <13716567376yh@gmail.com> Date: Tue, 18 Oct 2022 15:49:56 +0800 Subject: [PATCH 1/2] optimize concat_ws Signed-off-by: remzi <13716567376yh@gmail.com> --- datafusion/expr/src/expr_fn.rs | 13 +- datafusion/expr/src/literal.rs | 6 + .../optimizer/src/simplify_expressions.rs | 257 ++++++++++++------ datafusion/optimizer/src/type_coercion.rs | 2 +- .../optimizer/tests/integration-test.rs | 13 + 5 files changed, 203 insertions(+), 88 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index ba5351a850a3f..e0954963fade6 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -19,7 +19,7 @@ use crate::expr::{BinaryExpr, GroupingSet}; use crate::{ - aggregate_function, built_in_function, conditional_expressions::CaseBuilder, lit, + aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, @@ -134,11 +134,11 @@ pub fn concat(args: &[Expr]) -> Expr { } /// Concatenates all but the first argument, with separators. -/// The first argument is used as the separator string, and should not be NULL. -/// Other NULL arguments are ignored. -pub fn concat_ws(sep: impl Into, values: &[Expr]) -> Expr { - let mut args = vec![lit(sep.into())]; - args.extend_from_slice(values); +/// The first argument is used as the separator. +/// NULL arguments in `values` are ignored. +pub fn concat_ws(sep: Expr, values: Vec) -> Expr { + let mut args = values; + args.insert(0, sep); Expr::ScalarFunction { fun: built_in_function::BuiltinScalarFunction::ConcatWithSeparator, args, @@ -524,6 +524,7 @@ pub fn call_fn(name: impl AsRef, args: Vec) -> Result { #[cfg(test)] mod test { use super::*; + use crate::lit; #[test] fn filter_is_null_and_is_not_null() { diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index 08646b808644b..dc7412b5946c2 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -53,6 +53,12 @@ impl Literal for String { } } +impl Literal for &String { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + } +} + impl Literal for Vec { fn lit(&self) -> Expr { Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index a9970abf6b641..541ab600c8065 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -27,7 +27,7 @@ use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::{ expr::Between, - expr_fn::{and, or}, + expr_fn::{and, concat_ws, or}, expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, lit, logical_plan::LogicalPlan, @@ -256,6 +256,118 @@ fn negate_clause(expr: Expr) -> Expr { } } +/// Simplify the `concat` function by +/// 1. filtering out all `null` literals +/// 2. concatenating contiguous literal arguments +/// +/// For example: +/// `concat(col(a), 'hello ', 'world', col(b), null)` +/// will be optimized to +/// `concat(col(a), 'hello world', col(b))` +fn simpl_concat(args: Vec) -> Result { + let mut new_args = Vec::with_capacity(args.len()); + let mut contiguous_scalar = "".to_string(); + for arg in args { + match arg { + // filter out `null` args + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} + // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. + // Concatenate it with the `contiguous_scalar`. + Expr::Literal( + ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), + ) => contiguous_scalar += &v, + Expr::Literal(x) => return Err(DataFusionError::Internal(format!( + "The scalar {} should be casted to string type during the type coercion.", + x + ))), + // If the arg is not a literal, we should first push the current `contiguous_scalar` + // to the `new_args` (if it is not empty) and reset it to empty string. + // Then pushing this arg to the `new_args`. + arg => { + if !contiguous_scalar.is_empty() { + new_args.push(lit(contiguous_scalar)); + contiguous_scalar = "".to_string(); + } + new_args.push(arg); + } + } + } + if !contiguous_scalar.is_empty() { + new_args.push(lit(contiguous_scalar)); + } + + Ok(Expr::ScalarFunction { + fun: BuiltinScalarFunction::Concat, + args: new_args, + }) +} + +fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { + match delimiter { + Expr::Literal( + ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter), + ) => { + match delimiter { + // when the delimiter is an empty string, + // we can use `concat` to replace `concat_ws` + Some(delimiter) if delimiter.is_empty() => simpl_concat(args.to_vec()), + Some(delimiter) => { + let mut new_args = Vec::with_capacity(args.len()); + new_args.push(lit(delimiter)); + let mut contiguous_scalar = None; + for arg in args { + match arg { + // filter out null args + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} + Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => { + match contiguous_scalar { + None => contiguous_scalar = Some(v.to_string()), + Some(mut pre) => { + pre += delimiter; + pre += v; + contiguous_scalar = Some(pre) + } + } + } + Expr::Literal(s) => return Err(DataFusionError::Internal(format!("The scalar {} should be casted to string type during the type coercion.", s))), + // If the arg is not a literal, we should first push the current `contiguous_scalar` + // to the `new_args` and reset it to None. + // Then pushing this arg to the `new_args`. + arg => { + if let Some(val) = contiguous_scalar { + new_args.push(lit(val)); + } + new_args.push(arg.clone()); + contiguous_scalar = None; + } + } + } + if let Some(val) = contiguous_scalar { + new_args.push(lit(val)); + } + Ok(Expr::ScalarFunction { + fun: BuiltinScalarFunction::ConcatWithSeparator, + args: new_args, + }) + } + // if the delimiter is null, then the value of the whole expression is null. + None => Ok(Expr::Literal(ScalarValue::Utf8(None))), + } + } + Expr::Literal(d) => Err(DataFusionError::Internal(format!( + "The scalar {} should be casted to string type during the type coercion.", + d + ))), + d => Ok(concat_ws( + d.clone(), + args.iter() + .cloned() + .filter(|x| !is_null(x)) + .collect::>(), + )), + } +} + impl OptimizerRule for SimplifyExpressions { fn name(&self) -> &str { "simplify_expressions" @@ -880,62 +992,19 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { Expr::ScalarFunction { fun: BuiltinScalarFunction::Concat, args, - } => { - let mut new_args = Vec::with_capacity(args.len()); - let mut contiguous_scalar = "".to_string(); - for e in args { - match e { - // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. - // Concatenate it with `contiguous_scalar`. - Expr::Literal( - ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x), - ) => { - if let Some(s) = x { - contiguous_scalar += &s; - } - } - // If the arg is not a literal, we should first push the current `contiguous_scalar` - // to the `new_args` (if it is not empty) and reset it to empty string. - // Then pushing this arg to the `new_args`. - e => { - if !contiguous_scalar.is_empty() { - new_args.push(Expr::Literal(ScalarValue::Utf8(Some( - contiguous_scalar.clone(), - )))); - contiguous_scalar = "".to_string(); - } - new_args.push(e); - } - } - } - if !contiguous_scalar.is_empty() { - new_args - .push(Expr::Literal(ScalarValue::Utf8(Some(contiguous_scalar)))); - } - - Expr::ScalarFunction { - fun: BuiltinScalarFunction::Concat, - args: new_args, - } - } + } => simpl_concat(args)?, // concat_ws Expr::ScalarFunction { fun: BuiltinScalarFunction::ConcatWithSeparator, args, - } => { - match &args[..] { - // concat_ws(null, ..) --> null - [Expr::Literal(sp), ..] if sp.is_null() => { - Expr::Literal(ScalarValue::Utf8(None)) - } - // TODO https://github.com/apache/arrow-datafusion/issues/3599 - _ => Expr::ScalarFunction { - fun: BuiltinScalarFunction::ConcatWithSeparator, - args, - }, - } - } + } => match &args[..] { + [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, + _ => Expr::ScalarFunction { + fun: BuiltinScalarFunction::ConcatWithSeparator, + args, + }, + }, // // Rules for Between @@ -994,6 +1063,7 @@ mod tests { use chrono::{DateTime, TimeZone, Utc}; use datafusion_common::{DFField, DFSchemaRef}; use datafusion_expr::expr::Case; + use datafusion_expr::expr_fn::{concat, concat_ws}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ and, binary_expr, call_fn, col, create_udf, lit, lit_timestamp_nano, @@ -1379,56 +1449,81 @@ mod tests { } #[test] - fn test_simplify_concat_ws_null_separator() { - fn build_concat_ws_expr(args: &[Expr]) -> Expr { - Expr::ScalarFunction { - fun: BuiltinScalarFunction::ConcatWithSeparator, - args: args.to_vec(), - } + fn test_simplify_concat_ws() { + let null = Expr::Literal(ScalarValue::Utf8(None)); + // the delimiter is not a literal + { + let expr = concat_ws(col("c"), vec![lit("a"), null.clone(), lit("b")]); + let expected = concat_ws(col("c"), vec![lit("a"), lit("b")]); + assert_eq!(simplify(expr), expected); + } + + // the delimiter is an empty string + { + let expr = concat_ws(lit(""), vec![col("a"), lit("c"), lit("b")]); + let expected = concat(&[col("a"), lit("cb")]); + assert_eq!(simplify(expr), expected); + } + + // the delimiter is a not-empty string + { + let expr = concat_ws( + lit("-"), + vec![ + null.clone(), + col("c0"), + lit("hello"), + null.clone(), + lit("rust"), + col("c1"), + lit(""), + lit(""), + null, + ], + ); + let expected = concat_ws( + lit("-"), + vec![col("c0"), lit("hello-rust"), col("c1"), lit("-")], + ); + assert_eq!(simplify(expr), expected) } + } + #[test] + fn test_simplify_concat_ws_with_null() { let null = Expr::Literal(ScalarValue::Utf8(None)); - // simple test + // null delimiter -> null { - let expr = build_concat_ws_expr(&[null.clone(), col("c1"), col("c2")]); + let expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); assert_eq!(simplify(expr), null); } - // NULLs in other positions are not simplified. + // filter out null args { - let expr = build_concat_ws_expr(&[lit("|"), null.clone(), col("c2")]); - assert_eq!(simplify(expr.clone()), expr); + let expr = concat_ws(lit("|"), vec![col("c1"), null.clone(), col("c2")]); + let expected = concat_ws(lit("|"), vec![col("c1"), col("c2")]); + assert_eq!(simplify(expr), expected); } // nested test { - let sub_expr = build_concat_ws_expr(&[null.clone(), col("c1"), col("c2")]); - let expr = build_concat_ws_expr(&[lit("|"), sub_expr, col("c3")]); - assert_eq!( - simplify(expr), - build_concat_ws_expr(&[lit("|"), null.clone(), col("c3")]) - ); + let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); + let expr = concat_ws(lit("|"), vec![sub_expr, col("c3")]); + assert_eq!(simplify(expr), concat_ws(lit("|"), vec![col("c3")])); } - // nested test -- separator + // null delimiter (nested) { - let sub_expr = build_concat_ws_expr(&[null.clone(), col("c1"), col("c2")]); - let expr = build_concat_ws_expr(&[sub_expr, col("c3"), col("c4")]); + let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); + let expr = concat_ws(sub_expr, vec![col("c3"), col("c4")]); assert_eq!(simplify(expr), null); } } #[test] fn test_simplify_concat() { - fn build_concat_expr(args: &[Expr]) -> Expr { - Expr::ScalarFunction { - fun: BuiltinScalarFunction::Concat, - args: args.to_vec(), - } - } - let null = Expr::Literal(ScalarValue::Utf8(None)); - let expr = build_concat_expr(&[ + let expr = concat(&[ null.clone(), col("c0"), lit("hello "), @@ -1438,7 +1533,7 @@ mod tests { lit(""), null, ]); - let expected = build_concat_expr(&[col("c0"), lit("hello rust"), col("c1")]); + let expected = concat(&[col("c0"), lit("hello rust"), col("c1")]); assert_eq!(simplify(expr), expected) } diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 6c80a4df50b12..2833eee048b46 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -978,7 +978,7 @@ mod test { // concat_ws { - let expr = concat_ws("-", &args); + let expr = concat_ws(lit("-"), args.to_vec()); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?); diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index af7bc635aad7a..a4ee36925af94 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -212,6 +212,19 @@ fn concat_literals() -> Result<()> { Ok(()) } +#[test] +fn concat_ws_literals() -> Result<()> { + let sql = "SELECT concat_ws('-', true, col_int32, false, null, 'hello', col_utf8, 12, '', 3.4) \ + AS col + FROM test"; + let plan = test_sql(sql)?; + let expected = + "Projection: concatwithseparator(Utf8(\"-\"), Utf8(\"1\"), CAST(test.col_int32 AS Utf8), Utf8(\"0-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... From bb2bdad2a778988be8b818372d5759b751a10682 Mon Sep 17 00:00:00 2001 From: remzi <13716567376yh@gmail.com> Date: Tue, 18 Oct 2022 16:15:40 +0800 Subject: [PATCH 2/2] fmt and docs Signed-off-by: remzi <13716567376yh@gmail.com> --- datafusion/optimizer/src/simplify_expressions.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index 541ab600c8065..7e54266405c69 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -276,10 +276,12 @@ fn simpl_concat(args: Vec) -> Result { Expr::Literal( ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), ) => contiguous_scalar += &v, - Expr::Literal(x) => return Err(DataFusionError::Internal(format!( + Expr::Literal(x) => { + return Err(DataFusionError::Internal(format!( "The scalar {} should be casted to string type during the type coercion.", x - ))), + ))) + } // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` (if it is not empty) and reset it to empty string. // Then pushing this arg to the `new_args`. @@ -302,6 +304,11 @@ fn simpl_concat(args: Vec) -> Result { }) } +/// Simply the `concat_ws` function by +/// 1. folding to `null` if the delimiter is null +/// 2. filtering out `null` arguments +/// 3. using `concat` to replace `concat_ws` if the delimiter is an empty string +/// 4. concatenating contiguous literals if the delimiter is a literal. fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { Expr::Literal(