diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index c96f6eea71510..95c640b2a7088 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -878,12 +878,56 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { out_expr.rewrite(self)? } + // concat + ScalarFunction { + fun: BuiltinScalarFunction::Concat, + args, + } => { + let mut new_args = Vec::with_capacity(args.len()); + let mut contiguous_scalar = "".to_string(); + for e in args { + match e { + // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. + // Concatenate it with `contiguous_scalar`. + Expr::Literal( + ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x), + ) => { + if let Some(s) = x { + contiguous_scalar += &s; + } + } + // If the arg is not a literal, we should first push the current `contiguous_scalar` + // to the `new_args` (if it is not empty) and reset it to empty string. + // Then pushing this arg to the `new_args`. + e => { + if !contiguous_scalar.is_empty() { + new_args.push(Expr::Literal(ScalarValue::Utf8(Some( + contiguous_scalar.clone(), + )))); + contiguous_scalar = "".to_string(); + } + new_args.push(e); + } + } + } + if !contiguous_scalar.is_empty() { + new_args + .push(Expr::Literal(ScalarValue::Utf8(Some(contiguous_scalar)))); + } + + ScalarFunction { + fun: BuiltinScalarFunction::Concat, + args: new_args, + } + } + // concat_ws ScalarFunction { fun: BuiltinScalarFunction::ConcatWithSeparator, args, } => { match &args[..] { + // concat_ws(null, ..) --> null [Expr::Literal(sp), ..] if sp.is_null() => { Expr::Literal(ScalarValue::Utf8(None)) } @@ -1352,6 +1396,30 @@ mod tests { } } + #[test] + fn test_simplify_concat() { + fn build_concat_expr(args: &[Expr]) -> Expr { + Expr::ScalarFunction { + fun: BuiltinScalarFunction::Concat, + args: args.to_vec(), + } + } + + let null = Expr::Literal(ScalarValue::Utf8(None)); + let expr = build_concat_expr(&[ + null.clone(), + col("c0"), + lit("hello "), + null.clone(), + lit("rust"), + col("c1"), + lit(""), + null, + ]); + let expected = build_concat_expr(&[col("c0"), lit("hello rust"), col("c1")]); + assert_eq!(simplify(expr), expected) + } + // ------------------------------ // --- ConstEvaluator tests ----- // ------------------------------ diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index f6fe685ee2820..3071b880e6687 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -177,6 +177,19 @@ fn between_date64_plus_interval() -> Result<()> { Ok(()) } +#[test] +fn concat_literals() -> Result<()> { + let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \ + AS col + FROM test"; + let plan = test_sql(sql)?; + let expected = + "Projection: concat(Utf8(\"1\"), CAST(test.col_int32 AS Utf8), Utf8(\"0hello\"), test.col_utf8, Utf8(\"123.4\")) AS col\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...