From f5bf52b800bfa9e855fddff76dba1069d2248424 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 15 Oct 2025 16:34:50 -0700 Subject: [PATCH 1/3] chore: use `NullBuffer::union` for Spark `concat` --- .../spark/src/function/string/concat.rs | 115 +++++------------- 1 file changed, 32 insertions(+), 83 deletions(-) diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index 0e981e7c37224..bc620c394c1f4 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayBuilder}; +use arrow::array::Array; +use arrow::buffer::NullBuffer; use arrow::datatypes::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ @@ -122,13 +123,13 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { apply_null_mask(result, null_mask) } -/// Compute NULL mask for the arguments -/// Returns None if all scalars and any is NULL, or a Vector of -/// boolean representing the null mask for incoming arrays +/// Compute NULL mask for the arguments using NullBuffer::union +/// Returns None if all scalars and any is NULL, or an Option +/// representing the combined null mask for incoming arrays fn compute_null_mask( args: &[ColumnarValue], number_rows: usize, -) -> Result>> { +) -> Result>> { // Check if all arguments are scalars let all_scalars = args .iter() @@ -145,9 +146,9 @@ fn compute_null_mask( } } // No NULLs in scalars - Ok(Some(vec![])) + Ok(Some(None)) } else { - // For arrays, compute NULL mask for each row + // For arrays, compute NULL mask for each row using NullBuffer::union let array_len = args .iter() .find_map(|arg| match arg { @@ -166,24 +167,20 @@ fn compute_null_mask( .collect(); let arrays = arrays?; - // Compute NULL mask - let mut null_mask = vec![false; array_len]; - for array in &arrays { - for (i, null_flag) in null_mask.iter_mut().enumerate().take(array_len) { - if array.is_null(i) { - *null_flag = true; - } - } - } + // Use NullBuffer::union to combine all null buffers + let combined_nulls = arrays + .iter() + .map(|arr| arr.nulls()) + .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); - Ok(Some(null_mask)) + Ok(Some(combined_nulls)) } } -/// Apply NULL mask to the result +/// Apply NULL mask to the result using NullBuffer::union fn apply_null_mask( result: ColumnarValue, - null_mask: Option>, + null_mask: Option>, ) -> Result { match (result, null_mask) { // Scalar with NULL mask means return NULL @@ -191,70 +188,22 @@ fn apply_null_mask( Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) } // Scalar without NULL mask, return as-is - (scalar @ ColumnarValue::Scalar(_), Some(mask)) if mask.is_empty() => Ok(scalar), - // Array with NULL mask - (ColumnarValue::Array(array), Some(null_mask)) if !null_mask.is_empty() => { - let array_len = array.len(); - let return_type = array.data_type(); - - let mut builder: Box = match return_type { - DataType::Utf8 => { - let string_array = array - .as_any() - .downcast_ref::() - .unwrap(); - let mut builder = - arrow::array::StringBuilder::with_capacity(array_len, 0); - for (i, &is_null) in null_mask.iter().enumerate().take(array_len) { - if is_null || string_array.is_null(i) { - builder.append_null(); - } else { - builder.append_value(string_array.value(i)); - } - } - Box::new(builder) - } - DataType::LargeUtf8 => { - let string_array = array - .as_any() - .downcast_ref::() - .unwrap(); - let mut builder = - arrow::array::LargeStringBuilder::with_capacity(array_len, 0); - for (i, &is_null) in null_mask.iter().enumerate().take(array_len) { - if is_null || string_array.is_null(i) { - builder.append_null(); - } else { - builder.append_value(string_array.value(i)); - } - } - Box::new(builder) - } - DataType::Utf8View => { - let string_array = array - .as_any() - .downcast_ref::() - .unwrap(); - let mut builder = - arrow::array::StringViewBuilder::with_capacity(array_len); - for (i, &is_null) in null_mask.iter().enumerate().take(array_len) { - if is_null || string_array.is_null(i) { - builder.append_null(); - } else { - builder.append_value(string_array.value(i)); - } - } - Box::new(builder) - } - _ => { - return datafusion_common::exec_err!( - "Unsupported return type for concat: {:?}", - return_type - ); - } - }; - - Ok(ColumnarValue::Array(builder.finish())) + (scalar @ ColumnarValue::Scalar(_), Some(None)) => Ok(scalar), + // Array with NULL mask - use NullBuffer::union to combine nulls + (ColumnarValue::Array(array), Some(Some(null_mask))) => { + // Combine the result's existing nulls with our computed null mask + let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask)); + + // Create new array with combined nulls + let new_array = array + .into_data() + .into_builder() + .nulls(combined_nulls) + .build()?; + + Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array( + new_array, + )))) } // Array without NULL mask, return as-is (array @ ColumnarValue::Array(_), _) => Ok(array), From e76800d837a9434edf784163bc0d4867645fba9d Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 15 Oct 2025 16:56:00 -0700 Subject: [PATCH 2/3] chore: use `NullBuffer::union` for Spark `concat` --- datafusion/spark/src/function/string/concat.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index bc620c394c1f4..218fcfc335356 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -32,6 +32,10 @@ use std::sync::Arc; /// /// Concatenates multiple input strings into a single string. /// Returns NULL if any input is NULL. +/// +/// Differences with DataFusion concat: +/// - Support 0 arguments +/// - Return NULL if any input is NULL #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkConcat { signature: Signature, @@ -124,7 +128,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { } /// Compute NULL mask for the arguments using NullBuffer::union -/// Returns None if all scalars and any is NULL, or an Option +/// Returns None if all scalars and any is NULL, or an Option of NullBuffer /// representing the combined null mask for incoming arrays fn compute_null_mask( args: &[ColumnarValue], From 07b81fe213ac8f037bf73846530b3f39009f416f Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 16 Oct 2025 09:06:08 -0700 Subject: [PATCH 3/3] chore: use `NullBuffer::union` for Spark `concat` --- .../spark/src/function/string/concat.rs | 42 ++++++++++++------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index 218fcfc335356..0dcc58d5bb8ed 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -85,6 +85,16 @@ impl ScalarUDFImpl for SparkConcat { } } +/// Represents the null state for Spark concat +enum NullMaskResolution { + /// Return NULL as the result (e.g., scalar inputs with at least one NULL) + ReturnNull, + /// No null mask needed (e.g., all scalar inputs are non-NULL) + NoMask, + /// Null mask to apply for arrays + Apply(NullBuffer), +} + /// Concatenates strings, returning NULL if any input is NULL /// This is a Spark-specific wrapper around DataFusion's concat that returns NULL /// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs. @@ -108,7 +118,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { let null_mask = compute_null_mask(&arg_values, number_rows)?; // If all scalars and any is NULL, return NULL immediately - if null_mask.is_none() { + if matches!(null_mask, NullMaskResolution::ReturnNull) { return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); } @@ -128,12 +138,10 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { } /// Compute NULL mask for the arguments using NullBuffer::union -/// Returns None if all scalars and any is NULL, or an Option of NullBuffer -/// representing the combined null mask for incoming arrays fn compute_null_mask( args: &[ColumnarValue], number_rows: usize, -) -> Result>> { +) -> Result { // Check if all arguments are scalars let all_scalars = args .iter() @@ -144,13 +152,12 @@ fn compute_null_mask( for arg in args { if let ColumnarValue::Scalar(scalar) = arg { if scalar.is_null() { - // Return None to indicate all values should be NULL - return Ok(None); + return Ok(NullMaskResolution::ReturnNull); } } } // No NULLs in scalars - Ok(Some(None)) + Ok(NullMaskResolution::NoMask) } else { // For arrays, compute NULL mask for each row using NullBuffer::union let array_len = args @@ -177,24 +184,27 @@ fn compute_null_mask( .map(|arr| arr.nulls()) .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); - Ok(Some(combined_nulls)) + match combined_nulls { + Some(nulls) => Ok(NullMaskResolution::Apply(nulls)), + None => Ok(NullMaskResolution::NoMask), + } } } /// Apply NULL mask to the result using NullBuffer::union fn apply_null_mask( result: ColumnarValue, - null_mask: Option>, + null_mask: NullMaskResolution, ) -> Result { match (result, null_mask) { - // Scalar with NULL mask means return NULL - (ColumnarValue::Scalar(_), None) => { + // Scalar with ReturnNull mask means return NULL + (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => { Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) } - // Scalar without NULL mask, return as-is - (scalar @ ColumnarValue::Scalar(_), Some(None)) => Ok(scalar), + // Scalar without mask, return as-is + (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar), // Array with NULL mask - use NullBuffer::union to combine nulls - (ColumnarValue::Array(array), Some(Some(null_mask))) => { + (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => { // Combine the result's existing nulls with our computed null mask let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask)); @@ -210,8 +220,8 @@ fn apply_null_mask( )))) } // Array without NULL mask, return as-is - (array @ ColumnarValue::Array(_), _) => Ok(array), - // Shouldn't happen + (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array), + // Edge cases that shouldn't happen in practice (scalar, _) => Ok(scalar), } }