Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 216 additions & 0 deletions vortex-array/src/scalar_fn/fns/case_when.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ use crate::scalar_fn::ChildName;
use crate::scalar_fn::ExecutionArgs;
use crate::scalar_fn::ScalarFnId;
use crate::scalar_fn::ScalarFnVTable;
use crate::scalar_fn::SimplifyCtx;
use crate::scalar_fn::fns::is_not_null::IsNotNull;
use crate::scalar_fn::fns::is_null::IsNull;
use crate::scalar_fn::fns::literal::Literal;
use crate::scalar_fn::fns::zip::zip_impl;

/// Options for the n-ary CaseWhen expression.
Expand Down Expand Up @@ -251,6 +255,51 @@ impl ScalarFnVTable for CaseWhen {
merge_case_branches(branches, else_value, ctx)
}

/// Collapses a COALESCE-shaped `CASE WHEN` into a `fill_null` (or into its input).
///
/// Recognized shapes, both requiring exactly one WHEN/THEN pair plus an ELSE:
///
/// ```text
/// CASE WHEN is_null(x)     THEN c ELSE x END  ==>  fill_null(x, c)
/// CASE WHEN is_not_null(x) THEN x ELSE c END  ==>  fill_null(x, c)
/// ```
///
/// The rewritten form mentions `x` only once, whereas the original resolves it twice
/// (once inside the null predicate and once as the value branch).
///
/// Returns `Ok(None)` when the expression does not match, or when the fill `c` is not
/// a `Literal`.
fn simplify(
    &self,
    options: &Self::Options,
    expr: &Expression,
    _ctx: &dyn SimplifyCtx,
) -> VortexResult<Option<Expression>> {
    // Only the single-branch form with an explicit ELSE can be a COALESCE.
    let single_pair_with_else = options.num_when_then_pairs == 1 && options.has_else;
    if !single_pair_with_else {
        return Ok(None);
    }

    // Child layout for one pair + ELSE: [WHEN predicate, THEN value, ELSE value].
    let predicate = expr.child(0);
    let then_branch = expr.child(1);
    let else_branch = expr.child(2);

    // Work out which branch passes `x` through and which one supplies the fill.
    let (input, fill_value) = if predicate.is::<IsNull>() && predicate.child(0) == else_branch
    {
        // `is_null(x) ? c : x` — ELSE repeats the predicate operand, THEN is the fill.
        (else_branch, then_branch)
    } else if predicate.is::<IsNotNull>() && predicate.child(0) == then_branch {
        // `is_not_null(x) ? x : c` — THEN repeats the predicate operand, ELSE is the fill.
        (then_branch, else_branch)
    } else {
        return Ok(None);
    };

    // The fill must be a `Literal`: `fill_null`'s kernel reads the fill value via
    // `as_constant()`, so a non-constant fill would produce an unexecutable expression.
    let rewritten = match fill_value.as_opt::<Literal>() {
        None => None,
        // Replacing the nulls of `x` with NULL changes nothing, so just yield `x`.
        Some(literal) if literal.is_null() => Some(input.clone()),
        Some(_) => Some(crate::expr::fill_null(input.clone(), fill_value.clone())),
    };

    Ok(rewritten)
}

/// Always reports `true`; the supplied options are ignored.
///
/// NOTE(review): presumably this tells the optimizer that CASE WHEN's result depends on
/// the nullness of its inputs — confirm against the `ScalarFnVTable` trait documentation.
fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
true
}
Expand Down Expand Up @@ -410,12 +459,15 @@ mod tests {
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::PType;
use crate::dtype::StructFields;
use crate::expr::case_when;
use crate::expr::case_when_no_else;
use crate::expr::col;
use crate::expr::eq;
use crate::expr::get_item;
use crate::expr::gt;
use crate::expr::is_not_null;
use crate::expr::is_null;
use crate::expr::lit;
use crate::expr::nested_case_when;
use crate::expr::root;
Expand Down Expand Up @@ -1193,6 +1245,170 @@ mod tests {
assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
}

// ==================== Simplify: COALESCE -> fill_null ====================

/// Creates the scope dtype used by the simplify tests: a `NonNullable` struct with one
/// `Nullable(I64)` field for every name in `fields`.
fn nullable_i64_scope(fields: &[&str]) -> DType {
    let names = fields.to_vec().into();
    let field_dtypes = vec![DType::Primitive(PType::I64, Nullability::Nullable); fields.len()];
    let struct_fields = StructFields::new(names, field_dtypes);
    DType::Struct(struct_fields, Nullability::NonNullable)
}

#[test]
fn test_simplify_coalesce_is_null_rewrites_to_fill_null() -> VortexResult<()> {
    // The `is_null` spelling of COALESCE(x, 0):
    //   CASE WHEN is_null(x) THEN 0 ELSE x END  ==>  fill_null(x, 0)
    let coalesce = case_when(is_null(col("x")), lit(0i64), col("x"));
    let scope = nullable_i64_scope(&["x"]);

    let optimized = coalesce.optimize_recursive(&scope)?;
    let rendered = optimized.to_string();

    assert!(
        rendered.starts_with("vortex.fill_null"),
        "expected fill_null, got {optimized}"
    );
    Ok(())
}

#[test]
fn test_simplify_coalesce_is_not_null_rewrites_to_fill_null() -> VortexResult<()> {
    // The `is_not_null` spelling of COALESCE(x, 0):
    //   CASE WHEN is_not_null(x) THEN x ELSE 0 END  ==>  fill_null(x, 0)
    let coalesce = case_when(is_not_null(col("x")), col("x"), lit(0i64));
    let scope = nullable_i64_scope(&["x"]);

    let optimized = coalesce.optimize_recursive(&scope)?;
    let rendered = optimized.to_string();

    assert!(
        rendered.starts_with("vortex.fill_null"),
        "expected fill_null, got {optimized}"
    );
    Ok(())
}

#[test]
fn test_simplify_does_not_fire_when_operands_differ() -> VortexResult<()> {
    // `is_null` inspects column x while ELSE yields column y, so this is not a
    // COALESCE and the CASE WHEN has to survive optimization untouched.
    let mismatched = case_when(is_null(col("x")), lit(0i64), col("y"));
    let scope = nullable_i64_scope(&["x", "y"]);

    let s = mismatched.optimize_recursive(&scope)?.to_string();

    assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
    assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
    Ok(())
}

#[test]
fn test_simplify_does_not_fire_for_non_constant_fill() -> VortexResult<()> {
    // Here the fill is another *column* rather than a literal. fill_null cannot
    // consume a non-constant fill value, so the rewrite has to stay away.
    let column_fill = case_when(is_null(col("x")), col("c"), col("x"));
    let scope = nullable_i64_scope(&["x", "c"]);

    let s = column_fill.optimize_recursive(&scope)?.to_string();

    assert!(s.contains("CASE"), "expected CASE WHEN to remain, got {s}");
    assert!(!s.contains("fill_null"), "must not rewrite, got {s}");
    Ok(())
}

#[test]
fn test_simplify_null_fill_collapses_to_input() -> VortexResult<()> {
    // A NULL fill leaves `x` exactly as it was, so either spelling reduces to `x`:
    //   CASE WHEN is_null(x)     THEN null ELSE x    END  ==>  x
    //   CASE WHEN is_not_null(x) THEN x    ELSE null END  ==>  x
    let make_null_literal = || {
        let nullable_i64 = DType::Primitive(PType::I64, Nullability::Nullable);
        lit(Scalar::null(nullable_i64))
    };

    let candidates = [
        case_when(is_null(col("x")), make_null_literal(), col("x")),
        case_when(is_not_null(col("x")), col("x"), make_null_literal()),
    ];
    let scope = nullable_i64_scope(&["x"]);

    for expr in candidates {
        let optimized = expr.optimize_recursive(&scope)?;
        assert_eq!(
            optimized.to_string(),
            "$.x",
            "expected collapse to input column, got {optimized}"
        );
    }
    Ok(())
}

#[test]
fn test_simplify_null_fill_semantic_equivalence() -> VortexResult<()> {
    // Collapsing to the input must leave both the values and `x`'s nullability intact.
    let sample = || PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
    let array = sample();
    let scope = DType::Primitive(PType::I64, Nullability::Nullable);

    let null_literal = lit(Scalar::null(DType::Primitive(
        PType::I64,
        Nullability::Nullable,
    )));
    let original = case_when(is_null(root()), null_literal, root());

    let optimized = original.optimize_recursive(&scope)?;
    assert_eq!(
        optimized.to_string(),
        "$",
        "expected collapse to root, got {optimized}"
    );

    // Both the untouched and the simplified expression must echo the input back.
    let expected = sample();
    assert_arrays_eq!(evaluate_expr(&original, &array), expected);
    assert_arrays_eq!(evaluate_expr(&optimized, &array), expected);
    Ok(())
}

#[test]
fn test_simplify_does_not_fire_without_else() -> VortexResult<()> {
    // No ELSE branch is present, so the COALESCE pattern cannot match.
    let no_else = case_when_no_else(is_null(col("x")), lit(0i64));
    let scope = nullable_i64_scope(&["x"]);

    let optimized = no_else.optimize_recursive(&scope)?;
    let was_rewritten = optimized.to_string().contains("fill_null");

    assert!(
        !was_rewritten,
        "must not rewrite a no-ELSE case_when, got {optimized}"
    );
    Ok(())
}

#[test]
fn test_simplify_does_not_fire_for_multi_pair() -> VortexResult<()> {
    // Two WHEN/THEN pairs: the first looks like a COALESCE, but the extra branch
    // means the expression as a whole is not one.
    let branches = vec![
        (is_null(col("x")), lit(0i64)),
        (gt(col("x"), lit(5i64)), lit(1i64)),
    ];
    let multi_pair = nested_case_when(branches, Some(col("x")));
    let scope = nullable_i64_scope(&["x"]);

    let optimized = multi_pair.optimize_recursive(&scope)?;
    let was_rewritten = optimized.to_string().contains("fill_null");

    assert!(
        !was_rewritten,
        "must not rewrite a multi-pair case_when, got {optimized}"
    );
    Ok(())
}

#[test]
fn test_simplify_semantic_equivalence() -> VortexResult<()> {
    // Evaluating the rewritten expression must yield the same values as the CASE WHEN.
    let array = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3)]).into_array();
    let scope = DType::Primitive(PType::I64, Nullability::Nullable);
    let original = case_when(is_null(root()), lit(0i64), root());

    let optimized = original.optimize_recursive(&scope)?;
    let rendered = optimized.to_string();
    assert!(
        rendered.starts_with("vortex.fill_null"),
        "expected fill_null, got {optimized}"
    );

    // The CASE WHEN result stays nullable, while the rewrite narrows it to NonNullable
    // since a non-null fill leaves no nulls behind. The values themselves agree.
    let original_result = evaluate_expr(&original, &array);
    assert_arrays_eq!(
        original_result,
        PrimitiveArray::from_option_iter([Some(1i64), Some(0), Some(3)]).into_array()
    );

    let optimized_result = evaluate_expr(&optimized, &array);
    assert_arrays_eq!(optimized_result, buffer![1i64, 0, 3].into_array());
    Ok(())
}

#[test]
fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
// Exercises the scalar path: alternating rows produce one slice per row (no runs),
Expand Down
Loading