Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions vortex-array/src/expr/analysis/annotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,25 @@ pub fn descendent_annotations<A: AnnotationFn>(
let mut visitor = AnnotationVisitor {
annotations: Default::default(),
annotate,
propagate_up: true,
};
expr.accept(&mut visitor).vortex_expect("Infallible");
visitor.annotations
}

/// Walk the expression tree and annotate each expression with zero or more
/// annotations.
///
/// Returns a map of each expression to all annotations. Annotations of
/// children are not propagated to parents.
pub fn direct_annotations<A: AnnotationFn>(
expr: &Expression,
annotate: A,
) -> Annotations<'_, A::Annotation> {
let mut visitor = AnnotationVisitor {
annotations: Default::default(),
annotate,
propagate_up: false,
};
expr.accept(&mut visitor).vortex_expect("Infallible");
visitor.annotations
Expand All @@ -50,6 +69,7 @@ pub fn descendent_annotations<A: AnnotationFn>(
struct AnnotationVisitor<'a, A: AnnotationFn> {
annotations: Annotations<'a, A::Annotation>,
annotate: A,
propagate_up: bool,
}

impl<'a, A: AnnotationFn> NodeVisitor<'a> for AnnotationVisitor<'a, A> {
Expand All @@ -70,6 +90,9 @@ impl<'a, A: AnnotationFn> NodeVisitor<'a> for AnnotationVisitor<'a, A> {
}

fn visit_up(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> {
if !self.propagate_up {
return Ok(TraversalOrder::Continue);
}
let child_annotations = node
.children()
.iter()
Expand Down
16 changes: 14 additions & 2 deletions vortex-array/src/expr/transform/partition.rs
Comment thread
myrrc marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

use std::fmt::Display;
use std::fmt::Formatter;
use std::hash::Hash;

use itertools::Itertools;
use vortex_error::VortexExpect;
Expand Down Expand Up @@ -49,11 +50,22 @@ where
{
// Annotate each expression with the annotations that any of its descendent expressions have.
let annotations = descendent_annotations(&expr, annotate_fn);
partition_annotations(expr.clone(), scope, annotations)
}

pub fn partition_annotations<A>(
expr: Expression,
scope: &DType,
annotations: Annotations<A>,
) -> VortexResult<PartitionedExpr<A>>
where
A: Display + Clone + Eq + Hash,
FieldName: From<A>,
{
// Now we split the original expression into sub-expressions based on the annotations, and
// generate a root expression to re-assemble the results.
let mut splitter = StructFieldExpressionSplitter::<A::Annotation>::new(&annotations);
let root = expr.clone().rewrite(&mut splitter)?.value;
let mut splitter = StructFieldExpressionSplitter::<A>::new(&annotations);
let root = expr.rewrite(&mut splitter)?.value;

let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
let mut partition_annotations = Vec::with_capacity(splitter.sub_expressions.len());
Expand Down
18 changes: 18 additions & 0 deletions vortex-array/src/scalar_fn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

use vortex_session::registry::Id;

use crate::scalar_fn::fns::byte_length::ByteLength;
use crate::scalar_fn::fns::get_item::GetItem;
use crate::scalar_fn::fns::literal::Literal;

mod vtable;
pub use vtable::*;

Expand Down Expand Up @@ -48,3 +52,17 @@ mod sealed {
/// This can be the **only** implementor for [`super::typed::DynScalarFn`].
impl<V: ScalarFnVTable> Sealed for TypedScalarFnInstance<V> {}
}

/// A scalar function has a negative cost if applying it to an array and
/// canonicalizing is cheaper than canonicalizing an array and applying it.
///
/// Example of negative cost expressions are byte_length() and get_item() since
/// they don't depend on input size.
///
/// Example of non-negative cost expression is like() as it's linear over
/// individual input.
pub fn is_negative_cost(id: ScalarFnId) -> bool {
id == ScalarFnVTable::id(&ByteLength)
|| id == ScalarFnVTable::id(&GetItem)
|| id == ScalarFnVTable::id(&Literal)
}
Loading
Loading