From a7d9ff685c2e45b6448749b096c1f031685d1c4d Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Tue, 23 Sep 2025 12:23:12 +0100 Subject: [PATCH 1/3] Add case expr simplifiers for literal comparisons --- .../simplify_expressions/expr_simplifier.rs | 225 +++++++++++++++++- .../src/simplify_expressions/utils.rs | 27 ++- 2 files changed, 250 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 2d10c7fe22b5c..ec93effe45d91 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1399,6 +1399,39 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // Rules for Case // + // CASE WHEN X THEN "a" WHEN Y THEN "b" ... END = "a" --> CASE WHEN X THEN "a" = "a" WHEN Y THEN "b" = "a" END + Expr::BinaryExpr(BinaryExpr { + left, + op: op @ (Eq | NotEq), + right, + }) if is_case_with_literal_outputs(&left) && is_lit(&right) => { + let case = as_case(&left)?; + Transformed::yes(Expr::Case(Case { + expr: None, + when_then_expr: case + .when_then_expr + .iter() + .map(|(when, then)| { + ( + when.clone(), + Box::new(Expr::BinaryExpr(BinaryExpr { + left: then.clone(), + op, + right: right.clone(), + })), + ) + }) + .collect(), + else_expr: case.else_expr.as_ref().map(|els| { + Box::new(Expr::BinaryExpr(BinaryExpr { + left: els.clone(), + op, + right, + })) + }), + })) + } + // CASE WHEN true THEN A ... END --> A // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END Expr::Case(Case { @@ -1447,7 +1480,11 @@ impl TreeNodeRewriter for Simplifier<'_, S> { when_then_expr, else_expr, }) if !when_then_expr.is_empty() - && when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number + // The rewrite is O(n²) in general so limit to small number of when-thens that can be true + && (when_then_expr.len() < 3 // small number of input whens + // or all thens are literal bools and a small number of them are true + || (when_then_expr.iter().all(|(_, then)| is_bool_lit(then)) + && when_then_expr.iter().filter(|(_, then)| is_true(then)).count() < 3)) && info.is_boolean_type(&when_then_expr[0].1)? => { // String disjunction of all the when predicates encountered so far. Not nullable. @@ -1471,6 +1508,56 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // Do a first pass at simplification out_expr.rewrite(self)? } + // CASE + // WHEN X THEN true + // WHEN Y THEN true + // WHEN Z THEN false + // ... + // ELSE true + // END + // + // ---> + // + // NOT(CASE + // WHEN X THEN false + // WHEN Y THEN false + // WHEN Z THEN true + // ... + // ELSE false + // END) + // + // Note: the rationale for this rewrite is that the case can then be further + // simplified into a small number of ANDs and ORs + Expr::Case(Case { + expr: None, + when_then_expr, + else_expr, + }) if !when_then_expr.is_empty() + && when_then_expr + .iter() + .all(|(_, then)| is_bool_lit(then)) // all thens are literal bools + // This simplification is only helpful if we end up with a small number of true thens + && when_then_expr + .iter() + .filter(|(_, then)| is_false(then)) + .count() + < 3 + && else_expr.as_deref().is_none_or(is_bool_lit) => + { + Transformed::yes( + Expr::Case(Case { + expr: None, + when_then_expr: when_then_expr + .iter() + .map(|(when, then)| { + (when.clone(), Box::new(then.clone().not())) + }) + .collect(), + else_expr: else_expr.map(|else_expr| Box::new(else_expr.not())), + }) + .not(), + ) + } Expr::ScalarFunction(ScalarFunction { func: udf, args }) => { match udf.simplify(args, info)? { ExprSimplifyResult::Original(args) => { @@ -3465,6 +3552,142 @@ mod tests { ); } + #[test] + fn simplify_literal_case_equality() { + // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" + let simple_case = Expr::Case(Case::new( + None, + vec![( + Box::new(col("c2_non_null").not_eq(lit(false))), + Box::new(lit("ok")), + )], + Some(Box::new(lit("not_ok"))), + )); + + // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" == "ok" + // --> + // CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok" + // --> + // CASE WHEN c2 != false THEN true ELSE false + // --> + // c2 + assert_eq!( + simplify(binary_expr(simple_case.clone(), Operator::Eq, lit("ok"),)), + col("c2_non_null"), + ); + + // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" != "ok" + // --> + // NOT(CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok") + // --> + // NOT(CASE WHEN c2 != false THEN true ELSE false) + // --> + // NOT(c2) + assert_eq!( + simplify(binary_expr(simple_case, Operator::NotEq, lit("ok"),)), + not(col("c2_non_null")), + ); + + let complex_case = Expr::Case(Case::new( + None, + vec![ + ( + Box::new(col("c1").eq(lit("inboxed"))), + Box::new(lit("pending")), + ), + ( + Box::new(col("c1").eq(lit("scheduled"))), + Box::new(lit("pending")), + ), + ( + Box::new(col("c1").eq(lit("completed"))), + Box::new(lit("completed")), + ), + ( + Box::new(col("c1").eq(lit("paused"))), + Box::new(lit("paused")), + ), + (Box::new(col("c2")), Box::new(lit("running"))), + ( + Box::new(col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0)))), + Box::new(lit("backing-off")), + ), + ], + Some(Box::new(lit("ready"))), + )); + + assert_eq!( + simplify(binary_expr( + complex_case.clone(), + Operator::Eq, + lit("completed"), + )), + not_distinct_from(col("c1").eq(lit("completed")), lit(true)).and( + distinct_from(col("c1").eq(lit("inboxed")), lit(true)) + .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true))) + ) + ); + + assert_eq!( + simplify(binary_expr( + complex_case.clone(), + Operator::NotEq, + lit("completed"), + )), + distinct_from(col("c1").eq(lit("completed")), lit(true)) + .or(not_distinct_from(col("c1").eq(lit("inboxed")), lit(true)) + .or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true)))) + ); + + assert_eq!( + simplify(binary_expr( + complex_case.clone(), + Operator::Eq, + lit("running"), + )), + not_distinct_from(col("c2"), lit(true)).and( + distinct_from(col("c1").eq(lit("inboxed")), lit(true)) + .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true))) + .and(distinct_from(col("c1").eq(lit("completed")), lit(true))) + .and(distinct_from(col("c1").eq(lit("paused")), lit(true))) + ) + ); + + assert_eq!( + simplify(binary_expr( + complex_case.clone(), + Operator::Eq, + lit("ready"), + )), + distinct_from(col("c1").eq(lit("inboxed")), lit(true)) + .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true))) + .and(distinct_from(col("c1").eq(lit("completed")), lit(true))) + .and(distinct_from(col("c1").eq(lit("paused")), lit(true))) + .and(distinct_from(col("c2"), lit(true))) + .and(distinct_from( + col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))), + lit(true) + )) + ); + + assert_eq!( + simplify(binary_expr( + complex_case.clone(), + Operator::NotEq, + lit("ready"), + )), + not_distinct_from(col("c1").eq(lit("inboxed")), lit(true)) + .or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true))) + .or(not_distinct_from(col("c1").eq(lit("completed")), lit(true))) + .or(not_distinct_from(col("c1").eq(lit("paused")), lit(true))) + .or(not_distinct_from(col("c2"), lit(true))) + .or(not_distinct_from( + col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))), + lit(true) + )) + ); + } + #[test] fn simplify_expr_case_when_then_else() { // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 2f7dadcebaa49..e526df206afe2 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -22,7 +22,7 @@ use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::{ expr::{Between, BinaryExpr, InList}, expr_fn::{and, bitwise_and, bitwise_or, or}, - Expr, Like, Operator, + Case, Expr, Like, Operator, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -265,6 +265,31 @@ pub fn as_bool_lit(expr: &Expr) -> Result> { } } +pub fn is_case_with_literal_outputs(expr: &Expr) -> bool { + match expr { + Expr::Case(Case { + expr: None, + when_then_expr, + else_expr, + }) => { + when_then_expr.iter().all(|(_, then)| is_lit(then)) + && else_expr.as_deref().is_none_or(is_lit) + } + _ => false, + } +} + +pub fn as_case(expr: &Expr) -> Result<&Case> { + match expr { + Expr::Case(case) => Ok(case), + _ => internal_err!("Expected case, got {expr:?}"), + } +} + +pub fn is_lit(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(_, _)) +} + /// negate a Not clause /// input is the clause to be negated.(args of Not clause) /// For BinaryExpr, use the negation of op instead. From f55d53806d841c0e91147141c639b997d7ed44c2 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Tue, 23 Sep 2025 18:34:13 +0100 Subject: [PATCH 2/3] Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs Co-authored-by: Andrew Lamb --- .../optimizer/src/simplify_expressions/expr_simplifier.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index ec93effe45d91..cdce3e1f1b1b7 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1399,6 +1399,8 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // Rules for Case // + // Inline a comparison to a literal with the case statement into the `THEN` clauses. + // which can enable further simplifications // CASE WHEN X THEN "a" WHEN Y THEN "b" ... END = "a" --> CASE WHEN X THEN "a" = "a" WHEN Y THEN "b" = "a" END Expr::BinaryExpr(BinaryExpr { left, From c3c0ef31ba395c6b68aa21e3f4253ca232a8f904 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Tue, 23 Sep 2025 18:44:30 +0100 Subject: [PATCH 3/3] Avoid expr clones --- .../simplify_expressions/expr_simplifier.rs | 21 +++++++++---------- .../src/simplify_expressions/utils.rs | 2 +- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index cdce3e1f1b1b7..b491a3529f353 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1407,26 +1407,26 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: op @ (Eq | NotEq), right, }) if is_case_with_literal_outputs(&left) && is_lit(&right) => { - let case = as_case(&left)?; + let case = into_case(*left)?; Transformed::yes(Expr::Case(Case { expr: None, when_then_expr: case .when_then_expr - .iter() + .into_iter() .map(|(when, then)| { ( - when.clone(), + when, Box::new(Expr::BinaryExpr(BinaryExpr { - left: then.clone(), + left: then, op, right: right.clone(), })), ) }) .collect(), - else_expr: case.else_expr.as_ref().map(|els| { + else_expr: case.else_expr.map(|els| { Box::new(Expr::BinaryExpr(BinaryExpr { - left: els.clone(), + left: els, op, right, })) @@ -1550,12 +1550,11 @@ impl TreeNodeRewriter for Simplifier<'_, S> { Expr::Case(Case { expr: None, when_then_expr: when_then_expr - .iter() - .map(|(when, then)| { - (when.clone(), Box::new(then.clone().not())) - }) + .into_iter() + .map(|(when, then)| (when, Box::new(Expr::Not(then)))) .collect(), - else_expr: else_expr.map(|else_expr| Box::new(else_expr.not())), + else_expr: else_expr + .map(|else_expr| Box::new(Expr::Not(else_expr))), }) .not(), ) diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index e526df206afe2..35e256f3064e3 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -279,7 +279,7 @@ pub fn is_case_with_literal_outputs(expr: &Expr) -> bool { } } -pub fn as_case(expr: &Expr) -> Result<&Case> { +pub fn into_case(expr: Expr) -> Result { match expr { Expr::Case(case) => Ok(case), _ => internal_err!("Expected case, got {expr:?}"),