From 1eba34e02a2927c44283db3259fb14f28686c52f Mon Sep 17 00:00:00 2001 From: remzi <13716567376yh@gmail.com> Date: Mon, 3 Oct 2022 17:37:25 +0800 Subject: [PATCH 1/2] simpl concat Signed-off-by: remzi <13716567376yh@gmail.com> --- .../optimizer/src/simplify_expressions.rs | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index 969fa01693291..30b3a6e2cb315 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -849,12 +849,61 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { out_expr.rewrite(self)? } + // concat + ScalarFunction { + fun: BuiltinScalarFunction::Concat, + args, + } => { + let mut new_args = Vec::with_capacity(args.len()); + let mut contiguous_scalar = "".to_string(); + for e in args { + match e { + // ignore `null` scalar and concatenate it with `contiguous scalar`. + Expr::Literal(x) => { + match x { + // true --> '1', false --> '0' + ScalarValue::Boolean(b) => { + contiguous_scalar += b + .map(|b| if b { "1" } else { "0" }) + .unwrap_or(""); + } + x if !x.is_null() => { + contiguous_scalar += &x.to_string(); + } + _ => {} + } + } + e => { + // push the last `contiguous_scalar` and reset it. + if !contiguous_scalar.is_empty() { + new_args.push(Expr::Literal(ScalarValue::Utf8(Some( + contiguous_scalar.clone(), + )))); + contiguous_scalar = "".to_string(); + } + // push `e` directly because `e` is not a scalar and we cannot simplify it + new_args.push(e); + } + } + } + if !contiguous_scalar.is_empty() { + new_args + .push(Expr::Literal(ScalarValue::Utf8(Some(contiguous_scalar)))); + } + + ScalarFunction { + fun: BuiltinScalarFunction::Concat, + args: new_args, + } + } + // concat_ws ScalarFunction { fun: BuiltinScalarFunction::ConcatWithSeparator, args, } => { match &args[..] { + // concat_ws(null, ..) --> null [Expr::Literal(sp), ..] 
if sp.is_null() => { Expr::Literal(ScalarValue::Utf8(None)) } @@ -1263,6 +1312,37 @@ mod tests { } } + #[test] + fn test_simplify_concat() { + fn build_concat_expr(args: &[Expr]) -> Expr { + Expr::ScalarFunction { + fun: BuiltinScalarFunction::Concat, + args: args.to_vec(), + } + } + + let null = Expr::Literal(ScalarValue::Utf8(None)); + // concat(true, c0, false, null, 'hello', c1, 12, 3.4) --> concat('1', c0, '0hello', c1, '123.4') + let expr = build_concat_expr(&[ + lit(true), + col("c0"), + lit(false), + null, + lit("hello"), + col("c1"), + lit(12), + lit(3.4), + ]); + let expected = build_concat_expr(&[ + lit("1"), + col("c0"), + lit("0hello"), + col("c1"), + lit("123.4"), + ]); + assert_eq!(simplify(expr), expected) + } + // ------------------------------ // --- ConstEvaluator tests ----- // ------------------------------ From 10fb95c18cac2c76b2ad97b33f68b9a37ef15c89 Mon Sep 17 00:00:00 2001 From: remzi <13716567376yh@gmail.com> Date: Sun, 9 Oct 2022 17:57:43 +0800 Subject: [PATCH 2/2] update after type coercion Signed-off-by: remzi <13716567376yh@gmail.com> --- .../optimizer/src/simplify_expressions.rs | 46 +++++++------------ .../optimizer/tests/integration-test.rs | 13 ++++++ 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index d0736c8238f79..95c640b2a7088 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -887,30 +887,25 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { let mut contiguous_scalar = "".to_string(); for e in args { match e { - // ignore `null` scalar and concatenate it with `contiguous scalar`. 
- Expr::Literal(x) => { - match x { - // true --> '1', false --> '0' - ScalarValue::Boolean(b) => { - contiguous_scalar += b - .map(|b| if b { "1" } else { "0" }) - .unwrap_or(""); - } - x if !x.is_null() => { - contiguous_scalar += &x.to_string(); - } - _ => {} + // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. + // Append each non-null literal to `contiguous_scalar`. + Expr::Literal( + ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x), + ) => { + if let Some(s) = x { + contiguous_scalar += &s; } } + // If the arg is not a Utf8/LargeUtf8 literal, first push the current `contiguous_scalar` + // onto `new_args` (if it is not empty) and reset it to an empty string. + // Then push this arg onto `new_args`. e => { - // push the last `contiguous_scalar` and reset it. if !contiguous_scalar.is_empty() { new_args.push(Expr::Literal(ScalarValue::Utf8(Some( contiguous_scalar.clone(), )))); contiguous_scalar = "".to_string(); } - // push `e` directly because `e` is not a scalar and we cannot simplify it new_args.push(e); } } @@ -1411,24 +1406,17 @@ mod tests { } let null = Expr::Literal(ScalarValue::Utf8(None)); - // concat(true, c0, false, null, 'hello', c1, 12, 3.4) --> concat('1', c0, '0hello', c1, '123.4') let expr = build_concat_expr(&[ - lit(true), + null.clone(), col("c0"), - lit(false), - null, - lit("hello"), + lit("hello "), + null.clone(), + lit("rust"), col("c1"), - lit(12), - lit(3.4), - ]); - let expected = build_concat_expr(&[ - lit("1"), - col("c0"), - lit("0hello"), - col("c1"), - lit("123.4"), + lit(""), + null, ]); + let expected = build_concat_expr(&[col("c0"), lit("hello rust"), col("c1")]); assert_eq!(simplify(expr), expected) } diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index f6fe685ee2820..3071b880e6687 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -177,6 +177,19 @@ fn between_date64_plus_interval() -> 
Result<()> { Ok(()) } +#[test] +fn concat_literals() -> Result<()> { + let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \ + AS col + FROM test"; + let plan = test_sql(sql)?; + let expected = + "Projection: concat(Utf8(\"1\"), CAST(test.col_int32 AS Utf8), Utf8(\"0hello\"), test.col_utf8, Utf8(\"123.4\")) AS col\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...