diff --git a/vortex-array/src/expr/analysis/annotation.rs b/vortex-array/src/expr/analysis/annotation.rs index be842742a5a..32608f38d1e 100644 --- a/vortex-array/src/expr/analysis/annotation.rs +++ b/vortex-array/src/expr/analysis/annotation.rs @@ -42,6 +42,25 @@ pub fn descendent_annotations( let mut visitor = AnnotationVisitor { annotations: Default::default(), annotate, + propagate_up: true, + }; + expr.accept(&mut visitor).vortex_expect("Infallible"); + visitor.annotations +} + +/// Walk the expression tree and annotate each expression with zero or more +/// annotations. +/// +/// Returns a map of each expression to all annotations. Annotations of +/// children are not propagated to parents. +pub fn direct_annotations( + expr: &Expression, + annotate: A, +) -> Annotations<'_, A::Annotation> { + let mut visitor = AnnotationVisitor { + annotations: Default::default(), + annotate, + propagate_up: false, }; expr.accept(&mut visitor).vortex_expect("Infallible"); visitor.annotations @@ -50,6 +69,7 @@ pub fn descendent_annotations( struct AnnotationVisitor<'a, A: AnnotationFn> { annotations: Annotations<'a, A::Annotation>, annotate: A, + propagate_up: bool, } impl<'a, A: AnnotationFn> NodeVisitor<'a> for AnnotationVisitor<'a, A> { @@ -70,6 +90,9 @@ impl<'a, A: AnnotationFn> NodeVisitor<'a> for AnnotationVisitor<'a, A> { } fn visit_up(&mut self, node: &'a Expression) -> VortexResult { + if !self.propagate_up { + return Ok(TraversalOrder::Continue); + } let child_annotations = node .children() .iter() diff --git a/vortex-array/src/expr/transform/partition.rs b/vortex-array/src/expr/transform/partition.rs index ef852a6cd7f..92dca145d47 100644 --- a/vortex-array/src/expr/transform/partition.rs +++ b/vortex-array/src/expr/transform/partition.rs @@ -3,6 +3,7 @@ use std::fmt::Display; use std::fmt::Formatter; +use std::hash::Hash; use itertools::Itertools; use vortex_error::VortexExpect; @@ -49,11 +50,22 @@ where { // Annotate each expression with the annotations that any of its descendent expressions have. let annotations = descendent_annotations(&expr, annotate_fn); + partition_annotations(expr.clone(), scope, annotations) +} +pub fn partition_annotations( + expr: Expression, + scope: &DType, + annotations: Annotations, +) -> VortexResult> +where + A: Display + Clone + Eq + Hash, + FieldName: From, +{ // Now we split the original expression into sub-expressions based on the annotations, and // generate a root expression to re-assemble the results. - let mut splitter = StructFieldExpressionSplitter::::new(&annotations); - let root = expr.clone().rewrite(&mut splitter)?.value; + let mut splitter = StructFieldExpressionSplitter::::new(&annotations); + let root = expr.rewrite(&mut splitter)?.value; let mut partitions = Vec::with_capacity(splitter.sub_expressions.len()); let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len()); diff --git a/vortex-array/src/scalar_fn/mod.rs b/vortex-array/src/scalar_fn/mod.rs index 590ccb44224..9f5bd5c4628 100644 --- a/vortex-array/src/scalar_fn/mod.rs +++ b/vortex-array/src/scalar_fn/mod.rs @@ -9,6 +9,10 @@ use vortex_session::registry::Id; +use crate::scalar_fn::fns::byte_length::ByteLength; +use crate::scalar_fn::fns::get_item::GetItem; +use crate::scalar_fn::fns::literal::Literal; + mod vtable; pub use vtable::*; @@ -48,3 +52,17 @@ mod sealed { /// This can be the **only** implementor for [`super::typed::DynScalarFn`]. impl Sealed for TypedScalarFnInstance {} } + +/// A scalar function has a negative cost if applying it to an array and +/// canonicalizing is cheaper than canonicalizing an array and applying it. +/// +/// Example of negative cost expressions are byte_length() and get_item() since +/// they don't depend on input size. +/// +/// Example of non-negative cost expression is like() as it's linear over +/// individual input. +pub fn is_negative_cost(id: ScalarFnId) -> bool { + id == ScalarFnVTable::id(&ByteLength) + || id == ScalarFnVTable::id(&GetItem) + || id == ScalarFnVTable::id(&Literal) +} diff --git a/vortex-layout/src/layouts/dict/reader.rs b/vortex-layout/src/layouts/dict/reader.rs index 6c8d91706db..b4a61b76e2f 100644 --- a/vortex-layout/src/layouts/dict/reader.rs +++ b/vortex-layout/src/layouts/dict/reader.rs @@ -19,9 +19,16 @@ use vortex_array::arrays::SharedArray; use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; +use vortex_array::dtype::Nullability; use vortex_array::expr::Expression; +use vortex_array::expr::direct_annotations; +use vortex_array::expr::is_root; +use vortex_array::expr::label_tree; +use vortex_array::expr::pack; use vortex_array::expr::root; +use vortex_array::expr::transform::partition_annotations; use vortex_array::optimizer::ArrayOptimizer; +use vortex_array::scalar_fn::is_negative_cost; use vortex_error::VortexError; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -101,10 +108,7 @@ impl DictReader { ) .vortex_expect("must construct dict values array evaluation") .map_err(Arc::new) - .map(move |array| { - let array = array?; - Ok(SharedArray::new(array).into_array()) - }) + .map(move |array| Ok(SharedArray::new(array?).into_array())) .boxed() .shared() }) @@ -156,6 +160,44 @@ impl DictReader { } } +// On expression pushdown, "inner" is packed as field with this name. +// "outer" expects this name as input. +const PUSHDOWN_ANNOTATION: &str = ""; + +/// Split expression into two parts: +/// +/// left is the outer part that we want to apply to array after canonicalizing. +/// right is the optional inner part that we want to apply to array before +/// canonicalizing. +/// +/// We want to push to array only if expression has a negative cost, is +/// infallible and null-insensitive. +fn split_expression_for_pushdown( + expr: Expression, + dtype: &DType, +) -> VortexResult<(Expression, Option)> { + let references_root = label_tree(&expr, is_root, |acc, &child| acc | child); + let annotations = direct_annotations(&expr, |expr| { + let signature = expr.signature(); + if !signature.is_fallible() + && !signature.is_null_sensitive() + && is_negative_cost(expr.id()) + && references_root.get(&expr).copied().unwrap_or(true) + { + vec![PUSHDOWN_ANNOTATION] + } else { + vec![] + } + }); + let partition = partition_annotations(expr.clone(), dtype, annotations)?; + if partition.partitions.is_empty() { + Ok((partition.root, None)) + } else { + debug_assert_eq!(1, partition.partitions.len()); + Ok((partition.root, Some(partition.partitions[0].clone()))) + } +} + impl LayoutReader for DictReader { fn name(&self) -> &Arc { &self.name @@ -233,13 +275,37 @@ impl LayoutReader for DictReader { mask: MaskFuture, ) -> VortexResult>> { // TODO: fix up expr partitioning with fallible & null sensitive annotations - let values_eval = self.values_array(); let codes_eval = self .codes .projection_evaluation(row_range, &root(), mask) .map_err(|err| err.with_context("While evaluating projection on codes"))?; - let expr = expr.clone(); + let (expr_outer, expr_inner) = split_expression_for_pushdown(expr.clone(), self.dtype())?; + + let values_eval = if let Some(inner) = expr_inner { + // "outer" takes a struct field with PUSHDOWN_ANNOTATION name, so + // pack inner with this name as well + let inner = pack([(PUSHDOWN_ANNOTATION, inner)], Nullability::NonNullable); + + // We can't use values_eval as it uses values_array_uncanonical + // which in turn gets populated from self.values. If + // self.values_array() is called first, it will populate + // self.values with uncompressed data. Supply uncached data + let values_len = self.values_len; + self.values + .projection_evaluation( + &(0..values_len as u64), + &root(), + MaskFuture::new_true(values_len), + ) + .vortex_expect("must construct dict values array evaluation") + .map_err(Arc::new) + .map(move |array| Ok(SharedArray::new(array?.apply(&inner)?).into_array())) + .boxed() + .shared() + } else { + self.values_array() + }; let all_values_referenced = self.layout.has_all_values_referenced(); Ok(async move { let (values, codes) = try_join!(values_eval.map_err(VortexError::from), codes_eval)?; @@ -256,7 +322,7 @@ impl LayoutReader for DictReader { .into_array() .optimize()?; - array.apply(&expr) + array.apply(&expr_outer) } .boxed()) } @@ -272,6 +338,7 @@ mod tests { use rstest::rstest; use vortex_array::ArrayContext; + use vortex_array::ArrayRef; use vortex_array::Canonical; use vortex_array::IntoArray as _; use vortex_array::LEGACY_SESSION; @@ -285,8 +352,14 @@ mod tests { use vortex_array::dtype::FieldName; use vortex_array::dtype::FieldNames; use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::expr::Expression; + use vortex_array::expr::byte_length; + use vortex_array::expr::cast; use vortex_array::expr::eq; + use vortex_array::expr::get_item; use vortex_array::expr::is_not_null; + use vortex_array::expr::like; use vortex_array::expr::lit; use vortex_array::expr::pack; use vortex_array::expr::root; @@ -294,6 +367,7 @@ mod tests { use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_error::VortexExpect; + use vortex_error::VortexResult; use vortex_io::runtime::Handle; use vortex_io::runtime::single::block_on; use vortex_io::session::RuntimeSession; @@ -303,6 +377,8 @@ mod tests { use crate::LayoutId; use crate::LayoutRef; use crate::LayoutStrategy; + use crate::layouts::dict::reader::PUSHDOWN_ANNOTATION; + use crate::layouts::dict::reader::split_expression_for_pushdown; use crate::layouts::dict::writer::DictLayoutOptions; use crate::layouts::dict::writer::DictStrategy; use crate::layouts::flat::writer::FlatLayoutStrategy; @@ -324,6 +400,33 @@ mod tests { .with_handle(handle) } + async fn write_dict_layout( + array: ArrayRef, + session: &VortexSession, + ) -> (LayoutRef, Arc) { + let strategy = DictStrategy::new( + FlatLayoutStrategy::default(), + FlatLayoutStrategy::default(), + FlatLayoutStrategy::default(), + DictLayoutOptions::default(), + ); + let segments = Arc::new(TestSegments::default()); + let (ptr, eof) = SequenceId::root().split(); + let dtype = array.dtype().clone(); + let layout = strategy + .write_stream( + ArrayContext::empty(), + Arc::::clone(&segments), + SequentialStreamAdapter::new(dtype, array.to_array_stream().sequenced(ptr)) + .sendable(), + eof, + session, + ) + .await + .unwrap(); + (layout, segments) + } + #[test] fn reading_nested_packs_works() { block_on(|handle| async move { @@ -546,4 +649,122 @@ mod tests { assert_arrays_eq!(actual_canonical, expected); }) } + + #[test] + fn reading_byte_length_pushdown_works() { + let array = VarBinArray::from_iter( + [ + Some("abc"), + Some("defg"), + None, + Some("abc"), + Some("defg"), + None, + Some("abc"), + Some("defg"), + None, + ], + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + + let expected = array + .clone() + .apply(&byte_length(root())) + .unwrap() + .into_array(); + + block_on(|handle| async move { + let session = session_with_handle(handle); + let (layout, segments) = write_dict_layout(array, &session).await; + assert_eq!(layout.encoding_id(), LayoutId::new("vortex.dict")); + let actual = layout + .new_reader("".into(), segments, &session, &Default::default()) + .unwrap() + .projection_evaluation( + &(0..layout.row_count()), + &byte_length(root()), + MaskFuture::new_true(layout.row_count().try_into().unwrap()), + ) + .unwrap() + .await + .unwrap() + .into_array(); + assert_arrays_eq!(actual, expected); + }) + } + + fn pushed_inner(exprs: impl IntoIterator) -> Expression { + pack( + exprs + .into_iter() + .enumerate() + .map(|(idx, e)| (format!("_{idx}"), e)), + Nullability::NonNullable, + ) + } + + fn pushed_ref(idx: usize) -> Expression { + get_item(format!("_{idx}"), get_item("", root())) + } + + fn test_apply(original: Expression, outer: Expression, inner: Expression) -> VortexResult<()> { + let array = VarBinArray::from_iter( + [Some("abc"), Some("def"), None], + DType::Utf8(Nullability::Nullable), + ) + .into_array(); + + let pushed = array.clone().apply(&pack( + [(PUSHDOWN_ANNOTATION, inner)], + Nullability::NonNullable, + ))?; + let actual = pushed.apply(&outer)?; + let expected = array.apply(&original)?; + assert_arrays_eq!(actual, expected); + Ok(()) + } + + #[test] + fn split_expr_root() { + let (outer, inner) = split_expression_for_pushdown(root(), &DType::Null).unwrap(); + assert_eq!(outer, root()); + assert_eq!(inner, None); + } + + #[test] + fn split_expr_partial_pushdown() -> VortexResult<()> { + // cast is fallible, thus not pushed + let target = DType::Primitive(PType::I64, Nullability::Nullable); + let expr = cast(byte_length(root()), target.clone()); + let (outer, inner) = + split_expression_for_pushdown(expr.clone(), &DType::Utf8(false.into()))?; + let inner = inner.unwrap(); + // [0] = cast([1], dtype) + // [1] = byte_length(root) + assert_eq!(outer, cast(pushed_ref(0), target)); + assert_eq!(inner, pushed_inner([byte_length(root())])); + test_apply(expr, outer, inner) + } + + #[test] + fn split_expr_full_pushdown() -> VortexResult<()> { + let expr = byte_length(root()); + let (outer, inner) = + split_expression_for_pushdown(expr.clone(), &DType::Utf8(false.into()))?; + let inner = inner.unwrap(); + assert_eq!(outer, pushed_ref(0)); + assert_eq!(inner, pushed_inner([byte_length(root())])); + test_apply(expr, outer, inner) + } + + #[test] + fn split_expr_no_pushdown() { + // like is fallible, thus not pushed. lit() does not reference root() + let expr = like(root(), lit(1u64)); + let (outer, inner) = + split_expression_for_pushdown(expr.clone(), &DType::Utf8(true.into())).unwrap(); + assert_eq!(outer, expr); + assert_eq!(inner, None); + } }