diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 0e307153341bf..711a521da14c8 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -346,7 +346,10 @@ impl CaseExpr { .downcast_ref::<BooleanArray>() .expect("predicate should evaluate to a boolean array"); // invert the bitmask - let bit_mask = not(bit_mask)?; + let bit_mask = match bit_mask.null_count() { + 0 => not(bit_mask)?, + _ => not(&prep_null_mask_filter(bit_mask))?, + }; match then_expr.evaluate(batch)? { ColumnarValue::Array(array) => { Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?)) @@ -885,6 +888,32 @@ mod tests { Ok(()) } + #[test] + fn test_when_null_and_some_cond_else_null() -> Result<()> { + let batch = case_test_batch()?; + let schema = batch.schema(); + + let when = binary( + Arc::new(Literal::new(ScalarValue::Boolean(None))), + Operator::And, + binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?, + &schema, + )?; + let then = col("a", &schema)?; + + // SELECT CASE WHEN (NULL AND a = 'foo') THEN a ELSE NULL END + let expr = Arc::new(CaseExpr::try_new(None, vec![(when, then)], None)?); + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_string_array(&result); + + // all result values should be null + assert_eq!(result.logical_null_count(), batch.num_rows()); + Ok(()) + } + fn case_test_batch() -> Result<RecordBatch> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]); diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 3c967eed219a9..4f3320931d2c5 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -50,7 +50,7 @@ NULL 6 NULL NULL -7 +NULL # column or implicit null query I @@ -61,8 +61,30 @@ NULL 6 NULL NULL +NULL + +# column or implicit null (no nulls) +query I +SELECT CASE WHEN NULLIF(NVL(a, 0) >= 0, FALSE) THEN b END FROM foo +---- +2 +4 +6 +NULL +NULL 7 +# column or implicit null (all nulls) +query I +SELECT CASE WHEN NULLIF(NVL(a, 0) >= 0, TRUE) THEN b END FROM foo +---- +NULL +NULL +NULL +NULL +NULL +NULL + # scalar or scalar (string) query T SELECT CASE WHEN a > 2 THEN 'even' ELSE 'odd' END FROM foo