Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use crate::expr::{BinaryExpr, GroupingSet};
use crate::{
aggregate_function, built_in_function, conditional_expressions::CaseBuilder, lit,
aggregate_function, built_in_function, conditional_expressions::CaseBuilder,
logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF,
BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility,
Expand Down Expand Up @@ -134,11 +134,11 @@ pub fn concat(args: &[Expr]) -> Expr {
}

/// Concatenates all but the first argument, with separators.
/// The first argument is used as the separator string, and should not be NULL.
/// Other NULL arguments are ignored.
pub fn concat_ws(sep: impl Into<String>, values: &[Expr]) -> Expr {
let mut args = vec![lit(sep.into())];
args.extend_from_slice(values);
/// The first argument is used as the separator.
/// NULL arguments in `values` are ignored.
pub fn concat_ws(sep: Expr, values: Vec<Expr>) -> Expr {
let mut args = values;
args.insert(0, sep);
Expr::ScalarFunction {
fun: built_in_function::BuiltinScalarFunction::ConcatWithSeparator,
args,
Expand Down Expand Up @@ -524,6 +524,7 @@ pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) -> Result<Expr> {
#[cfg(test)]
mod test {
use super::*;
use crate::lit;

#[test]
fn filter_is_null_and_is_not_null() {
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ impl Literal for String {
}
}

/// Builds a `Utf8` string literal from a borrowed `String`; the referenced
/// string is cloned into the resulting expression.
impl Literal for &String {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"TIL" lit() 👍

fn lit(&self) -> Expr {
// `self` is `&&String` here, so dereference once before cloning the text.
Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned())))
}
}

impl Literal for Vec<u8> {
fn lit(&self) -> Expr {
Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())))
Expand Down
264 changes: 183 additions & 81 deletions datafusion/optimizer/src/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::{
expr::Between,
expr_fn::{and, or},
expr_fn::{and, concat_ws, or},
expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion},
lit,
logical_plan::LogicalPlan,
Expand Down Expand Up @@ -256,6 +256,125 @@ fn negate_clause(expr: Expr) -> Expr {
}
}

/// Simplifies a call to `concat`.
///
/// Two rewrites are applied while walking the arguments in order:
/// 1. `null` string literals are dropped
/// 2. runs of adjacent string literals are merged into one literal
///
/// For example:
/// `concat(col(a), 'hello ', 'world', col(b), null)`
/// becomes
/// `concat(col(a), 'hello world', col(b))`
///
/// # Errors
///
/// Returns an internal error if an argument is a literal that is neither
/// `Utf8` nor `LargeUtf8`; type coercion is expected to have cast every
/// literal argument to a string type before this runs.
fn simpl_concat(args: Vec<Expr>) -> Result<Expr> {
    /// Moves any buffered literal text into `out` as one literal argument,
    /// leaving the buffer empty. An empty buffer produces no argument.
    fn flush_literal(buffer: &mut String, out: &mut Vec<Expr>) {
        if !buffer.is_empty() {
            out.push(lit(std::mem::take(buffer)));
        }
    }

    let mut simplified = Vec::with_capacity(args.len());
    // Text of the literal run currently being accumulated.
    let mut buffer = String::new();

    for arg in args {
        match arg {
            // String literals: a `null` contributes nothing, anything else
            // is appended to the current run.
            Expr::Literal(ScalarValue::Utf8(text) | ScalarValue::LargeUtf8(text)) => {
                if let Some(text) = text {
                    buffer.push_str(&text);
                }
            }
            // Any other literal type should not survive type coercion.
            Expr::Literal(other) => {
                return Err(DataFusionError::Internal(format!(
                    "The scalar {} should be casted to string type during the type coercion.",
                    other
                )))
            }
            // A non-literal argument ends the current run: emit the run
            // first so the argument order is preserved.
            other => {
                flush_literal(&mut buffer, &mut simplified);
                simplified.push(other);
            }
        }
    }
    // Emit a trailing run, if any.
    flush_literal(&mut buffer, &mut simplified);

    Ok(Expr::ScalarFunction {
        fun: BuiltinScalarFunction::Concat,
        args: simplified,
    })
}

/// Simplify the `concat_ws` function by
/// 1. folding to `null` if the delimiter is null
/// 2. filtering out `null` arguments
/// 3. using `concat` to replace `concat_ws` if the delimiter is an empty string
/// 4. concatenating contiguous literals if the delimiter is a literal.
///
/// `delimiter` is the first argument of the original call and `args` are the
/// remaining ones.
///
/// Returns an internal error when the delimiter, or an argument checked while
/// the delimiter is a non-empty string literal, is a literal that is neither
/// `Utf8` nor `LargeUtf8`.
fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<Expr> {
match delimiter {
// the delimiter is a string literal; the inner `delimiter` shadows the
// outer one and is the `Option` holding its value
Expr::Literal(
ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter),
) => {
match delimiter {
// when the delimiter is an empty string,
// we can use `concat` to replace `concat_ws`
Some(delimiter) if delimiter.is_empty() => simpl_concat(args.to_vec()),
Some(delimiter) => {
let mut new_args = Vec::with_capacity(args.len());
// the delimiter stays as the first argument of the rewritten call
new_args.push(lit(delimiter));
// `None` means no literal run is pending. An `Option` is used rather
// than an empty string because an empty string literal is still a
// value here and must be kept.
let mut contiguous_scalar = None;
for arg in args {
match arg {
// filter out null args
Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {}
// a string literal either starts a new run or is appended to the
// pending run, joined by the delimiter
Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => {
match contiguous_scalar {
None => contiguous_scalar = Some(v.to_string()),
Some(mut pre) => {
pre += delimiter;
pre += v;
contiguous_scalar = Some(pre)
}
}
}
Expr::Literal(s) => return Err(DataFusionError::Internal(format!("The scalar {} should be casted to string type during the type coercion.", s))),
// If the arg is not a literal, we should first push the current `contiguous_scalar`
// to the `new_args` and reset it to None.
// Then pushing this arg to the `new_args`.
arg => {
if let Some(val) = contiguous_scalar {
new_args.push(lit(val));
}
new_args.push(arg.clone());
contiguous_scalar = None;
}
}
}
// push the trailing literal run, if any
if let Some(val) = contiguous_scalar {
new_args.push(lit(val));
}
Comment on lines +325 to +354

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern of creating the contiguous scalar is so similar -- I wonder if it could be extracted out into a function -- perhaps as a follow on PR

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for reviewing @alamb ♥️

The logic for concat and concat_ws is a little different, because in concat_ws we must consider the delimiter and we can't ignore the empty string literals. I will try to find a way to refactor them.

Ok(Expr::ScalarFunction {
fun: BuiltinScalarFunction::ConcatWithSeparator,
args: new_args,
})
}
// if the delimiter is null, then the value of the whole expression is null.
None => Ok(Expr::Literal(ScalarValue::Utf8(None))),
}
}
// any other literal delimiter should have been cast to a string type already
Expr::Literal(d) => Err(DataFusionError::Internal(format!(
"The scalar {} should be casted to string type during the type coercion.",
d
))),
// the delimiter is not a literal, so literals cannot be merged; only the
// arguments rejected by `is_null` (defined outside this view) are dropped
d => Ok(concat_ws(
d.clone(),
args.iter()
.cloned()
.filter(|x| !is_null(x))
.collect::<Vec<Expr>>(),
)),
}
}

impl OptimizerRule for SimplifyExpressions {
fn name(&self) -> &str {
"simplify_expressions"
Expand Down Expand Up @@ -880,62 +999,19 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
Expr::ScalarFunction {
fun: BuiltinScalarFunction::Concat,
args,
} => {
let mut new_args = Vec::with_capacity(args.len());
let mut contiguous_scalar = "".to_string();
for e in args {
match e {
// All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
// Concatenate it with `contiguous_scalar`.
Expr::Literal(
ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x),
) => {
if let Some(s) = x {
contiguous_scalar += &s;
}
}
// If the arg is not a literal, we should first push the current `contiguous_scalar`
// to the `new_args` (if it is not empty) and reset it to empty string.
// Then pushing this arg to the `new_args`.
e => {
if !contiguous_scalar.is_empty() {
new_args.push(Expr::Literal(ScalarValue::Utf8(Some(
contiguous_scalar.clone(),
))));
contiguous_scalar = "".to_string();
}
new_args.push(e);
}
}
}
if !contiguous_scalar.is_empty() {
new_args
.push(Expr::Literal(ScalarValue::Utf8(Some(contiguous_scalar))));
}

Expr::ScalarFunction {
fun: BuiltinScalarFunction::Concat,
args: new_args,
}
}
} => simpl_concat(args)?,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️


// concat_ws
Expr::ScalarFunction {
fun: BuiltinScalarFunction::ConcatWithSeparator,
args,
} => {
match &args[..] {
// concat_ws(null, ..) --> null
[Expr::Literal(sp), ..] if sp.is_null() => {
Expr::Literal(ScalarValue::Utf8(None))
}
// TODO https://github.com/apache/arrow-datafusion/issues/3599
_ => Expr::ScalarFunction {
fun: BuiltinScalarFunction::ConcatWithSeparator,
args,
},
}
}
} => match &args[..] {
[delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?,
_ => Expr::ScalarFunction {
fun: BuiltinScalarFunction::ConcatWithSeparator,
args,
},
},

//
// Rules for Between
Expand Down Expand Up @@ -994,6 +1070,7 @@ mod tests {
use chrono::{DateTime, TimeZone, Utc};
use datafusion_common::{DFField, DFSchemaRef};
use datafusion_expr::expr::Case;
use datafusion_expr::expr_fn::{concat, concat_ws};
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
and, binary_expr, call_fn, col, create_udf, lit, lit_timestamp_nano,
Expand Down Expand Up @@ -1379,56 +1456,81 @@ mod tests {
}

#[test]
fn test_simplify_concat_ws_null_separator() {
fn build_concat_ws_expr(args: &[Expr]) -> Expr {
Expr::ScalarFunction {
fun: BuiltinScalarFunction::ConcatWithSeparator,
args: args.to_vec(),
}
fn test_simplify_concat_ws() {
let null = Expr::Literal(ScalarValue::Utf8(None));
// the delimiter is not a literal
{
let expr = concat_ws(col("c"), vec![lit("a"), null.clone(), lit("b")]);
let expected = concat_ws(col("c"), vec![lit("a"), lit("b")]);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so cool!

assert_eq!(simplify(expr), expected);
}

// the delimiter is an empty string
{
let expr = concat_ws(lit(""), vec![col("a"), lit("c"), lit("b")]);
let expected = concat(&[col("a"), lit("cb")]);
assert_eq!(simplify(expr), expected);
}

// the delimiter is a non-empty string
{
let expr = concat_ws(
lit("-"),
vec![
null.clone(),
col("c0"),
lit("hello"),
null.clone(),
lit("rust"),
col("c1"),
lit(""),
lit(""),
null,
],
);
let expected = concat_ws(
lit("-"),
vec![col("c0"), lit("hello-rust"), col("c1"), lit("-")],
);
assert_eq!(simplify(expr), expected)
}
}

#[test]
fn test_simplify_concat_ws_with_null() {
let null = Expr::Literal(ScalarValue::Utf8(None));
// simple test
// null delimiter -> null
{
let expr = build_concat_ws_expr(&[null.clone(), col("c1"), col("c2")]);
let expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]);
assert_eq!(simplify(expr), null);
}

// NULLs in other positions are not simplified.
// filter out null args
{
let expr = build_concat_ws_expr(&[lit("|"), null.clone(), col("c2")]);
assert_eq!(simplify(expr.clone()), expr);
let expr = concat_ws(lit("|"), vec![col("c1"), null.clone(), col("c2")]);
let expected = concat_ws(lit("|"), vec![col("c1"), col("c2")]);
assert_eq!(simplify(expr), expected);
}

// nested test
{
let sub_expr = build_concat_ws_expr(&[null.clone(), col("c1"), col("c2")]);
let expr = build_concat_ws_expr(&[lit("|"), sub_expr, col("c3")]);
assert_eq!(
simplify(expr),
build_concat_ws_expr(&[lit("|"), null.clone(), col("c3")])
);
let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]);
let expr = concat_ws(lit("|"), vec![sub_expr, col("c3")]);
assert_eq!(simplify(expr), concat_ws(lit("|"), vec![col("c3")]));
}

// nested test -- separator
// null delimiter (nested)
{
let sub_expr = build_concat_ws_expr(&[null.clone(), col("c1"), col("c2")]);
let expr = build_concat_ws_expr(&[sub_expr, col("c3"), col("c4")]);
let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]);
let expr = concat_ws(sub_expr, vec![col("c3"), col("c4")]);
assert_eq!(simplify(expr), null);
}
}

#[test]
fn test_simplify_concat() {
fn build_concat_expr(args: &[Expr]) -> Expr {
Expr::ScalarFunction {
fun: BuiltinScalarFunction::Concat,
args: args.to_vec(),
}
}

let null = Expr::Literal(ScalarValue::Utf8(None));
let expr = build_concat_expr(&[
let expr = concat(&[
null.clone(),
col("c0"),
lit("hello "),
Expand All @@ -1438,7 +1540,7 @@ mod tests {
lit(""),
null,
]);
let expected = build_concat_expr(&[col("c0"), lit("hello rust"), col("c1")]);
let expected = concat(&[col("c0"), lit("hello rust"), col("c1")]);
assert_eq!(simplify(expr), expected)
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ mod test {

// concat_ws
{
let expr = concat_ws("-", &args);
let expr = concat_ws(lit("-"), args.to_vec());

let plan =
LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
Expand Down
Loading