diff --git a/vortex-array/src/array/erased.rs b/vortex-array/src/array/erased.rs index 9b6d92014c9..9dad46fe0ae 100644 --- a/vortex-array/src/array/erased.rs +++ b/vortex-array/src/array/erased.rs @@ -41,6 +41,7 @@ use crate::arrays::DictArray; use crate::arrays::FilterArray; use crate::arrays::Null; use crate::arrays::Primitive; +use crate::arrays::ScalarFn; use crate::arrays::SliceArray; use crate::arrays::VarBin; use crate::arrays::VarBinView; @@ -276,7 +277,7 @@ impl ArrayRef { /// Execute the array to extract a scalar at the given index. pub fn execute_scalar(&self, index: usize, ctx: &mut ExecutionCtx) -> VortexResult { vortex_ensure!(index < self.len(), OutOfBounds: index, 0, self.len()); - if self.dtype().is_nullable() && self.is_invalid(index, ctx)? { + if self.dtype().is_nullable() && !self.is::() && self.is_invalid(index, ctx)? { return Ok(Scalar::null(self.dtype().clone())); } let scalar = self.0.data.execute_scalar(self, index, ctx)?; diff --git a/vortex-array/src/arrays/scalar_fn/vtable/mod.rs b/vortex-array/src/arrays/scalar_fn/vtable/mod.rs index 6833848deae..e350e5cd0c1 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/mod.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/mod.rs @@ -8,6 +8,7 @@ use std::hash::Hash; use std::hash::Hasher; use std::marker::PhantomData; use std::ops::Deref; +use std::sync::Arc; use itertools::Itertools; use vortex_error::VortexResult; @@ -37,13 +38,18 @@ use crate::dtype::DType; use crate::executor::ExecutionCtx; use crate::executor::ExecutionResult; use crate::expr::Expression; +use crate::expr::lit; use crate::matcher::Matcher; use crate::scalar_fn; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; +use crate::scalar_fn::ReduceCtx; +use crate::scalar_fn::ReduceNode; +use crate::scalar_fn::ReduceNodeRef; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTableExt; +use crate::scalar_fn::TypedScalarFnInstance; use crate::scalar_fn::VecExecutionArgs; use crate::serde::ArrayChildren; @@ -175,7 +181,7 @@ pub trait ScalarFnFactoryExt: scalar_fn::ScalarFnVTable { options: Self::Options, children: impl Into>, ) -> VortexResult { - let scalar_fn = scalar_fn::TypedScalarFnInstance::new(self.clone(), options).erased(); + let scalar_fn = TypedScalarFnInstance::new(self.clone(), options).erased(); let children = children.into(); vortex_ensure!( @@ -201,6 +207,24 @@ pub trait ScalarFnFactoryExt: scalar_fn::ScalarFnVTable { } impl ScalarFnFactoryExt for V {} +pub(crate) fn scalar_fn_array_expr(array: ArrayView<'_, ScalarFn>) -> VortexResult { + let inputs: Vec<_> = array + .iter_children() + .map(|child| { + if let Some(scalar) = child.as_constant() { + return Ok(lit(scalar)); + } + + Expression::try_new( + TypedScalarFnInstance::new(ArrayExpr, FakeEq(child.clone())).erased(), + [], + ) + }) + .collect::>()?; + + Expression::try_new(array.scalar_fn().clone(), inputs) +} + /// A matcher that matches any scalar function expression. #[derive(Debug)] pub struct AnyScalarFn; @@ -320,12 +344,21 @@ impl scalar_fn::ScalarFnVTable for ArrayExpr { crate::Executable::execute(options.0.clone(), ctx) } + fn reduce( + &self, + options: &Self::Options, + _node: &dyn ReduceNode, + _ctx: &dyn ReduceCtx, + ) -> VortexResult> { + Ok(Some(Arc::new(options.0.clone()))) + } + fn validity( &self, options: &Self::Options, _expression: &Expression, - ) -> VortexResult> { + ) -> VortexResult { let validity_array = options.0.validity()?.to_array(options.0.len()); - Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), []))) + Ok(ArrayExpr.new_expr(FakeEq(validity_array), [])) } } diff --git a/vortex-array/src/arrays/scalar_fn/vtable/operations.rs b/vortex-array/src/arrays/scalar_fn/vtable/operations.rs index 67de19d936f..782fdfda033 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/operations.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/operations.rs @@ -65,6 +65,7 @@ mod tests { use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; use crate::arrays::ScalarFnArray; + use crate::arrays::StructArray; use crate::arrays::scalar_fn::ScalarFnArrayExt; use crate::assert_arrays_eq; use crate::scalar::Scalar; @@ -172,6 +173,29 @@ mod tests { Ok(()) } + #[test] + fn scalar_fn_scalar_at_handles_value_derived_validity() -> VortexResult<()> { + let child = StructArray::from_fields(&[( + "a", + PrimitiveArray::from_option_iter([Some(1i32), None]).into_array(), + )])? + .into_array(); + let expr = crate::expr::get_item("a", crate::expr::root()); + let array = ScalarFnArray::try_new(expr.scalar_fn().clone(), vec![child])?.into_array(); + + assert_eq!( + array.execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())?, + Scalar::primitive(1i32, true.into()) + ); + assert!( + array + .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())? + .is_null() + ); + + Ok(()) + } + #[test] fn test_scalar_fn_comparison() -> VortexResult<()> { let lhs = buffer![1i32, 5, 3].into_array(); diff --git a/vortex-array/src/arrays/scalar_fn/vtable/validity.rs b/vortex-array/src/arrays/scalar_fn/vtable/validity.rs index 2ac376155e3..5a738b1a688 100644 --- a/vortex-array/src/arrays/scalar_fn/vtable/validity.rs +++ b/vortex-array/src/arrays/scalar_fn/vtable/validity.rs @@ -3,72 +3,57 @@ use vortex_error::VortexResult; -use crate::ArrayRef; use crate::IntoArray; -use crate::LEGACY_SESSION; -use crate::VortexSessionExecute; use crate::array::ArrayView; use crate::array::ValidityVTable; -use crate::arrays::scalar_fn::ScalarFnArrayExt; -use crate::arrays::scalar_fn::vtable::ArrayExpr; -use crate::arrays::scalar_fn::vtable::FakeEq; +use crate::arrays::ConstantArray; use crate::arrays::scalar_fn::vtable::ScalarFn; -use crate::expr::Expression; -use crate::expr::lit; -use crate::scalar_fn::TypedScalarFnInstance; -use crate::scalar_fn::VecExecutionArgs; -use crate::scalar_fn::fns::literal::Literal; -use crate::scalar_fn::fns::root::Root; +use crate::arrays::scalar_fn::vtable::scalar_fn_array_expr; use crate::validity::Validity; -/// Execute an expression tree recursively. -/// -/// This assumes all leaf expressions are either ArrayExpr (wrapping actual arrays) or Literals. -fn execute_expr(expr: &Expression, row_count: usize) -> VortexResult { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - - // Handle Root expression - this should not happen in validity expressions - if expr.is::() { - vortex_error::vortex_bail!("Root expression cannot be executed in validity context"); - } - - // Handle Literal expression - create a constant array - if expr.is::() { - let scalar = expr.as_::(); - return Ok(crate::arrays::ConstantArray::new(scalar.clone(), row_count).into_array()); - } - - // Recursively execute child expressions to get input arrays - let inputs: Vec = expr - .children() - .iter() - .map(|child| execute_expr(child, row_count)) - .collect::>()?; - - let args = VecExecutionArgs::new(inputs, row_count); - - Ok(expr.scalar_fn().execute(&args, &mut ctx)?.into_array()) -} - impl ValidityVTable for ScalarFn { fn validity(array: ArrayView<'_, ScalarFn>) -> VortexResult { - let inputs: Vec<_> = array - .iter_children() - .map(|child| { - if let Some(scalar) = child.as_constant() { - return Ok(lit(scalar)); - } - Expression::try_new( - TypedScalarFnInstance::new(ArrayExpr, FakeEq(child.clone())).erased(), - [], - ) - }) - .collect::>()?; - - let expr = Expression::try_new(array.scalar_fn().clone(), inputs)?; - let validity_expr = array.scalar_fn().validity(&expr)?; + let expr = scalar_fn_array_expr(array)?.validity()?; + let input = ConstantArray::new(true, array.len()).into_array(); + Ok(Validity::Array(input.apply(&expr)?)) + } +} - // Execute the validity expression. All leaves are ArrayExpr nodes. - Ok(Validity::Array(execute_expr(&validity_expr, array.len())?)) +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + use vortex_mask::Mask; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::arrays::BoolArray; + use crate::arrays::ScalarFn; + use crate::arrays::scalar_fn::ScalarFnArrayExt; + use crate::arrays::scalar_fn::vtable::ScalarFnFactoryExt; + use crate::scalar_fn::fns::binary::Binary; + use crate::scalar_fn::fns::operators::Operator; + use crate::validity::Validity; + + #[test] + fn scalar_fn_validity_stays_lazy() -> VortexResult<()> { + let lhs = BoolArray::from_iter([Some(true), None, Some(false)]).into_array(); + let rhs = BoolArray::from_iter([Some(true), Some(false), None]).into_array(); + let predicate = Binary.try_new_array(lhs.len(), Operator::And, [lhs, rhs])?; + + let Validity::Array(validity_array) = predicate.validity()? else { + panic!("scalar function validity should be represented as an array"); + }; + + let validity_scalar_fn = validity_array + .as_opt::() + .expect("validity should remain a lazy scalar function array"); + assert!(validity_scalar_fn.scalar_fn().is::()); + + let validity_mask = + validity_array.execute::(&mut LEGACY_SESSION.create_execution_ctx())?; + assert!(validity_mask.all_true()); + + Ok(()) } } diff --git a/vortex-array/src/scalar_fn/erased.rs b/vortex-array/src/scalar_fn/erased.rs index 6e0011c297a..5f547d641ee 100644 --- a/vortex-array/src/scalar_fn/erased.rs +++ b/vortex-array/src/scalar_fn/erased.rs @@ -20,16 +20,13 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::scalar_fn::EmptyOptions; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ReduceCtx; use crate::scalar_fn::ReduceNode; use crate::scalar_fn::ReduceNodeRef; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::ScalarFnVTableExt; use crate::scalar_fn::SimplifyCtx; -use crate::scalar_fn::fns::is_not_null::IsNotNull; use crate::scalar_fn::options::ScalarFnOptions; use crate::scalar_fn::signature::ScalarFnSignature; use crate::scalar_fn::typed::DynScalarFn; @@ -132,10 +129,7 @@ impl ScalarFnRef { /// Transforms the expression into one representing the validity of this expression. pub fn validity(&self, expr: &Expression) -> VortexResult { - Ok(self.0.validity(expr)?.unwrap_or_else(|| { - // TODO(ngates): make validity a mandatory method on VTable to avoid this fallback. - IsNotNull.new_expr(EmptyOptions, [expr.clone()]) - })) + self.0.validity(expr) } /// Execute the expression given the input arguments. diff --git a/vortex-array/src/scalar_fn/fns/between/mod.rs b/vortex-array/src/scalar_fn/fns/between/mod.rs index 2cbe02183c8..8d37411334e 100644 --- a/vortex-array/src/scalar_fn/fns/between/mod.rs +++ b/vortex-array/src/scalar_fn/fns/between/mod.rs @@ -299,11 +299,11 @@ impl ScalarFnVTable for Between { &self, _options: &Self::Options, expression: &Expression, - ) -> VortexResult> { + ) -> VortexResult { let arr = expression.child(0).validity()?; let lower = expression.child(1).validity()?; let upper = expression.child(2).validity()?; - Ok(Some(and(and(arr, lower), upper))) + Ok(and(and(arr, lower), upper)) } fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index bda7e61ffd7..fe0d1eb085d 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -20,7 +20,10 @@ use crate::dtype::DType; use crate::dtype::Nullability; use crate::expr::and; use crate::expr::expression::Expression; +use crate::expr::fill_null; use crate::expr::lit; +use crate::expr::not; +use crate::expr::or; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; @@ -233,21 +236,30 @@ impl ScalarFnVTable for Binary { Ok(None) } - fn validity( - &self, - operator: &Operator, - expression: &Expression, - ) -> VortexResult> { + fn validity(&self, operator: &Operator, expression: &Expression) -> VortexResult { let lhs = expression.child(0).validity()?; let rhs = expression.child(1).validity()?; Ok(match operator { - // AND and OR are kleene logic. - Operator::And => None, - Operator::Or => None, + // AND and OR are Kleene logic. Their result is valid if both children are valid, + // or if a valid child value alone determines the result. + Operator::And => or( + and(lhs, rhs), + or( + not(fill_null(expression.child(0).clone(), lit(true))), + not(fill_null(expression.child(1).clone(), lit(true))), + ), + ), + Operator::Or => or( + and(lhs, rhs), + or( + fill_null(expression.child(0).clone(), lit(false)), + fill_null(expression.child(1).clone(), lit(false)), + ), + ), _ => { // All other binary operators are null if either side is null. - Some(and(lhs, rhs)) + and(lhs, rhs) } }) } @@ -540,6 +552,60 @@ mod tests { assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array()) } + #[test] + fn test_kleene_boolean_result_validity() -> VortexResult<()> { + use crate::IntoArray; + use crate::arrays::BoolArray; + use crate::validity::Validity; + + let lhs = BoolArray::from_iter([ + Some(true), + Some(true), + Some(false), + Some(false), + None, + None, + Some(true), + Some(false), + None, + ]) + .into_array(); + let rhs = BoolArray::from_iter([ + Some(true), + Some(false), + Some(true), + None, + Some(true), + Some(false), + None, + None, + None, + ]) + .into_array(); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + let and_result = lhs.binary(rhs.clone(), Operator::And)?; + let and_expected = + Validity::from_iter([true, true, true, true, false, true, false, true, false]); + assert!( + and_result + .validity()? + .mask_eq(&and_expected, and_result.len(), &mut ctx)? + ); + + let or_result = lhs.binary(rhs, Operator::Or)?; + let or_expected = + Validity::from_iter([true, true, true, false, true, false, true, false, false]); + assert!( + or_result + .validity()? + .mask_eq(&or_expected, or_result.len(), &mut ctx)? + ); + + Ok(()) + } + #[test] fn test_scalar_subtract_unsigned() { use vortex_buffer::buffer; diff --git a/vortex-array/src/scalar_fn/fns/byte_length.rs b/vortex-array/src/scalar_fn/fns/byte_length.rs index aa9c508ea89..5b2de882a9d 100644 --- a/vortex-array/src/scalar_fn/fns/byte_length.rs +++ b/vortex-array/src/scalar_fn/fns/byte_length.rs @@ -123,12 +123,8 @@ impl ScalarFnVTable for ByteLength { } } - fn validity( - &self, - _: &Self::Options, - expression: &Expression, - ) -> VortexResult> { - Ok(Some(expression.child(0).validity()?)) + fn validity(&self, _: &Self::Options, expression: &Expression) -> VortexResult { + expression.child(0).validity() } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { diff --git a/vortex-array/src/scalar_fn/fns/case_when.rs b/vortex-array/src/scalar_fn/fns/case_when.rs index 0a5f6b4ee14..92c33a4c7b8 100644 --- a/vortex-array/src/scalar_fn/fns/case_when.rs +++ b/vortex-array/src/scalar_fn/fns/case_when.rs @@ -35,6 +35,8 @@ use crate::builders::builder_with_capacity; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::expr::Expression; +use crate::expr::lit; +use crate::expr::zip_expr; use crate::scalar::Scalar; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; @@ -300,6 +302,29 @@ impl ScalarFnVTable for CaseWhen { Ok(Some(crate::expr::fill_null(x.clone(), fill.clone()))) } + fn validity( + &self, + options: &Self::Options, + expression: &Expression, + ) -> VortexResult { + let num_pairs = options.num_when_then_pairs as usize; + let mut validity = if options.has_else { + expression.child(num_pairs * 2).validity()? + } else { + lit(false) + }; + + for pair_idx in (0..num_pairs).rev() { + validity = zip_expr( + expression.child(pair_idx * 2).clone(), + expression.child(pair_idx * 2 + 1).validity()?, + validity, + ); + } + + Ok(validity) + } + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { true } diff --git a/vortex-array/src/scalar_fn/fns/cast/mod.rs b/vortex-array/src/scalar_fn/fns/cast/mod.rs index 20852779d42..c5452402bfd 100644 --- a/vortex-array/src/scalar_fn/fns/cast/mod.rs +++ b/vortex-array/src/scalar_fn/fns/cast/mod.rs @@ -149,12 +149,12 @@ impl ScalarFnVTable for Cast { Ok(None) } - fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult> { - Ok(Some(if dtype.is_nullable() { + fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult { + Ok(if dtype.is_nullable() { expression.child(0).validity()? } else { lit(true) - })) + }) } // This might apply a nullability diff --git a/vortex-array/src/scalar_fn/fns/dynamic.rs b/vortex-array/src/scalar_fn/fns/dynamic.rs index f6e6619282a..7d419d5e11d 100644 --- a/vortex-array/src/scalar_fn/fns/dynamic.rs +++ b/vortex-array/src/scalar_fn/fns/dynamic.rs @@ -20,6 +20,7 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::expr::Expression; +use crate::expr::is_not_null; use crate::expr::traversal::NodeExt; use crate::expr::traversal::NodeVisitor; use crate::expr::traversal::TraversalOrder; @@ -119,6 +120,14 @@ impl ScalarFnVTable for DynamicComparison { .into_array()) } + fn validity( + &self, + _dynamic: &DynamicComparisonExpr, + expression: &Expression, + ) -> VortexResult { + Ok(is_not_null(expression.clone())) + } + // Defer to the child fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false diff --git a/vortex-array/src/scalar_fn/fns/fill_null/mod.rs b/vortex-array/src/scalar_fn/fns/fill_null/mod.rs index b2d07b41255..6b19fef6f51 100644 --- a/vortex-array/src/scalar_fn/fns/fill_null/mod.rs +++ b/vortex-array/src/scalar_fn/fns/fill_null/mod.rs @@ -126,10 +126,10 @@ impl ScalarFnVTable for FillNull { &self, _options: &Self::Options, expression: &Expression, - ) -> VortexResult> { + ) -> VortexResult { // After fill_null, the result validity depends on the fill value's nullability. // If fill_value is non-nullable, the result is always valid. - Ok(Some(expression.child(1).validity()?)) + expression.child(1).validity() } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { diff --git a/vortex-array/src/scalar_fn/fns/get_item.rs b/vortex-array/src/scalar_fn/fns/get_item.rs index b9cb9202f38..4a02b68f06f 100644 --- a/vortex-array/src/scalar_fn/fns/get_item.rs +++ b/vortex-array/src/scalar_fn/fns/get_item.rs @@ -20,6 +20,7 @@ use crate::dtype::DType; use crate::dtype::FieldName; use crate::dtype::Nullability; use crate::expr::Expression; +use crate::expr::is_not_null; use crate::expr::lit; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; @@ -185,6 +186,14 @@ impl ScalarFnVTable for GetItem { Ok(None) } + fn validity( + &self, + _field_name: &FieldName, + expression: &Expression, + ) -> VortexResult { + Ok(is_not_null(expression.clone())) + } + // This will apply struct nullability field. We could add a dtype?? fn is_null_sensitive(&self, _field_name: &FieldName) -> bool { true diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index f2849f53ccf..68cc2e45b5f 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -11,9 +11,11 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::ConstantArray; +use crate::arrays::ScalarFn; use crate::dtype::DType; use crate::dtype::Nullability; use crate::expr::Expression; +use crate::expr::lit; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -76,10 +78,16 @@ impl ScalarFnVTable for IsNotNull { &self, _data: &Self::Options, args: &dyn ExecutionArgs, - _ctx: &mut ExecutionCtx, + ctx: &mut ExecutionCtx, ) -> VortexResult { let child = args.get(0)?; - match child.validity()? { + let validity = if child.is::() { + child.execute::(ctx)?.validity()? + } else { + child.validity()? + }; + + match validity { Validity::NonNullable | Validity::AllValid => { Ok(ConstantArray::new(true, args.row_count()).into_array()) } @@ -88,6 +96,14 @@ impl ScalarFnVTable for IsNotNull { } } + fn validity( + &self, + _options: &Self::Options, + _expression: &Expression, + ) -> VortexResult { + Ok(lit(true)) + } + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index 8df263a4b22..59d099ac522 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -9,9 +9,12 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::ConstantArray; +use crate::arrays::ScalarFn; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; +use crate::expr::Expression; +use crate::expr::lit; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -63,14 +66,20 @@ impl ScalarFnVTable for IsNull { &self, _data: &Self::Options, args: &dyn ExecutionArgs, - _ctx: &mut ExecutionCtx, + ctx: &mut ExecutionCtx, ) -> VortexResult { let child = args.get(0)?; if let Some(scalar) = child.as_constant() { return Ok(ConstantArray::new(scalar.is_null(), args.row_count()).into_array()); } - match child.validity()? { + let validity = if child.is::() { + child.execute::(ctx)?.validity()? + } else { + child.validity()? + }; + + match validity { Validity::NonNullable | Validity::AllValid => { Ok(ConstantArray::new(false, args.row_count()).into_array()) } @@ -79,6 +88,14 @@ impl ScalarFnVTable for IsNull { } } + fn validity( + &self, + _options: &Self::Options, + _expression: &Expression, + ) -> VortexResult { + Ok(lit(true)) + } + fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } diff --git a/vortex-array/src/scalar_fn/fns/like/mod.rs b/vortex-array/src/scalar_fn/fns/like/mod.rs index f1c53b146cb..72a67d51936 100644 --- a/vortex-array/src/scalar_fn/fns/like/mod.rs +++ b/vortex-array/src/scalar_fn/fns/like/mod.rs @@ -147,11 +147,11 @@ impl ScalarFnVTable for Like { &self, _options: &Self::Options, expression: &Expression, - ) -> VortexResult> { + ) -> VortexResult { tracing::warn!("Computing validity for LIKE expression"); let child_validity = expression.child(0).validity()?; let pattern_validity = expression.child(1).validity()?; - Ok(Some(and(child_validity, pattern_validity))) + Ok(and(child_validity, pattern_validity)) } fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { diff --git a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs index 4b39f51a7f2..3039df3bf74 100644 --- a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs +++ b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs @@ -33,6 +33,8 @@ use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::IntegerPType; use crate::dtype::Nullability; +use crate::expr::Expression; +use crate::expr::and; use crate::match_each_integer_ptype; use crate::match_each_unsigned_integer_ptype; use crate::scalar::ListScalar; @@ -121,6 +123,17 @@ impl ScalarFnVTable for ListContains { compute_list_contains(&list_array, &value_array, ctx) } + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult { + Ok(and( + expression.child(0).validity()?, + expression.child(1).validity()?, + )) + } + // Nullability matters for contains([], x) where x is false. fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true diff --git a/vortex-array/src/scalar_fn/fns/literal.rs b/vortex-array/src/scalar_fn/fns/literal.rs index 5181a5250dd..c48b45ce4a5 100644 --- a/vortex-array/src/scalar_fn/fns/literal.rs +++ b/vortex-array/src/scalar_fn/fns/literal.rs @@ -93,12 +93,8 @@ impl ScalarFnVTable for Literal { Ok(ConstantArray::new(scalar.clone(), args.row_count()).into_array()) } - fn validity( - &self, - scalar: &Scalar, - _expression: &Expression, - ) -> VortexResult> { - Ok(Some(lit(scalar.is_valid()))) + fn validity(&self, scalar: &Scalar, _expression: &Expression) -> VortexResult { + Ok(lit(scalar.is_valid())) } fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { diff --git a/vortex-array/src/scalar_fn/fns/mask/mod.rs b/vortex-array/src/scalar_fn/fns/mask/mod.rs index 4dd55948ec8..44baab5f6f1 100644 --- a/vortex-array/src/scalar_fn/fns/mask/mod.rs +++ b/vortex-array/src/scalar_fn/fns/mask/mod.rs @@ -127,11 +127,11 @@ impl ScalarFnVTable for Mask { &self, _options: &Self::Options, expression: &Expression, - ) -> VortexResult> { - Ok(Some(and( + ) -> VortexResult { + Ok(and( expression.child(0).validity()?, expression.child(1).clone(), - ))) + )) } } diff --git a/vortex-array/src/scalar_fn/fns/merge.rs b/vortex-array/src/scalar_fn/fns/merge.rs index 390d3a2213e..94355c24de1 100644 --- a/vortex-array/src/scalar_fn/fns/merge.rs +++ b/vortex-array/src/scalar_fn/fns/merge.rs @@ -237,8 +237,8 @@ impl ScalarFnVTable for Merge { &self, _options: &Self::Options, _expression: &Expression, - ) -> VortexResult> { - Ok(Some(lit(true))) + ) -> VortexResult { + Ok(lit(true)) } fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { diff --git a/vortex-array/src/scalar_fn/fns/not/mod.rs b/vortex-array/src/scalar_fn/fns/not/mod.rs index 156a6568df7..5a0014c7d57 100644 --- a/vortex-array/src/scalar_fn/fns/not/mod.rs +++ b/vortex-array/src/scalar_fn/fns/not/mod.rs @@ -18,6 +18,7 @@ use crate::arrays::ConstantArray; use crate::arrays::bool::BoolArrayExt; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; +use crate::expr::Expression; use crate::scalar::Scalar; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; @@ -98,6 +99,14 @@ impl ScalarFnVTable for Not { child.execute::(ctx)?.not() } + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult { + expression.child(0).validity() + } + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } diff --git a/vortex-array/src/scalar_fn/fns/pack.rs b/vortex-array/src/scalar_fn/fns/pack.rs index 2a87bf069bc..b2f81f84ac3 100644 --- a/vortex-array/src/scalar_fn/fns/pack.rs +++ b/vortex-array/src/scalar_fn/fns/pack.rs @@ -130,8 +130,8 @@ impl ScalarFnVTable for Pack { &self, _options: &Self::Options, _expression: &Expression, - ) -> VortexResult> { - Ok(Some(lit(true))) + ) -> VortexResult { + Ok(lit(true)) } fn execute( diff --git a/vortex-array/src/scalar_fn/fns/root.rs b/vortex-array/src/scalar_fn/fns/root.rs index 7bd5b758796..39c6271d57e 100644 --- a/vortex-array/src/scalar_fn/fns/root.rs +++ b/vortex-array/src/scalar_fn/fns/root.rs @@ -12,6 +12,7 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::expression::Expression; +use crate::expr::is_not_null; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -77,6 +78,14 @@ impl ScalarFnVTable for Root { vortex_bail!("Root expression is not executable") } + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult { + Ok(is_not_null(expression.clone())) + } + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } diff --git a/vortex-array/src/scalar_fn/fns/select.rs b/vortex-array/src/scalar_fn/fns/select.rs index 54b63a00f89..c4c22cfd5cd 100644 --- a/vortex-array/src/scalar_fn/fns/select.rs +++ b/vortex-array/src/scalar_fn/fns/select.rs @@ -232,6 +232,14 @@ impl ScalarFnVTable for Select { Ok(None) } + fn validity( + &self, + _selection: &FieldSelection, + expression: &Expression, + ) -> VortexResult { + expression.child(0).validity() + } + fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool { true } diff --git a/vortex-array/src/scalar_fn/fns/stat.rs b/vortex-array/src/scalar_fn/fns/stat.rs index 84fc5760495..d3f39af7efb 100644 --- a/vortex-array/src/scalar_fn/fns/stat.rs +++ b/vortex-array/src/scalar_fn/fns/stat.rs @@ -21,6 +21,7 @@ use crate::aggregate_fn::fns::all_null::AllNull; use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::expr::Expression; +use crate::expr::is_not_null; use crate::expr::stats::Precision; use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; @@ -123,6 +124,14 @@ impl ScalarFnVTable for StatFn { let dtype = stat_dtype(options.aggregate_fn(), input.dtype())?; stat_array(&input, options.aggregate_fn(), dtype, args.row_count()) } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult { + Ok(is_not_null(expression.clone())) + } } fn stat_dtype(aggregate_fn: &AggregateFnRef, input_dtype: &DType) -> VortexResult { diff --git a/vortex-array/src/scalar_fn/fns/variant_get/mod.rs b/vortex-array/src/scalar_fn/fns/variant_get/mod.rs index ac6aa562aa6..aa292332fdf 100644 --- a/vortex-array/src/scalar_fn/fns/variant_get/mod.rs +++ b/vortex-array/src/scalar_fn/fns/variant_get/mod.rs @@ -26,6 +26,7 @@ use crate::dtype::DType; use crate::dtype::FieldName; use crate::dtype::Nullability; use crate::expr::Expression; +use crate::expr::is_not_null; use crate::scalar::Scalar; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; @@ -163,6 +164,14 @@ impl ScalarFnVTable for VariantGet { let array = ChunkedArray::try_new(chunks, dtype)?.into_array(); VariantArray::try_new(array, None).map(|array| array.into_array()) } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult { + Ok(is_not_null(expression.clone())) + } } fn variant_get_scalar( diff --git a/vortex-array/src/scalar_fn/fns/zip/mod.rs b/vortex-array/src/scalar_fn/fns/zip/mod.rs index 8a3e1b5bbc6..0f5d4bb96e9 100644 --- a/vortex-array/src/scalar_fn/fns/zip/mod.rs +++ b/vortex-array/src/scalar_fn/fns/zip/mod.rs @@ -24,6 +24,7 @@ use crate::builders::builder_with_capacity; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::expr::Expression; +use crate::expr::zip_expr; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -155,6 +156,18 @@ impl ScalarFnVTable for Zip { Ok(None) } + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult { + Ok(zip_expr( + expression.child(2).clone(), + expression.child(0).validity()?, + expression.child(1).validity()?, + )) + } + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { true } diff --git a/vortex-array/src/scalar_fn/foreign.rs b/vortex-array/src/scalar_fn/foreign.rs index d2abd3b9a95..eefe799a279 100644 --- a/vortex-array/src/scalar_fn/foreign.rs +++ b/vortex-array/src/scalar_fn/foreign.rs @@ -118,4 +118,15 @@ impl ScalarFnVTable for ForeignScalarFnVTable { ) -> VortexResult { vortex_bail!("Cannot execute unknown scalar function '{}'", self.id); } + + fn validity( + &self, + _options: &Self::Options, + _expression: &Expression, + ) -> VortexResult { + vortex_bail!( + "Cannot compute validity for unknown scalar function '{}'", + self.id + ); + } } diff --git a/vortex-array/src/scalar_fn/internal/row_count.rs b/vortex-array/src/scalar_fn/internal/row_count.rs index eee838f803d..6429502a5b6 100644 --- a/vortex-array/src/scalar_fn/internal/row_count.rs +++ b/vortex-array/src/scalar_fn/internal/row_count.rs @@ -12,6 +12,7 @@ use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::expr::Expression; +use vortex_array::expr::lit; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; use vortex_array::scalar_fn::EmptyOptions; @@ -83,6 +84,14 @@ impl ScalarFnVTable for RowCount { vortex_bail!("RowCount must be substituted before evaluation") } + fn validity( + &self, + _options: &Self::Options, + _expression: &Expression, + ) -> VortexResult { + Ok(lit(true)) + } + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } diff --git a/vortex-array/src/scalar_fn/typed.rs b/vortex-array/src/scalar_fn/typed.rs index 1b14a9d613d..9756503f7fa 100644 --- a/vortex-array/src/scalar_fn/typed.rs +++ b/vortex-array/src/scalar_fn/typed.rs @@ -100,7 +100,7 @@ pub(super) trait DynScalarFn: 'static + Send + Sync + super::sealed::Sealed { ctx: &dyn SimplifyCtx, ) -> VortexResult>; fn simplify_untyped(&self, expression: &Expression) -> VortexResult>; - fn validity(&self, expression: &Expression) -> VortexResult>; + fn validity(&self, expression: &Expression) -> VortexResult; // Options operations — self-contained fn options_serialize(&self) -> VortexResult>>; @@ -208,7 +208,7 @@ impl DynScalarFn for TypedScalarFnInstance { V::simplify_untyped(&self.vtable, &self.options, expression) } - fn validity(&self, expression: &Expression) -> VortexResult> { + fn validity(&self, expression: &Expression) -> VortexResult { V::validity(&self.vtable, &self.options, expression) } diff --git a/vortex-array/src/scalar_fn/vtable.rs b/vortex-array/src/scalar_fn/vtable.rs index c66afc34932..5bdbfdd2f5b 100644 --- a/vortex-array/src/scalar_fn/vtable.rs +++ b/vortex-array/src/scalar_fn/vtable.rs @@ -179,18 +179,12 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync { /// Returns an expression that evaluates to the validity of the result of this expression. /// - /// If a validity expression cannot be constructed, returns `None` and the expression will - /// be evaluated as normal before extracting the validity mask from the result. - /// - /// This is essentially a specialized form of a `reduce_parent` + /// This is essentially a specialized form of a `reduce_parent`. fn validity( &self, options: &Self::Options, expression: &Expression, - ) -> VortexResult> { - _ = (options, expression); - Ok(None) - } + ) -> VortexResult; /// Returns whether this expression itself is null-sensitive. Conservatively default to *true*. /// diff --git a/vortex-geo/src/scalar_fn/distance.rs b/vortex-geo/src/scalar_fn/distance.rs index 7f222cb763a..cf730a4cfb1 100644 --- a/vortex-geo/src/scalar_fn/distance.rs +++ b/vortex-geo/src/scalar_fn/distance.rs @@ -13,6 +13,8 @@ use vortex_array::arrays::ScalarFnArray; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; +use vortex_array::expr::Expression; +use vortex_array::expr::lit; use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; @@ -125,6 +127,14 @@ impl ScalarFnVTable for GeoDistance { } } } + + fn validity( + &self, + _options: &Self::Options, + _expression: &Expression, + ) -> VortexResult { + Ok(lit(true)) + } } /// Distance from each row of `points` to a constant `query` point, decoded once and broadcast. diff --git a/vortex-layout/src/layouts/row_idx/expr.rs b/vortex-layout/src/layouts/row_idx/expr.rs index 95ddce3d762..122ee3ca9e8 100644 --- a/vortex-layout/src/layouts/row_idx/expr.rs +++ b/vortex-layout/src/layouts/row_idx/expr.rs @@ -8,6 +8,7 @@ use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::expr::Expression; +use vortex_array::expr::lit; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; use vortex_array::scalar_fn::EmptyOptions; @@ -61,6 +62,14 @@ impl ScalarFnVTable for RowIdx { "RowIdxExpr should not be executed directly, use it in the context of a Vortex scan and it will be substituted for a row index array" ); } + + fn validity( + &self, + _options: &Self::Options, + _expression: &Expression, + ) -> VortexResult { + Ok(lit(true)) + } } pub fn row_idx() -> Expression { diff --git a/vortex-row/src/encode.rs b/vortex-row/src/encode.rs index 04feec89415..906d283abcf 100644 --- a/vortex-row/src/encode.rs +++ b/vortex-row/src/encode.rs @@ -18,6 +18,8 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; +use vortex_array::expr::Expression; +use vortex_array::expr::lit; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; use vortex_array::scalar_fn::ExecutionArgs; @@ -90,6 +92,14 @@ impl ScalarFnVTable for RowEncode { execute_row_encode(options, args, ctx) } + fn validity( + &self, + _options: &Self::Options, + _expression: &Expression, + ) -> VortexResult { + Ok(lit(true)) + } + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { true } diff --git a/vortex-row/src/size.rs b/vortex-row/src/size.rs index 6636c4e9f34..b9f932dbf3d 100644 --- a/vortex-row/src/size.rs +++ b/vortex-row/src/size.rs @@ -18,6 +18,8 @@ use vortex_array::dtype::FieldNames; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::dtype::StructFields; +use vortex_array::expr::Expression; +use vortex_array::expr::lit; use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; @@ -242,6 +244,14 @@ impl ScalarFnVTable for RowSize { .into_array()) } + fn validity( + &self, + _options: &Self::Options, + _expression: &Expression, + ) -> VortexResult { + Ok(lit(true)) + } + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { true } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 82328a9b2d9..972f2e13166 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -178,12 +178,12 @@ impl ScalarFnVTable for CosineSimilarity { &self, _options: &Self::Options, expression: &Expression, - ) -> VortexResult> { + ) -> VortexResult { // The result is null if either input tensor is null. let lhs_validity = expression.child(0).validity()?; let rhs_validity = expression.child(1).validity()?; - Ok(Some(and(lhs_validity, rhs_validity))) + Ok(and(lhs_validity, rhs_validity)) } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index cfea956f314..e11ac730c35 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -162,12 +162,12 @@ impl ScalarFnVTable for InnerProduct { &self, _options: &Self::Options, expression: &Expression, - ) -> VortexResult> { + ) -> VortexResult { // The result is null if either input tensor is null. let lhs_validity = expression.child(0).validity()?; let rhs_validity = expression.child(1).validity()?; - Ok(Some(and(lhs_validity, rhs_validity))) + Ok(and(lhs_validity, rhs_validity)) } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 4f4159b4a21..c4ba7a3ba36 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -255,11 +255,11 @@ impl ScalarFnVTable for L2Denorm { &self, _options: &Self::Options, expression: &Expression, - ) -> VortexResult> { + ) -> VortexResult { let normalized_validity = expression.child(0).validity()?; let norms_validity = expression.child(1).validity()?; - Ok(Some(and(normalized_validity, norms_validity))) + Ok(and(normalized_validity, norms_validity)) } fn is_null_sensitive(&self, _options: &Self::Options) -> bool { diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index aa85ab4c78e..ab7c387f595 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -182,9 +182,9 @@ impl ScalarFnVTable for L2Norm { &self, _options: &Self::Options, expression: &Expression, - ) -> VortexResult> { + ) -> VortexResult { // The result is null if the input tensor is null. - Ok(Some(expression.child(0).validity()?)) + expression.child(0).validity() } fn is_null_sensitive(&self, _options: &Self::Options) -> bool {