diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs
index 0b14a3eafb7..c178e536b38 100644
--- a/vortex-array/src/scalar_fn/fns/case_when.rs
+++ b/vortex-array/src/scalar_fn/fns/case_when.rs
@@ -41,6 +41,10 @@ use crate::scalar_fn::ChildName;
 use crate::scalar_fn::ExecutionArgs;
 use crate::scalar_fn::ScalarFnId;
 use crate::scalar_fn::ScalarFnVTable;
+use crate::scalar_fn::SimplifyCtx;
+use crate::scalar_fn::fns::is_not_null::IsNotNull;
+use crate::scalar_fn::fns::is_null::IsNull;
+use crate::scalar_fn::fns::literal::Literal;
 use crate::scalar_fn::fns::zip::zip_impl;
 
 /// Options for the n-ary CaseWhen expression.
@@ -251,6 +255,51 @@ impl ScalarFnVTable for CaseWhen {
         merge_case_branches(branches, else_value, ctx)
     }
 
+    fn simplify(
+        &self,
+        options: &Self::Options,
+        expr: &Expression,
+        _ctx: &dyn SimplifyCtx,
+    ) -> VortexResult<Option<Expression>> {
+        // Rewrite the COALESCE-shaped CASE WHEN into `fill_null`, which references `x`
+        // once and lowers to a single fill kernel instead of a `zip`/merge that resolves
+        // `x` twice (once for the `is_null` predicate, once for the value branch).
+        //
+        // CASE WHEN is_null(x) THEN c ELSE x END ==> fill_null(x, c)
+        // CASE WHEN is_not_null(x) THEN x ELSE c END ==> fill_null(x, c)
+        //
+        // The fill `c` must be a `Literal`: `fill_null`'s kernel reads the fill value via
+        // `as_constant()`, so a non-constant fill would produce an unexecutable expression.
+        if options.num_when_then_pairs != 1 || !options.has_else {
+            return Ok(None);
+        }
+
+        let when = expr.child(0);
+        let then = expr.child(1);
+        let els = expr.child(2);
+
+        // `is_null(x) ? c : x` — predicate operand and ELSE are the same `x`, fill is THEN.
+        let (x, fill) = if when.is::<IsNull>() && when.child(0) == els {
+            (els, then)
+            // `is_not_null(x) ? x : c` — predicate operand and THEN are the same `x`, fill is ELSE.
+        } else if when.is::<IsNotNull>() && when.child(0) == then {
+            (then, els)
+        } else {
+            return Ok(None);
+        };
+
+        let Some(scalar) = fill.as_opt::<Literal>() else {
+            return Ok(None);
+        };
+
+        if scalar.is_null() {
+            // Filling the nulls of `x` with NULL is a no-op
+            return Ok(Some(x.clone()));
+        }
+
+        Ok(Some(crate::expr::fill_null(x.clone(), fill.clone())))
+    }
+
     fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
         true
     }
@@ -410,12 +459,15 @@ mod tests {
     use crate::dtype::DType;
     use crate::dtype::Nullability;
     use crate::dtype::PType;
+    use crate::dtype::StructFields;
     use crate::expr::case_when;
     use crate::expr::case_when_no_else;
     use crate::expr::col;
     use crate::expr::eq;
     use crate::expr::get_item;
     use crate::expr::gt;
+    use crate::expr::is_not_null;
+    use crate::expr::is_null;
     use crate::expr::lit;
     use crate::expr::nested_case_when;
     use crate::expr::root;
@@ -1193,6 +1245,170 @@ mod tests {
         assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
     }
 
+    // ==================== Simplify: COALESCE -> fill_null ====================
+
+    /// Builds a non-nullable struct scope whose named fields are all `Nullable(I64)`.
+    fn nullable_i64_scope(fields: &[&str]) -> DType {
+        DType::Struct(
+            StructFields::new(
+                fields.to_vec().into(),
+                vec![DType::Primitive(PType::I64, Nullability::Nullable); fields.len()],
+            ),
+            Nullability::NonNullable,
+        )
+    }
+
+    #[test]
+    fn test_simplify_coalesce_is_null_rewrites_to_fill_null() -> VortexResult<()> {
+        // CASE WHEN is_null(x) THEN 0 ELSE x END ==> fill_null(x, 0)
+        let expr = case_when(is_null(col("x")), lit(0i64), col("x"));
+        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
+        assert!(
+            optimized.to_string().starts_with("vortex.fill_null"),
+            "expected fill_null, got {optimized}"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_simplify_coalesce_is_not_null_rewrites_to_fill_null() -> VortexResult<()> {
+        // CASE WHEN is_not_null(x) THEN x ELSE 0 END ==> fill_null(x, 0)
+        let expr = case_when(is_not_null(col("x")), col("x"), lit(0i64));
+        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
+        assert!(
+            optimized.to_string().starts_with("vortex.fill_null"),
+            "expected fill_null, got {optimized}"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_simplify_does_not_fire_when_operands_differ() -> VortexResult<()> {
+        // The is_null operand (x) and the ELSE (y) are different columns: not a COALESCE.
+        let expr = case_when(is_null(col("x")), lit(0i64), col("y"));
+        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x", "y"]))?;
+        let s = optimized.to_string();
+        assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
+        assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
+        Ok(())
+    }
+
+    #[test]
+    fn test_simplify_does_not_fire_for_non_constant_fill() -> VortexResult<()> {
+        // COALESCE(x, c) with a *column* fill: fill_null cannot consume a non-constant
+        // fill value, so the rewrite must not fire.
+        let expr = case_when(is_null(col("x")), col("c"), col("x"));
+        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x", "c"]))?;
+        let s = optimized.to_string();
+        assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
+        assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
+        Ok(())
+    }
+
+    #[test]
+    fn test_simplify_null_fill_collapses_to_input() -> VortexResult<()> {
+        // Filling the nulls of x with NULL is a no-op, so both forms collapse to just `x`.
+        // CASE WHEN is_null(x) THEN null ELSE x END ==> x
+        // CASE WHEN is_not_null(x) THEN x ELSE null END ==> x
+        let null_fill = || {
+            lit(Scalar::null(DType::Primitive(
+                PType::I64,
+                Nullability::Nullable,
+            )))
+        };
+
+        for expr in [
+            case_when(is_null(col("x")), null_fill(), col("x")),
+            case_when(is_not_null(col("x")), col("x"), null_fill()),
+        ] {
+            let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
+            assert_eq!(
+                optimized.to_string(),
+                "$.x",
+                "expected collapse to input column, got {optimized}"
+            );
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn test_simplify_null_fill_semantic_equivalence() -> VortexResult<()> {
+        // The collapse-to-input rewrite must preserve values (and `x`'s nullability).
+        let array = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
+        let scope = DType::Primitive(PType::I64, Nullability::Nullable);
+        let null_fill = lit(Scalar::null(DType::Primitive(
+            PType::I64,
+            Nullability::Nullable,
+        )));
+
+        let original = case_when(is_null(root()), null_fill, root());
+        let optimized = original.optimize_recursive(&scope)?;
+        assert_eq!(
+            optimized.to_string(),
+            "$",
+            "expected collapse to root, got {optimized}"
+        );
+
+        let expected = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
+        assert_arrays_eq!(evaluate_expr(&original, &array), expected);
+        assert_arrays_eq!(evaluate_expr(&optimized, &array), expected);
+        Ok(())
+    }
+
+    #[test]
+    fn test_simplify_does_not_fire_without_else() -> VortexResult<()> {
+        let expr = case_when_no_else(is_null(col("x")), lit(0i64));
+        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
+        assert!(
+            !optimized.to_string().contains("fill_null"),
+            "must not rewrite a no-ELSE case_when, got {optimized}"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_simplify_does_not_fire_for_multi_pair() -> VortexResult<()> {
+        let expr = nested_case_when(
+            vec![
+                (is_null(col("x")), lit(0i64)),
+                (gt(col("x"), lit(5i64)), lit(1i64)),
+            ],
+            Some(col("x")),
+        );
+        let optimized = expr.optimize_recursive(&nullable_i64_scope(&["x"]))?;
+        assert!(
+            !optimized.to_string().contains("fill_null"),
+            "must not rewrite a multi-pair case_when, got {optimized}"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_simplify_semantic_equivalence() -> VortexResult<()> {
+        // The optimized expression must produce the same values as the original CASE WHEN.
+        let array = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
+        let scope = DType::Primitive(PType::I64, Nullability::Nullable);
+
+        let original = case_when(is_null(root()), lit(0i64), root());
+        let optimized = original.optimize_recursive(&scope)?;
+        assert!(
+            optimized.to_string().starts_with("vortex.fill_null"),
+            "expected fill_null, got {optimized}"
+        );
+
+        // Original keeps CASE WHEN's nullable result dtype; the rewrite tightens it to
+        // NonNullable because a non-null fill cannot leave any nulls behind. Values match.
+        assert_arrays_eq!(
+            evaluate_expr(&original, &array),
+            PrimitiveArray::from_option_iter([Some(1i64), Some(0), Some(3)]).into_array()
+        );
+        assert_arrays_eq!(
+            evaluate_expr(&optimized, &array),
+            buffer![1i64, 0, 3].into_array()
+        );
+        Ok(())
+    }
+
     #[test]
     fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
         // Exercises the scalar path: alternating rows produce one slice per row (no runs),