diff --git a/docs/developer-guide/index.md b/docs/developer-guide/index.md index 37ad276044d..e6eec4bfb70 100644 --- a/docs/developer-guide/index.md +++ b/docs/developer-guide/index.md @@ -24,6 +24,7 @@ internals/session internals/async-runtime internals/vtables internals/execution +internals/stats-pruning internals/io internals/serialization internals/cuda @@ -38,4 +39,4 @@ caption: Integrations integrations/datafusion integrations/duckdb integrations/spark -``` \ No newline at end of file +``` diff --git a/docs/developer-guide/internals/stats-pruning.md b/docs/developer-guide/internals/stats-pruning.md new file mode 100644 index 00000000000..4acaf9efc2a --- /dev/null +++ b/docs/developer-guide/internals/stats-pruning.md @@ -0,0 +1,39 @@ +# Stats Pruning + +Vortex uses statistics to prove when a filter cannot match a row group, zone, or +file. The proof expression returns `true` when the input can be skipped. It +returns `false` or `null` when pruning is not proven. + +Both `false` and `null` are non-pruning outcomes, but they mean different +things. `false` means the available stats disproved the skip proof. `null` means +the proof was unknown, usually because a required stat was missing or inexact. + +The pruning pipeline has two phases: + +1. `Expression::falsify(scope, session)` asks the session's + `StatsRewriteRule`s to rewrite a filter into an abstract proof expression. + Rules describe semantics in terms of `vortex.stat(input, aggregate_fn)` + placeholders. These placeholders name the statistic needed by the proof, but + not where that statistic is stored. +2. `bind_stats` lowers those abstract stat placeholders with a `StatBinder`. + The binder maps stats to the representation used by the caller, such as + zone-map table fields, file-level stat literals, or typed null literals for + missing stats. + +Missing stats lower to typed null literals. This preserves the three-valued +logic used by pruning: only a non-null `true` value proves that the scope can be +skipped. A missing stat therefore cannot accidentally prune data. + +## Binding Targets + +Zone maps bind stats to fields in their per-zone stats table. The lowered +expression is evaluated against that table and produces a mask where `true` +means the zone can be skipped. + +File-level stats bind stats to literal values from the file footer. The lowered +expression is reduced and evaluated once for the full file. If it evaluates to +`true`, the file stats reader can return an all-false pruning mask without +reading child layouts. + +For the layout model around these pruning points, see +[Layouts](../../concepts/layouts.md) and [Scanning](../../concepts/scanning.md). diff --git a/vortex-array/src/expr/expression.rs b/vortex-array/src/expr/expression.rs index cc21fb9a9a6..41cd61a369d 100644 --- a/vortex-array/src/expr/expression.rs +++ b/vortex-array/src/expr/expression.rs @@ -15,9 +15,7 @@ use vortex_error::vortex_ensure; use vortex_session::VortexSession; use crate::dtype::DType; -use crate::expr::StatsCatalog; use crate::expr::display::DisplayTreeExpr; -use crate::expr::stats::Stat; use crate::scalar_fn::ScalarFnRef; use crate::scalar_fn::fns::root::Root; @@ -114,28 +112,6 @@ impl Expression { self.scalar_fn.validity(self) } - /// An expression over zone-statistics which implies all records in the zone evaluate to false. - /// - /// Given an expression, `e`, if `e.stat_falsification(..)` evaluates to true, it is guaranteed - /// that `e` evaluates to false on all records in the zone. However, the inverse is not - /// necessarily true: even if the falsification evaluates to false, `e` need not evaluate to - /// true on all records. - /// - /// The [`StatsCatalog`] can be used to constrain or rename stats used in the final expr. - /// - /// # Examples - /// - /// - An expression over one variable: `x > 0` is false for all records in a zone if the maximum - /// value of the column `x` in that zone is less than or equal to zero: `max(x) <= 0`. - /// - An expression over two variables: `x > y` becomes `max(x) <= min(y)`. - /// - A conjunctive expression: `x > y AND z < x` becomes `max(x) <= min(y) OR min(z) >= max(x). - /// - /// Some expressions, in theory, have falsifications but this function does not support them - /// such as `x < (y < z)` or `x LIKE "needle%"`. - pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option { - self.scalar_fn().stat_falsification(self, catalog) - } - /// Returns an expression that proves this predicate is definitely false from stats. /// /// `scope` is the dtype of the row this expression evaluates over. @@ -164,28 +140,6 @@ impl Expression { crate::stats::rewrite::StatsRewriteCtx::new(session, scope).satisfy(self) } - /// Returns an expression representing the zoned statistic for the given stat, if available. - /// - /// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a - /// scope. Expressions can implement this function to propagate such statistics through the - /// expression tree. For example, the `a + 10` expression could propagate `min: min(a) + 10`. - /// - /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an - /// issue to discuss a solution to this. - pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option { - self.scalar_fn().stat_expression(self, stat, catalog) - } - - /// Returns an expression representing the zoned maximum statistic, if available. - pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option { - self.stat_expression(Stat::Min, catalog) - } - - /// Returns an expression representing the zoned maximum statistic, if available. - pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option { - self.stat_expression(Stat::Max, catalog) - } - /// Format the expression as a compact string. /// /// Since this is a recursive formatter, it is exposed on the public Expression type. diff --git a/vortex-array/src/expr/mod.rs b/vortex-array/src/expr/mod.rs index a5d32510443..c3728155517 100644 --- a/vortex-array/src/expr/mod.rs +++ b/vortex-array/src/expr/mod.rs @@ -34,7 +34,6 @@ pub(crate) mod field; pub mod forms; mod optimize; pub mod proto; -pub mod pruning; pub mod stats; pub mod transform; pub mod traversal; @@ -42,7 +41,6 @@ pub mod traversal; pub use analysis::*; pub use expression::*; pub use exprs::*; -pub use pruning::StatsCatalog; pub trait VortexExprExt { /// Accumulate all field references from this expression and its children in a set diff --git a/vortex-array/src/expr/pruning/mod.rs b/vortex-array/src/expr/pruning/mod.rs deleted file mode 100644 index 7c20508b7a8..00000000000 --- a/vortex-array/src/expr/pruning/mod.rs +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -pub(crate) mod pruning_expr; -mod relation; - -pub use pruning_expr::RequiredStats; -pub use pruning_expr::checked_pruning_expr; -pub use pruning_expr::field_path_stat_field_name; -pub use relation::Relation; - -use crate::dtype::FieldPath; -use crate::expr::Expression; -use crate::expr::stats::Stat; - -/// A catalog of available stats that are associated with field paths. -pub trait StatsCatalog { - /// Given a field path and statistic, return an expression that when evaluated over the catalog - /// will return that stat for the referenced field. - /// - /// This is likely to be a column expression, or a literal. - /// - /// Returns `None` if the stat is not available for the field path. - fn stats_ref(&self, _field_path: &FieldPath, _stat: Stat) -> Option { - None - } -} diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs deleted file mode 100644 index 00d29fbcf99..00000000000 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ /dev/null @@ -1,561 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::cell::RefCell; -use std::iter; - -use itertools::Itertools; -use vortex_utils::aliases::hash_map::HashMap; - -use super::relation::Relation; -use crate::dtype::Field; -use crate::dtype::FieldName; -use crate::dtype::FieldPath; -use crate::dtype::FieldPathSet; -use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::get_item; -use crate::expr::root; -use crate::expr::stats::Stat; - -pub type RequiredStats = Relation; - -// A catalog that return a stat column whenever it is required, tracking all accessed -// stats and returning them later. -#[derive(Default)] -pub(crate) struct TrackingStatsCatalog { - usage: RefCell>, -} - -impl TrackingStatsCatalog { - /// Consume the catalog, yielding a map of field statistics that were required - /// for each expression. - fn into_usages(self) -> HashMap<(FieldPath, Stat), Expression> { - self.usage.into_inner() - } -} - -// A catalog that return a stat column if it exists in the given scope. -struct ScopeStatsCatalog<'a> { - inner: TrackingStatsCatalog, - available_stats: &'a FieldPathSet, -} - -impl StatsCatalog for ScopeStatsCatalog<'_> { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { - let stat_path = field_path.clone().push(stat.name()); - - if self.available_stats.contains(&stat_path) { - self.inner.stats_ref(field_path, stat) - } else { - None - } - } -} - -impl StatsCatalog for TrackingStatsCatalog { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { - let mut expr = root(); - let name = field_path_stat_field_name(field_path, stat); - expr = get_item(name, expr); - self.usage - .borrow_mut() - .insert((field_path.clone(), stat), expr.clone()); - Some(expr) - } -} - -#[doc(hidden)] -pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName { - field_path - .parts() - .iter() - .map(|f| match f { - Field::Name(n) => n.as_ref(), - Field::ElementType => todo!("element type not currently handled"), - }) - .chain(iter::once(stat.name())) - .join("_") - .into() -} - -/// Build a pruning expr mask, using an existing set of stats. -/// The available stats are provided as a set of [`FieldPath`]. -/// -/// A pruning expression is one that returns `true` for all positions where the original expression -/// cannot hold, and false if it cannot be determined from stats alone whether the positions can -/// be pruned. -/// -/// Some rewrites, such as `is_not_null(...)`, emit -/// [`row_count`][crate::scalar_fn::internal::row_count] placeholders. The evaluation layer must -/// replace those placeholders with the row count for its current scope before -/// executing the returned expression. -/// -/// If the falsification logic attempts to access an unknown stat, -/// this function will return `None`. -pub fn checked_pruning_expr( - expr: &Expression, - available_stats: &FieldPathSet, -) -> Option<(Expression, RequiredStats)> { - let catalog = ScopeStatsCatalog { - inner: Default::default(), - available_stats, - }; - - let expr = expr.stat_falsification(&catalog)?; - - // TODO(joe): filter access by used exprs - let mut relation: Relation = Relation::new(); - for ((field_path, stat), _) in catalog.inner.into_usages() { - relation.insert(field_path, stat) - } - - Some((expr, relation)) -} - -#[cfg(test)] -mod tests { - use rstest::fixture; - use rstest::rstest; - use vortex_utils::aliases::hash_set::HashSet; - - use super::HashMap; - use crate::dtype::DType; - use crate::dtype::FieldName; - use crate::dtype::FieldNames; - use crate::dtype::FieldPath; - use crate::dtype::FieldPathSet; - use crate::dtype::Nullability; - use crate::dtype::StructFields; - use crate::expr::and; - use crate::expr::between; - use crate::expr::cast; - use crate::expr::col; - use crate::expr::eq; - use crate::expr::get_item; - use crate::expr::gt; - use crate::expr::gt_eq; - use crate::expr::lit; - use crate::expr::lt; - use crate::expr::lt_eq; - use crate::expr::not_eq; - use crate::expr::or; - use crate::expr::pruning::checked_pruning_expr; - use crate::expr::pruning::field_path_stat_field_name; - use crate::expr::root; - use crate::expr::stats::Stat; - use crate::scalar_fn::fns::between::BetweenOptions; - use crate::scalar_fn::fns::between::StrictComparison; - - // Implement some checked pruning expressions. - #[fixture] - fn available_stats() -> FieldPathSet { - let field_a = FieldPath::from_name("a"); - let field_b = FieldPath::from_name("b"); - - FieldPathSet::from_iter([ - field_a.clone().push(Stat::Min.name()), - field_a.push(Stat::Max.name()), - field_b.clone().push(Stat::Min.name()), - field_b.push(Stat::Max.name()), - ]) - } - - #[rstest] - pub fn pruning_equals(available_stats: FieldPathSet) { - let name = FieldName::from("a"); - let literal_eq = lit(42); - let eq_expr = eq(get_item("a", root()), literal_eq.clone()); - let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap(); - let expected_expr = or( - gt( - get_item( - field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min), - root(), - ), - literal_eq.clone(), - ), - gt( - literal_eq, - col(field_path_stat_field_name( - &FieldPath::from_name(name), - Stat::Max, - )), - ), - ); - assert_eq!(&converted, &expected_expr); - } - - #[rstest] - pub fn pruning_equals_column(available_stats: FieldPathSet) { - let column = FieldName::from("a"); - let other_col = FieldName::from("b"); - let eq_expr = eq(col(column.clone()), col(other_col.clone())); - - let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([ - ( - FieldPath::from_name(column.clone()), - HashSet::from_iter([Stat::Min, Stat::Max]) - ), - ( - FieldPath::from_name(other_col.clone()), - HashSet::from_iter([Stat::Max, Stat::Min]) - ) - ]) - ); - let expected_expr = or( - gt( - col(field_path_stat_field_name( - &FieldPath::from_name(column.clone()), - Stat::Min, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(other_col.clone()), - Stat::Max, - )), - ), - gt( - col(field_path_stat_field_name( - &FieldPath::from_name(other_col), - Stat::Min, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(column), - Stat::Max, - )), - ), - ); - assert_eq!(&converted, &expected_expr); - } - - #[rstest] - pub fn pruning_not_equals_column(available_stats: FieldPathSet) { - let column = FieldName::from("a"); - let other_col = FieldName::from("b"); - let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone())); - - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([ - ( - FieldPath::from_name(column.clone()), - HashSet::from_iter([Stat::Min, Stat::Max]) - ), - ( - FieldPath::from_name(other_col.clone()), - HashSet::from_iter([Stat::Max, Stat::Min]) - ) - ]) - ); - let expected_expr = and( - eq( - col(field_path_stat_field_name( - &FieldPath::from_name(column.clone()), - Stat::Min, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(other_col.clone()), - Stat::Max, - )), - ), - eq( - col(field_path_stat_field_name( - &FieldPath::from_name(column), - Stat::Max, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(other_col), - Stat::Min, - )), - ), - ); - - assert_eq!(&converted, &expected_expr); - } - - #[rstest] - pub fn pruning_gt_column(available_stats: FieldPathSet) { - let column = FieldName::from("a"); - let other_col = FieldName::from("b"); - let other_expr = col(other_col.clone()); - let not_eq_expr = gt(col(column.clone()), other_expr); - - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([ - ( - FieldPath::from_name(column.clone()), - HashSet::from_iter([Stat::Max]) - ), - ( - FieldPath::from_name(other_col.clone()), - HashSet::from_iter([Stat::Min]) - ) - ]) - ); - let expected_expr = lt_eq( - col(field_path_stat_field_name( - &FieldPath::from_name(column), - Stat::Max, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(other_col), - Stat::Min, - )), - ); - assert_eq!(&converted, &expected_expr); - } - - #[rstest] - pub fn pruning_gt_value(available_stats: FieldPathSet) { - let column = FieldName::from("a"); - let other_col = lit(42); - let not_eq_expr = gt(col(column.clone()), other_col.clone()); - - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([( - FieldPath::from_name(column.clone()), - HashSet::from_iter([Stat::Max]) - ),]) - ); - let expected_expr = lt_eq( - col(field_path_stat_field_name( - &FieldPath::from_name(column), - Stat::Max, - )), - other_col, - ); - assert_eq!(&converted, &(expected_expr)); - } - - #[rstest] - pub fn pruning_lt_column(available_stats: FieldPathSet) { - let column = FieldName::from("a"); - let other_col = FieldName::from("b"); - let other_expr = col(other_col.clone()); - let not_eq_expr = lt(col(column.clone()), other_expr); - - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([ - ( - FieldPath::from_name(column.clone()), - HashSet::from_iter([Stat::Min]) - ), - ( - FieldPath::from_name(other_col.clone()), - HashSet::from_iter([Stat::Max]) - ) - ]) - ); - let expected_expr = gt_eq( - col(field_path_stat_field_name( - &FieldPath::from_name(column), - Stat::Min, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(other_col), - Stat::Max, - )), - ); - assert_eq!(&converted, &expected_expr); - } - - #[rstest] - pub fn pruning_lt_value(available_stats: FieldPathSet) { - // expression => a < 42 - // pruning expr => a.min >= 42 - let expr = lt(col("a"), lit(42)); - - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))]) - ); - assert_eq!(&converted, >_eq(col("a_min"), lit(42))); - } - - #[rstest] - fn pruning_identity(available_stats: FieldPathSet) { - let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50))); - - let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap(); - - let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50))); - assert_eq!(&predicate.to_string(), &expected_expr.to_string()); - } - #[rstest] - pub fn pruning_and_or_operators(available_stats: FieldPathSet) { - // Test case: a > 10 AND a < 50 - let column = FieldName::from("a"); - let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50))); - let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap(); - - // Expected: a_max <= 10 OR a_min >= 50 - assert_eq!( - &predicate, - &or( - lt_eq(col(FieldName::from("a_max")), lit(10)), - gt_eq(col(FieldName::from("a_min")), lit(50)), - ), - ); - } - - #[rstest] - fn test_gt_eq_with_booleans(available_stats: FieldPathSet) { - // Consider this unusual, but valid (in Arrow, BooleanArray implements ArrayOrd), filter expression: - // x > (y > z) - // The x column is a Boolean-valued column. The y and z columns are numeric. True > False. - // Suppose we had a Vortex zone whose min/max statistics for each column were: - // x: [True, True] - // y: [1, 2] - // z: [0, 2] - // The pruning predicate will convert the aforementioned expression into: - // x_max <= (y_min > z_min) - // If we evaluate that pruning expression on our zone we get: - // x_max <= (y_min > z_min) - // x_max <= (1 > 0 ) - // x_max <= True - // True <= True - // True - // If a pruning predicate evaluates to true then, as stated in PruningPredicate::evaluate: - // > a true value means the chunk can be pruned. - // But, the following record lies within the above intervals and *passes* the filter expression! We - // cannot prune this zone because we need this record! - // {x: True, y: 1, z: 2} - // x > (y > z) - // True > (1 > 2) - // True > False - // True - let expr = gt_eq(col("x"), gt(col("y"), col("z"))); - assert!(checked_pruning_expr(&expr, &available_stats).is_none()); - // TODO(DK): a sufficiently complex pruner would produce: `x_max <= (y_max > z_min)` - } - - #[fixture] - fn available_stats_with_nans() -> FieldPathSet { - let float_col = FieldPath::from_name("float_col"); - let int_col = FieldPath::from_name("int_col"); - - FieldPathSet::from_iter([ - // Float columns will have a NaNCount. - float_col.clone().push(Stat::Min.name()), - float_col.clone().push(Stat::Max.name()), - float_col.push(Stat::NaNCount.name()), - // int columns will not have a NanCount serialized into the layout - int_col.clone().push(Stat::Min.name()), - int_col.push(Stat::Max.name()), - ]) - } - - #[rstest] - fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) { - let expr = gt_eq(col("float_col"), lit(f32::NAN)); - let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap(); - assert_eq!( - &converted, - &and( - and( - eq(col("float_col_nan_count"), lit(0u64)), - // NaNCount of NaN is 1 - eq(lit(1u64), lit(0u64)), - ), - // This is the standard conversion of the >= operator. Comparing NAN to a max - // stat is nonsensical, as min/max stats ignore NaNs, but this should be short-circuited - // by the previous check for nan_count anyway. - lt(col("float_col_max"), lit(f32::NAN)), - ) - ); - - // One half of the expression requires NAN count check, the other half does not. - let expr = and( - gt(col("float_col"), lit(10f32)), - lt(col("int_col"), lit(10)), - ); - - let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap(); - - assert_eq!( - &converted, - &or( - // NaNCount check is enforced for the float column - and( - and( - eq(col("float_col_nan_count"), lit(0u64)), - // NanCount of a non-NaN float literal is 0 - eq(lit(0u64), lit(0u64)), - ), - // We want the opposite: we can prune IF either one is false. - lt_eq(col("float_col_max"), lit(10f32)), - ), - // NanCount check is skipped for the int column - gt_eq(col("int_col_min"), lit(10)), - ) - ) - } - - #[rstest] - fn pruning_between(available_stats: FieldPathSet) { - let expr = between( - col("a"), - lit(10), - lit(50), - BetweenOptions { - lower_strict: StrictComparison::NonStrict, - upper_strict: StrictComparison::NonStrict, - }, - ); - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([( - FieldPath::from_name("a"), - HashSet::from_iter([Stat::Min, Stat::Max]) - )]) - ); - assert_eq!( - &converted, - &or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50))) - ); - } - - #[rstest] - fn pruning_cast_get_item_eq(available_stats: FieldPathSet) { - // This test verifies that cast properly forwards analysis methods to - // enable pruning. - let struct_dtype = DType::Struct( - StructFields::new( - FieldNames::from([FieldName::from("a"), FieldName::from("b")]), - vec![ - DType::Utf8(Nullability::Nullable), - DType::Utf8(Nullability::Nullable), - ], - ), - Nullability::NonNullable, - ); - let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value")); - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([( - FieldPath::from_name("a"), - HashSet::from_iter([Stat::Min, Stat::Max]) - )]) - ); - assert_eq!( - &converted, - &or( - gt(col("a_min"), lit("value")), - gt(lit("value"), col("a_max")) - ) - ); - } -} diff --git a/vortex-array/src/expr/pruning/relation.rs b/vortex-array/src/expr/pruning/relation.rs deleted file mode 100644 index 3a33771c46e..00000000000 --- a/vortex-array/src/expr/pruning/relation.rs +++ /dev/null @@ -1,50 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::hash::Hash; - -use vortex_utils::aliases::hash_map::HashMap; -use vortex_utils::aliases::hash_map::IntoIter; -use vortex_utils::aliases::hash_set::HashSet; - -#[derive(Debug, Clone)] -pub struct Relation { - map: HashMap>, -} - -impl Default for Relation { - fn default() -> Self { - Self::new() - } -} - -impl Relation { - pub fn new() -> Self { - Relation { - map: HashMap::new(), - } - } - - pub fn insert(&mut self, k: K, v: V) { - self.map.entry(k).or_default().insert(v); - } - - pub fn map(&self) -> &HashMap> { - &self.map - } -} - -impl From>> for Relation { - fn from(value: HashMap>) -> Self { - Self { map: value } - } -} - -impl IntoIterator for Relation { - type Item = (K, HashSet); - type IntoIter = IntoIter>; - - fn into_iter(self) -> Self::IntoIter { - self.map.into_iter() - } -} diff --git a/vortex-array/src/scalar_fn/erased.rs b/vortex-array/src/scalar_fn/erased.rs index 10e82d25455..6e0011c297a 100644 --- a/vortex-array/src/scalar_fn/erased.rs +++ b/vortex-array/src/scalar_fn/erased.rs @@ -20,8 +20,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::scalar_fn::EmptyOptions; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ReduceCtx; @@ -180,25 +178,6 @@ impl ScalarFnRef { pub(crate) fn simplify_untyped(&self, expr: &Expression) -> VortexResult> { self.0.simplify_untyped(expr) } - - /// Compute stat falsification expression. - pub(crate) fn stat_falsification( - &self, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - self.0.stat_falsification(expr, catalog) - } - - /// Compute stat expression. - pub(crate) fn stat_expression( - &self, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - self.0.stat_expression(expr, stat, catalog) - } } impl Debug for ScalarFnRef { diff --git a/vortex-array/src/scalar_fn/fns/between/mod.rs b/vortex-array/src/scalar_fn/fns/between/mod.rs index 013438e23b2..2cbe02183c8 100644 --- a/vortex-array/src/scalar_fn/fns/between/mod.rs +++ b/vortex-array/src/scalar_fn/fns/between/mod.rs @@ -25,7 +25,6 @@ use crate::arrays::Primitive; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::DType::Bool; -use crate::expr::StatsCatalog; use crate::expr::expression::Expression; use crate::scalar::Scalar; use crate::scalar_fn::Arity; @@ -33,8 +32,6 @@ use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::ScalarFnVTableExt; -use crate::scalar_fn::fns::binary::Binary; use crate::scalar_fn::fns::binary::execute_boolean; use crate::scalar_fn::fns::operators::CompareOperator; use crate::scalar_fn::fns::operators::Operator; @@ -298,22 +295,6 @@ impl ScalarFnVTable for Between { between_canonical(&arr, &lower, &upper, options, ctx) } - fn stat_falsification( - &self, - options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let arr = expr.child(0).clone(); - let lower = expr.child(1).clone(); - let upper = expr.child(2).clone(); - - let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]); - let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]); - - and(lhs, rhs).stat_falsification(catalog) - } - fn validity( &self, _options: &Self::Options, diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index bbb392a3752..bda7e61ffd7 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -17,23 +17,17 @@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; -use crate::expr::StatsCatalog; +use crate::dtype::Nullability; use crate::expr::and; -use crate::expr::and_collect; -use crate::expr::eq; use crate::expr::expression::Expression; -use crate::expr::gt; -use crate::expr::gt_eq; use crate::expr::lit; -use crate::expr::lt; -use crate::expr::lt_eq; -use crate::expr::or_collect; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; +use crate::scalar_fn::SimplifyCtx; +use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::operators::CompareOperator; use crate::scalar_fn::fns::operators::Operator; @@ -49,6 +43,7 @@ mod numeric; pub(crate) use numeric::*; use crate::scalar::NumericOperator; +use crate::scalar::Scalar; #[derive(Clone)] pub struct Binary; @@ -169,108 +164,73 @@ impl ScalarFnVTable for Binary { } } - fn stat_falsification( + fn simplify_untyped( &self, operator: &Operator, expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // Wrap another predicate with an optional NaNCount check, if the stat is available. - // - // For example, regular pruning conversion for `A >= B` would be - // - // A.max < B.min - // - // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting - // in: - // - // (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min - // - // Non-floating point column and literal expressions should be unaffected as they do not - // have a nan_count statistic defined. - fn with_nan_predicate( - lhs: &Expression, - rhs: &Expression, - value_predicate: Expression, - catalog: &dyn StatsCatalog, - ) -> Expression { - let nan_predicate = and_collect( - lhs.stat_expression(Stat::NaNCount, catalog) - .into_iter() - .chain(rhs.stat_expression(Stat::NaNCount, catalog)) - .map(|nans| eq(nans, lit(0u64))), - ); - - if let Some(nan_check) = nan_predicate { - and(nan_check, value_predicate) - } else { - value_predicate - } - } - + ) -> VortexResult> { let lhs = expr.child(0); let rhs = expr.child(1); - match operator { - Operator::Eq => { - let min_lhs = lhs.stat_min(catalog); - let max_lhs = lhs.stat_max(catalog); - let min_rhs = rhs.stat_min(catalog); - let max_rhs = rhs.stat_max(catalog); + let bool_literal = |expr: &Expression| { + expr.as_opt::()? + .as_bool_opt() + .map(|value| value.value()) + }; - let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b)); - let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b)); - - let min_max_check = or_collect(left.into_iter().chain(right))?; - - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::NotEq => { - let min_lhs = lhs.stat_min(catalog)?; - let max_lhs = lhs.stat_max(catalog)?; - - let min_rhs = rhs.stat_min(catalog)?; - let max_rhs = rhs.stat_max(catalog)?; - - let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Gt => { - let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Gte => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Lt => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); + // AND/OR use Kleene three-valued logic. `None` below is a boolean null. + // + // AND: + // - false AND x => false + // - true AND x => x + // - null AND null => null + // + // OR: + // - true OR x => true + // - false OR x => x + // - null OR null => null + // + // Other null cases either fall out of the identity/annihilator rules + // above (`null AND true`, `null OR false`) or cannot be simplified under + // Kleene semantics (`null AND x`, `null OR x` for non-literal `x`). + Ok(match operator { + Operator::And => match (bool_literal(lhs), bool_literal(rhs)) { + (Some(Some(false)), _) | (_, Some(Some(false))) => Some(lit(false)), + (Some(Some(true)), _) => Some(rhs.clone()), + (_, Some(Some(true))) => Some(lhs.clone()), + (Some(None), Some(None)) => Some(lhs.clone()), + _ => None, + }, + Operator::Or => match (bool_literal(lhs), bool_literal(rhs)) { + (Some(Some(true)), _) | (_, Some(Some(true))) => Some(lit(true)), + (Some(Some(false)), _) => Some(rhs.clone()), + (_, Some(Some(false))) => Some(lhs.clone()), + (Some(None), Some(None)) => Some(lhs.clone()), + _ => None, + }, + _ => None, + }) + } - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Lte => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); + fn simplify( + &self, + operator: &Operator, + expr: &Expression, + ctx: &dyn SimplifyCtx, + ) -> VortexResult> { + let is_literal_null = + |expr: &Expression| expr.as_opt::().is_some_and(Scalar::is_null); - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::And => or_collect( - lhs.stat_falsification(catalog) - .into_iter() - .chain(rhs.stat_falsification(catalog)), - ), - Operator::Or => Some(and( - lhs.stat_falsification(catalog)?, - rhs.stat_falsification(catalog)?, - )), - Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, + if operator.is_comparison() + && (is_literal_null(expr.child(0)) || is_literal_null(expr.child(1))) + { + // Validate the comparison before reducing it. This preserves type + // errors for expressions like `int_col = null_utf8`. + ctx.return_dtype(expr)?; + return Ok(Some(lit(Scalar::null(DType::Bool(Nullability::Nullable))))); } + + Ok(None) } fn validity( @@ -318,6 +278,7 @@ impl ScalarFnVTable for Binary { #[cfg(test)] mod tests { use vortex_error::VortexExpect; + use vortex_error::VortexResult; use super::*; use crate::LEGACY_SESSION; @@ -326,11 +287,16 @@ mod tests { use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; + use crate::dtype::PType; use crate::expr::Expression; use crate::expr::and_collect; use crate::expr::col; + use crate::expr::eq; + use crate::expr::gt; + use crate::expr::gt_eq; use crate::expr::lit; use crate::expr::lt; + use crate::expr::lt_eq; use crate::expr::not_eq; use crate::expr::or; use crate::expr::or_collect; @@ -446,6 +412,36 @@ mod tests { ); } + #[test] + fn comparison_with_typed_null_simplifies_after_type_check() -> VortexResult<()> { + let dtype = test_harness::struct_dtype(); + + let expr = eq( + col("col1"), + lit(Scalar::null(DType::Primitive( + PType::U16, + Nullability::Nullable, + ))), + ); + + assert_eq!( + expr.optimize_recursive(&dtype)?, + lit(Scalar::null(DType::Bool(Nullability::Nullable))) + ); + Ok(()) + } + + #[test] + fn comparison_with_incompatible_null_still_type_checks() { + let dtype = test_harness::struct_dtype(); + let expr = eq( + col("col1"), + lit(Scalar::null(DType::Utf8(Nullability::Nullable))), + ); + + assert!(expr.optimize_recursive(&dtype).is_err()); + } + #[test] fn test_display_print() { let expr = gt(lit(1), lit(2)); diff --git a/vortex-array/src/scalar_fn/fns/cast/mod.rs b/vortex-array/src/scalar_fn/fns/cast/mod.rs index abc59af2c9a..20852779d42 100644 --- a/vortex-array/src/scalar_fn/fns/cast/mod.rs +++ b/vortex-array/src/scalar_fn/fns/cast/mod.rs @@ -32,11 +32,8 @@ use crate::arrays::VarBinView; use crate::arrays::struct_::compute::cast::struct_cast; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; -use crate::expr::StatsCatalog; -use crate::expr::cast; use crate::expr::expression::Expression; use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; @@ -152,39 +149,6 @@ impl ScalarFnVTable for Cast { Ok(None) } - fn stat_expression( - &self, - dtype: &DType, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - match stat { - Stat::IsConstant - | Stat::IsSorted - | Stat::IsStrictSorted - | Stat::NaNCount - | Stat::Sum - | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog), - Stat::Max | Stat::Min => { - // We cast min/max to the new type - expr.child(0) - .stat_expression(stat, catalog) - .map(|x| cast(x, dtype.clone())) - } - Stat::NullCount => { - // if !expr.data().is_nullable() { - // NOTE(ngates): we should decide on the semantics here. In theory, the null - // count of something cast to non-nullable will be zero. But if we return - // that we know this to be zero, then a pruning predicate may eliminate data - // that would otherwise have caused the cast to error. - // return Some(lit(0u64)); - // } - None - } - } - } - fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult> { Ok(Some(if dtype.is_nullable() { expression.child(0).validity()? diff --git a/vortex-array/src/scalar_fn/fns/dynamic.rs b/vortex-array/src/scalar_fn/fns/dynamic.rs index 7efebf79220..f6e6619282a 100644 --- a/vortex-array/src/scalar_fn/fns/dynamic.rs +++ b/vortex-array/src/scalar_fn/fns/dynamic.rs @@ -20,7 +20,6 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::traversal::NodeExt; use crate::expr::traversal::NodeVisitor; use crate::expr::traversal::TraversalOrder; @@ -120,50 +119,6 @@ impl ScalarFnVTable for DynamicComparison { .into_array()) } - fn stat_falsification( - &self, - dynamic: &DynamicComparisonExpr, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let lhs = expr.child(0); - match dynamic.operator { - CompareOperator::Eq | CompareOperator::NotEq => None, - CompareOperator::Gt => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Lte, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_max(catalog)?], - )), - CompareOperator::Gte => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Lt, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_max(catalog)?], - )), - CompareOperator::Lt => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Gte, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_min(catalog)?], - )), - CompareOperator::Lte => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Gt, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_min(catalog)?], - )), - } - } - // Defer to the child fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false diff --git a/vortex-array/src/scalar_fn/fns/get_item.rs b/vortex-array/src/scalar_fn/fns/get_item.rs index de7e45ca9b0..b9cb9202f38 100644 --- a/vortex-array/src/scalar_fn/fns/get_item.rs +++ b/vortex-array/src/scalar_fn/fns/get_item.rs @@ -18,12 +18,9 @@ use crate::builtins::ArrayBuiltins; use crate::builtins::ExprBuiltins; use crate::dtype::DType; use crate::dtype::FieldName; -use crate::dtype::FieldPath; use crate::dtype::Nullability; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -188,24 +185,6 @@ impl ScalarFnVTable for GetItem { Ok(None) } - fn stat_expression( - &self, - field_name: &FieldName, - _expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - // TODO(ngates): I think we can do better here and support stats over nested fields. - // It would be nice if delegating to our child would return a struct of statistics - // matching the nested DType such that we can write: - // `get_item(expr.child(0).stat_expression(...), expr.data().field_name())` - - // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same - // name as a field in the root struct. This should be resolved with upcoming change to - // falsify expressions, but for now I'm preserving the existing buggy behavior. - catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat) - } - // This will apply struct nullability field. We could add a dtype?? fn is_null_sensitive(&self, _field_name: &FieldName) -> bool { true diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index 589333304e2..f2849f53ccf 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -3,7 +3,6 @@ use std::fmt::Formatter; -use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_error::VortexResult; use vortex_session::VortexSession; use vortex_session::registry::CachedId; @@ -15,16 +14,12 @@ use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::eq; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::ScalarFnVTableExt; use crate::validity::Validity; /// Expression that checks for non-null values. @@ -100,26 +95,16 @@ impl ScalarFnVTable for IsNotNull { fn is_fallible(&self, _instance: &Self::Options) -> bool { false } - - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // is_not_null is falsified when ALL values are null, i.e. null_count == row_count. - let child = expr.child(0); - let null_count_expr = child.stat_expression(Stat::NullCount, catalog)?; - Some(eq(null_count_expr, RowCount.new_expr(EmptyOptions, []))) - } } #[cfg(test)] mod tests { + use std::sync::LazyLock; + use vortex_buffer::buffer; use vortex_error::VortexExpect as _; - use vortex_utils::aliases::hash_map::HashMap; - use vortex_utils::aliases::hash_set::HashSet; + use vortex_error::VortexResult; + use vortex_session::VortexSession; use crate::IntoArray; use crate::LEGACY_SESSION; @@ -127,22 +112,24 @@ mod tests { use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; use crate::dtype::DType; - use crate::dtype::Field; - use crate::dtype::FieldPath; - use crate::dtype::FieldPathSet; use crate::dtype::Nullability; use crate::expr::col; use crate::expr::eq; use crate::expr::get_item; use crate::expr::is_not_null; - use crate::expr::pruning::checked_pruning_expr; + use crate::expr::or; use crate::expr::root; - use crate::expr::stats::Stat; use crate::expr::test_harness; use crate::scalar::Scalar; use crate::scalar_fn::EmptyOptions; + use crate::scalar_fn::ScalarFnVTableExt; use crate::scalar_fn::internal::row_count::RowCount; - use crate::scalar_fn::vtable::ScalarFnVTableExt; + use crate::stats::StatsSession; + use crate::stats::all_null; + use crate::stats::null_count; + + static STATS_SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); #[test] fn dtype() { @@ -262,25 +249,16 @@ mod tests { } #[test] - fn test_is_not_null_falsification() { + fn test_is_not_null_falsification() -> VortexResult<()> { let expr = is_not_null(col("a")); - let (pruning_expr, st) = checked_pruning_expr( - &expr, - &FieldPathSet::from_iter([FieldPath::from_iter([ - Field::Name("a".into()), - Field::Name("null_count".into()), - ])]), - ) - .unwrap(); - - assert_eq!( - &pruning_expr, - &eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, [])) - ); assert_eq!( - st.map(), - &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))]) + expr.falsify(&test_harness::struct_dtype(), &STATS_SESSION)?, + Some(or( + eq(null_count(col("a")), RowCount.new_expr(EmptyOptions, []),), + all_null(col("a")), + )) ); + Ok(()) } } diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index 7315fbe8c07..8df263a4b22 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -12,11 +12,6 @@ use crate::arrays::ConstantArray; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; -use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::eq; -use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -84,16 +79,6 @@ impl ScalarFnVTable for IsNull { } } - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?; - Some(eq(null_count_expr, lit(0u64))) - } - fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } @@ -105,10 +90,12 @@ impl ScalarFnVTable for IsNull { #[cfg(test)] mod tests { + use std::sync::LazyLock; + use vortex_buffer::buffer; use vortex_error::VortexExpect as _; - use vortex_utils::aliases::hash_map::HashMap; - use vortex_utils::aliases::hash_set::HashSet; + use vortex_error::VortexResult; + use vortex_session::VortexSession; use crate::IntoArray; use crate::LEGACY_SESSION; @@ -116,20 +103,22 @@ mod tests { use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; use crate::dtype::DType; - use crate::dtype::Field; - use crate::dtype::FieldPath; - use crate::dtype::FieldPathSet; use crate::dtype::Nullability; use crate::expr::col; use crate::expr::eq; use crate::expr::get_item; use crate::expr::is_null; use crate::expr::lit; - use crate::expr::pruning::checked_pruning_expr; + use crate::expr::or; use crate::expr::root; - use crate::expr::stats::Stat; use crate::expr::test_harness; use crate::scalar::Scalar; + use crate::stats::StatsSession; + use crate::stats::all_non_null; + use crate::stats::null_count; + + static STATS_SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); #[test] fn dtype() { @@ -246,23 +235,17 @@ mod tests { } #[test] - fn test_is_null_falsification() { + fn test_is_null_falsification() -> VortexResult<()> { let expr = is_null(col("a")); - let (pruning_expr, st) = checked_pruning_expr( - &expr, - &FieldPathSet::from_iter([FieldPath::from_iter([ - Field::Name("a".into()), - Field::Name("null_count".into()), - ])]), - ) - .unwrap(); - - assert_eq!(&pruning_expr, &eq(col("a_null_count"), lit(0u64))); assert_eq!( - st.map(), - &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))]) + expr.falsify(&test_harness::struct_dtype(), &STATS_SESSION)?, + Some(or( + eq(null_count(col("a")), lit(0u64)), + all_non_null(col("a")), + )) ); + Ok(()) } #[test] diff --git a/vortex-array/src/scalar_fn/fns/like/mod.rs b/vortex-array/src/scalar_fn/fns/like/mod.rs index b7f357020f1..850484a1450 100644 --- a/vortex-array/src/scalar_fn/fns/like/mod.rs +++ b/vortex-array/src/scalar_fn/fns/like/mod.rs @@ -21,20 +21,12 @@ use crate::arrow::Datum; use crate::arrow::from_arrow_columnar; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::and; -use crate::expr::gt; -use crate::expr::gt_eq; -use crate::expr::lit; -use crate::expr::lt; -use crate::expr::or; -use crate::scalar::StringLike; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::fns::literal::Literal; /// Options for SQL LIKE function #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -165,49 +157,6 @@ impl ScalarFnVTable for Like { fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false } - - fn stat_falsification( - &self, - like_opts: &LikeOptions, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // Attempt to do min/max pruning for LIKE 'exact' or LIKE 'prefix%' - - // Don't attempt to handle ilike or negated like - if like_opts.negated || like_opts.case_insensitive { - return None; - } - - // Extract the pattern out - let pat = expr.child(1).as_::(); - - // LIKE NULL is nonsensical, don't try to handle it - let pat_str = pat.as_utf8().value()?; - - let src = expr.child(0).clone(); - let src_min = src.stat_min(catalog)?; - let src_max = src.stat_max(catalog)?; - - match LikeVariant::from_str(pat_str)? { - LikeVariant::Exact(text) => { - // col LIKE 'exact' ==> col.min > 'exact' || col.max < 'exact' - Some(or( - gt(src_min, lit(text.as_ref())), - lt(src_max, lit(text.as_ref())), - )) - } - LikeVariant::Prefix(prefix) => { - // col LIKE 'prefix%' ==> col.max < 'prefix' || col.min >= 'prefiy' - let succ = prefix.to_string().increment().ok()?; - - Some(or( - gt_eq(src_min, lit(succ)), - lt(src_max, lit(prefix.as_ref())), - )) - } - } - } } /// Implementation of LIKE using the Arrow crate. @@ -295,15 +244,11 @@ mod tests { use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; - use crate::expr::col; use crate::expr::get_item; - use crate::expr::ilike; use crate::expr::like; use crate::expr::lit; use crate::expr::not; use crate::expr::not_ilike; - use crate::expr::not_like; - use crate::expr::pruning::pruning_expr::TrackingStatsCatalog; use crate::expr::root; use crate::scalar_fn::fns::like::LikeVariant; @@ -390,50 +335,4 @@ mod tests { assert_eq!(LikeVariant::from_str(r"%\%%"), None); assert_eq!(LikeVariant::from_str("_pattern"), None); } - - #[test] - fn test_like_pushdown() { - // Test that LIKE prefix and exactness filters can be pushed down into stats filtering - // at scan time. - let catalog = TrackingStatsCatalog::default(); - - let pruning_expr = like(col("a"), lit("prefix%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "prefiy") or ($.a_max < "prefix"))"#); - - let pruning_expr = like(col("a"), lit(r"\%%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "&") or ($.a_max < "%"))"#); - - // Multiple wildcards - let pruning_expr = like(col("a"), lit("pref%ix%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "preg") or ($.a_max < "pref"))"#); - - let pruning_expr = like(col("a"), lit("pref_ix_")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "preg") or ($.a_max < "pref"))"#); - - // Exact match - let pruning_expr = like(col("a"), lit("exactly")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min > "exactly") or ($.a_max < "exactly"))"#); - - // Suffix search skips pushdown - let pruning_expr = like(col("a"), lit("%suffix")).stat_falsification(&catalog); - assert_eq!(pruning_expr, None); - - // NOT LIKE, ILIKE not supported currently - assert_eq!( - None, - not_like(col("a"), lit("a")).stat_falsification(&catalog) - ); - assert_eq!(None, ilike(col("a"), lit("a")).stat_falsification(&catalog)); - } } diff --git a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs index e16991763ed..4b39f51a7f2 100644 --- a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs +++ b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs @@ -33,13 +33,6 @@ use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::IntegerPType; use crate::dtype::Nullability; -use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::and_collect; -use crate::expr::gt; -use crate::expr::lit; -use crate::expr::lt; -use crate::expr::or; use crate::match_each_integer_ptype; use crate::match_each_unsigned_integer_ptype; use crate::scalar::ListScalar; @@ -51,7 +44,6 @@ use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; use crate::scalar_fn::fns::binary::Binary; -use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::operators::Operator; use crate::validity::Validity; @@ -129,43 +121,6 @@ impl ScalarFnVTable for ListContains { compute_list_contains(&list_array, &value_array, ctx) } - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let list = expr.child(0); - let needle = expr.child(1); - - // falsification(contains([1,2,5], x)) => - // falsification(x != 1) and falsification(x != 2) and falsification(x != 5) - let min = list.stat_min(catalog)?; - let max = list.stat_max(catalog)?; - // If the list is constant when we can compare each element to the value - if min == max { - let list_ = min - .as_opt::() - .and_then(|l| l.as_list_opt()) - .and_then(|l| l.elements())?; - if list_.is_empty() { - // contains([], x) is always false. - return Some(lit(true)); - } - let value_max = needle.stat_max(catalog)?; - let value_min = needle.stat_min(catalog)?; - - return and_collect(list_.iter().map(move |v| { - or( - lt(value_max.clone(), lit(v.clone())), - gt(value_min.clone(), lit(v.clone())), - ) - })); - } - - None - } - // Nullability matters for contains([], x) where x is false. fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true @@ -439,13 +394,14 @@ fn list_is_not_empty( #[cfg(test)] mod tests { use std::sync::Arc; + use std::sync::LazyLock; use itertools::Itertools; use rstest::rstest; use vortex_buffer::BitBuffer; use vortex_buffer::Buffer; - use vortex_utils::aliases::hash_map::HashMap; - use vortex_utils::aliases::hash_set::HashSet; + use vortex_error::VortexResult; + use vortex_session::VortexSession; use crate::ArrayRef; use crate::IntoArray; @@ -457,12 +413,10 @@ mod tests { #[expect(deprecated)] use crate::canonical::ToCanonical as _; use crate::dtype::DType; - use crate::dtype::Field; - use crate::dtype::FieldPath; - use crate::dtype::FieldPathSet; use crate::dtype::Nullability; use crate::dtype::PType::I32; use crate::dtype::StructFields; + use crate::expr::Expression; use crate::expr::and; use crate::expr::col; use crate::expr::get_item; @@ -471,7 +425,6 @@ mod tests { use crate::expr::lit; use crate::expr::lt; use crate::expr::or; - use crate::expr::pruning::checked_pruning_expr; use crate::expr::root; use crate::expr::stats::Stat; use crate::scalar::Scalar; @@ -479,8 +432,17 @@ mod tests { use crate::scalar_fn::fns::list_contains::ConstantArray; use crate::scalar_fn::fns::list_contains::ListViewArray; use crate::scalar_fn::fns::list_contains::PrimitiveArray; + use crate::stats::StatsSession; + use crate::stats::stat as stat_expr; use crate::validity::Validity; + static STATS_SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + fn stat(expr: Expression, stat: Stat) -> Expression { + stat_expr(expr, stat.aggregate_fn().unwrap()) + } + fn test_array() -> ArrayRef { ListArray::try_new( PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3]).into_array(), @@ -621,7 +583,7 @@ mod tests { } #[test] - pub fn list_falsification() { + pub fn list_falsification() -> VortexResult<()> { let expr = list_contains( lit(Scalar::list( Arc::new(DType::Primitive(I32, Nullability::NonNullable)), @@ -630,34 +592,34 @@ mod tests { )), col("a"), ); - - let (expr, st) = checked_pruning_expr( - &expr, - &FieldPathSet::from_iter([ - FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]), - FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]), - ]), - ) - .unwrap(); + let scope = DType::Struct( + StructFields::new( + ["a"].into(), + vec![DType::Primitive(I32, Nullability::NonNullable)], + ), + Nullability::NonNullable, + ); assert_eq!( - &expr, - &and( + expr.falsify(&scope, &STATS_SESSION)?, + Some(and( and( - or(lt(col("a_max"), lit(1i32)), gt(col("a_min"), lit(1i32)),), - or(lt(col("a_max"), lit(2i32)), gt(col("a_min"), lit(2i32)),) + or( + lt(stat(col("a"), Stat::Max), lit(1i32)), + gt(stat(col("a"), Stat::Min), lit(1i32)), + ), + or( + lt(stat(col("a"), Stat::Max), lit(2i32)), + gt(stat(col("a"), Stat::Min), lit(2i32)), + ) ), - or(lt(col("a_max"), lit(3i32)), gt(col("a_min"), lit(3i32)),) - ) - ); - - assert_eq!( - st.map(), - &HashMap::from_iter([( - FieldPath::from_name("a"), - HashSet::from([Stat::Min, Stat::Max]) - )]) + or( + lt(stat(col("a"), Stat::Max), lit(3i32)), + gt(stat(col("a"), Stat::Min), lit(3i32)), + ) + )) ); + Ok(()) } #[test] diff --git a/vortex-array/src/scalar_fn/fns/literal.rs b/vortex-array/src/scalar_fn/fns/literal.rs index 16b112e5a78..5181a5250dd 100644 --- a/vortex-array/src/scalar_fn/fns/literal.rs +++ b/vortex-array/src/scalar_fn/fns/literal.rs @@ -16,9 +16,6 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; -use crate::match_each_float_ptype; use crate::scalar::Scalar; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; @@ -96,50 +93,6 @@ impl ScalarFnVTable for Literal { Ok(ConstantArray::new(scalar.clone(), args.row_count()).into_array()) } - fn stat_expression( - &self, - scalar: &Scalar, - _expr: &Expression, - stat: Stat, - _catalog: &dyn StatsCatalog, - ) -> Option { - // NOTE(ngates): we return incorrect `1` values for counts here since we don't have - // row-count information. We could resolve this in the future by introducing a `count()` - // expression that evaluates to the row count of the provided scope. But since this is - // only currently used for pruning, it doesn't change the outcome. - - match stat { - Stat::Min | Stat::Max => Some(lit(scalar.clone())), - Stat::IsConstant => Some(lit(true)), - Stat::NaNCount => { - // The NaNCount for a non-float literal is not defined. - // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise. - let value = scalar.as_primitive_opt()?; - if !value.ptype().is_float() { - return None; - } - - match_each_float_ptype!(value.ptype(), |T| { - if value.typed_value::().is_some_and(|v| v.is_nan()) { - Some(lit(1u64)) - } else { - Some(lit(0u64)) - } - }) - } - Stat::NullCount => { - if scalar.is_null() { - Some(lit(1u64)) - } else { - Some(lit(0u64)) - } - } - Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => { - None - } - } - } - fn validity( &self, scalar: &Scalar, diff --git a/vortex-array/src/scalar_fn/fns/root.rs b/vortex-array/src/scalar_fn/fns/root.rs index 87b8b62ccf4..7bd5b758796 100644 --- a/vortex-array/src/scalar_fn/fns/root.rs +++ b/vortex-array/src/scalar_fn/fns/root.rs @@ -11,10 +11,7 @@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; -use crate::dtype::FieldPath; -use crate::expr::StatsCatalog; use crate::expr::expression::Expression; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -80,16 +77,6 @@ impl ScalarFnVTable for Root { vortex_bail!("Root expression is not executable") } - fn stat_expression( - &self, - _options: &Self::Options, - _expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - catalog.stats_ref(&FieldPath::root(), stat) - } - fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } diff --git a/vortex-array/src/scalar_fn/typed.rs b/vortex-array/src/scalar_fn/typed.rs index fbfc8caace4..1b14a9d613d 100644 --- a/vortex-array/src/scalar_fn/typed.rs +++ b/vortex-array/src/scalar_fn/typed.rs @@ -24,8 +24,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; @@ -103,17 +101,6 @@ pub(super) trait DynScalarFn: 'static + Send + Sync + super::sealed::Sealed { ) -> VortexResult>; fn simplify_untyped(&self, expression: &Expression) -> VortexResult>; fn validity(&self, expression: &Expression) -> VortexResult>; - fn stat_falsification( - &self, - expression: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option; - fn stat_expression( - &self, - expression: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option; // Options operations — self-contained fn options_serialize(&self) -> VortexResult>>; @@ -225,23 +212,6 @@ impl DynScalarFn for TypedScalarFnInstance { V::validity(&self.vtable, &self.options, expression) } - fn stat_falsification( - &self, - expression: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - V::stat_falsification(&self.vtable, &self.options, expression, catalog) - } - - fn stat_expression( - &self, - expression: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - V::stat_expression(&self.vtable, &self.options, expression, stat, catalog) - } - fn options_serialize(&self) -> VortexResult>> { V::serialize(&self.vtable, &self.options) } diff --git a/vortex-array/src/scalar_fn/vtable.rs b/vortex-array/src/scalar_fn/vtable.rs index f4862f6876a..c66afc34932 100644 --- a/vortex-array/src/scalar_fn/vtable.rs +++ b/vortex-array/src/scalar_fn/vtable.rs @@ -20,8 +20,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::expr::traversal::Node; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnRef; @@ -179,34 +177,6 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync { Ok(None) } - /// See [`Expression::stat_falsification`]. - fn stat_falsification( - &self, - options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - _ = options; - _ = expr; - _ = catalog; - None - } - - /// See [`Expression::stat_expression`]. - fn stat_expression( - &self, - options: &Self::Options, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - _ = options; - _ = expr; - _ = stat; - _ = catalog; - None - } - /// Returns an expression that evaluates to the validity of the result of this expression. /// /// If a validity expression cannot be constructed, returns `None` and the expression will diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs new file mode 100644 index 00000000000..6d921fd0ab1 --- /dev/null +++ b/vortex-array/src/stats/bind.rs @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Bind abstract `vortex.stat` expressions to a concrete stats representation. +//! +//! Stats rewrite rules describe pruning in terms of `vortex.stat(input, aggregate_fn)` placeholders +//! so the rewrite is independent of where statistics are stored. These stat placeholders are +//! abstract because they name the statistic needed for a proof, but not how that statistic is +//! represented by a specific layout or reader. +//! +//! Binding is the later pass that replaces each abstract placeholder with the representation used +//! by a caller: zone-map field references, file-level stat literals, or typed nulls for missing +//! stats. This lets all callers share the same falsification rules while keeping layout-specific +//! stat storage behind [`StatBinder`]. + +use vortex_error::VortexResult; + +use crate::aggregate_fn::AggregateFnRef; +use crate::dtype::DType; +use crate::expr::Expression; +use crate::expr::lit; +use crate::expr::traversal::NodeExt; +use crate::expr::traversal::Transformed; +use crate::scalar::Scalar; +use crate::scalar_fn::fns::stat::StatFn; + +/// A target that can bind abstract statistics to concrete expressions. +/// +/// Implementations define how a pruning proof should read stats from a specific backing +/// representation. For example, a zone-map binder can translate a `max(col)` placeholder into a +/// field reference in the per-zone stats table, while a file-stats binder can translate the same +/// placeholder into a literal value from the file footer. +pub trait StatBinder { + /// The dtype scope used to type-check expressions before stats are bound. + fn scope(&self) -> &DType; + + /// Bind `aggregate_fn(input)` to a concrete expression. + /// + /// Implementations should return `Ok(None)` when the requested aggregate + /// statistic is unavailable in their backing representation. + fn bind_aggregate( + &self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult>; + + /// Expression to use when a stat is unavailable. + /// + /// The default is a nullable null literal, which preserves three-valued + /// pruning semantics for stats-table execution. + fn missing_stat(&self, dtype: DType) -> VortexResult { + Ok(null_expr(dtype)) + } +} + +/// Bind all `vortex.stat` expressions in `predicate`. +/// +/// The predicate is usually the output of a stats rewrite rule. Rewrite rules +/// are responsible for expressing stat semantics; binding maps aggregate-backed +/// stat requests to the concrete stats representation supported by the binder. +pub fn bind_stats( + predicate: Expression, + binder: &B, +) -> VortexResult { + let scope = binder.scope().clone(); + Ok(predicate + .transform_down(|expr| { + if !expr.is::() { + return Ok(Transformed::no(expr)); + } + + match bind_stat_fn(&expr, &scope, binder)? { + Some(bound) => Ok(Transformed::yes(bound)), + None => { + let dtype = expr.return_dtype(&scope)?; + Ok(Transformed::yes(binder.missing_stat(dtype)?)) + } + } + })? + .into_inner()) +} + +fn bind_stat_fn( + expr: &Expression, + scope: &DType, + binder: &(impl StatBinder + ?Sized), +) -> VortexResult> { + let options = expr.as_::(); + let aggregate_fn = options.aggregate_fn(); + // `StatFn` has exactly one child: the expression the aggregate statistic is computed over. + let input = expr.child(0); + + let stat_dtype = expr.return_dtype(scope)?; + binder.bind_aggregate(input, aggregate_fn, &stat_dtype) +} + +fn null_expr(dtype: DType) -> Expression { + lit(Scalar::null(dtype.as_nullable())) +} + +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + + use super::*; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::dtype::StructFields; + use crate::expr::and; + use crate::expr::col; + use crate::expr::get_item; + use crate::expr::is_null; + use crate::expr::or; + use crate::expr::root; + use crate::expr::stats::Stat; + use crate::stats::all_non_nan; + use crate::stats::nan_count; + + struct TestBinder { + input_scope: DType, + bind_nan_count: bool, + } + + impl TestBinder { + fn new(bind_nan_count: bool) -> Self { + Self { + input_scope: DType::Struct( + StructFields::from_iter([( + "f", + DType::Primitive(PType::F32, Nullability::NonNullable), + )]), + Nullability::NonNullable, + ), + bind_nan_count, + } + } + } + + impl StatBinder for TestBinder { + fn scope(&self) -> &DType { + &self.input_scope + } + + fn bind_aggregate( + &self, + _input: &Expression, + aggregate_fn: &AggregateFnRef, + _stat_dtype: &DType, + ) -> VortexResult> { + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; + + if stat == Stat::NaNCount && self.bind_nan_count { + Ok(Some(get_item("f_nan_count", root()))) + } else { + Ok(None) + } + } + } + + #[test] + fn nan_count_binds_to_direct_stat_slot() -> VortexResult<()> { + let binder = TestBinder::new(true); + + let bound = bind_stats(nan_count(col("f")), &binder)?; + + assert_eq!(bound, col("f_nan_count")); + Ok(()) + } + + #[test] + fn all_non_nan_does_not_derive_from_nan_count() -> VortexResult<()> { + let binder = TestBinder::new(true); + + let bound = bind_stats(all_non_nan(col("f")), &binder)?; + + assert_eq!(bound, lit(Scalar::null(DType::Bool(Nullability::Nullable)))); + Ok(()) + } + + #[test] + fn missing_stats_bind_to_null_without_reducing() -> VortexResult<()> { + let binder = TestBinder::new(false); + let null_bool = lit(Scalar::null(DType::Bool(Nullability::Nullable))); + + let bound = bind_stats(and(lit(false), all_non_nan(col("f"))), &binder)?; + + assert_eq!(bound, and(lit(false), null_bool.clone())); + + let bound = bind_stats(or(lit(true), all_non_nan(col("f"))), &binder)?; + + assert_eq!(bound, or(lit(true), null_bool)); + Ok(()) + } + + #[test] + fn unrelated_expressions_do_not_request_nan_count() -> VortexResult<()> { + let binder = TestBinder::new(false); + + let bound = bind_stats(is_null(col("f")), &binder)?; + + assert_eq!(bound, is_null(col("f"))); + Ok(()) + } +} diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index ceb085e0815..5f5684dbde2 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -19,9 +19,10 @@ pub use expr::sum; pub use stats_set::*; mod array; +pub mod bind; pub mod expr; pub mod flatbuffers; -pub(crate) mod rewrite; +pub mod rewrite; pub mod session; mod stats_set; diff --git a/vortex-array/src/stats/rewrite.rs b/vortex-array/src/stats/rewrite.rs index 7d723155fad..2bbeeb00022 100644 --- a/vortex-array/src/stats/rewrite.rs +++ b/vortex-array/src/stats/rewrite.rs @@ -21,7 +21,7 @@ mod builtins; pub(crate) use builtins::register_builtins; /// Shared reference to a stats rewrite rule. -pub(crate) type StatsRewriteRuleRef = Arc; +pub type StatsRewriteRuleRef = Arc; /// A plugin-provided rule for predicates whose root scalar function matches this rule. /// @@ -40,7 +40,7 @@ pub(crate) type StatsRewriteRuleRef = Arc; /// `expr` is the full predicate expression whose root scalar function id is /// [`Self::scalar_fn_id`]. Use [`StatsRewriteCtx`] to resolve dtypes and recursively rewrite child /// predicates. -pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { +pub trait StatsRewriteRule: Debug + Send + Sync + 'static { /// Returns the scalar function id handled by this rule. fn scalar_fn_id(&self) -> ScalarFnId; @@ -83,35 +83,35 @@ pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { } /// Context passed to stats rewrite rules. -pub(crate) struct StatsRewriteCtx<'a> { +pub struct StatsRewriteCtx<'a> { session: &'a VortexSession, scope: &'a DType, } impl<'a> StatsRewriteCtx<'a> { /// Create a rewrite context for `session`. - pub(crate) fn new(session: &'a VortexSession, scope: &'a DType) -> Self { + pub fn new(session: &'a VortexSession, scope: &'a DType) -> Self { Self { session, scope } } /// Returns the session that owns the rewrite registry. - pub(crate) fn session(&self) -> &'a VortexSession { + pub fn session(&self) -> &'a VortexSession { self.session } /// Return the dtype of `expr` within this rewrite scope. - pub(crate) fn return_dtype(&self, expr: &Expression) -> VortexResult { + pub fn return_dtype(&self, expr: &Expression) -> VortexResult { expr.return_dtype(self.scope) } /// Rewrite `expr` into a stats-backed falsifier. - pub(crate) fn falsify(&self, expr: &Expression) -> VortexResult> { + pub fn falsify(&self, expr: &Expression) -> VortexResult> { self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::falsify) } /// Rewrite `expr` into a stats-backed satisfier. - pub(crate) fn satisfy(&self, expr: &Expression) -> VortexResult> { + pub fn satisfy(&self, expr: &Expression) -> VortexResult> { self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::satisfy) } diff --git a/vortex-array/src/stats/rewrite/builtins.rs b/vortex-array/src/stats/rewrite/builtins.rs index 4311b7295f3..ea656f24546 100644 --- a/vortex-array/src/stats/rewrite/builtins.rs +++ b/vortex-array/src/stats/rewrite/builtins.rs @@ -52,23 +52,26 @@ use crate::stats::session::StatsSession; /// Register built-in stats rewrite rules. pub(crate) fn register_builtins(session: &StatsSession) { - session.register_rewrite(BinaryStatsRewrite); + session.register_rewrite(BinaryNanCountStatsRewrite); + session.register_rewrite(BinaryAllNonNanStatsRewrite); session.register_rewrite(BetweenStatsRewrite); - session.register_rewrite(IsNullLegacyStatsRewrite); + session.register_rewrite(IsNullNullCountStatsRewrite); session.register_rewrite(IsNullAllNonNullStatsRewrite); session.register_rewrite(IsNullAllNullStatsRewrite); - session.register_rewrite(IsNotNullLegacyStatsRewrite); + session.register_rewrite(IsNotNullNullCountStatsRewrite); session.register_rewrite(IsNotNullAllNullStatsRewrite); session.register_rewrite(IsNotNullAllNonNullStatsRewrite); session.register_rewrite(LikeStatsRewrite); - session.register_rewrite(ListContainsStatsRewrite); - session.register_rewrite(DynamicComparisonStatsRewrite); + session.register_rewrite(ListContainsNanCountStatsRewrite); + session.register_rewrite(ListContainsAllNonNanStatsRewrite); + session.register_rewrite(DynamicComparisonNanCountStatsRewrite); + session.register_rewrite(DynamicComparisonAllNonNanStatsRewrite); } #[derive(Debug)] -struct BinaryStatsRewrite; +struct BinaryNanCountStatsRewrite; -impl StatsRewriteRule for BinaryStatsRewrite { +impl StatsRewriteRule for BinaryNanCountStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { Binary.id() } @@ -78,60 +81,93 @@ impl StatsRewriteRule for BinaryStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - let operator = expr.as_::(); - let lhs = expr.child(0); - let rhs = expr.child(1); - - Ok(match operator { - Operator::Eq => { - let left = min(lhs, ctx).zip(max(rhs, ctx)).map(|(a, b)| gt(a, b)); - let right = min(rhs, ctx).zip(max(lhs, ctx)).map(|(a, b)| gt(a, b)); - or_collect(left.into_iter().chain(right)) - .map(|value_predicate| with_nan_predicate(ctx, lhs, rhs, value_predicate)) - .transpose()? - } - Operator::NotEq => min(lhs, ctx) - .zip(max(rhs, ctx)) - .zip(max(lhs, ctx).zip(min(rhs, ctx))) - .map(|((min_lhs, max_rhs), (max_lhs, min_rhs))| { - with_nan_predicate( - ctx, - lhs, - rhs, - and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)), - ) - }) - .transpose()?, - Operator::Gt => max(lhs, ctx) - .zip(min(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt_eq(a, b))) - .transpose()?, - Operator::Gte => max(lhs, ctx) - .zip(min(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt(a, b))) - .transpose()?, - Operator::Lt => min(lhs, ctx) - .zip(max(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt_eq(a, b))) - .transpose()?, - Operator::Lte => min(lhs, ctx) - .zip(max(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt(a, b))) - .transpose()?, - Operator::And => { - let lhs_falsifier = ctx.falsify(lhs)?; - let rhs_falsifier = ctx.falsify(rhs)?; - or_collect(lhs_falsifier.into_iter().chain(rhs_falsifier)) - } - Operator::Or => match (ctx.falsify(lhs)?, ctx.falsify(rhs)?) { - (Some(lhs), Some(rhs)) => Some(and(lhs, rhs)), - _ => None, - }, - Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, - }) + binary_falsify::(expr, ctx) + } +} + +#[derive(Debug)] +struct BinaryAllNonNanStatsRewrite; + +impl StatsRewriteRule for BinaryAllNonNanStatsRewrite { + fn scalar_fn_id(&self) -> ScalarFnId { + Binary.id() + } + + fn falsify( + &self, + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, + ) -> VortexResult> { + binary_falsify::(expr, ctx) } } +fn binary_falsify( + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, +) -> VortexResult> { + let operator = expr.as_::(); + let lhs = expr.child(0); + let rhs = expr.child(1); + + Ok(match operator { + Operator::Eq => { + let left = min(lhs, ctx).zip(max(rhs, ctx)).map(|(a, b)| gt(a, b)); + let right = min(rhs, ctx).zip(max(lhs, ctx)).map(|(a, b)| gt(a, b)); + or_collect(left.into_iter().chain(right)) + .map(|value_predicate| with_non_nan_guards::

(ctx, [lhs, rhs], value_predicate)) + .transpose()? + .flatten() + } + Operator::NotEq => min(lhs, ctx) + .zip(max(rhs, ctx)) + .zip(max(lhs, ctx).zip(min(rhs, ctx))) + .map(|((min_lhs, max_rhs), (max_lhs, min_rhs))| { + with_non_nan_guards::

( + ctx, + [lhs, rhs], + and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)), + ) + }) + .transpose()? + .flatten(), + Operator::Gt => max(lhs, ctx) + .zip(min(rhs, ctx)) + .map(|(a, b)| with_non_nan_guards::

(ctx, [lhs, rhs], lt_eq(a, b))) + .transpose()? + .flatten(), + Operator::Gte => max(lhs, ctx) + .zip(min(rhs, ctx)) + .map(|(a, b)| with_non_nan_guards::

(ctx, [lhs, rhs], lt(a, b))) + .transpose()? + .flatten(), + Operator::Lt => min(lhs, ctx) + .zip(max(rhs, ctx)) + .map(|(a, b)| with_non_nan_guards::

(ctx, [lhs, rhs], gt_eq(a, b))) + .transpose()? + .flatten(), + Operator::Lte => min(lhs, ctx) + .zip(max(rhs, ctx)) + .map(|(a, b)| with_non_nan_guards::

(ctx, [lhs, rhs], gt(a, b))) + .transpose()? + .flatten(), + Operator::And => { + if !P::EMIT_UNGUARDED_REWRITES { + return Ok(None); + } + + let lhs_falsifier = ctx.falsify(lhs)?; + let rhs_falsifier = ctx.falsify(rhs)?; + or_collect(lhs_falsifier.into_iter().chain(rhs_falsifier)) + } + Operator::Or => match (ctx.falsify(lhs)?, ctx.falsify(rhs)?) { + (Some(lhs), Some(rhs)) if P::EMIT_UNGUARDED_REWRITES => Some(and(lhs, rhs)), + _ => None, + }, + Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, + }) +} + #[derive(Debug)] struct BetweenStatsRewrite; @@ -157,9 +193,9 @@ impl StatsRewriteRule for BetweenStatsRewrite { } #[derive(Debug)] -struct IsNullLegacyStatsRewrite; +struct IsNullNullCountStatsRewrite; -impl StatsRewriteRule for IsNullLegacyStatsRewrite { +impl StatsRewriteRule for IsNullNullCountStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { IsNull.id() } @@ -217,9 +253,9 @@ impl StatsRewriteRule for IsNullAllNullStatsRewrite { } #[derive(Debug)] -struct IsNotNullLegacyStatsRewrite; +struct IsNotNullNullCountStatsRewrite; -impl StatsRewriteRule for IsNotNullLegacyStatsRewrite { +impl StatsRewriteRule for IsNotNullNullCountStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { IsNotNull.id() } @@ -332,9 +368,9 @@ impl StatsRewriteRule for LikeStatsRewrite { } #[derive(Debug)] -struct ListContainsStatsRewrite; +struct ListContainsNanCountStatsRewrite; -impl StatsRewriteRule for ListContainsStatsRewrite { +impl StatsRewriteRule for ListContainsNanCountStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { ListContains.id() } @@ -344,46 +380,71 @@ impl StatsRewriteRule for ListContainsStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - let list = expr.child(0); - let needle = expr.child(1); + list_contains_falsify::(expr, ctx) + } +} - let Some(list_scalar) = literal_stat(list, Stat::Min) else { - return Ok(None); - }; - let elements = list_scalar - .as_opt::() - .and_then(|literal| literal.as_list_opt()) - .and_then(|list| list.elements()); - let Some(elements) = elements else { - return Ok(None); - }; - if elements.is_empty() { - return Ok(Some(lit(true))); - } +#[derive(Debug)] +struct ListContainsAllNonNanStatsRewrite; - let Some(value_max) = max(needle, ctx) else { - return Ok(None); - }; - let Some(value_min) = min(needle, ctx) else { - return Ok(None); - }; +impl StatsRewriteRule for ListContainsAllNonNanStatsRewrite { + fn scalar_fn_id(&self) -> ScalarFnId { + ListContains.id() + } - let value_predicate = and_collect(elements.iter().map(|value| { - or( - lt(value_max.clone(), lit(value.clone())), - gt(value_min.clone(), lit(value.clone())), - ) - })); - value_predicate - .map(|value_predicate| with_all_non_nan_predicate(ctx, [needle], value_predicate)) - .transpose() + fn falsify( + &self, + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, + ) -> VortexResult> { + list_contains_falsify::(expr, ctx) } } +fn list_contains_falsify( + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, +) -> VortexResult> { + let list = expr.child(0); + let needle = expr.child(1); + + let Some(list_scalar) = literal_stat(list, Stat::Min) else { + return Ok(None); + }; + let elements = list_scalar + .as_opt::() + .and_then(|literal| literal.as_list_opt()) + .and_then(|list| list.elements()); + let Some(elements) = elements else { + return Ok(None); + }; + if elements.is_empty() { + return Ok(P::EMIT_UNGUARDED_REWRITES.then(|| lit(true))); + } + + let Some(value_max) = max(needle, ctx) else { + return Ok(None); + }; + let Some(value_min) = min(needle, ctx) else { + return Ok(None); + }; + + let value_predicate = and_collect(elements.iter().map(|value| { + or( + lt(value_max.clone(), lit(value.clone())), + gt(value_min.clone(), lit(value.clone())), + ) + })); + value_predicate + .map(|value_predicate| with_non_nan_guards::

(ctx, [needle], value_predicate)) + .transpose() + .map(Option::flatten) +} + #[derive(Debug)] -struct DynamicComparisonStatsRewrite; +struct DynamicComparisonNanCountStatsRewrite; -impl StatsRewriteRule for DynamicComparisonStatsRewrite { +impl StatsRewriteRule for DynamicComparisonNanCountStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { DynamicComparison.id() } @@ -393,31 +454,55 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - let dynamic = expr.as_::(); - let lhs = expr.child(0); - - let Some((operator, lhs_stat)) = (match dynamic.operator { - CompareOperator::Eq | CompareOperator::NotEq => None, - CompareOperator::Gt => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)), - CompareOperator::Gte => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)), - CompareOperator::Lt => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)), - CompareOperator::Lte => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)), - }) else { - return Ok(None); - }; + dynamic_comparison_falsify::(expr, ctx) + } +} - let value_predicate = DynamicComparison.new_expr( - DynamicComparisonExpr { - operator, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - [lhs_stat], - ); - with_all_non_nan_predicate(ctx, [lhs], value_predicate).map(Some) +#[derive(Debug)] +struct DynamicComparisonAllNonNanStatsRewrite; + +impl StatsRewriteRule for DynamicComparisonAllNonNanStatsRewrite { + fn scalar_fn_id(&self) -> ScalarFnId { + DynamicComparison.id() + } + + fn falsify( + &self, + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, + ) -> VortexResult> { + dynamic_comparison_falsify::(expr, ctx) } } +fn dynamic_comparison_falsify( + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, +) -> VortexResult> { + let dynamic = expr.as_::(); + let lhs = expr.child(0); + + let Some((operator, lhs_stat)) = (match dynamic.operator { + CompareOperator::Eq | CompareOperator::NotEq => None, + CompareOperator::Gt => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)), + CompareOperator::Gte => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)), + CompareOperator::Lt => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)), + CompareOperator::Lte => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)), + }) else { + return Ok(None); + }; + + let value_predicate = DynamicComparison.new_expr( + DynamicComparisonExpr { + operator, + rhs: Arc::clone(&dynamic.rhs), + default: !dynamic.default, + }, + [lhs_stat], + ); + with_non_nan_guards::

(ctx, [lhs], value_predicate) +} + fn min(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option { stat_expr(expr, Stat::Min, ctx) } @@ -438,36 +523,77 @@ fn all_non_null(expr: &Expression) -> Expression { stat_fn(expr.clone(), AllNonNull.bind(AggregateEmptyOptions)) } +enum NanCheck { + NotNeeded, + Check(Expression), + Unavailable, +} + +trait NonNanProof { + const EMIT_UNGUARDED_REWRITES: bool; + + fn check(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> VortexResult; +} + +struct NanCountProof; + +impl NonNanProof for NanCountProof { + const EMIT_UNGUARDED_REWRITES: bool = true; + + fn check(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> VortexResult { + non_nan_check(ctx, expr, |expr| { + match stat_expr(expr, Stat::NaNCount, ctx) { + Some(nan_count) => NanCheck::Check(eq(nan_count, lit(0u64))), + None => NanCheck::Unavailable, + } + }) + } +} + +struct AllNonNanProof; + +impl NonNanProof for AllNonNanProof { + const EMIT_UNGUARDED_REWRITES: bool = false; + + fn check(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> VortexResult { + non_nan_check(ctx, expr, |expr| { + NanCheck::Check(stat_fn(expr.clone(), AllNonNan.bind(AggregateEmptyOptions))) + }) + } +} + // Min/max do not order NaN values, so comparison rewrites are only sound when every // candidate value is known to be non-NaN. Cast result dtypes are not enough: a cast // from float to non-float still needs a proof about the float source values. -fn all_non_nan_stat( +fn non_nan_check( ctx: &StatsRewriteCtx<'_>, expr: &Expression, -) -> VortexResult> { + proof: impl FnOnce(&Expression) -> NanCheck, +) -> VortexResult { if let Some(scalar) = expr.as_opt::() { let Some(value) = scalar.as_primitive_opt() else { - return Ok(None); + return Ok(NanCheck::NotNeeded); }; - return Ok(value.is_nan().then(|| lit(false))); + return Ok(if value.is_nan() { + NanCheck::Check(lit(false)) + } else { + NanCheck::NotNeeded + }); } if expr.is::() { if !has_nans(&ctx.return_dtype(expr.child(0))?) { - return Ok(None); + return Ok(NanCheck::NotNeeded); } - return all_non_nan_stat(ctx, expr.child(0)); + return non_nan_check(ctx, expr.child(0), proof); } if !has_nans(&ctx.return_dtype(expr)?) { - return Ok(None); + return Ok(NanCheck::NotNeeded); } - Ok(Some(stat_fn( - expr.clone(), - AllNonNan.bind(AggregateEmptyOptions), - ))) + Ok(proof(expr)) } fn has_nans(dtype: &DType) -> bool { @@ -501,33 +627,28 @@ fn stat_expr(expr: &Expression, stat: Stat, ctx: &StatsRewriteCtx<'_>) -> Option .then(|| stat_fn(expr.clone(), aggregate_fn)) } -fn with_nan_predicate( - ctx: &StatsRewriteCtx<'_>, - lhs: &Expression, - rhs: &Expression, - value_predicate: Expression, -) -> VortexResult { - with_all_non_nan_predicate(ctx, [lhs, rhs], value_predicate) -} - -fn with_all_non_nan_predicate<'a>( +fn with_non_nan_guards<'a, P: NonNanProof>( ctx: &StatsRewriteCtx<'_>, exprs: impl IntoIterator, value_predicate: Expression, -) -> VortexResult { +) -> VortexResult> { let mut nan_checks = Vec::new(); for expr in exprs { - if let Some(check) = all_non_nan_stat(ctx, expr)? { - nan_checks.push(check); + match P::check(ctx, expr)? { + NanCheck::NotNeeded => {} + NanCheck::Check(check) => nan_checks.push(check), + NanCheck::Unavailable => return Ok(None), } } let nan_predicate = and_collect(nan_checks); Ok(match nan_predicate { - Some(nan_check) => and(nan_check, value_predicate), + Some(nan_check) => Some(and(nan_check, value_predicate)), // No possible NaN-bearing expression remains, so the value predicate is - // already guarded. - None => value_predicate, + // already guarded. Only one registered rule emits this unguarded + // rewrite so non-float comparisons are not duplicated. + None if P::EMIT_UNGUARDED_REWRITES => Some(value_predicate), + None => None, }) } @@ -657,8 +778,17 @@ mod tests { expr.satisfy(&test_scope(), &SESSION) } - fn nan_free(expr: Expression) -> Expression { - stat_fn(expr, AllNonNan.bind(AggregateEmptyOptions)) + fn nan_guarded(expr: Expression, value_predicate: Expression) -> Expression { + or( + and( + eq(stat(expr.clone(), Stat::NaNCount), lit(0u64)), + value_predicate.clone(), + ), + and( + stat_fn(expr, AllNonNan.bind(AggregateEmptyOptions)), + value_predicate, + ), + ) } #[test] @@ -812,6 +942,33 @@ mod tests { )) ); + let expr = like(col("s"), lit(r"\%%")); + assert_eq!( + falsify(&expr)?, + Some(or( + gt_eq(stat(col("s"), Stat::Min), lit("&")), + lt(stat(col("s"), Stat::Max), lit("%")), + )) + ); + + let expr = like(col("s"), lit("pref%ix%")); + assert_eq!( + falsify(&expr)?, + Some(or( + gt_eq(stat(col("s"), Stat::Min), lit("preg")), + lt(stat(col("s"), Stat::Max), lit("pref")), + )) + ); + + let expr = like(col("s"), lit("pref_ix_")); + assert_eq!( + falsify(&expr)?, + Some(or( + gt_eq(stat(col("s"), Stat::Min), lit("preg")), + lt(stat(col("s"), Stat::Max), lit("pref")), + )) + ); + let expr = like(col("s"), lit("exact")); assert_eq!( falsify(&expr)?, @@ -858,8 +1015,8 @@ mod tests { assert_eq!( falsify(&expr)?, - Some(and( - nan_free(col("f")), + Some(nan_guarded( + col("f"), lt_eq(cast(stat(col("f"), Stat::Max), dtype), lit(5i32)), )) ); diff --git a/vortex-array/src/stats/session.rs b/vortex-array/src/stats/session.rs index 867ec4f8fb0..b18ca111b19 100644 --- a/vortex-array/src/stats/session.rs +++ b/vortex-array/src/stats/session.rs @@ -36,14 +36,12 @@ impl Default for StatsSession { impl StatsSession { /// Register a stats rewrite rule. - #[allow(dead_code)] - pub(crate) fn register_rewrite(&self, rule: R) { + pub fn register_rewrite(&self, rule: R) { self.register_rewrite_ref(Arc::new(rule)); } /// Register a shared stats rewrite rule. - #[allow(dead_code)] - pub(crate) fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) { + pub fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) { let mut rules = self.rewrite_rules.write(); let rule_id = rule.scalar_fn_id(); let mut updated_rules = rules @@ -74,7 +72,7 @@ impl SessionVar for StatsSession { } /// Extension trait for accessing stats session data. -pub(crate) trait StatsSessionExt: SessionExt { +pub trait StatsSessionExt: SessionExt { /// Returns the stats session state. fn stats(&self) -> &StatsSession { self.get::() diff --git a/vortex-duckdb/src/projection.rs b/vortex-duckdb/src/projection.rs index 5b94306de7d..4521115666c 100644 --- a/vortex-duckdb/src/projection.rs +++ b/vortex-duckdb/src/projection.rs @@ -17,7 +17,6 @@ use vortex::expr::root; use vortex::expr::select; use vortex::layout::layouts::row_idx::row_idx; use vortex::scan::selection::Selection; -use vortex_utils::aliases::hash_set::HashSet; use crate::convert::try_from_table_filter; use crate::convert::try_from_virtual_column_filter; @@ -199,6 +198,12 @@ pub struct Filter { pub has_non_optional_filter: bool, } +fn push_filter_expr(filter_exprs: &mut Vec, expr: &Expression) { + if !filter_exprs.iter().any(|existing| existing == expr) { + filter_exprs.push(expr.clone()); + } +} + impl Filter { /// Creates a table filter expression, row selection, and row range from the table filter set, /// column metadata, additional filter expressions, and the top-level DType. @@ -211,29 +216,26 @@ impl Filter { ) -> VortexResult { let mut has_non_optional_filter = false; - let mut table_filter_exprs: HashSet = if let Some(filter) = table_filter_set { - filter - .into_iter() - .filter(|(idx, _)| { - let idx_u: usize = idx.as_(); - !is_virtual_column(column_ids[idx_u]) - }) - .map(|(idx, ex)| { - has_non_optional_filter |= - !matches!(ex.as_class(), TableFilterClass::Optional(_)); - - let idx_u: usize = idx.as_(); - let col_idx: usize = column_ids[idx_u].as_(); - let name = &column_fields.get(col_idx).vortex_expect("exists").name; - try_from_table_filter(ex, &col(name.as_str()), dtype) - }) - .collect::>>>()? - .unwrap_or_else(HashSet::new) - } else { - HashSet::new() - }; + let mut table_filter_exprs = Vec::new(); + if let Some(filter) = table_filter_set { + for (idx, ex) in filter.into_iter().filter(|(idx, _)| { + let idx_u: usize = idx.as_(); + !is_virtual_column(column_ids[idx_u]) + }) { + has_non_optional_filter |= !matches!(ex.as_class(), TableFilterClass::Optional(_)); + + let idx_u: usize = idx.as_(); + let col_idx: usize = column_ids[idx_u].as_(); + let name = &column_fields.get(col_idx).vortex_expect("exists").name; + if let Some(expr) = try_from_table_filter(ex, &col(name.as_str()), dtype)? { + push_filter_expr(&mut table_filter_exprs, &expr); + } + } + } - table_filter_exprs.extend(additional_filters.iter().cloned()); + for expr in additional_filters { + push_filter_expr(&mut table_filter_exprs, expr); + } let mut file_selection = Selection::All; let mut row_selection = Selection::All; @@ -349,4 +351,17 @@ mod tests { let ids = [0, 1, 2]; assert_ne!(Projection::new(None, &ids, &fields).projection, root()); } + + #[test] + fn test_push_filter_expr_preserves_order() { + let first = col("first"); + let second = col("second"); + + let mut filter_exprs = Vec::new(); + push_filter_expr(&mut filter_exprs, &first); + push_filter_expr(&mut filter_exprs, &second); + push_filter_expr(&mut filter_exprs, &first); + + assert_eq!(filter_exprs, vec![first, second]); + } } diff --git a/vortex-file/src/file.rs b/vortex-file/src/file.rs index 23e2114b1c3..162f0e347e3 100644 --- a/vortex-file/src/file.rs +++ b/vortex-file/src/file.rs @@ -12,18 +12,9 @@ use std::sync::OnceLock; use itertools::Itertools; use vortex_array::ArrayRef; -use vortex_array::Columnar; -use vortex_array::IntoArray; -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::ConstantArray; use vortex_array::dtype::DType; -use vortex_array::dtype::Field; use vortex_array::dtype::FieldMask; -use vortex_array::dtype::FieldPath; -use vortex_array::dtype::FieldPathSet; use vortex_array::expr::Expression; -use vortex_array::expr::pruning::checked_pruning_expr; -use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_error::VortexResult; use vortex_layout::LayoutReader; use vortex_layout::scan::layout::LayoutReaderDataSource; @@ -32,11 +23,10 @@ use vortex_layout::scan::split_by::SplitBy; use vortex_layout::segments::SegmentSource; use vortex_scan::DataSourceRef; use vortex_session::VortexSession; -use vortex_utils::aliases::hash_map::HashMap; use crate::FileStatistics; use crate::footer::Footer; -use crate::pruning::extract_relevant_file_stats_as_struct_row; +use crate::pruning::can_prune_file_stats; use crate::v2::FileStatsLayoutReader; /// Represents a Vortex file, providing access to its metadata and content. @@ -202,59 +192,14 @@ impl VortexFile { return Ok(false); }; - let set = FieldPathSet::from_iter( - fields - .names() - .iter() - .zip(stats.stats_sets().iter()) - .flat_map(|(name, stats)| { - stats.iter().map(|(stat, _)| { - FieldPath::from_iter([ - Field::Name(name.clone()), - Field::Name(stat.name().into()), - ]) - }) - }), - ); - - let Some((predicate, required_stats)) = checked_pruning_expr(filter, &set) else { - return Ok(false); - }; - - let required_file_stats = HashMap::from_iter( - required_stats - .map() - .iter() - .map(|(path, stats)| (path.clone(), stats.clone())), - ); - - let Some(file_stats) = extract_relevant_file_stats_as_struct_row( - &required_file_stats, - stats.stats_sets(), + can_prune_file_stats( + filter, + self.footer.dtype(), + self.footer.row_count(), + stats, fields, - )? - else { - return Ok(false); - }; - - // Apply the predicate, then substitute any row_count placeholders in the resulting array - // tree with a ConstantArray carrying the file-level row count. - let applied = file_stats.apply(&predicate)?; - let row_count_replacement = - ConstantArray::new(self.footer.row_count(), applied.len()).into_array(); - let applied = substitute_row_count(applied, &row_count_replacement)?; - - let mut ctx = self.session.create_execution_ctx(); - Ok(match applied.execute::(&mut ctx)? { - Columnar::Constant(s) => s.scalar().as_bool().value() == Some(true), - Columnar::Canonical(c) => { - c.into_array() - .execute_scalar(0, &mut ctx)? - .as_bool() - .value() - == Some(true) - } - }) + &self.session, + ) } pub fn splits(&self) -> VortexResult>> { diff --git a/vortex-file/src/pruning.rs b/vortex-file/src/pruning.rs index 74ef90a034e..559df97d2d0 100644 --- a/vortex-file/src/pruning.rs +++ b/vortex-file/src/pruning.rs @@ -1,78 +1,125 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::Arc; - -use vortex_array::ArrayRef; +use vortex_array::Canonical; use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::AggregateFnRef; use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::StructArray; -use vortex_array::dtype::Field; -use vortex_array::dtype::FieldName; -use vortex_array::dtype::FieldNames; +use vortex_array::arrays::NullArray; +use vortex_array::dtype::DType; use vortex_array::dtype::FieldPath; use vortex_array::dtype::StructFields; -use vortex_array::expr::pruning::field_path_stat_field_name; +use vortex_array::expr::Expression; +use vortex_array::expr::is_root; +use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; -use vortex_array::expr::stats::StatsProvider; -use vortex_array::stats::StatsSet; -use vortex_array::validity::Validity; -use vortex_error::VortexExpect; +use vortex_array::scalar::Scalar; +use vortex_array::scalar_fn::fns::cast::Cast; +use vortex_array::scalar_fn::fns::get_item::GetItem; +use vortex_array::scalar_fn::fns::literal::Literal; +use vortex_array::scalar_fn::internal::row_count::substitute_row_count; +use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_stats; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use vortex_utils::aliases::hash_map::HashMap; -use vortex_utils::aliases::hash_set::HashSet; - -pub fn extract_relevant_file_stats_as_struct_row( - access: &HashMap>, - stats_sets: &Arc<[StatsSet]>, - struct_dtype: &StructFields, -) -> VortexResult> { - if access.is_empty() { - return StructArray::try_new(FieldNames::default(), vec![], 1, Validity::NonNullable) - .map(|s| Some(s.into_array())); +use vortex_session::VortexSession; + +use crate::FileStatistics; + +pub(crate) fn can_prune_file_stats( + expr: &Expression, + dtype: &DType, + row_count: u64, + file_stats: &FileStatistics, + struct_fields: &StructFields, + session: &VortexSession, +) -> VortexResult { + let Some(pruning_expr) = expr.falsify(dtype, session)? else { + return Ok(false); + }; + + let binder = FileStatsBinder { + dtype, + file_stats, + struct_fields, + }; + let pruning_expr = bind_stats(pruning_expr, &binder)?; + + let simplified = pruning_expr.optimize_recursive(&DType::Null)?; + if let Some(result) = simplified.as_opt::() { + return Ok(result.as_bool().value() == Some(true)); } - let mut columns: Vec<(FieldName, ArrayRef)> = Vec::with_capacity(access.len() * 2); - for (field_path, stats) in access.into_iter() { - if field_path.parts().len() != 1 { + let pruning = NullArray::new(1).into_array().apply(&pruning_expr)?; + let row_count_replacement = ConstantArray::new(row_count, pruning.len()).into_array(); + let pruning = substitute_row_count(pruning, &row_count_replacement)?; + + let mut ctx = session.create_execution_ctx(); + let result = pruning + .execute::(&mut ctx)? + .into_bool() + .into_array() + .execute_scalar(0, &mut ctx)?; + + Ok(result.as_bool().value() == Some(true)) +} + +struct FileStatsBinder<'a> { + dtype: &'a DType, + file_stats: &'a FileStatistics, + struct_fields: &'a StructFields, +} + +impl StatBinder for FileStatsBinder<'_> { + fn scope(&self) -> &DType { + self.dtype + } + + fn bind_aggregate( + &self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + _stat_dtype: &DType, + ) -> VortexResult> { + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { return Ok(None); - } - let Field::Name(field) = &field_path.parts()[0] else { + }; + let Some(field_path) = direct_field_path(input) else { return Ok(None); }; + Ok(self.stat_ref(&field_path, stat)) + } +} - let field_idx = struct_dtype - .find(field) - .ok_or_else(|| vortex_err!("Missing field: {field}"))?; - let field_dtype = struct_dtype - .field_by_index(field_idx) - .vortex_expect("Field must exist"); - - let Some(stat_set) = stats_sets.get(field_idx) else { - vortex_bail!("missing stat field {} from stats set", field) - }; - let typed_stats = stat_set.as_typed_ref(&field_dtype); - - for stat in stats { - if matches!( - stat, - Stat::Max | Stat::Min | Stat::NaNCount | Stat::NullCount - ) { - let Some(stat_value) = typed_stats.get(*stat).as_exact() else { - vortex_bail!("missing stat {}, {} from stats set", field, stat) - }; - columns.push(( - field_path_stat_field_name(field_path, *stat), - ConstantArray::new(stat_value, 1).into_array(), - )); - } else { - todo!("unsupported file prune stat {stat}") - } +impl FileStatsBinder<'_> { + fn stat_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { + // FileStats currently only holds top-level field statistics. + if field_path.parts().len() != 1 { + return None; } + + let field_name = field_path.parts()[0].as_name()?; + let field_idx = self.struct_fields.find(field_name)?; + let field_stats = self.file_stats.stats_sets().get(field_idx)?; + + let stat_value = field_stats.get(stat).as_exact()?; + let field_dtype = self.struct_fields.field_by_index(field_idx)?; + let stat_dtype = stat.dtype(&field_dtype)?; + let stat_scalar = Scalar::try_new(stat_dtype, Some(stat_value)).ok()?; + + Some(lit(stat_scalar)) + } +} + +fn direct_field_path(expr: &Expression) -> Option { + if is_root(expr) { + return Some(FieldPath::root()); } - Ok(Some( - StructArray::from_fields(columns.as_slice())?.into_array(), - )) + + if expr.is::() { + return direct_field_path(expr.child(0)); + } + + let field_name = expr.as_opt::()?; + direct_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) } diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index b686905d0b6..324ce76f8b7 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -10,23 +10,11 @@ use std::ops::Range; use std::sync::Arc; -use vortex_array::Canonical; -use vortex_array::IntoArray; use vortex_array::MaskFuture; -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::NullArray; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; -use vortex_array::dtype::FieldPath; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; -use vortex_array::expr::StatsCatalog; -use vortex_array::expr::lit; -use vortex_array::expr::stats::Stat; -use vortex_array::scalar::Scalar; -use vortex_array::scalar_fn::fns::literal::Literal; -use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_error::VortexResult; use vortex_layout::ArrayFuture; use vortex_layout::LayoutReader; @@ -38,6 +26,7 @@ use vortex_session::VortexSession; use vortex_utils::aliases::dash_map::DashMap; use crate::FileStatistics; +use crate::pruning::can_prune_file_stats; /// A [`LayoutReader`] decorator that prunes entire files based on file-level statistics. /// @@ -83,36 +72,14 @@ impl FileStatsLayoutReader { /// Row-count placeholders are resolved against the full file row count, /// independent of the requested row range. fn evaluate_file_stats(&self, expr: &Expression) -> VortexResult { - let Some(pruning_expr) = expr.stat_falsification(self) else { - // If there is no pruning expression, we can't prune. - return Ok(false); - }; - - // Given how we implemented the StatsCatalog, we know the expression must be all literals - // or row_count placeholders. We can therefore optimize with a null scope since there are - // no field references that need to be resolved. - let simplified = pruning_expr.optimize_recursive(&DType::Null)?; - if let Some(result) = simplified.as_opt::() { - // Can prune if the result is non-nullable and true - return Ok(result.as_bool().value() == Some(true)); - } - - // Sometimes expressions don't implement constant folding to literals... In this case, - // we apply the expression over a null array and substitute any row_count placeholders - // in the resulting array tree with the file's row count. - let pruning = NullArray::new(1).into_array().apply(&pruning_expr)?; - let row_count_replacement = - ConstantArray::new(self.child.row_count(), pruning.len()).into_array(); - let pruning = substitute_row_count(pruning, &row_count_replacement)?; - - let mut ctx = self.session.create_execution_ctx(); - let result = pruning - .execute::(&mut ctx)? - .into_bool() - .into_array() - .execute_scalar(0, &mut ctx)?; - - Ok(result.as_bool().value() == Some(true)) + can_prune_file_stats( + expr, + self.child.dtype(), + self.child.row_count(), + &self.file_stats, + &self.struct_fields, + &self.session, + ) } pub fn file_stats(&self) -> &FileStatistics { @@ -120,27 +87,6 @@ impl FileStatsLayoutReader { } } -/// Implements [`StatsCatalog`] to provide file-level stats to expressions during pruning evaluation. -impl StatsCatalog for FileStatsLayoutReader { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { - // FileStats currently only holds top-level field statistics. - if field_path.parts().len() != 1 { - return None; - } - - let field_name = field_path.parts()[0].as_name()?; - let field_idx = self.struct_fields.find(field_name)?; - let field_stats = self.file_stats.stats_sets().get(field_idx)?; - - let stat_value = field_stats.get(stat).as_exact()?; - let field_dtype = self.struct_fields.field_by_index(field_idx)?; - let stat_dtype = stat.dtype(&field_dtype)?; - let stat_scalar = Scalar::try_new(stat_dtype, Some(stat_value)).ok()?; - - Some(lit(stat_scalar)) - } -} - impl LayoutReader for FileStatsLayoutReader { fn name(&self) -> &Arc { self.child.name() @@ -224,6 +170,7 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; + use vortex_array::expr::checked_add; use vortex_array::expr::get_item; use vortex_array::expr::gt; use vortex_array::expr::is_not_null; @@ -361,6 +308,43 @@ mod tests { }) } + #[test] + fn no_pruning_for_computed_expression_stats() -> VortexResult<()> { + block_on(|handle| async { + let session = SESSION.clone().with_handle(handle); + let ctx = ArrayContext::empty(); + let segments = Arc::new(TestSegments::default()); + let (ptr, eof) = SequenceId::root().split(); + let struct_array = + StructArray::from_fields([("col", buffer![0i32, 100].into_array())].as_slice())?; + let strategy = TableStrategy::new( + Arc::new(FlatLayoutStrategy::default()), + Arc::new(FlatLayoutStrategy::default()), + ); + let layout = strategy + .write_stream( + ctx, + Arc::::clone(&segments), + struct_array.into_array().to_array_stream().sequenced(ptr), + eof, + &session, + ) + .await?; + + let child = layout.new_reader("".into(), segments, &SESSION, &Default::default())?; + let reader = + FileStatsLayoutReader::new(child, test_file_stats(0, 100), SESSION.clone()); + + let expr = gt(checked_add(get_item("col", root()), lit(5i32)), lit(102i32)); + let mask = Mask::new_true(2); + let result = reader.pruning_evaluation(&(0..2), &expr, mask)?.await?; + + assert_eq!(result, Mask::new_true(2)); + + Ok(()) + }) + } + /// Regression test: `IS NULL` on a nullable timestamp column must not fail with a /// dtype mismatch. The bug was that `stats_ref` used the *field* dtype (timestamp) /// for the `NullCount` stat scalar instead of the stat's own dtype (u64). diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 789a6a74e8d..dbfaab93910 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -8,33 +8,21 @@ use std::sync::Arc; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; -use vortex_array::aggregate_fn::fns::all_nan::AllNan; -use vortex_array::aggregate_fn::fns::all_non_nan::AllNonNan; -use vortex_array::aggregate_fn::fns::all_non_null::AllNonNull; -use vortex_array::aggregate_fn::fns::all_null::AllNull; -use vortex_array::aggregate_fn::fns::nan_count::NanCount; +use vortex_array::aggregate_fn::AggregateFnRef; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; use vortex_array::arrays::struct_::StructArrayExt; use vortex_array::dtype::DType; -use vortex_array::dtype::Nullability; use vortex_array::expr::Expression; -use vortex_array::expr::eq; use vortex_array::expr::get_item; use vortex_array::expr::is_root; -use vortex_array::expr::lit; use vortex_array::expr::root; use vortex_array::expr::stats::Stat; -use vortex_array::expr::traversal::NodeExt; -use vortex_array::expr::traversal::Transformed; -use vortex_array::scalar::Scalar; -use vortex_array::scalar_fn::EmptyOptions; -use vortex_array::scalar_fn::ScalarFnVTableExt; -use vortex_array::scalar_fn::fns::stat::StatFn; -use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_array::scalar_fn::internal::row_count::contains_row_count; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; +use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_stats; use vortex_array::validity::Validity; use vortex_buffer::buffer; use vortex_error::VortexResult; @@ -132,109 +120,44 @@ impl ZoneMap { } fn lower_stats(&self, predicate: Expression) -> VortexResult { - // Rewritten predicates are evaluated against the stats table, not the data - // column. Lower each StatFn before execution so unavailable stats become - // nullable "unknown" constants rather than prune signals. - predicate - .transform_down(|expr| { - if expr.is::() { - return self.lower_stat_fn(expr).map(Transformed::yes); - } - - Ok(Transformed::no(expr)) - }) - .map(Transformed::into_inner) + let binder = ZoneMapStatsBinder { zone_map: self }; + bind_stats(predicate, &binder)?.optimize_recursive(self.array.dtype()) } +} - fn lower_stat_fn(&self, expr: Expression) -> VortexResult { - // This is the bridge from aggregate-backed bound expressions to the legacy - // zoned stats columns. Exact NullCount and NanCount can prove richer - // all-* aggregates; non-root or missing stats lower to nullable unknowns. - let options = expr.as_::(); - let input = expr.child(0); - let input_dtype = input.return_dtype(&self.column_dtype)?; - let input_is_root = is_root(input); - - if options.aggregate_fn().is::() { - if !has_nans(&input_dtype) { - return Ok(lit(false)); - } - if !input_is_root { - return Ok(null_expr(DType::Bool(Nullability::NonNullable))); - } - return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, row_count_expr())); - } - - if options.aggregate_fn().is::() { - if !has_nans(&input_dtype) { - return Ok(lit(true)); - } - if !input_is_root { - return Ok(null_expr(DType::Bool(Nullability::NonNullable))); - } - return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, lit(0u64))); - } +struct ZoneMapStatsBinder<'a> { + zone_map: &'a ZoneMap, +} - if options.aggregate_fn().is::() && !has_nans(&input_dtype) { - return Ok(lit(0u64)); - } +impl StatBinder for ZoneMapStatsBinder<'_> { + fn scope(&self) -> &DType { + &self.zone_map.column_dtype + } - let return_dtype = match options.aggregate_fn().return_dtype(&input_dtype) { - Some(return_dtype) => return_dtype, - None => vortex_bail!( - "Aggregate function {} does not support input dtype {}", - options.aggregate_fn(), - input_dtype - ), + fn bind_aggregate( + &self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + _stat_dtype: &DType, + ) -> VortexResult> { + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); }; - - if !input_is_root { - return Ok(null_expr(return_dtype)); - } - - if options.aggregate_fn().is::() { - return Ok(eq(self.stat_field_expr(Stat::NullCount)?, row_count_expr())); - } - - if options.aggregate_fn().is::() { - return Ok(eq(self.stat_field_expr(Stat::NullCount)?, lit(0u64))); + if !is_root(input) { + return Ok(None); } - - let Some(stat) = Stat::from_aggregate_fn(options.aggregate_fn()) else { - return Ok(null_expr(return_dtype)); - }; - - self.stat_field_expr(stat) - } - - fn stat_field_expr(&self, stat: Stat) -> VortexResult { - if self.array.unmasked_field_by_name_opt(stat.name()).is_some() { - return Ok(get_item(stat.name(), root())); + if self + .zone_map + .array + .unmasked_field_by_name_opt(stat.name()) + .is_none() + { + return Ok(None); } - - let Some(dtype) = stat.dtype(&self.column_dtype) else { - vortex_bail!( - "Stat {} does not support column dtype {}", - stat, - self.column_dtype - ); - }; - Ok(null_expr(dtype)) + Ok(Some(get_item(stat.name(), root()))) } } -fn row_count_expr() -> Expression { - RowCount.new_expr(EmptyOptions, []) -} - -fn null_expr(dtype: DType) -> Expression { - lit(Scalar::null(dtype.as_nullable())) -} - -fn has_nans(dtype: &DType) -> bool { - matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) -} - /// Build per-zone row counts for a zone map. /// /// `zone_len` is the nominal zone size; only the final zone may be shorter. The @@ -422,7 +345,7 @@ mod tests { } #[test] - fn all_null_stat_fn_lowers_to_null_count_and_row_count() { + fn is_null_falsification_uses_null_count() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -436,12 +359,18 @@ mod tests { ) .unwrap(); - let mask = zone_map.prune(&all_null(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true])); + let expr = is_null(root()); + let pruning_expr = falsify(&expr, PType::U64.into()); + + let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([true, false, false]) + ); } #[test] - fn all_non_null_stat_fn_lowers_to_null_count() { + fn abstract_null_stats_do_not_derive_from_null_count() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -455,15 +384,21 @@ mod tests { ) .unwrap(); + let mask = zone_map.prune(&all_null(root()), &SESSION).unwrap(); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([false, false, false]) + ); + let mask = zone_map.prune(&all_non_null(root()), &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), - BoolArray::from_iter([true, false, false]) + BoolArray::from_iter([false, false, false]) ); } #[test] - fn non_float_nan_stat_fns_lower_to_constants() { + fn non_float_nan_stat_fns_error() { let zone_map = ZoneMap::try_new( PType::I32.into(), StructArray::try_new(FieldNames::empty(), vec![], 2, Validity::NonNullable).unwrap(), @@ -473,11 +408,8 @@ mod tests { ) .unwrap(); - let mask = zone_map.prune(&all_nan(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, false])); - - let mask = zone_map.prune(&all_non_nan(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([true, true])); + assert!(zone_map.prune(&all_nan(root()), &SESSION).is_err()); + assert!(zone_map.prune(&all_non_nan(root()), &SESSION).is_err()); } #[test] @@ -507,7 +439,7 @@ mod tests { } #[test] - fn float_min_max_stat_fn_requires_nan_count() { + fn float_min_max_prunes_only_with_all_non_nan_proof() { let zone_map = ZoneMap::try_new( PType::F32.into(), StructArray::from_fields(&[ @@ -566,7 +498,7 @@ mod tests { } #[test] - fn float_cast_min_max_stat_fn_uses_source_nan_count() { + fn float_cast_min_max_stat_fn_requires_all_non_nan() { let zone_map = ZoneMap::try_new( PType::F32.into(), StructArray::from_fields(&[