diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 97dfc09c4f2a6..3c96f953f0000 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1400,13 +1400,35 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // // CASE WHEN true THEN A ... END --> A + // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END Expr::Case(Case { expr: None, mut when_then_expr, else_expr: _, - }) if !when_then_expr.is_empty() && is_true(when_then_expr[0].0.as_ref()) => { - let (_, then_) = when_then_expr.swap_remove(0); - Transformed::yes(*then_) + // if let guard is not stabilized so we can't use it yet: https://github.com/rust-lang/rust/issues/51114 + // Once it's supported we can avoid searching through when_then_expr twice in the below .any() and .position() calls + // }) if let Some(i) = when_then_expr.iter().position(|(when, _)| is_true(when.as_ref())) => { + }) if when_then_expr + .iter() + .any(|(when, _)| is_true(when.as_ref())) => + { + let i = when_then_expr + .iter() + .position(|(when, _)| is_true(when.as_ref())) + .unwrap(); + let (_, then_) = when_then_expr.swap_remove(i); + // CASE WHEN true THEN A ... END --> A + if i == 0 { + return Ok(Transformed::yes(*then_)); + } + + // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END + when_then_expr.truncate(i); + Transformed::yes(Expr::Case(Case { + expr: None, + when_then_expr, + else_expr: Some(then_), + })) } // CASE @@ -3563,7 +3585,7 @@ mod tests { } #[test] - fn simplify_expr_case_when_true() { + fn simplify_expr_case_when_first_true() { // CASE WHEN true THEN 1 ELSE x END --> 1 assert_eq!( simplify(Expr::Case(Case::new( @@ -3632,6 +3654,82 @@ mod tests { assert_eq!(simplify(expr.clone()), expr); } + #[test] + fn simplify_expr_case_when_any_true() { + // CASE WHEN x > 0 THEN a WHEN true THEN b ELSE c END --> CASE WHEN x > 0 THEN a ELSE b END + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![ + (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), + (Box::new(lit(true)), Box::new(col("b"))), + ], + Some(Box::new(col("c"))), + ))), + Expr::Case(Case::new( + None, + vec![(Box::new(col("x").gt(lit(0))), Box::new(col("a")))], + Some(Box::new(col("b"))), + )) + ); + + // CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c WHEN z = 0 THEN d ELSE e END + // --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![ + (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), + (Box::new(col("y").lt(lit(0))), Box::new(col("b"))), + (Box::new(lit(true)), Box::new(col("c"))), + (Box::new(col("z").eq(lit(0))), Box::new(col("d"))), + ], + Some(Box::new(col("e"))), + ))), + Expr::Case(Case::new( + None, + vec![ + (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), + (Box::new(col("y").lt(lit(0))), Box::new(col("b"))), + ], + Some(Box::new(col("c"))), + )) + ); + + // CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c END (no else) + // --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![ + (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), + (Box::new(col("y").lt(lit(0))), Box::new(col("b"))), + (Box::new(lit(true)), Box::new(col("c"))), + ], + None, + ))), + Expr::Case(Case::new( + None, + vec![ + (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), + (Box::new(col("y").lt(lit(0))), Box::new(col("b"))), + ], + Some(Box::new(col("c"))), + )) + ); + + // Negative test: CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END should not be simplified + let expr = Expr::Case(Case::new( + None, + vec![ + (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), + (Box::new(col("y").lt(lit(0))), Box::new(col("b"))), + ], + Some(Box::new(col("c"))), + )); + assert_eq!(simplify(expr.clone()), expr); + } + fn distinct_from(left: impl Into, right: impl Into) -> Expr { Expr::BinaryExpr(BinaryExpr { left: Box::new(left.into()),