-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Optimize the concat_ws function
#3869
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,7 @@ use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; | |
| use datafusion_expr::expr::BinaryExpr; | ||
| use datafusion_expr::{ | ||
| expr::Between, | ||
| expr_fn::{and, or}, | ||
| expr_fn::{and, concat_ws, or}, | ||
| expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, | ||
| lit, | ||
| logical_plan::LogicalPlan, | ||
|
|
@@ -256,6 +256,125 @@ fn negate_clause(expr: Expr) -> Expr { | |
| } | ||
| } | ||
|
|
||
| /// Simplify the `concat` function by | ||
| /// 1. filtering out all `null` literals | ||
| /// 2. concatenating contiguous literal arguments | ||
| /// | ||
| /// For example: | ||
| /// `concat(col(a), 'hello ', 'world', col(b), null)` | ||
| /// will be optimized to | ||
| /// `concat(col(a), 'hello world', col(b))` | ||
| fn simpl_concat(args: Vec<Expr>) -> Result<Expr> { | ||
| let mut new_args = Vec::with_capacity(args.len()); | ||
| let mut contiguous_scalar = "".to_string(); | ||
| for arg in args { | ||
| match arg { | ||
| // filter out `null` args | ||
| Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} | ||
| // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. | ||
| // Concatenate it with the `contiguous_scalar`. | ||
| Expr::Literal( | ||
| ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), | ||
| ) => contiguous_scalar += &v, | ||
| Expr::Literal(x) => { | ||
| return Err(DataFusionError::Internal(format!( | ||
| "The scalar {} should be casted to string type during the type coercion.", | ||
| x | ||
| ))) | ||
| } | ||
| // If the arg is not a literal, we should first push the current `contiguous_scalar` | ||
| // to the `new_args` (if it is not empty) and reset it to empty string. | ||
| // Then pushing this arg to the `new_args`. | ||
| arg => { | ||
| if !contiguous_scalar.is_empty() { | ||
| new_args.push(lit(contiguous_scalar)); | ||
| contiguous_scalar = "".to_string(); | ||
| } | ||
| new_args.push(arg); | ||
| } | ||
| } | ||
| } | ||
| if !contiguous_scalar.is_empty() { | ||
| new_args.push(lit(contiguous_scalar)); | ||
| } | ||
|
|
||
| Ok(Expr::ScalarFunction { | ||
| fun: BuiltinScalarFunction::Concat, | ||
| args: new_args, | ||
| }) | ||
| } | ||
|
|
||
| /// Simply the `concat_ws` function by | ||
| /// 1. folding to `null` if the delimiter is null | ||
| /// 2. filtering out `null` arguments | ||
| /// 3. using `concat` to replace `concat_ws` if the delimiter is an empty string | ||
| /// 4. concatenating contiguous literals if the delimiter is a literal. | ||
| fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<Expr> { | ||
| match delimiter { | ||
| Expr::Literal( | ||
| ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter), | ||
| ) => { | ||
| match delimiter { | ||
| // when the delimiter is an empty string, | ||
| // we can use `concat` to replace `concat_ws` | ||
| Some(delimiter) if delimiter.is_empty() => simpl_concat(args.to_vec()), | ||
| Some(delimiter) => { | ||
| let mut new_args = Vec::with_capacity(args.len()); | ||
| new_args.push(lit(delimiter)); | ||
| let mut contiguous_scalar = None; | ||
| for arg in args { | ||
| match arg { | ||
| // filter out null args | ||
| Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} | ||
| Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => { | ||
| match contiguous_scalar { | ||
| None => contiguous_scalar = Some(v.to_string()), | ||
| Some(mut pre) => { | ||
| pre += delimiter; | ||
| pre += v; | ||
| contiguous_scalar = Some(pre) | ||
| } | ||
| } | ||
| } | ||
| Expr::Literal(s) => return Err(DataFusionError::Internal(format!("The scalar {} should be casted to string type during the type coercion.", s))), | ||
| // If the arg is not a literal, we should first push the current `contiguous_scalar` | ||
| // to the `new_args` and reset it to None. | ||
| // Then pushing this arg to the `new_args`. | ||
| arg => { | ||
| if let Some(val) = contiguous_scalar { | ||
| new_args.push(lit(val)); | ||
| } | ||
| new_args.push(arg.clone()); | ||
| contiguous_scalar = None; | ||
| } | ||
| } | ||
| } | ||
| if let Some(val) = contiguous_scalar { | ||
| new_args.push(lit(val)); | ||
| } | ||
|
Comment on lines
+325
to
+354
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This pattern of creating the contiguous scalar is so similar -- I wonder if it could be extracted out into a function -- perhaps as a follow on PR
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for reviewing @alamb The logic for |
||
| Ok(Expr::ScalarFunction { | ||
| fun: BuiltinScalarFunction::ConcatWithSeparator, | ||
| args: new_args, | ||
| }) | ||
| } | ||
| // if the delimiter is null, then the value of the whole expression is null. | ||
| None => Ok(Expr::Literal(ScalarValue::Utf8(None))), | ||
| } | ||
| } | ||
| Expr::Literal(d) => Err(DataFusionError::Internal(format!( | ||
| "The scalar {} should be casted to string type during the type coercion.", | ||
| d | ||
| ))), | ||
| d => Ok(concat_ws( | ||
| d.clone(), | ||
| args.iter() | ||
| .cloned() | ||
| .filter(|x| !is_null(x)) | ||
| .collect::<Vec<Expr>>(), | ||
| )), | ||
| } | ||
| } | ||
|
|
||
| impl OptimizerRule for SimplifyExpressions { | ||
| fn name(&self) -> &str { | ||
| "simplify_expressions" | ||
|
|
@@ -880,62 +999,19 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { | |
| Expr::ScalarFunction { | ||
| fun: BuiltinScalarFunction::Concat, | ||
| args, | ||
| } => { | ||
| let mut new_args = Vec::with_capacity(args.len()); | ||
| let mut contiguous_scalar = "".to_string(); | ||
| for e in args { | ||
| match e { | ||
| // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. | ||
| // Concatenate it with `contiguous_scalar`. | ||
| Expr::Literal( | ||
| ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x), | ||
| ) => { | ||
| if let Some(s) = x { | ||
| contiguous_scalar += &s; | ||
| } | ||
| } | ||
| // If the arg is not a literal, we should first push the current `contiguous_scalar` | ||
| // to the `new_args` (if it is not empty) and reset it to empty string. | ||
| // Then pushing this arg to the `new_args`. | ||
| e => { | ||
| if !contiguous_scalar.is_empty() { | ||
| new_args.push(Expr::Literal(ScalarValue::Utf8(Some( | ||
| contiguous_scalar.clone(), | ||
| )))); | ||
| contiguous_scalar = "".to_string(); | ||
| } | ||
| new_args.push(e); | ||
| } | ||
| } | ||
| } | ||
| if !contiguous_scalar.is_empty() { | ||
| new_args | ||
| .push(Expr::Literal(ScalarValue::Utf8(Some(contiguous_scalar)))); | ||
| } | ||
|
|
||
| Expr::ScalarFunction { | ||
| fun: BuiltinScalarFunction::Concat, | ||
| args: new_args, | ||
| } | ||
| } | ||
| } => simpl_concat(args)?, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ❤️ |
||
|
|
||
| // concat_ws | ||
| Expr::ScalarFunction { | ||
| fun: BuiltinScalarFunction::ConcatWithSeparator, | ||
| args, | ||
| } => { | ||
| match &args[..] { | ||
| // concat_ws(null, ..) --> null | ||
| [Expr::Literal(sp), ..] if sp.is_null() => { | ||
| Expr::Literal(ScalarValue::Utf8(None)) | ||
| } | ||
| // TODO https://github.com/apache/arrow-datafusion/issues/3599 | ||
| _ => Expr::ScalarFunction { | ||
| fun: BuiltinScalarFunction::ConcatWithSeparator, | ||
| args, | ||
| }, | ||
| } | ||
| } | ||
| } => match &args[..] { | ||
| [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, | ||
| _ => Expr::ScalarFunction { | ||
| fun: BuiltinScalarFunction::ConcatWithSeparator, | ||
| args, | ||
| }, | ||
| }, | ||
|
|
||
| // | ||
| // Rules for Between | ||
|
|
@@ -994,6 +1070,7 @@ mod tests { | |
| use chrono::{DateTime, TimeZone, Utc}; | ||
| use datafusion_common::{DFField, DFSchemaRef}; | ||
| use datafusion_expr::expr::Case; | ||
| use datafusion_expr::expr_fn::{concat, concat_ws}; | ||
| use datafusion_expr::logical_plan::table_scan; | ||
| use datafusion_expr::{ | ||
| and, binary_expr, call_fn, col, create_udf, lit, lit_timestamp_nano, | ||
|
|
@@ -1379,56 +1456,81 @@ mod tests { | |
| } | ||
|
|
||
| #[test] | ||
| fn test_simplify_concat_ws_null_separator() { | ||
| fn build_concat_ws_expr(args: &[Expr]) -> Expr { | ||
| Expr::ScalarFunction { | ||
| fun: BuiltinScalarFunction::ConcatWithSeparator, | ||
| args: args.to_vec(), | ||
| } | ||
| fn test_simplify_concat_ws() { | ||
| let null = Expr::Literal(ScalarValue::Utf8(None)); | ||
| // the delimiter is not a literal | ||
| { | ||
| let expr = concat_ws(col("c"), vec![lit("a"), null.clone(), lit("b")]); | ||
| let expected = concat_ws(col("c"), vec![lit("a"), lit("b")]); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so cool! |
||
| assert_eq!(simplify(expr), expected); | ||
| } | ||
|
|
||
| // the delimiter is an empty string | ||
| { | ||
| let expr = concat_ws(lit(""), vec![col("a"), lit("c"), lit("b")]); | ||
| let expected = concat(&[col("a"), lit("cb")]); | ||
| assert_eq!(simplify(expr), expected); | ||
| } | ||
|
|
||
| // the delimiter is a not-empty string | ||
| { | ||
| let expr = concat_ws( | ||
| lit("-"), | ||
| vec![ | ||
| null.clone(), | ||
| col("c0"), | ||
| lit("hello"), | ||
| null.clone(), | ||
| lit("rust"), | ||
| col("c1"), | ||
| lit(""), | ||
| lit(""), | ||
| null, | ||
| ], | ||
| ); | ||
| let expected = concat_ws( | ||
| lit("-"), | ||
| vec![col("c0"), lit("hello-rust"), col("c1"), lit("-")], | ||
| ); | ||
| assert_eq!(simplify(expr), expected) | ||
| } | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_simplify_concat_ws_with_null() { | ||
| let null = Expr::Literal(ScalarValue::Utf8(None)); | ||
| // simple test | ||
| // null delimiter -> null | ||
| { | ||
| let expr = build_concat_ws_expr(&[null.clone(), col("c1"), col("c2")]); | ||
| let expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); | ||
| assert_eq!(simplify(expr), null); | ||
| } | ||
|
|
||
| // NULLs in other positions are not simplified. | ||
| // filter out null args | ||
| { | ||
| let expr = build_concat_ws_expr(&[lit("|"), null.clone(), col("c2")]); | ||
| assert_eq!(simplify(expr.clone()), expr); | ||
| let expr = concat_ws(lit("|"), vec![col("c1"), null.clone(), col("c2")]); | ||
| let expected = concat_ws(lit("|"), vec![col("c1"), col("c2")]); | ||
| assert_eq!(simplify(expr), expected); | ||
| } | ||
|
|
||
| // nested test | ||
| { | ||
| let sub_expr = build_concat_ws_expr(&[null.clone(), col("c1"), col("c2")]); | ||
| let expr = build_concat_ws_expr(&[lit("|"), sub_expr, col("c3")]); | ||
| assert_eq!( | ||
| simplify(expr), | ||
| build_concat_ws_expr(&[lit("|"), null.clone(), col("c3")]) | ||
| ); | ||
| let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); | ||
| let expr = concat_ws(lit("|"), vec![sub_expr, col("c3")]); | ||
| assert_eq!(simplify(expr), concat_ws(lit("|"), vec![col("c3")])); | ||
| } | ||
|
|
||
| // nested test -- separator | ||
| // null delimiter (nested) | ||
| { | ||
| let sub_expr = build_concat_ws_expr(&[null.clone(), col("c1"), col("c2")]); | ||
| let expr = build_concat_ws_expr(&[sub_expr, col("c3"), col("c4")]); | ||
| let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); | ||
| let expr = concat_ws(sub_expr, vec![col("c3"), col("c4")]); | ||
| assert_eq!(simplify(expr), null); | ||
| } | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_simplify_concat() { | ||
| fn build_concat_expr(args: &[Expr]) -> Expr { | ||
| Expr::ScalarFunction { | ||
| fun: BuiltinScalarFunction::Concat, | ||
| args: args.to_vec(), | ||
| } | ||
| } | ||
|
|
||
| let null = Expr::Literal(ScalarValue::Utf8(None)); | ||
| let expr = build_concat_expr(&[ | ||
| let expr = concat(&[ | ||
| null.clone(), | ||
| col("c0"), | ||
| lit("hello "), | ||
|
|
@@ -1438,7 +1540,7 @@ mod tests { | |
| lit(""), | ||
| null, | ||
| ]); | ||
| let expected = build_concat_expr(&[col("c0"), lit("hello rust"), col("c1")]); | ||
| let expected = concat(&[col("c0"), lit("hello rust"), col("c1")]); | ||
| assert_eq!(simplify(expr), expected) | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"TIL"
lit()👍