diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index ca1f824ac36d0..e6b819517a753 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -1254,3 +1254,51 @@ async fn comparisons_with_null_lt() { assert!(batch.columns()[0].is_null(1)); } } + +#[tokio::test] +async fn binary_mathematical_operator_with_null_lt() { + let ctx = SessionContext::new(); + + let cases = vec![ + // 1. Integer and NULL + "select column1 + NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column1 - NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column1 * NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column1 / NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column1 % NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + // 2. Float and NULL + "select column2 + NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column2 - NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column2 * NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column2 / NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + "select column2 % NULL from (VALUES (1, 2.3), (2, 5.4)) as t", + // ---- + // ---- same queries, reversed argument order + // ---- + // 3. NULL and Integer + "select NULL + column1 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL - column1 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL * column1 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL / column1 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL % column1 from (VALUES (1, 2.3), (2, 5.4)) as t", + // 4. NULL and Float + "select NULL + column2 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL - column2 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL * column2 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL / column2 from (VALUES (1, 2.3), (2, 5.4)) as t", + "select NULL % column2 from (VALUES (1, 2.3), (2, 5.4)) as t", + ]; + + for sql in cases { + println!("Computing: {}", sql); + + let mut actual = execute_to_batches(&ctx, sql).await; + assert_eq!(actual.len(), 1); + + let batch = actual.pop().unwrap(); + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 1); + assert!(batch.columns()[0].is_null(0)); + assert!(batch.columns()[0].is_null(1)); + } +} diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/binary_rule.rs index f3d7534674ebd..c9ef1e4963c74 100644 --- a/datafusion/expr/src/binary_rule.rs +++ b/datafusion/expr/src/binary_rule.rs @@ -282,7 +282,7 @@ fn mathematics_numerical_coercion( use arrow::datatypes::DataType::*; // error on any non-numeric type - if !is_numeric(lhs_type) || !is_numeric(rhs_type) { + if !both_numeric_or_null_and_numeric(lhs_type, rhs_type) { return None; }; @@ -412,6 +412,15 @@ pub fn is_numeric(dt: &DataType) -> bool { } } +/// Determine if at least of one of lhs and rhs is numeric, and the other must be NULL or numeric +fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool { + match (lhs_type, rhs_type) { + (_, DataType::Null) => is_numeric(lhs_type), + (DataType::Null, _) => is_numeric(rhs_type), + _ => is_numeric(lhs_type) && is_numeric(rhs_type), + } +} + /// Coercion rules for dictionary values (aka the type of the dictionary itself) fn dictionary_value_coercion( lhs_type: &DataType, diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 9c6eedad1ac76..fb6f34c50c41e 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -701,16 +701,18 @@ macro_rules! compute_bool_op { /// LEFT is array, RIGHT is scalar value macro_rules! compute_op_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast array"); - // generate the scalar function name, such as lt_scalar, from the $OP parameter - // (which could have a value of lt) and the suffix _scalar - Ok(Arc::new(paste::expr! {[<$OP _scalar>]}( - &ll, - $RIGHT.try_into()?, - )?)) + if $RIGHT.is_null() { + Ok(Arc::new(new_null_array($LEFT.data_type(), $LEFT.len()))) + } else { + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + Ok(Arc::new(paste::expr! {[<$OP _scalar>]}( + &ll, + $RIGHT.try_into()?, + )?)) + } }}; }