From 5b6a86224d49e0378484d32588a27cf9d19e7a36 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 10 Jun 2026 14:49:56 -0700 Subject: [PATCH 01/28] Make stats rewrite rules public Port file pruning to session stats rewrites Signed-off-by: "Nicholas Gates" Signed-off-by: Nicholas Gates --- vortex-array/src/expr/pruning/mod.rs | 1 + vortex-array/src/expr/pruning/pruning_expr.rs | 175 ++++++++++++++++++ vortex-array/src/stats/mod.rs | 2 +- vortex-array/src/stats/rewrite.rs | 16 +- vortex-array/src/stats/session.rs | 8 +- vortex-file/src/file.rs | 6 +- vortex-file/src/v2/file_stats_reader.rs | 108 ++++++++++- 7 files changed, 299 insertions(+), 17 deletions(-) diff --git a/vortex-array/src/expr/pruning/mod.rs b/vortex-array/src/expr/pruning/mod.rs index 7c20508b7a8..bbcfa5942a0 100644 --- a/vortex-array/src/expr/pruning/mod.rs +++ b/vortex-array/src/expr/pruning/mod.rs @@ -6,6 +6,7 @@ mod relation; pub use pruning_expr::RequiredStats; pub use pruning_expr::checked_pruning_expr; +pub use pruning_expr::checked_pruning_expr_with_session; pub use pruning_expr::field_path_stat_field_name; pub use relation::Relation; diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 00d29fbcf99..54c9666c283 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -5,18 +5,36 @@ use std::cell::RefCell; use std::iter; use itertools::Itertools; +use vortex_error::VortexResult; +use vortex_session::VortexSession; use vortex_utils::aliases::hash_map::HashMap; use super::relation::Relation; +use crate::aggregate_fn::fns::all_nan::AllNan; +use crate::aggregate_fn::fns::all_non_nan::AllNonNan; +use crate::aggregate_fn::fns::all_non_null::AllNonNull; +use crate::aggregate_fn::fns::all_null::AllNull; +use crate::aggregate_fn::fns::nan_count::NanCount; +use crate::dtype::DType; use crate::dtype::Field; use crate::dtype::FieldName; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; use crate::expr::Expression; use crate::expr::StatsCatalog; +use crate::expr::analysis::referenced_field_paths; +use crate::expr::eq; use crate::expr::get_item; +use crate::expr::lit; use crate::expr::root; use crate::expr::stats::Stat; +use crate::expr::traversal::NodeExt; +use crate::expr::traversal::Transformed; +use crate::scalar::Scalar; +use crate::scalar_fn::EmptyOptions; +use crate::scalar_fn::ScalarFnVTableExt; +use crate::scalar_fn::fns::stat::StatFn; +use crate::scalar_fn::internal::row_count::RowCount; pub type RequiredStats = Relation; @@ -113,6 +131,163 @@ pub fn checked_pruning_expr( Some((expr, relation)) } +/// Build a pruning expression using session-registered stats rewrite rules. +/// +/// The returned expression is lowered to the same stats-table field references as +/// [`checked_pruning_expr`]. If a rewrite asks for a stat that is not present in +/// `available_stats`, this returns `Ok(None)`. +pub fn checked_pruning_expr_with_session( + expr: &Expression, + scope: &DType, + available_stats: &FieldPathSet, + session: &VortexSession, +) -> VortexResult> { + let Some(predicate) = expr.falsify(scope, session)? else { + return Ok(None); + }; + + lower_stat_fns(predicate, scope, available_stats) +} + +fn lower_stat_fns( + predicate: Expression, + scope: &DType, + available_stats: &FieldPathSet, +) -> VortexResult> { + let mut required_stats = Relation::new(); + let mut missing_stat = false; + let lowered = predicate + .transform_down(|expr| { + if !expr.is::() { + return Ok(Transformed::no(expr)); + } + + if let Some(lowered) = + lower_stat_fn(&expr, scope, available_stats, &mut required_stats)? + { + return Ok(Transformed::yes(lowered)); + } + + missing_stat = true; + let dtype = expr.return_dtype(scope)?; + Ok(Transformed::yes(null_expr(dtype))) + })? + .into_inner(); + + if missing_stat { + return Ok(None); + } + + Ok(Some((lowered, required_stats))) +} + +fn lower_stat_fn( + expr: &Expression, + scope: &DType, + available_stats: &FieldPathSet, + required_stats: &mut RequiredStats, +) -> VortexResult> { + let options = expr.as_::(); + let aggregate_fn = options.aggregate_fn(); + let input = expr.child(0); + let input_dtype = input.return_dtype(scope)?; + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(Some(lit(false))); + } + return lower_stat_ref( + input, + Stat::NaNCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, row_count_expr()))); + } + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(Some(lit(true))); + } + return lower_stat_ref( + input, + Stat::NaNCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, lit(0u64)))); + } + + if aggregate_fn.is::() && !has_nans(&input_dtype) { + return Ok(Some(lit(0u64))); + } + + if aggregate_fn.is::() { + return lower_stat_ref( + input, + Stat::NullCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, row_count_expr()))); + } + + if aggregate_fn.is::() { + return lower_stat_ref( + input, + Stat::NullCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, lit(0u64)))); + } + + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; + + lower_stat_ref(input, stat, scope, available_stats, required_stats) +} + +fn lower_stat_ref( + input: &Expression, + stat: Stat, + scope: &DType, + available_stats: &FieldPathSet, + required_stats: &mut RequiredStats, +) -> VortexResult> { + let field_paths = referenced_field_paths(input, scope)?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + let stat_path = field_path.clone().push(stat.name()); + if !available_stats.contains(&stat_path) { + return Ok(None); + } + + required_stats.insert(field_path.clone(), stat); + Ok(Some(get_item( + field_path_stat_field_name(field_path, stat), + root(), + ))) +} + +fn row_count_expr() -> Expression { + RowCount.new_expr(EmptyOptions, []) +} + +fn null_expr(dtype: DType) -> Expression { + lit(Scalar::null(dtype.as_nullable())) +} + +fn has_nans(dtype: &DType) -> bool { + matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) +} + #[cfg(test)] mod tests { use rstest::fixture; diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index ceb085e0815..3d4cfeb6111 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -21,7 +21,7 @@ pub use stats_set::*; mod array; pub mod expr; pub mod flatbuffers; -pub(crate) mod rewrite; +pub mod rewrite; pub mod session; mod stats_set; diff --git a/vortex-array/src/stats/rewrite.rs b/vortex-array/src/stats/rewrite.rs index 52d354df1a0..dfba62ded40 100644 --- a/vortex-array/src/stats/rewrite.rs +++ b/vortex-array/src/stats/rewrite.rs @@ -21,7 +21,7 @@ mod builtins; pub(crate) use builtins::register_builtins; /// Shared reference to a stats rewrite rule. -pub(crate) type StatsRewriteRuleRef = Arc; +pub type StatsRewriteRuleRef = Arc; /// A plugin-provided rule for predicates whose root scalar function matches this rule. /// @@ -40,7 +40,7 @@ pub(crate) type StatsRewriteRuleRef = Arc; /// `expr` is the full predicate expression whose root scalar function id is /// [`Self::scalar_fn_id`]. Use [`StatsRewriteCtx`] to resolve dtypes and recursively rewrite child /// predicates. -pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { +pub trait StatsRewriteRule: Debug + Send + Sync + 'static { /// Returns the scalar function id handled by this rule. fn scalar_fn_id(&self) -> ScalarFnId; @@ -83,35 +83,35 @@ pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { } /// Context passed to stats rewrite rules. -pub(crate) struct StatsRewriteCtx<'a> { +pub struct StatsRewriteCtx<'a> { session: &'a VortexSession, scope: &'a DType, } impl<'a> StatsRewriteCtx<'a> { /// Create a rewrite context for `session`. - pub(crate) fn new(session: &'a VortexSession, scope: &'a DType) -> Self { + pub fn new(session: &'a VortexSession, scope: &'a DType) -> Self { Self { session, scope } } /// Returns the session that owns the rewrite registry. - pub(crate) fn session(&self) -> &'a VortexSession { + pub fn session(&self) -> &'a VortexSession { self.session } /// Return the dtype of `expr` within this rewrite scope. - pub(crate) fn return_dtype(&self, expr: &Expression) -> VortexResult { + pub fn return_dtype(&self, expr: &Expression) -> VortexResult { expr.return_dtype(self.scope) } /// Rewrite `expr` into a stats-backed falsifier. - pub(crate) fn falsify(&self, expr: &Expression) -> VortexResult> { + pub fn falsify(&self, expr: &Expression) -> VortexResult> { self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::falsify) } /// Rewrite `expr` into a stats-backed satisfier. - pub(crate) fn satisfy(&self, expr: &Expression) -> VortexResult> { + pub fn satisfy(&self, expr: &Expression) -> VortexResult> { self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::satisfy) } diff --git a/vortex-array/src/stats/session.rs b/vortex-array/src/stats/session.rs index 2d4325b2cd7..91eae4a4fa9 100644 --- a/vortex-array/src/stats/session.rs +++ b/vortex-array/src/stats/session.rs @@ -37,14 +37,12 @@ impl Default for StatsSession { impl StatsSession { /// Register a stats rewrite rule. - #[allow(dead_code)] - pub(crate) fn register_rewrite(&self, rule: R) { + pub fn register_rewrite(&self, rule: R) { self.register_rewrite_ref(Arc::new(rule)); } /// Register a shared stats rewrite rule. - #[allow(dead_code)] - pub(crate) fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) { + pub fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) { let mut rules = self.rewrite_rules.write(); let rule_id = rule.scalar_fn_id(); let mut updated_rules = rules @@ -75,7 +73,7 @@ impl SessionVar for StatsSession { } /// Extension trait for accessing stats session data. -pub(crate) trait StatsSessionExt: SessionExt { +pub trait StatsSessionExt: SessionExt { /// Returns the stats session state. fn stats(&self) -> Ref<'_, StatsSession> { self.get::() diff --git a/vortex-file/src/file.rs b/vortex-file/src/file.rs index ded986f6210..225d18b561a 100644 --- a/vortex-file/src/file.rs +++ b/vortex-file/src/file.rs @@ -22,7 +22,7 @@ use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; use vortex_array::dtype::FieldPathSet; use vortex_array::expr::Expression; -use vortex_array::expr::pruning::checked_pruning_expr; +use vortex_array::expr::pruning::checked_pruning_expr_with_session; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_error::VortexResult; use vortex_layout::LayoutReader; @@ -217,7 +217,9 @@ impl VortexFile { }), ); - let Some((predicate, required_stats)) = checked_pruning_expr(filter, &set) else { + let Some((predicate, required_stats)) = + checked_pruning_expr_with_session(filter, self.footer.dtype(), &set, &self.session)? + else { return Ok(false); }; diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index 0121c12b07d..22c92e817d6 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -10,22 +10,37 @@ use std::ops::Range; use std::sync::Arc; +use itertools::Itertools; use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::MaskFuture; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::fns::all_nan::AllNan; +use vortex_array::aggregate_fn::fns::all_non_nan::AllNonNan; +use vortex_array::aggregate_fn::fns::all_non_null::AllNonNull; +use vortex_array::aggregate_fn::fns::all_null::AllNull; +use vortex_array::aggregate_fn::fns::nan_count::NanCount; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::NullArray; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; +use vortex_array::dtype::Nullability; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; use vortex_array::expr::StatsCatalog; +use vortex_array::expr::analysis::referenced_field_paths; +use vortex_array::expr::eq; use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; +use vortex_array::expr::traversal::NodeExt; +use vortex_array::expr::traversal::Transformed; use vortex_array::scalar::Scalar; +use vortex_array::scalar_fn::EmptyOptions; +use vortex_array::scalar_fn::ScalarFnVTableExt; use vortex_array::scalar_fn::fns::literal::Literal; +use vortex_array::scalar_fn::fns::stat::StatFn; +use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_error::VortexResult; use vortex_layout::ArrayFuture; @@ -83,10 +98,11 @@ impl FileStatsLayoutReader { /// Row-count placeholders are resolved against the full file row count, /// independent of the requested row range. fn evaluate_file_stats(&self, expr: &Expression) -> VortexResult { - let Some(pruning_expr) = expr.stat_falsification(self) else { + let Some(pruning_expr) = expr.falsify(self.child.dtype(), &self.session)? else { // If there is no pruning expression, we can't prune. return Ok(false); }; + let pruning_expr = self.lower_stats(pruning_expr)?; // Given how we implemented the StatsCatalog, we know the expression must be all literals // or row_count placeholders. We can therefore optimize with a null scope since there are @@ -115,11 +131,101 @@ impl FileStatsLayoutReader { Ok(result.as_bool().value() == Some(true)) } + fn lower_stats(&self, predicate: Expression) -> VortexResult { + predicate + .transform_down(|expr| { + if expr.is::() { + return self.lower_stat_fn(expr).map(Transformed::yes); + } + + Ok(Transformed::no(expr)) + }) + .map(Transformed::into_inner) + } + + fn lower_stat_fn(&self, expr: Expression) -> VortexResult { + let options = expr.as_::(); + let aggregate_fn = options.aggregate_fn(); + let input = expr.child(0); + let input_dtype = input.return_dtype(self.child.dtype())?; + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(lit(false)); + } + return Ok(self + .stat_ref(input, Stat::NaNCount)? + .map(|stat| eq(stat, row_count_expr())) + .unwrap_or_else(null_bool_expr)); + } + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(lit(true)); + } + return Ok(self + .stat_ref(input, Stat::NaNCount)? + .map(|stat| eq(stat, lit(0u64))) + .unwrap_or_else(null_bool_expr)); + } + + if aggregate_fn.is::() && !has_nans(&input_dtype) { + return Ok(lit(0u64)); + } + + if aggregate_fn.is::() { + return Ok(self + .stat_ref(input, Stat::NullCount)? + .map(|stat| eq(stat, row_count_expr())) + .unwrap_or_else(null_bool_expr)); + } + + if aggregate_fn.is::() { + return Ok(self + .stat_ref(input, Stat::NullCount)? + .map(|stat| eq(stat, lit(0u64))) + .unwrap_or_else(null_bool_expr)); + } + + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(null_expr(expr.return_dtype(self.child.dtype())?)); + }; + + let return_dtype = expr.return_dtype(self.child.dtype())?; + Ok(self + .stat_ref(input, stat)? + .unwrap_or_else(|| null_expr(return_dtype))) + } + + fn stat_ref(&self, input: &Expression, stat: Stat) -> VortexResult> { + let field_paths = referenced_field_paths(input, self.child.dtype())?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + Ok(self.stats_ref(field_path, stat)) + } + pub fn file_stats(&self) -> &FileStatistics { &self.file_stats } } +fn row_count_expr() -> Expression { + RowCount.new_expr(EmptyOptions, []) +} + +fn null_expr(dtype: DType) -> Expression { + lit(Scalar::null(dtype.as_nullable())) +} + +fn null_bool_expr() -> Expression { + null_expr(DType::Bool(Nullability::NonNullable)) +} + +fn has_nans(dtype: &DType) -> bool { + matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) +} + /// Implements [`StatsCatalog`] to provide file-level stats to expressions during pruning evaluation. impl StatsCatalog for FileStatsLayoutReader { fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { From e8dd011748a1138750b3946a0af0d30a92fdcadb Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 11:18:48 -0400 Subject: [PATCH 02/28] Centralize stat expression binding Signed-off-by: "Nicholas Gates" Signed-off-by: Nicholas Gates --- vortex-array/src/expr/pruning/pruning_expr.rs | 182 ++++-------------- vortex-array/src/stats/bind.rs | 160 +++++++++++++++ vortex-array/src/stats/mod.rs | 1 + vortex-file/src/v2/file_stats_reader.rs | 119 +++--------- vortex-layout/src/layouts/zoned/zone_map.rs | 140 +++----------- 5 files changed, 256 insertions(+), 346 deletions(-) create mode 100644 vortex-array/src/stats/bind.rs diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 54c9666c283..5208b2efada 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -10,11 +10,6 @@ use vortex_session::VortexSession; use vortex_utils::aliases::hash_map::HashMap; use super::relation::Relation; -use crate::aggregate_fn::fns::all_nan::AllNan; -use crate::aggregate_fn::fns::all_non_nan::AllNonNan; -use crate::aggregate_fn::fns::all_non_null::AllNonNull; -use crate::aggregate_fn::fns::all_null::AllNull; -use crate::aggregate_fn::fns::nan_count::NanCount; use crate::dtype::DType; use crate::dtype::Field; use crate::dtype::FieldName; @@ -23,18 +18,11 @@ use crate::dtype::FieldPathSet; use crate::expr::Expression; use crate::expr::StatsCatalog; use crate::expr::analysis::referenced_field_paths; -use crate::expr::eq; use crate::expr::get_item; -use crate::expr::lit; use crate::expr::root; use crate::expr::stats::Stat; -use crate::expr::traversal::NodeExt; -use crate::expr::traversal::Transformed; -use crate::scalar::Scalar; -use crate::scalar_fn::EmptyOptions; -use crate::scalar_fn::ScalarFnVTableExt; -use crate::scalar_fn::fns::stat::StatFn; -use crate::scalar_fn::internal::row_count::RowCount; +use crate::stats::bind::StatBinder; +use crate::stats::bind::bind_stats; pub type RequiredStats = Relation; @@ -146,146 +134,54 @@ pub fn checked_pruning_expr_with_session( return Ok(None); }; - lower_stat_fns(predicate, scope, available_stats) -} - -fn lower_stat_fns( - predicate: Expression, - scope: &DType, - available_stats: &FieldPathSet, -) -> VortexResult> { - let mut required_stats = Relation::new(); - let mut missing_stat = false; - let lowered = predicate - .transform_down(|expr| { - if !expr.is::() { - return Ok(Transformed::no(expr)); - } - - if let Some(lowered) = - lower_stat_fn(&expr, scope, available_stats, &mut required_stats)? - { - return Ok(Transformed::yes(lowered)); - } - - missing_stat = true; - let dtype = expr.return_dtype(scope)?; - Ok(Transformed::yes(null_expr(dtype))) - })? - .into_inner(); - - if missing_stat { + let mut binder = RequiredStatsBinder { + scope, + available_stats, + required_stats: Relation::new(), + }; + let Some(lowered) = bind_stats(predicate, &mut binder)? else { return Ok(None); - } + }; - Ok(Some((lowered, required_stats))) + Ok(Some((lowered, binder.required_stats))) } -fn lower_stat_fn( - expr: &Expression, - scope: &DType, - available_stats: &FieldPathSet, - required_stats: &mut RequiredStats, -) -> VortexResult> { - let options = expr.as_::(); - let aggregate_fn = options.aggregate_fn(); - let input = expr.child(0); - let input_dtype = input.return_dtype(scope)?; - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(Some(lit(false))); - } - return lower_stat_ref( - input, - Stat::NaNCount, - scope, - available_stats, - required_stats, - ) - .map(|stat| stat.map(|stat| eq(stat, row_count_expr()))); - } - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(Some(lit(true))); - } - return lower_stat_ref( - input, - Stat::NaNCount, - scope, - available_stats, - required_stats, - ) - .map(|stat| stat.map(|stat| eq(stat, lit(0u64)))); - } +struct RequiredStatsBinder<'a> { + scope: &'a DType, + available_stats: &'a FieldPathSet, + required_stats: RequiredStats, +} - if aggregate_fn.is::() && !has_nans(&input_dtype) { - return Ok(Some(lit(0u64))); +impl StatBinder for RequiredStatsBinder<'_> { + fn scope(&self) -> &DType { + self.scope } - if aggregate_fn.is::() { - return lower_stat_ref( - input, - Stat::NullCount, - scope, - available_stats, - required_stats, - ) - .map(|stat| stat.map(|stat| eq(stat, row_count_expr()))); - } + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + _stat_dtype: &DType, + ) -> VortexResult> { + let field_paths = referenced_field_paths(input, self.scope)?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + let stat_path = field_path.clone().push(stat.name()); + if !self.available_stats.contains(&stat_path) { + return Ok(None); + } - if aggregate_fn.is::() { - return lower_stat_ref( - input, - Stat::NullCount, - scope, - available_stats, - required_stats, - ) - .map(|stat| stat.map(|stat| eq(stat, lit(0u64)))); + self.required_stats.insert(field_path.clone(), stat); + Ok(Some(get_item( + field_path_stat_field_name(field_path, stat), + root(), + ))) } - let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { - return Ok(None); - }; - - lower_stat_ref(input, stat, scope, available_stats, required_stats) -} - -fn lower_stat_ref( - input: &Expression, - stat: Stat, - scope: &DType, - available_stats: &FieldPathSet, - required_stats: &mut RequiredStats, -) -> VortexResult> { - let field_paths = referenced_field_paths(input, scope)?; - let Some(field_path) = field_paths.iter().exactly_one().ok() else { - return Ok(None); - }; - let stat_path = field_path.clone().push(stat.name()); - if !available_stats.contains(&stat_path) { - return Ok(None); + fn missing_stat(&mut self, _dtype: DType) -> VortexResult> { + Ok(None) } - - required_stats.insert(field_path.clone(), stat); - Ok(Some(get_item( - field_path_stat_field_name(field_path, stat), - root(), - ))) -} - -fn row_count_expr() -> Expression { - RowCount.new_expr(EmptyOptions, []) -} - -fn null_expr(dtype: DType) -> Expression { - lit(Scalar::null(dtype.as_nullable())) -} - -fn has_nans(dtype: &DType) -> bool { - matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) } #[cfg(test)] diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs new file mode 100644 index 00000000000..714404fdceb --- /dev/null +++ b/vortex-array/src/stats/bind.rs @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Bind abstract `vortex.stat` expressions to a concrete stats representation. + +use vortex_error::VortexResult; + +use crate::aggregate_fn::fns::all_nan::AllNan; +use crate::aggregate_fn::fns::all_non_nan::AllNonNan; +use crate::aggregate_fn::fns::all_non_null::AllNonNull; +use crate::aggregate_fn::fns::all_null::AllNull; +use crate::aggregate_fn::fns::nan_count::NanCount; +use crate::dtype::DType; +use crate::expr::Expression; +use crate::expr::eq; +use crate::expr::lit; +use crate::expr::stats::Stat; +use crate::expr::traversal::NodeExt; +use crate::expr::traversal::Transformed; +use crate::scalar::Scalar; +use crate::scalar_fn::EmptyOptions; +use crate::scalar_fn::ScalarFnVTableExt; +use crate::scalar_fn::fns::stat::StatFn; +use crate::scalar_fn::internal::row_count::RowCount; + +/// A target that can bind abstract statistics to concrete expressions. +pub trait StatBinder { + /// The dtype scope used to type-check expressions before stats are bound. + fn scope(&self) -> &DType; + + /// Bind `stat(input)` to a concrete expression. + /// + /// Returning `Ok(None)` marks the stat as unavailable. [`bind_stats`] will + /// then call [`Self::missing_stat`] with the dtype expected from the + /// original `vortex.stat` expression. + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + stat_dtype: &DType, + ) -> VortexResult>; + + /// Expression to use when a stat is unavailable. + /// + /// The default is a nullable null literal, which preserves three-valued + /// pruning semantics for stats-table execution. Catalog-like binders can + /// return `Ok(None)` to reject expressions that require unavailable stats. + fn missing_stat(&mut self, dtype: DType) -> VortexResult> { + Ok(Some(null_expr(dtype))) + } +} + +/// Bind all `vortex.stat` expressions in `predicate`. +/// +/// The predicate is usually the output of a stats rewrite rule. This function +/// centralizes the legacy aggregate/stat mapping: `all_null` and `all_nan` +/// style aggregate expressions are expanded through exact count stats, while +/// direct aggregate stats are delegated to the supplied binder. +pub fn bind_stats( + predicate: Expression, + binder: &mut impl StatBinder, +) -> VortexResult> { + let scope = binder.scope().clone(); + let mut missing_stat = false; + let lowered = predicate + .transform_down(|expr| { + if !expr.is::() { + return Ok(Transformed::no(expr)); + } + + match bind_stat_fn(&expr, &scope, binder)? { + Some(bound) => Ok(Transformed::yes(bound)), + None => { + let dtype = expr.return_dtype(&scope)?; + match binder.missing_stat(dtype.clone())? { + Some(missing) => Ok(Transformed::yes(missing)), + None => { + missing_stat = true; + Ok(Transformed::yes(null_expr(dtype))) + } + } + } + } + })? + .into_inner(); + + if missing_stat { + return Ok(None); + } + + Ok(Some(lowered)) +} + +fn bind_stat_fn( + expr: &Expression, + scope: &DType, + binder: &mut impl StatBinder, +) -> VortexResult> { + let options = expr.as_::(); + let aggregate_fn = options.aggregate_fn(); + let input = expr.child(0); + let input_dtype = input.return_dtype(scope)?; + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(Some(lit(false))); + } + let stat_dtype = expr.return_dtype(scope)?; + return Ok(binder + .bind_stat(input, Stat::NaNCount, &stat_dtype)? + .map(|stat| eq(stat, row_count_expr()))); + } + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(Some(lit(true))); + } + let stat_dtype = expr.return_dtype(scope)?; + return Ok(binder + .bind_stat(input, Stat::NaNCount, &stat_dtype)? + .map(|stat| eq(stat, lit(0u64)))); + } + + if aggregate_fn.is::() && !has_nans(&input_dtype) { + return Ok(Some(lit(0u64))); + } + + if aggregate_fn.is::() { + let stat_dtype = expr.return_dtype(scope)?; + return Ok(binder + .bind_stat(input, Stat::NullCount, &stat_dtype)? + .map(|stat| eq(stat, row_count_expr()))); + } + + if aggregate_fn.is::() { + let stat_dtype = expr.return_dtype(scope)?; + return Ok(binder + .bind_stat(input, Stat::NullCount, &stat_dtype)? + .map(|stat| eq(stat, lit(0u64)))); + } + + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; + + let stat_dtype = expr.return_dtype(scope)?; + binder.bind_stat(input, stat, &stat_dtype) +} + +fn row_count_expr() -> Expression { + RowCount.new_expr(EmptyOptions, []) +} + +fn null_expr(dtype: DType) -> Expression { + lit(Scalar::null(dtype.as_nullable())) +} + +fn has_nans(dtype: &DType) -> bool { + matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) +} diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index 3d4cfeb6111..5f5684dbde2 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -19,6 +19,7 @@ pub use expr::sum; pub use stats_set::*; mod array; +pub mod bind; pub mod expr; pub mod flatbuffers; pub mod rewrite; diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index 22c92e817d6..be2efce92a2 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -15,34 +15,24 @@ use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::MaskFuture; use vortex_array::VortexSessionExecute; -use vortex_array::aggregate_fn::fns::all_nan::AllNan; -use vortex_array::aggregate_fn::fns::all_non_nan::AllNonNan; -use vortex_array::aggregate_fn::fns::all_non_null::AllNonNull; -use vortex_array::aggregate_fn::fns::all_null::AllNull; -use vortex_array::aggregate_fn::fns::nan_count::NanCount; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::NullArray; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; -use vortex_array::dtype::Nullability; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; use vortex_array::expr::StatsCatalog; use vortex_array::expr::analysis::referenced_field_paths; -use vortex_array::expr::eq; use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; -use vortex_array::expr::traversal::NodeExt; -use vortex_array::expr::traversal::Transformed; use vortex_array::scalar::Scalar; -use vortex_array::scalar_fn::EmptyOptions; -use vortex_array::scalar_fn::ScalarFnVTableExt; use vortex_array::scalar_fn::fns::literal::Literal; -use vortex_array::scalar_fn::fns::stat::StatFn; -use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; +use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_stats; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_layout::ArrayFuture; use vortex_layout::LayoutReader; use vortex_layout::LayoutReaderRef; @@ -132,77 +122,11 @@ impl FileStatsLayoutReader { } fn lower_stats(&self, predicate: Expression) -> VortexResult { - predicate - .transform_down(|expr| { - if expr.is::() { - return self.lower_stat_fn(expr).map(Transformed::yes); - } - - Ok(Transformed::no(expr)) - }) - .map(Transformed::into_inner) - } - - fn lower_stat_fn(&self, expr: Expression) -> VortexResult { - let options = expr.as_::(); - let aggregate_fn = options.aggregate_fn(); - let input = expr.child(0); - let input_dtype = input.return_dtype(self.child.dtype())?; - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(lit(false)); - } - return Ok(self - .stat_ref(input, Stat::NaNCount)? - .map(|stat| eq(stat, row_count_expr())) - .unwrap_or_else(null_bool_expr)); - } - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(lit(true)); - } - return Ok(self - .stat_ref(input, Stat::NaNCount)? - .map(|stat| eq(stat, lit(0u64))) - .unwrap_or_else(null_bool_expr)); - } - - if aggregate_fn.is::() && !has_nans(&input_dtype) { - return Ok(lit(0u64)); - } - - if aggregate_fn.is::() { - return Ok(self - .stat_ref(input, Stat::NullCount)? - .map(|stat| eq(stat, row_count_expr())) - .unwrap_or_else(null_bool_expr)); - } - - if aggregate_fn.is::() { - return Ok(self - .stat_ref(input, Stat::NullCount)? - .map(|stat| eq(stat, lit(0u64))) - .unwrap_or_else(null_bool_expr)); - } - - let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { - return Ok(null_expr(expr.return_dtype(self.child.dtype())?)); + let mut binder = FileStatsBinder { reader: self }; + let Some(predicate) = bind_stats(predicate, &mut binder)? else { + vortex_bail!("missing stats should lower to null literals"); }; - - let return_dtype = expr.return_dtype(self.child.dtype())?; - Ok(self - .stat_ref(input, stat)? - .unwrap_or_else(|| null_expr(return_dtype))) - } - - fn stat_ref(&self, input: &Expression, stat: Stat) -> VortexResult> { - let field_paths = referenced_field_paths(input, self.child.dtype())?; - let Some(field_path) = field_paths.iter().exactly_one().ok() else { - return Ok(None); - }; - Ok(self.stats_ref(field_path, stat)) + Ok(predicate) } pub fn file_stats(&self) -> &FileStatistics { @@ -210,20 +134,27 @@ impl FileStatsLayoutReader { } } -fn row_count_expr() -> Expression { - RowCount.new_expr(EmptyOptions, []) -} - -fn null_expr(dtype: DType) -> Expression { - lit(Scalar::null(dtype.as_nullable())) +struct FileStatsBinder<'a> { + reader: &'a FileStatsLayoutReader, } -fn null_bool_expr() -> Expression { - null_expr(DType::Bool(Nullability::NonNullable)) -} +impl StatBinder for FileStatsBinder<'_> { + fn scope(&self) -> &DType { + self.reader.child.dtype() + } -fn has_nans(dtype: &DType) -> bool { - matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + _stat_dtype: &DType, + ) -> VortexResult> { + let field_paths = referenced_field_paths(input, self.scope())?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + Ok(self.reader.stats_ref(field_path, stat)) + } } /// Implements [`StatsCatalog`] to provide file-level stats to expressions during pruning evaluation. diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 96154e69571..5c9fdcdad32 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -8,33 +8,20 @@ use std::sync::Arc; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; -use vortex_array::aggregate_fn::fns::all_nan::AllNan; -use vortex_array::aggregate_fn::fns::all_non_nan::AllNonNan; -use vortex_array::aggregate_fn::fns::all_non_null::AllNonNull; -use vortex_array::aggregate_fn::fns::all_null::AllNull; -use vortex_array::aggregate_fn::fns::nan_count::NanCount; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; use vortex_array::arrays::struct_::StructArrayExt; use vortex_array::dtype::DType; -use vortex_array::dtype::Nullability; use vortex_array::expr::Expression; -use vortex_array::expr::eq; use vortex_array::expr::get_item; use vortex_array::expr::is_root; -use vortex_array::expr::lit; use vortex_array::expr::root; use vortex_array::expr::stats::Stat; -use vortex_array::expr::traversal::NodeExt; -use vortex_array::expr::traversal::Transformed; -use vortex_array::scalar::Scalar; -use vortex_array::scalar_fn::EmptyOptions; -use vortex_array::scalar_fn::ScalarFnVTableExt; -use vortex_array::scalar_fn::fns::stat::StatFn; -use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_array::scalar_fn::internal::row_count::contains_row_count; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; +use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_stats; use vortex_array::validity::Validity; use vortex_buffer::buffer; use vortex_error::VortexResult; @@ -132,107 +119,42 @@ impl ZoneMap { } fn lower_stats(&self, predicate: Expression) -> VortexResult { - // Rewritten predicates are evaluated against the stats table, not the data - // column. Lower each StatFn before execution so unavailable stats become - // nullable "unknown" constants rather than prune signals. - predicate - .transform_down(|expr| { - if expr.is::() { - return self.lower_stat_fn(expr).map(Transformed::yes); - } - - Ok(Transformed::no(expr)) - }) - .map(Transformed::into_inner) - } - - fn lower_stat_fn(&self, expr: Expression) -> VortexResult { - // This is the bridge from aggregate-backed bound expressions to the legacy - // zoned stats columns. Exact NullCount and NanCount can prove richer - // all-* aggregates; non-root or missing stats lower to nullable unknowns. - let options = expr.as_::(); - let input = expr.child(0); - let input_dtype = input.return_dtype(&self.column_dtype)?; - let input_is_root = is_root(input); - - if options.aggregate_fn().is::() { - if !has_nans(&input_dtype) { - return Ok(lit(false)); - } - if !input_is_root { - return Ok(null_expr(DType::Bool(Nullability::NonNullable))); - } - return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, row_count_expr())); - } - - if options.aggregate_fn().is::() { - if !has_nans(&input_dtype) { - return Ok(lit(true)); - } - if !input_is_root { - return Ok(null_expr(DType::Bool(Nullability::NonNullable))); - } - return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, lit(0u64))); - } - - if options.aggregate_fn().is::() && !has_nans(&input_dtype) { - return Ok(lit(0u64)); - } - - let return_dtype = match options.aggregate_fn().return_dtype(&input_dtype) { - Some(return_dtype) => return_dtype, - None => vortex_bail!( - "Aggregate function {} does not support input dtype {}", - options.aggregate_fn(), - input_dtype - ), + let mut binder = ZoneMapStatsBinder { zone_map: self }; + let Some(predicate) = bind_stats(predicate, &mut binder)? else { + vortex_bail!("missing stats should lower to null literals"); }; - - if !input_is_root { - return Ok(null_expr(return_dtype)); - } - - if options.aggregate_fn().is::() { - return Ok(eq(self.stat_field_expr(Stat::NullCount)?, row_count_expr())); - } - - if options.aggregate_fn().is::() { - return Ok(eq(self.stat_field_expr(Stat::NullCount)?, lit(0u64))); - } - - let Some(stat) = Stat::from_aggregate_fn(options.aggregate_fn()) else { - return Ok(null_expr(return_dtype)); - }; - - self.stat_field_expr(stat) - } - - fn stat_field_expr(&self, stat: Stat) -> VortexResult { - if self.array.unmasked_field_by_name_opt(stat.name()).is_some() { - return Ok(get_item(stat.name(), root())); - } - - let Some(dtype) = stat.dtype(&self.column_dtype) else { - vortex_bail!( - "Stat {} does not support column dtype {}", - stat, - self.column_dtype - ); - }; - Ok(null_expr(dtype)) + Ok(predicate) } } -fn row_count_expr() -> Expression { - RowCount.new_expr(EmptyOptions, []) +struct ZoneMapStatsBinder<'a> { + zone_map: &'a ZoneMap, } -fn null_expr(dtype: DType) -> Expression { - lit(Scalar::null(dtype.as_nullable())) -} +impl StatBinder for ZoneMapStatsBinder<'_> { + fn scope(&self) -> &DType { + &self.zone_map.column_dtype + } -fn has_nans(dtype: &DType) -> bool { - matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + _stat_dtype: &DType, + ) -> VortexResult> { + if !is_root(input) { + return Ok(None); + } + if self + .zone_map + .array + .unmasked_field_by_name_opt(stat.name()) + .is_none() + { + return Ok(None); + } + Ok(Some(get_item(stat.name(), root()))) + } } /// Build per-zone row counts for a zone map. From c1d6d945d39ed9bafbb60456b84f313e76304d1a Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 14:29:34 -0400 Subject: [PATCH 03/28] Fix file stats binding for computed expressions Signed-off-by: "Nicholas Gates" Signed-off-by: Nicholas Gates --- vortex-file/src/v2/file_stats_reader.rs | 56 ++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index be2efce92a2..7d9bd0d66cf 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -10,7 +10,6 @@ use std::ops::Range; use std::sync::Arc; -use itertools::Itertools; use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::MaskFuture; @@ -23,10 +22,11 @@ use vortex_array::dtype::FieldPath; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; use vortex_array::expr::StatsCatalog; -use vortex_array::expr::analysis::referenced_field_paths; +use vortex_array::expr::is_root; use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; use vortex_array::scalar::Scalar; +use vortex_array::scalar_fn::fns::get_item::GetItem; use vortex_array::scalar_fn::fns::literal::Literal; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::stats::bind::StatBinder; @@ -149,14 +149,22 @@ impl StatBinder for FileStatsBinder<'_> { stat: Stat, _stat_dtype: &DType, ) -> VortexResult> { - let field_paths = referenced_field_paths(input, self.scope())?; - let Some(field_path) = field_paths.iter().exactly_one().ok() else { + let Some(field_path) = direct_field_path(input) else { return Ok(None); }; - Ok(self.reader.stats_ref(field_path, stat)) + Ok(self.reader.stats_ref(&field_path, stat)) } } +fn direct_field_path(expr: &Expression) -> Option { + if is_root(expr) { + return Some(FieldPath::root()); + } + + let field_name = expr.as_opt::()?; + direct_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) +} + /// Implements [`StatsCatalog`] to provide file-level stats to expressions during pruning evaluation. impl StatsCatalog for FileStatsLayoutReader { fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { @@ -261,6 +269,7 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; + use vortex_array::expr::checked_add; use vortex_array::expr::get_item; use vortex_array::expr::gt; use vortex_array::expr::is_not_null; @@ -402,6 +411,43 @@ mod tests { }) } + #[test] + fn no_pruning_for_computed_expression_stats() -> VortexResult<()> { + block_on(|handle| async { + let session = SESSION.clone().with_handle(handle); + let ctx = ArrayContext::empty(); + let segments = Arc::new(TestSegments::default()); + let (ptr, eof) = SequenceId::root().split(); + let struct_array = + StructArray::from_fields([("col", buffer![0i32, 100].into_array())].as_slice())?; + let strategy = TableStrategy::new( + Arc::new(FlatLayoutStrategy::default()), + Arc::new(FlatLayoutStrategy::default()), + ); + let layout = strategy + .write_stream( + ctx, + Arc::::clone(&segments), + struct_array.into_array().to_array_stream().sequenced(ptr), + eof, + &session, + ) + .await?; + + let child = layout.new_reader("".into(), segments, &SESSION, &Default::default())?; + let reader = + FileStatsLayoutReader::new(child, test_file_stats(0, 100), SESSION.clone()); + + let expr = gt(checked_add(get_item("col", root()), lit(5i32)), lit(102i32)); + let mask = Mask::new_true(2); + let result = reader.pruning_evaluation(&(0..2), &expr, mask)?.await?; + + assert_eq!(result, Mask::new_true(2)); + + Ok(()) + }) + } + /// Regression test: `IS NULL` on a nullable timestamp column must not fail with a /// dtype mismatch. The bug was that `stats_ref` used the *field* dtype (timestamp) /// for the `NullCount` stat scalar instead of the stat's own dtype (u64). From 468ebfda10c2cfccb3cacb96aac55439b00ef5d4 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 15:34:21 -0400 Subject: [PATCH 04/28] Fuse checked pruning stats rewrites Signed-off-by: "Nicholas Gates" Signed-off-by: Nicholas Gates --- vortex-array/src/expr/pruning/mod.rs | 1 - vortex-array/src/expr/pruning/pruning_expr.rs | 218 +++++++++--------- vortex-array/src/scalar_fn/fns/is_not_null.rs | 3 + vortex-array/src/scalar_fn/fns/is_null.rs | 3 + .../src/scalar_fn/fns/list_contains/mod.rs | 10 + vortex-array/src/stats/bind.rs | 192 +++++++-------- vortex-file/src/file.rs | 4 +- vortex-layout/src/layouts/zoned/zone_map.rs | 28 +-- 8 files changed, 247 insertions(+), 212 deletions(-) diff --git a/vortex-array/src/expr/pruning/mod.rs b/vortex-array/src/expr/pruning/mod.rs index bbcfa5942a0..7c20508b7a8 100644 --- a/vortex-array/src/expr/pruning/mod.rs +++ b/vortex-array/src/expr/pruning/mod.rs @@ -6,7 +6,6 @@ mod relation; pub use pruning_expr::RequiredStats; pub use pruning_expr::checked_pruning_expr; -pub use pruning_expr::checked_pruning_expr_with_session; pub use pruning_expr::field_path_stat_field_name; pub use relation::Relation; diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 5208b2efada..724dd6a1408 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -1,12 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +#[cfg(test)] use std::cell::RefCell; use std::iter; use itertools::Itertools; use vortex_error::VortexResult; use vortex_session::VortexSession; +#[cfg(test)] use vortex_utils::aliases::hash_map::HashMap; use super::relation::Relation; @@ -16,49 +18,29 @@ use crate::dtype::FieldName; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; use crate::expr::Expression; +#[cfg(test)] use crate::expr::StatsCatalog; use crate::expr::analysis::referenced_field_paths; use crate::expr::get_item; +use crate::expr::is_root; use crate::expr::root; use crate::expr::stats::Stat; +use crate::scalar_fn::fns::cast::Cast; +use crate::scalar_fn::fns::get_item::GetItem; use crate::stats::bind::StatBinder; use crate::stats::bind::bind_stats; pub type RequiredStats = Relation; -// A catalog that return a stat column whenever it is required, tracking all accessed +// A catalog that returns a stat column whenever it is required, tracking all accessed // stats and returning them later. +#[cfg(test)] #[derive(Default)] pub(crate) struct TrackingStatsCatalog { usage: RefCell>, } -impl TrackingStatsCatalog { - /// Consume the catalog, yielding a map of field statistics that were required - /// for each expression. - fn into_usages(self) -> HashMap<(FieldPath, Stat), Expression> { - self.usage.into_inner() - } -} - -// A catalog that return a stat column if it exists in the given scope. -struct ScopeStatsCatalog<'a> { - inner: TrackingStatsCatalog, - available_stats: &'a FieldPathSet, -} - -impl StatsCatalog for ScopeStatsCatalog<'_> { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { - let stat_path = field_path.clone().push(stat.name()); - - if self.available_stats.contains(&stat_path) { - self.inner.stats_ref(field_path, stat) - } else { - None - } - } -} - +#[cfg(test)] impl StatsCatalog for TrackingStatsCatalog { fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { let mut expr = root(); @@ -85,8 +67,7 @@ pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldNa .into() } -/// Build a pruning expr mask, using an existing set of stats. -/// The available stats are provided as a set of [`FieldPath`]. +/// Build a pruning expression using session-registered stats rewrite rules. /// /// A pruning expression is one that returns `true` for all positions where the original expression /// cannot hold, and false if it cannot be determined from stats alone whether the positions can @@ -97,34 +78,10 @@ pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldNa /// replace those placeholders with the row count for its current scope before /// executing the returned expression. /// -/// If the falsification logic attempts to access an unknown stat, -/// this function will return `None`. +/// The returned expression is lowered to stats-table field references. Proof branches that require +/// stats not present in `available_stats` are discarded; this returns `Ok(None)` if no usable proof +/// remains. pub fn checked_pruning_expr( - expr: &Expression, - available_stats: &FieldPathSet, -) -> Option<(Expression, RequiredStats)> { - let catalog = ScopeStatsCatalog { - inner: Default::default(), - available_stats, - }; - - let expr = expr.stat_falsification(&catalog)?; - - // TODO(joe): filter access by used exprs - let mut relation: Relation = Relation::new(); - for ((field_path, stat), _) in catalog.inner.into_usages() { - relation.insert(field_path, stat) - } - - Some((expr, relation)) -} - -/// Build a pruning expression using session-registered stats rewrite rules. -/// -/// The returned expression is lowered to the same stats-table field references as -/// [`checked_pruning_expr`]. If a rewrite asks for a stat that is not present in -/// `available_stats`, this returns `Ok(None)`. -pub fn checked_pruning_expr_with_session( expr: &Expression, scope: &DType, available_stats: &FieldPathSet, @@ -163,9 +120,15 @@ impl StatBinder for RequiredStatsBinder<'_> { stat: Stat, _stat_dtype: &DType, ) -> VortexResult> { - let field_paths = referenced_field_paths(input, self.scope)?; - let Some(field_path) = field_paths.iter().exactly_one().ok() else { - return Ok(None); + let field_path = match direct_stat_field_path(input) { + Some(field_path) => field_path, + None => { + let field_paths = referenced_field_paths(input, self.scope)?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + field_path.clone() + } }; let stat_path = field_path.clone().push(stat.name()); if !self.available_stats.contains(&stat_path) { @@ -174,7 +137,7 @@ impl StatBinder for RequiredStatsBinder<'_> { self.required_stats.insert(field_path.clone(), stat); Ok(Some(get_item( - field_path_stat_field_name(field_path, stat), + field_path_stat_field_name(&field_path, stat), root(), ))) } @@ -182,22 +145,54 @@ impl StatBinder for RequiredStatsBinder<'_> { fn missing_stat(&mut self, _dtype: DType) -> VortexResult> { Ok(None) } + + fn bind_branch(&mut self, bind: F) -> VortexResult> + where + Self: Sized, + F: FnOnce(&mut Self) -> VortexResult>, + { + let required_stats = self.required_stats.clone(); + let bound = bind(self)?; + if bound.is_none() { + self.required_stats = required_stats; + } + Ok(bound) + } +} + +fn direct_stat_field_path(expr: &Expression) -> Option { + if is_root(expr) { + return Some(FieldPath::root()); + } + + if expr.is::() { + return direct_stat_field_path(expr.child(0)); + } + + let field_name = expr.as_opt::()?; + direct_stat_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) } #[cfg(test)] mod tests { + use std::sync::LazyLock; + use rstest::fixture; use rstest::rstest; + use vortex_session::VortexSession; + use vortex_utils::aliases::hash_map::HashMap; use vortex_utils::aliases::hash_set::HashSet; - use super::HashMap; + use super::RequiredStats; use crate::dtype::DType; use crate::dtype::FieldName; use crate::dtype::FieldNames; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; use crate::dtype::Nullability; + use crate::dtype::PType; use crate::dtype::StructFields; + use crate::expr::Expression; use crate::expr::and; use crate::expr::between; use crate::expr::cast; @@ -217,6 +212,38 @@ mod tests { use crate::expr::stats::Stat; use crate::scalar_fn::fns::between::BetweenOptions; use crate::scalar_fn::fns::between::StrictComparison; + use crate::stats::session::StatsSession; + + static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + fn test_scope() -> DType { + DType::Struct( + StructFields::from_iter([ + ("a", DType::Primitive(PType::I32, Nullability::NonNullable)), + ("b", DType::Primitive(PType::I32, Nullability::NonNullable)), + ("x", DType::Bool(Nullability::NonNullable)), + ("y", DType::Primitive(PType::I32, Nullability::NonNullable)), + ("z", DType::Primitive(PType::I32, Nullability::NonNullable)), + ( + "float_col", + DType::Primitive(PType::F32, Nullability::NonNullable), + ), + ( + "int_col", + DType::Primitive(PType::I32, Nullability::NonNullable), + ), + ]), + Nullability::NonNullable, + ) + } + + fn checked( + expr: &Expression, + available_stats: &FieldPathSet, + ) -> Option<(Expression, RequiredStats)> { + checked_pruning_expr(expr, &test_scope(), available_stats, &SESSION).unwrap() + } // Implement some checked pruning expressions. #[fixture] @@ -237,7 +264,7 @@ mod tests { let name = FieldName::from("a"); let literal_eq = lit(42); let eq_expr = eq(get_item("a", root()), literal_eq.clone()); - let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap(); + let (converted, _refs) = checked(&eq_expr, &available_stats).unwrap(); let expected_expr = or( gt( get_item( @@ -263,7 +290,7 @@ mod tests { let other_col = FieldName::from("b"); let eq_expr = eq(col(column.clone()), col(other_col.clone())); - let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(&eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([ @@ -308,7 +335,7 @@ mod tests { let other_col = FieldName::from("b"); let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone())); - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([ @@ -355,7 +382,7 @@ mod tests { let other_expr = col(other_col.clone()); let not_eq_expr = gt(col(column.clone()), other_expr); - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([ @@ -388,7 +415,7 @@ mod tests { let other_col = lit(42); let not_eq_expr = gt(col(column.clone()), other_col.clone()); - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([( @@ -413,7 +440,7 @@ mod tests { let other_expr = col(other_col.clone()); let not_eq_expr = lt(col(column.clone()), other_expr); - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([ @@ -446,7 +473,7 @@ mod tests { // pruning expr => a.min >= 42 let expr = lt(col("a"), lit(42)); - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (converted, refs) = checked(&expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))]) @@ -458,7 +485,7 @@ mod tests { fn pruning_identity(available_stats: FieldPathSet) { let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50))); - let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (predicate, _) = checked(&expr, &available_stats).unwrap(); let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50))); assert_eq!(&predicate.to_string(), &expected_expr.to_string()); @@ -468,7 +495,7 @@ mod tests { // Test case: a > 10 AND a < 50 let column = FieldName::from("a"); let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50))); - let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap(); + let (predicate, _) = checked(&and_expr, &available_stats).unwrap(); // Expected: a_max <= 10 OR a_min >= 50 assert_eq!( @@ -507,7 +534,7 @@ mod tests { // True > False // True let expr = gt_eq(col("x"), gt(col("y"), col("z"))); - assert!(checked_pruning_expr(&expr, &available_stats).is_none()); + assert!(checked(&expr, &available_stats).is_none()); // TODO(DK): a sufficiently complex pruner would produce: `x_max <= (y_max > z_min)` } @@ -530,46 +557,27 @@ mod tests { #[rstest] fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) { let expr = gt_eq(col("float_col"), lit(f32::NAN)); - let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap(); - assert_eq!( - &converted, - &and( - and( - eq(col("float_col_nan_count"), lit(0u64)), - // NaNCount of NaN is 1 - eq(lit(1u64), lit(0u64)), - ), - // This is the standard conversion of the >= operator. Comparing NAN to a max - // stat is nonsensical, as min/max stats ignore NaNs, but this should be short-circuited - // by the previous check for nan_count anyway. - lt(col("float_col_max"), lit(f32::NAN)), - ) - ); + assert!(checked(&expr, &available_stats_with_nans).is_none()); - // One half of the expression requires NAN count check, the other half does not. + // One half of the expression requires an all-non-NaN proof, the other half does not. let expr = and( gt(col("float_col"), lit(10f32)), lt(col("int_col"), lit(10)), ); - let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap(); - + let (converted, refs) = checked(&expr, &available_stats_with_nans).unwrap(); + assert_eq!( + refs.map(), + &HashMap::from_iter([( + FieldPath::from_name("int_col"), + HashSet::from_iter([Stat::Min]) + )]) + ); assert_eq!( &converted, - &or( - // NaNCount check is enforced for the float column - and( - and( - eq(col("float_col_nan_count"), lit(0u64)), - // NanCount of a non-NaN float literal is 0 - eq(lit(0u64), lit(0u64)), - ), - // We want the opposite: we can prune IF either one is false. - lt_eq(col("float_col_max"), lit(10f32)), - ), - // NanCount check is skipped for the int column - gt_eq(col("int_col_min"), lit(10)), - ) + // The float branch cannot be proven without AllNonNan stats, so the + // remaining proof is the int branch. + >_eq(col("int_col_min"), lit(10)) ) } @@ -584,7 +592,7 @@ mod tests { upper_strict: StrictComparison::NonStrict, }, ); - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (converted, refs) = checked(&expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([( @@ -613,7 +621,7 @@ mod tests { Nullability::NonNullable, ); let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value")); - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (converted, refs) = checked(&expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([( diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index 589333304e2..59986171ed1 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -267,11 +267,14 @@ mod tests { let (pruning_expr, st) = checked_pruning_expr( &expr, + &test_harness::struct_dtype(), &FieldPathSet::from_iter([FieldPath::from_iter([ Field::Name("a".into()), Field::Name("null_count".into()), ])]), + &LEGACY_SESSION, ) + .unwrap() .unwrap(); assert_eq!( diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index 7315fbe8c07..6a971e7ecd4 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -251,11 +251,14 @@ mod tests { let (pruning_expr, st) = checked_pruning_expr( &expr, + &test_harness::struct_dtype(), &FieldPathSet::from_iter([FieldPath::from_iter([ Field::Name("a".into()), Field::Name("null_count".into()), ])]), + &LEGACY_SESSION, ) + .unwrap() .unwrap(); assert_eq!(&pruning_expr, &eq(col("a_null_count"), lit(0u64))); diff --git a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs index 978a1da1caf..bbaed489fd1 100644 --- a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs +++ b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs @@ -630,14 +630,24 @@ mod tests { )), col("a"), ); + let scope = DType::Struct( + StructFields::new( + ["a"].into(), + vec![DType::Primitive(I32, Nullability::NonNullable)], + ), + Nullability::NonNullable, + ); let (expr, st) = checked_pruning_expr( &expr, + &scope, &FieldPathSet::from_iter([ FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]), FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]), ]), + &LEGACY_SESSION, ) + .unwrap() .unwrap(); assert_eq!( diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 714404fdceb..82166e0b1bd 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -5,23 +5,15 @@ use vortex_error::VortexResult; -use crate::aggregate_fn::fns::all_nan::AllNan; -use crate::aggregate_fn::fns::all_non_nan::AllNonNan; -use crate::aggregate_fn::fns::all_non_null::AllNonNull; -use crate::aggregate_fn::fns::all_null::AllNull; -use crate::aggregate_fn::fns::nan_count::NanCount; +use crate::aggregate_fn::AggregateFnRef; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::eq; use crate::expr::lit; use crate::expr::stats::Stat; -use crate::expr::traversal::NodeExt; -use crate::expr::traversal::Transformed; use crate::scalar::Scalar; -use crate::scalar_fn::EmptyOptions; -use crate::scalar_fn::ScalarFnVTableExt; +use crate::scalar_fn::fns::binary::Binary; +use crate::scalar_fn::fns::operators::Operator; use crate::scalar_fn::fns::stat::StatFn; -use crate::scalar_fn::internal::row_count::RowCount; /// A target that can bind abstract statistics to concrete expressions. pub trait StatBinder { @@ -40,6 +32,23 @@ pub trait StatBinder { stat_dtype: &DType, ) -> VortexResult>; + /// Bind `aggregate_fn(input)` to a concrete expression. + /// + /// The default implementation supports aggregate functions with legacy + /// [`Stat`] slots. Binders that store richer aggregate stats can override + /// this method without extending the generic stats binding walker. + fn bind_aggregate( + &mut self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult> { + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; + self.bind_stat(input, stat, stat_dtype) + } + /// Expression to use when a stat is unavailable. /// /// The default is a nullable null literal, which preserves three-valued @@ -48,47 +57,101 @@ pub trait StatBinder { fn missing_stat(&mut self, dtype: DType) -> VortexResult> { Ok(Some(null_expr(dtype))) } + + /// Bind a proof branch, rolling back any binder-local bookkeeping when the + /// branch cannot be bound. + /// + /// Binders that only substitute expressions can use the default + /// implementation. Binders that track required stats should override this + /// so discarded proof branches do not leak requirements. + fn bind_branch(&mut self, bind: F) -> VortexResult> + where + Self: Sized, + F: FnOnce(&mut Self) -> VortexResult>, + { + bind(self) + } } /// Bind all `vortex.stat` expressions in `predicate`. /// -/// The predicate is usually the output of a stats rewrite rule. This function -/// centralizes the legacy aggregate/stat mapping: `all_null` and `all_nan` -/// style aggregate expressions are expanded through exact count stats, while -/// direct aggregate stats are delegated to the supplied binder. +/// The predicate is usually the output of a stats rewrite rule. Rewrite rules +/// are responsible for expressing stat semantics; binding maps aggregate-backed +/// stat requests to the concrete stats representation supported by the binder. pub fn bind_stats( predicate: Expression, binder: &mut impl StatBinder, ) -> VortexResult> { let scope = binder.scope().clone(); - let mut missing_stat = false; - let lowered = predicate - .transform_down(|expr| { - if !expr.is::() { - return Ok(Transformed::no(expr)); - } + bind_stats_expr(predicate, &scope, binder) +} - match bind_stat_fn(&expr, &scope, binder)? { - Some(bound) => Ok(Transformed::yes(bound)), - None => { - let dtype = expr.return_dtype(&scope)?; - match binder.missing_stat(dtype.clone())? { - Some(missing) => Ok(Transformed::yes(missing)), - None => { - missing_stat = true; - Ok(Transformed::yes(null_expr(dtype))) - } - } - } +fn bind_stats_expr( + expr: Expression, + scope: &DType, + binder: &mut impl StatBinder, +) -> VortexResult> { + if expr.is::() { + return match bind_stat_fn(&expr, scope, binder)? { + Some(bound) => Ok(Some(bound)), + None => { + let dtype = expr.return_dtype(scope)?; + binder.missing_stat(dtype) } - })? - .into_inner(); + }; + } + + if expr.is::() { + return bind_binary_expr(expr, scope, binder); + } - if missing_stat { - return Ok(None); + let mut children = Vec::with_capacity(expr.children().len()); + for child in expr.children().iter() { + let Some(child) = bind_stats_expr(child.clone(), scope, binder)? else { + return Ok(None); + }; + children.push(child); } - Ok(Some(lowered)) + Ok(Some(expr.with_children(children)?)) +} + +fn bind_binary_expr( + expr: Expression, + scope: &DType, + binder: &mut impl StatBinder, +) -> VortexResult> { + let operator = expr.as_::(); + + match operator { + Operator::Or => { + let lhs = binder + .bind_branch(|binder| bind_stats_expr(expr.child(0).clone(), scope, binder))?; + let rhs = binder + .bind_branch(|binder| bind_stats_expr(expr.child(1).clone(), scope, binder))?; + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), + (Some(expr), None) | (None, Some(expr)) => Ok(Some(expr)), + (None, None) => Ok(None), + } + } + Operator::And => binder.bind_branch(|binder| { + let lhs = bind_stats_expr(expr.child(0).clone(), scope, binder)?; + let rhs = bind_stats_expr(expr.child(1).clone(), scope, binder)?; + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), + _ => Ok(None), + } + }), + _ => binder.bind_branch(|binder| { + let lhs = bind_stats_expr(expr.child(0).clone(), scope, binder)?; + let rhs = bind_stats_expr(expr.child(1).clone(), scope, binder)?; + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), + _ => Ok(None), + } + }), + } } fn bind_stat_fn( @@ -99,62 +162,11 @@ fn bind_stat_fn( let options = expr.as_::(); let aggregate_fn = options.aggregate_fn(); let input = expr.child(0); - let input_dtype = input.return_dtype(scope)?; - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(Some(lit(false))); - } - let stat_dtype = expr.return_dtype(scope)?; - return Ok(binder - .bind_stat(input, Stat::NaNCount, &stat_dtype)? - .map(|stat| eq(stat, row_count_expr()))); - } - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(Some(lit(true))); - } - let stat_dtype = expr.return_dtype(scope)?; - return Ok(binder - .bind_stat(input, Stat::NaNCount, &stat_dtype)? - .map(|stat| eq(stat, lit(0u64)))); - } - - if aggregate_fn.is::() && !has_nans(&input_dtype) { - return Ok(Some(lit(0u64))); - } - - if aggregate_fn.is::() { - let stat_dtype = expr.return_dtype(scope)?; - return Ok(binder - .bind_stat(input, Stat::NullCount, &stat_dtype)? - .map(|stat| eq(stat, row_count_expr()))); - } - - if aggregate_fn.is::() { - let stat_dtype = expr.return_dtype(scope)?; - return Ok(binder - .bind_stat(input, Stat::NullCount, &stat_dtype)? - .map(|stat| eq(stat, lit(0u64)))); - } - - let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { - return Ok(None); - }; let stat_dtype = expr.return_dtype(scope)?; - binder.bind_stat(input, stat, &stat_dtype) -} - -fn row_count_expr() -> Expression { - RowCount.new_expr(EmptyOptions, []) + binder.bind_aggregate(input, aggregate_fn, &stat_dtype) } fn null_expr(dtype: DType) -> Expression { lit(Scalar::null(dtype.as_nullable())) } - -fn has_nans(dtype: &DType) -> bool { - matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) -} diff --git a/vortex-file/src/file.rs b/vortex-file/src/file.rs index 225d18b561a..9e39a6fab5c 100644 --- a/vortex-file/src/file.rs +++ b/vortex-file/src/file.rs @@ -22,7 +22,7 @@ use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; use vortex_array::dtype::FieldPathSet; use vortex_array::expr::Expression; -use vortex_array::expr::pruning::checked_pruning_expr_with_session; +use vortex_array::expr::pruning::checked_pruning_expr; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_error::VortexResult; use vortex_layout::LayoutReader; @@ -218,7 +218,7 @@ impl VortexFile { ); let Some((predicate, required_stats)) = - checked_pruning_expr_with_session(filter, self.footer.dtype(), &set, &self.session)? + checked_pruning_expr(filter, self.footer.dtype(), &set, &self.session)? else { return Ok(false); }; diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 5c9fdcdad32..f16082fc90e 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -344,7 +344,7 @@ mod tests { } #[test] - fn all_null_stat_fn_lowers_to_null_count_and_row_count() { + fn all_null_stat_fn_lowers_to_unknown_mask() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -359,11 +359,14 @@ mod tests { .unwrap(); let mask = zone_map.prune(&all_null(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true])); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([false, false, false]) + ); } #[test] - fn all_non_null_stat_fn_lowers_to_null_count() { + fn all_non_null_stat_fn_lowers_to_unknown_mask() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -380,12 +383,12 @@ mod tests { let mask = zone_map.prune(&all_non_null(root()), &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), - BoolArray::from_iter([true, false, false]) + BoolArray::from_iter([false, false, false]) ); } #[test] - fn non_float_nan_stat_fns_lower_to_constants() { + fn non_float_nan_stat_fns_error() { let zone_map = ZoneMap::try_new( PType::I32.into(), StructArray::try_new(FieldNames::empty(), vec![], 2, Validity::NonNullable).unwrap(), @@ -395,11 +398,8 @@ mod tests { ) .unwrap(); - let mask = zone_map.prune(&all_nan(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, false])); - - let mask = zone_map.prune(&all_non_nan(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([true, true])); + assert!(zone_map.prune(&all_nan(root()), &SESSION).is_err()); + assert!(zone_map.prune(&all_non_nan(root()), &SESSION).is_err()); } #[test] @@ -429,7 +429,7 @@ mod tests { } #[test] - fn float_min_max_stat_fn_requires_nan_count() { + fn float_min_max_stat_fn_requires_all_non_nan() { let zone_map = ZoneMap::try_new( PType::F32.into(), StructArray::from_fields(&[ @@ -483,12 +483,12 @@ mod tests { let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), - BoolArray::from_iter([true, false, false]) + BoolArray::from_iter([false, false, false]) ); } #[test] - fn float_cast_min_max_stat_fn_uses_source_nan_count() { + fn float_cast_min_max_stat_fn_requires_all_non_nan() { let zone_map = ZoneMap::try_new( PType::F32.into(), StructArray::from_fields(&[ @@ -525,7 +525,7 @@ mod tests { let pruning_expr = falsify(&expr, PType::F32.into()); let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true])); + assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, false])); } #[test] From cd4867432dbcc158de108402a4bd2b91d8caf55b Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 15:46:35 -0400 Subject: [PATCH 05/28] Simplify stats binding null handling Signed-off-by: "Nicholas Gates" Signed-off-by: Nicholas Gates --- vortex-array/src/expr/pruning/pruning_expr.rs | 94 +++++++++++----- vortex-array/src/scalar_fn/fns/binary/mod.rs | 56 ++++++++++ vortex-array/src/scalar_fn/fns/is_not_null.rs | 7 +- vortex-array/src/scalar_fn/fns/is_null.rs | 9 +- vortex-array/src/stats/bind.rs | 103 ++++-------------- 5 files changed, 157 insertions(+), 112 deletions(-) diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 724dd6a1408..61d578787a0 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -10,6 +10,7 @@ use vortex_error::VortexResult; use vortex_session::VortexSession; #[cfg(test)] use vortex_utils::aliases::hash_map::HashMap; +use vortex_utils::aliases::hash_set::HashSet; use super::relation::Relation; use crate::dtype::DType; @@ -27,6 +28,7 @@ use crate::expr::root; use crate::expr::stats::Stat; use crate::scalar_fn::fns::cast::Cast; use crate::scalar_fn::fns::get_item::GetItem; +use crate::scalar_fn::fns::literal::Literal; use crate::stats::bind::StatBinder; use crate::stats::bind::bind_stats; @@ -78,9 +80,9 @@ pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldNa /// replace those placeholders with the row count for its current scope before /// executing the returned expression. /// -/// The returned expression is lowered to stats-table field references. Proof branches that require -/// stats not present in `available_stats` are discarded; this returns `Ok(None)` if no usable proof -/// remains. +/// The returned expression is lowered to stats-table field references. Stats not present in +/// `available_stats` are replaced with typed null literals, preserving three-valued pruning +/// semantics without requiring callers to materialize unavailable stats. pub fn checked_pruning_expr( expr: &Expression, scope: &DType, @@ -99,8 +101,12 @@ pub fn checked_pruning_expr( let Some(lowered) = bind_stats(predicate, &mut binder)? else { return Ok(None); }; + let required_stats = filter_required_stats(&lowered, binder.required_stats); + if required_stats.map().is_empty() && !matches!(bool_literal(&lowered), Some(Some(true))) { + return Ok(None); + } - Ok(Some((lowered, binder.required_stats))) + Ok(Some((lowered, required_stats))) } struct RequiredStatsBinder<'a> { @@ -141,23 +147,6 @@ impl StatBinder for RequiredStatsBinder<'_> { root(), ))) } - - fn missing_stat(&mut self, _dtype: DType) -> VortexResult> { - Ok(None) - } - - fn bind_branch(&mut self, bind: F) -> VortexResult> - where - Self: Sized, - F: FnOnce(&mut Self) -> VortexResult>, - { - let required_stats = self.required_stats.clone(); - let bound = bind(self)?; - if bound.is_none() { - self.required_stats = required_stats; - } - Ok(bound) - } } fn direct_stat_field_path(expr: &Expression) -> Option { @@ -173,6 +162,44 @@ fn direct_stat_field_path(expr: &Expression) -> Option { direct_stat_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) } +fn filter_required_stats(expr: &Expression, required_stats: RequiredStats) -> RequiredStats { + let referenced_names = referenced_stat_field_names(expr); + let mut filtered = Relation::new(); + for (field_path, stats) in required_stats { + for stat in stats { + if referenced_names.contains(&field_path_stat_field_name(&field_path, stat)) { + filtered.insert(field_path.clone(), stat); + } + } + } + filtered +} + +fn referenced_stat_field_names(expr: &Expression) -> HashSet { + let mut refs = HashSet::new(); + collect_referenced_stat_field_names(expr, &mut refs); + refs +} + +fn collect_referenced_stat_field_names(expr: &Expression, refs: &mut HashSet) { + if let Some(field_name) = expr.as_opt::() + && is_root(expr.child(0)) + { + refs.insert(field_name.clone()); + return; + } + + for child in expr.children().iter() { + collect_referenced_stat_field_names(child, refs); + } +} + +fn bool_literal(expr: &Expression) -> Option> { + expr.as_opt::()? + .as_bool_opt() + .map(|value| value.value()) +} + #[cfg(test)] mod tests { use std::sync::LazyLock; @@ -210,6 +237,7 @@ mod tests { use crate::expr::pruning::field_path_stat_field_name; use crate::expr::root; use crate::expr::stats::Stat; + use crate::scalar::Scalar; use crate::scalar_fn::fns::between::BetweenOptions; use crate::scalar_fn::fns::between::StrictComparison; use crate::stats::session::StatsSession; @@ -568,16 +596,26 @@ mod tests { let (converted, refs) = checked(&expr, &available_stats_with_nans).unwrap(); assert_eq!( refs.map(), - &HashMap::from_iter([( - FieldPath::from_name("int_col"), - HashSet::from_iter([Stat::Min]) - )]) + &HashMap::from_iter([ + ( + FieldPath::from_name("float_col"), + HashSet::from_iter([Stat::Max]) + ), + ( + FieldPath::from_name("int_col"), + HashSet::from_iter([Stat::Min]) + ) + ]) ); assert_eq!( &converted, - // The float branch cannot be proven without AllNonNan stats, so the - // remaining proof is the int branch. - >_eq(col("int_col_min"), lit(10)) + &or( + and( + lit(Scalar::null(DType::Bool(Nullability::Nullable))), + lt_eq(col("float_col_max"), lit(10f32)), + ), + gt_eq(col("int_col_min"), lit(10)), + ) ) } diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 1c860cb75b5..6babe5263d1 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -17,6 +17,7 @@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; +use crate::dtype::Nullability; use crate::expr::StatsCatalog; use crate::expr::and; use crate::expr::and_collect; @@ -34,6 +35,7 @@ use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; +use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::operators::CompareOperator; use crate::scalar_fn::fns::operators::Operator; @@ -45,10 +47,45 @@ mod numeric; pub(crate) use numeric::*; use crate::scalar::NumericOperator; +use crate::scalar::Scalar; #[derive(Clone)] pub struct Binary; +fn simplify_and(lhs: &Expression, rhs: &Expression) -> Option { + match (bool_literal(lhs), bool_literal(rhs)) { + (Some(Some(false)), _) | (_, Some(Some(false))) => Some(lit(false)), + (Some(Some(true)), _) => Some(rhs.clone()), + (_, Some(Some(true))) => Some(lhs.clone()), + (Some(None), Some(None)) => Some(lhs.clone()), + _ => None, + } +} + +fn simplify_or(lhs: &Expression, rhs: &Expression) -> Option { + match (bool_literal(lhs), bool_literal(rhs)) { + (Some(Some(true)), _) | (_, Some(Some(true))) => Some(lit(true)), + (Some(Some(false)), _) => Some(rhs.clone()), + (_, Some(Some(false))) => Some(lhs.clone()), + (Some(None), Some(None)) => Some(lhs.clone()), + _ => None, + } +} + +fn bool_literal(expr: &Expression) -> Option> { + expr.as_opt::()? + .as_bool_opt() + .map(|value| value.value()) +} + +fn is_null_literal(expr: &Expression) -> bool { + expr.as_opt::().is_some_and(Scalar::is_null) +} + +fn null_bool() -> Expression { + lit(Scalar::null(DType::Bool(Nullability::Nullable))) +} + impl ScalarFnVTable for Binary { type Options = Operator; @@ -165,6 +202,25 @@ impl ScalarFnVTable for Binary { } } + fn simplify_untyped( + &self, + operator: &Operator, + expr: &Expression, + ) -> VortexResult> { + let lhs = expr.child(0); + let rhs = expr.child(1); + + if operator.is_comparison() && (is_null_literal(lhs) || is_null_literal(rhs)) { + return Ok(Some(null_bool())); + } + + Ok(match operator { + Operator::And => simplify_and(lhs, rhs), + Operator::Or => simplify_or(lhs, rhs), + _ => None, + }) + } + fn stat_falsification( &self, operator: &Operator, diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index 59986171ed1..eb60fc8aec8 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -135,6 +135,8 @@ mod tests { use crate::expr::eq; use crate::expr::get_item; use crate::expr::is_not_null; + use crate::expr::lit; + use crate::expr::or; use crate::expr::pruning::checked_pruning_expr; use crate::expr::root; use crate::expr::stats::Stat; @@ -279,7 +281,10 @@ mod tests { assert_eq!( &pruning_expr, - &eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, [])) + &or( + eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, [])), + lit(Scalar::null(DType::Bool(Nullability::Nullable))), + ) ); assert_eq!( st.map(), diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index 6a971e7ecd4..bbf8a8f2409 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -125,6 +125,7 @@ mod tests { use crate::expr::get_item; use crate::expr::is_null; use crate::expr::lit; + use crate::expr::or; use crate::expr::pruning::checked_pruning_expr; use crate::expr::root; use crate::expr::stats::Stat; @@ -261,7 +262,13 @@ mod tests { .unwrap() .unwrap(); - assert_eq!(&pruning_expr, &eq(col("a_null_count"), lit(0u64))); + assert_eq!( + &pruning_expr, + &or( + eq(col("a_null_count"), lit(0u64)), + lit(Scalar::null(DType::Bool(Nullability::Nullable))), + ) + ); assert_eq!( st.map(), &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))]) diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 82166e0b1bd..752664396c6 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -10,9 +10,9 @@ use crate::dtype::DType; use crate::expr::Expression; use crate::expr::lit; use crate::expr::stats::Stat; +use crate::expr::traversal::NodeExt; +use crate::expr::traversal::Transformed; use crate::scalar::Scalar; -use crate::scalar_fn::fns::binary::Binary; -use crate::scalar_fn::fns::operators::Operator; use crate::scalar_fn::fns::stat::StatFn; /// A target that can bind abstract statistics to concrete expressions. @@ -57,20 +57,6 @@ pub trait StatBinder { fn missing_stat(&mut self, dtype: DType) -> VortexResult> { Ok(Some(null_expr(dtype))) } - - /// Bind a proof branch, rolling back any binder-local bookkeeping when the - /// branch cannot be bound. - /// - /// Binders that only substitute expressions can use the default - /// implementation. Binders that track required stats should override this - /// so discarded proof branches do not leak requirements. - fn bind_branch(&mut self, bind: F) -> VortexResult> - where - Self: Sized, - F: FnOnce(&mut Self) -> VortexResult>, - { - bind(self) - } } /// Bind all `vortex.stat` expressions in `predicate`. @@ -83,75 +69,28 @@ pub fn bind_stats( binder: &mut impl StatBinder, ) -> VortexResult> { let scope = binder.scope().clone(); - bind_stats_expr(predicate, &scope, binder) -} - -fn bind_stats_expr( - expr: Expression, - scope: &DType, - binder: &mut impl StatBinder, -) -> VortexResult> { - if expr.is::() { - return match bind_stat_fn(&expr, scope, binder)? { - Some(bound) => Ok(Some(bound)), - None => { - let dtype = expr.return_dtype(scope)?; - binder.missing_stat(dtype) + let lowered = predicate + .transform_down(|expr| { + if !expr.is::() { + return Ok(Transformed::no(expr)); } - }; - } - - if expr.is::() { - return bind_binary_expr(expr, scope, binder); - } - - let mut children = Vec::with_capacity(expr.children().len()); - for child in expr.children().iter() { - let Some(child) = bind_stats_expr(child.clone(), scope, binder)? else { - return Ok(None); - }; - children.push(child); - } - Ok(Some(expr.with_children(children)?)) -} - -fn bind_binary_expr( - expr: Expression, - scope: &DType, - binder: &mut impl StatBinder, -) -> VortexResult> { - let operator = expr.as_::(); - - match operator { - Operator::Or => { - let lhs = binder - .bind_branch(|binder| bind_stats_expr(expr.child(0).clone(), scope, binder))?; - let rhs = binder - .bind_branch(|binder| bind_stats_expr(expr.child(1).clone(), scope, binder))?; - match (lhs, rhs) { - (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), - (Some(expr), None) | (None, Some(expr)) => Ok(Some(expr)), - (None, None) => Ok(None), - } - } - Operator::And => binder.bind_branch(|binder| { - let lhs = bind_stats_expr(expr.child(0).clone(), scope, binder)?; - let rhs = bind_stats_expr(expr.child(1).clone(), scope, binder)?; - match (lhs, rhs) { - (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), - _ => Ok(None), + match bind_stat_fn(&expr, &scope, binder)? { + Some(bound) => Ok(Transformed::yes(bound)), + None => { + let dtype = expr.return_dtype(&scope)?; + match binder.missing_stat(dtype.clone())? { + Some(missing) => Ok(Transformed::yes(missing)), + None => Ok(Transformed::yes(null_expr(dtype))), + } + } } - }), - _ => binder.bind_branch(|binder| { - let lhs = bind_stats_expr(expr.child(0).clone(), scope, binder)?; - let rhs = bind_stats_expr(expr.child(1).clone(), scope, binder)?; - match (lhs, rhs) { - (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), - _ => Ok(None), - } - }), - } + })? + .into_inner(); + + #[expect(deprecated)] + let lowered = lowered.simplify_untyped()?; + Ok(Some(lowered)) } fn bind_stat_fn( From 52fd44337e3a471702db98c08c51ae85de9dbdec Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 16:13:25 -0400 Subject: [PATCH 06/28] Install Java toolchain in CI Signed-off-by: Nicholas Gates --- .github/workflows/ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dea9f3b9ef4..ed68d375724 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -376,6 +376,10 @@ jobs: with: sccache: s3 - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + - uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 + with: + distribution: "corretto" + java-version: "17" - uses: ./.github/actions/setup-prebuild - run: ./gradlew javadoc working-directory: ./java From 429383824e9918bf29a7a3b3f1f8d6653b61b990 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 16:39:09 -0400 Subject: [PATCH 07/28] Remove legacy stat falsification hooks Signed-off-by: Nicholas Gates --- vortex-array/src/expr/expression.rs | 22 ---- vortex-array/src/expr/pruning/pruning_expr.rs | 27 ---- vortex-array/src/scalar_fn/erased.rs | 9 -- vortex-array/src/scalar_fn/fns/between/mod.rs | 19 --- vortex-array/src/scalar_fn/fns/binary/mod.rs | 117 +----------------- vortex-array/src/scalar_fn/fns/dynamic.rs | 45 ------- vortex-array/src/scalar_fn/fns/is_not_null.rs | 17 --- vortex-array/src/scalar_fn/fns/is_null.rs | 15 --- vortex-array/src/scalar_fn/fns/like/mod.rs | 101 --------------- .../src/scalar_fn/fns/list_contains/mod.rs | 45 ------- vortex-array/src/scalar_fn/typed.rs | 13 -- vortex-array/src/scalar_fn/vtable.rs | 13 -- vortex-file/src/v2/file_stats_reader.rs | 13 +- 13 files changed, 9 insertions(+), 447 deletions(-) diff --git a/vortex-array/src/expr/expression.rs b/vortex-array/src/expr/expression.rs index cc21fb9a9a6..043b61aaaf4 100644 --- a/vortex-array/src/expr/expression.rs +++ b/vortex-array/src/expr/expression.rs @@ -114,28 +114,6 @@ impl Expression { self.scalar_fn.validity(self) } - /// An expression over zone-statistics which implies all records in the zone evaluate to false. - /// - /// Given an expression, `e`, if `e.stat_falsification(..)` evaluates to true, it is guaranteed - /// that `e` evaluates to false on all records in the zone. However, the inverse is not - /// necessarily true: even if the falsification evaluates to false, `e` need not evaluate to - /// true on all records. - /// - /// The [`StatsCatalog`] can be used to constrain or rename stats used in the final expr. - /// - /// # Examples - /// - /// - An expression over one variable: `x > 0` is false for all records in a zone if the maximum - /// value of the column `x` in that zone is less than or equal to zero: `max(x) <= 0`. - /// - An expression over two variables: `x > y` becomes `max(x) <= min(y)`. - /// - A conjunctive expression: `x > y AND z < x` becomes `max(x) <= min(y) OR min(z) >= max(x). - /// - /// Some expressions, in theory, have falsifications but this function does not support them - /// such as `x < (y < z)` or `x LIKE "needle%"`. - pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option { - self.scalar_fn().stat_falsification(self, catalog) - } - /// Returns an expression that proves this predicate is definitely false from stats. /// /// `scope` is the dtype of the row this expression evaluates over. diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 61d578787a0..ee775b4f13e 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -1,15 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -#[cfg(test)] -use std::cell::RefCell; use std::iter; use itertools::Itertools; use vortex_error::VortexResult; use vortex_session::VortexSession; -#[cfg(test)] -use vortex_utils::aliases::hash_map::HashMap; use vortex_utils::aliases::hash_set::HashSet; use super::relation::Relation; @@ -19,8 +15,6 @@ use crate::dtype::FieldName; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; use crate::expr::Expression; -#[cfg(test)] -use crate::expr::StatsCatalog; use crate::expr::analysis::referenced_field_paths; use crate::expr::get_item; use crate::expr::is_root; @@ -34,27 +28,6 @@ use crate::stats::bind::bind_stats; pub type RequiredStats = Relation; -// A catalog that returns a stat column whenever it is required, tracking all accessed -// stats and returning them later. -#[cfg(test)] -#[derive(Default)] -pub(crate) struct TrackingStatsCatalog { - usage: RefCell>, -} - -#[cfg(test)] -impl StatsCatalog for TrackingStatsCatalog { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { - let mut expr = root(); - let name = field_path_stat_field_name(field_path, stat); - expr = get_item(name, expr); - self.usage - .borrow_mut() - .insert((field_path.clone(), stat), expr.clone()); - Some(expr) - } -} - #[doc(hidden)] pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName { field_path diff --git a/vortex-array/src/scalar_fn/erased.rs b/vortex-array/src/scalar_fn/erased.rs index 10e82d25455..69befb405e0 100644 --- a/vortex-array/src/scalar_fn/erased.rs +++ b/vortex-array/src/scalar_fn/erased.rs @@ -181,15 +181,6 @@ impl ScalarFnRef { self.0.simplify_untyped(expr) } - /// Compute stat falsification expression. - pub(crate) fn stat_falsification( - &self, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - self.0.stat_falsification(expr, catalog) - } - /// Compute stat expression. pub(crate) fn stat_expression( &self, diff --git a/vortex-array/src/scalar_fn/fns/between/mod.rs b/vortex-array/src/scalar_fn/fns/between/mod.rs index 0e0d9949195..bd546bed941 100644 --- a/vortex-array/src/scalar_fn/fns/between/mod.rs +++ b/vortex-array/src/scalar_fn/fns/between/mod.rs @@ -25,7 +25,6 @@ use crate::arrays::Primitive; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::DType::Bool; -use crate::expr::StatsCatalog; use crate::expr::expression::Expression; use crate::scalar::Scalar; use crate::scalar_fn::Arity; @@ -33,8 +32,6 @@ use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::ScalarFnVTableExt; -use crate::scalar_fn::fns::binary::Binary; use crate::scalar_fn::fns::binary::execute_boolean; use crate::scalar_fn::fns::operators::CompareOperator; use crate::scalar_fn::fns::operators::Operator; @@ -298,22 +295,6 @@ impl ScalarFnVTable for Between { between_canonical(&arr, &lower, &upper, options, ctx) } - fn stat_falsification( - &self, - options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let arr = expr.child(0).clone(); - let lower = expr.child(1).clone(); - let upper = expr.child(2).clone(); - - let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]); - let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]); - - and(lhs, rhs).stat_falsification(catalog) - } - fn validity( &self, _options: &Self::Options, diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 6babe5263d1..b51f86b3188 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -18,18 +18,9 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::dtype::Nullability; -use crate::expr::StatsCatalog; use crate::expr::and; -use crate::expr::and_collect; -use crate::expr::eq; use crate::expr::expression::Expression; -use crate::expr::gt; -use crate::expr::gt_eq; use crate::expr::lit; -use crate::expr::lt; -use crate::expr::lt_eq; -use crate::expr::or_collect; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; @@ -221,110 +212,6 @@ impl ScalarFnVTable for Binary { }) } - fn stat_falsification( - &self, - operator: &Operator, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // Wrap another predicate with an optional NaNCount check, if the stat is available. - // - // For example, regular pruning conversion for `A >= B` would be - // - // A.max < B.min - // - // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting - // in: - // - // (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min - // - // Non-floating point column and literal expressions should be unaffected as they do not - // have a nan_count statistic defined. - fn with_nan_predicate( - lhs: &Expression, - rhs: &Expression, - value_predicate: Expression, - catalog: &dyn StatsCatalog, - ) -> Expression { - let nan_predicate = and_collect( - lhs.stat_expression(Stat::NaNCount, catalog) - .into_iter() - .chain(rhs.stat_expression(Stat::NaNCount, catalog)) - .map(|nans| eq(nans, lit(0u64))), - ); - - if let Some(nan_check) = nan_predicate { - and(nan_check, value_predicate) - } else { - value_predicate - } - } - - let lhs = expr.child(0); - let rhs = expr.child(1); - match operator { - Operator::Eq => { - let min_lhs = lhs.stat_min(catalog); - let max_lhs = lhs.stat_max(catalog); - - let min_rhs = rhs.stat_min(catalog); - let max_rhs = rhs.stat_max(catalog); - - let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b)); - let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b)); - - let min_max_check = or_collect(left.into_iter().chain(right))?; - - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::NotEq => { - let min_lhs = lhs.stat_min(catalog)?; - let max_lhs = lhs.stat_max(catalog)?; - - let min_rhs = rhs.stat_min(catalog)?; - let max_rhs = rhs.stat_max(catalog)?; - - let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Gt => { - let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Gte => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Lt => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Lte => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::And => or_collect( - lhs.stat_falsification(catalog) - .into_iter() - .chain(rhs.stat_falsification(catalog)), - ), - Operator::Or => Some(and( - lhs.stat_falsification(catalog)?, - rhs.stat_falsification(catalog)?, - )), - Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, - } - } - fn validity( &self, operator: &Operator, @@ -381,8 +268,12 @@ mod tests { use crate::expr::Expression; use crate::expr::and_collect; use crate::expr::col; + use crate::expr::eq; + use crate::expr::gt; + use crate::expr::gt_eq; use crate::expr::lit; use crate::expr::lt; + use crate::expr::lt_eq; use crate::expr::not_eq; use crate::expr::or; use crate::expr::or_collect; diff --git a/vortex-array/src/scalar_fn/fns/dynamic.rs b/vortex-array/src/scalar_fn/fns/dynamic.rs index 7efebf79220..f6e6619282a 100644 --- a/vortex-array/src/scalar_fn/fns/dynamic.rs +++ b/vortex-array/src/scalar_fn/fns/dynamic.rs @@ -20,7 +20,6 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::traversal::NodeExt; use crate::expr::traversal::NodeVisitor; use crate::expr::traversal::TraversalOrder; @@ -120,50 +119,6 @@ impl ScalarFnVTable for DynamicComparison { .into_array()) } - fn stat_falsification( - &self, - dynamic: &DynamicComparisonExpr, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let lhs = expr.child(0); - match dynamic.operator { - CompareOperator::Eq | CompareOperator::NotEq => None, - CompareOperator::Gt => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Lte, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_max(catalog)?], - )), - CompareOperator::Gte => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Lt, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_max(catalog)?], - )), - CompareOperator::Lt => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Gte, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_min(catalog)?], - )), - CompareOperator::Lte => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Gt, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_min(catalog)?], - )), - } - } - // Defer to the child fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index eb60fc8aec8..a64aab32611 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -3,7 +3,6 @@ use std::fmt::Formatter; -use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_error::VortexResult; use vortex_session::VortexSession; use vortex_session::registry::CachedId; @@ -15,16 +14,12 @@ use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::eq; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::ScalarFnVTableExt; use crate::validity::Validity; /// Expression that checks for non-null values. @@ -100,18 +95,6 @@ impl ScalarFnVTable for IsNotNull { fn is_fallible(&self, _instance: &Self::Options) -> bool { false } - - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // is_not_null is falsified when ALL values are null, i.e. null_count == row_count. - let child = expr.child(0); - let null_count_expr = child.stat_expression(Stat::NullCount, catalog)?; - Some(eq(null_count_expr, RowCount.new_expr(EmptyOptions, []))) - } } #[cfg(test)] diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index bbf8a8f2409..807b9d9a043 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -12,11 +12,6 @@ use crate::arrays::ConstantArray; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; -use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::eq; -use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -84,16 +79,6 @@ impl ScalarFnVTable for IsNull { } } - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?; - Some(eq(null_count_expr, lit(0u64))) - } - fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } diff --git a/vortex-array/src/scalar_fn/fns/like/mod.rs b/vortex-array/src/scalar_fn/fns/like/mod.rs index b7f357020f1..850484a1450 100644 --- a/vortex-array/src/scalar_fn/fns/like/mod.rs +++ b/vortex-array/src/scalar_fn/fns/like/mod.rs @@ -21,20 +21,12 @@ use crate::arrow::Datum; use crate::arrow::from_arrow_columnar; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::and; -use crate::expr::gt; -use crate::expr::gt_eq; -use crate::expr::lit; -use crate::expr::lt; -use crate::expr::or; -use crate::scalar::StringLike; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::fns::literal::Literal; /// Options for SQL LIKE function #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -165,49 +157,6 @@ impl ScalarFnVTable for Like { fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false } - - fn stat_falsification( - &self, - like_opts: &LikeOptions, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // Attempt to do min/max pruning for LIKE 'exact' or LIKE 'prefix%' - - // Don't attempt to handle ilike or negated like - if like_opts.negated || like_opts.case_insensitive { - return None; - } - - // Extract the pattern out - let pat = expr.child(1).as_::(); - - // LIKE NULL is nonsensical, don't try to handle it - let pat_str = pat.as_utf8().value()?; - - let src = expr.child(0).clone(); - let src_min = src.stat_min(catalog)?; - let src_max = src.stat_max(catalog)?; - - match LikeVariant::from_str(pat_str)? { - LikeVariant::Exact(text) => { - // col LIKE 'exact' ==> col.min > 'exact' || col.max < 'exact' - Some(or( - gt(src_min, lit(text.as_ref())), - lt(src_max, lit(text.as_ref())), - )) - } - LikeVariant::Prefix(prefix) => { - // col LIKE 'prefix%' ==> col.max < 'prefix' || col.min >= 'prefiy' - let succ = prefix.to_string().increment().ok()?; - - Some(or( - gt_eq(src_min, lit(succ)), - lt(src_max, lit(prefix.as_ref())), - )) - } - } - } } /// Implementation of LIKE using the Arrow crate. @@ -295,15 +244,11 @@ mod tests { use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; - use crate::expr::col; use crate::expr::get_item; - use crate::expr::ilike; use crate::expr::like; use crate::expr::lit; use crate::expr::not; use crate::expr::not_ilike; - use crate::expr::not_like; - use crate::expr::pruning::pruning_expr::TrackingStatsCatalog; use crate::expr::root; use crate::scalar_fn::fns::like::LikeVariant; @@ -390,50 +335,4 @@ mod tests { assert_eq!(LikeVariant::from_str(r"%\%%"), None); assert_eq!(LikeVariant::from_str("_pattern"), None); } - - #[test] - fn test_like_pushdown() { - // Test that LIKE prefix and exactness filters can be pushed down into stats filtering - // at scan time. - let catalog = TrackingStatsCatalog::default(); - - let pruning_expr = like(col("a"), lit("prefix%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "prefiy") or ($.a_max < "prefix"))"#); - - let pruning_expr = like(col("a"), lit(r"\%%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "&") or ($.a_max < "%"))"#); - - // Multiple wildcards - let pruning_expr = like(col("a"), lit("pref%ix%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "preg") or ($.a_max < "pref"))"#); - - let pruning_expr = like(col("a"), lit("pref_ix_")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "preg") or ($.a_max < "pref"))"#); - - // Exact match - let pruning_expr = like(col("a"), lit("exactly")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min > "exactly") or ($.a_max < "exactly"))"#); - - // Suffix search skips pushdown - let pruning_expr = like(col("a"), lit("%suffix")).stat_falsification(&catalog); - assert_eq!(pruning_expr, None); - - // NOT LIKE, ILIKE not supported currently - assert_eq!( - None, - not_like(col("a"), lit("a")).stat_falsification(&catalog) - ); - assert_eq!(None, ilike(col("a"), lit("a")).stat_falsification(&catalog)); - } } diff --git a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs index bbaed489fd1..236bdd646eb 100644 --- a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs +++ b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs @@ -33,13 +33,6 @@ use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::IntegerPType; use crate::dtype::Nullability; -use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::and_collect; -use crate::expr::gt; -use crate::expr::lit; -use crate::expr::lt; -use crate::expr::or; use crate::match_each_integer_ptype; use crate::match_each_unsigned_integer_ptype; use crate::scalar::ListScalar; @@ -51,7 +44,6 @@ use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; use crate::scalar_fn::fns::binary::Binary; -use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::operators::Operator; use crate::validity::Validity; @@ -129,43 +121,6 @@ impl ScalarFnVTable for ListContains { compute_list_contains(&list_array, &value_array, ctx) } - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let list = expr.child(0); - let needle = expr.child(1); - - // falsification(contains([1,2,5], x)) => - // falsification(x != 1) and falsification(x != 2) and falsification(x != 5) - let min = list.stat_min(catalog)?; - let max = list.stat_max(catalog)?; - // If the list is constant when we can compare each element to the value - if min == max { - let list_ = min - .as_opt::() - .and_then(|l| l.as_list_opt()) - .and_then(|l| l.elements())?; - if list_.is_empty() { - // contains([], x) is always false. - return Some(lit(true)); - } - let value_max = needle.stat_max(catalog)?; - let value_min = needle.stat_min(catalog)?; - - return and_collect(list_.iter().map(move |v| { - or( - lt(value_max.clone(), lit(v.clone())), - gt(value_min.clone(), lit(v.clone())), - ) - })); - } - - None - } - // Nullability matters for contains([], x) where x is false. fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true diff --git a/vortex-array/src/scalar_fn/typed.rs b/vortex-array/src/scalar_fn/typed.rs index a2ef9549bff..83d2bfea496 100644 --- a/vortex-array/src/scalar_fn/typed.rs +++ b/vortex-array/src/scalar_fn/typed.rs @@ -101,11 +101,6 @@ pub(super) trait DynScalarFn: 'static + Send + Sync + super::sealed::Sealed { ) -> VortexResult>; fn simplify_untyped(&self, expression: &Expression) -> VortexResult>; fn validity(&self, expression: &Expression) -> VortexResult>; - fn stat_falsification( - &self, - expression: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option; fn stat_expression( &self, expression: &Expression, @@ -223,14 +218,6 @@ impl DynScalarFn for TypedScalarFnInstance { V::validity(&self.vtable, &self.options, expression) } - fn stat_falsification( - &self, - expression: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - V::stat_falsification(&self.vtable, &self.options, expression, catalog) - } - fn stat_expression( &self, expression: &Expression, diff --git a/vortex-array/src/scalar_fn/vtable.rs b/vortex-array/src/scalar_fn/vtable.rs index f4862f6876a..1556354f9f1 100644 --- a/vortex-array/src/scalar_fn/vtable.rs +++ b/vortex-array/src/scalar_fn/vtable.rs @@ -179,19 +179,6 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync { Ok(None) } - /// See [`Expression::stat_falsification`]. - fn stat_falsification( - &self, - options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - _ = options; - _ = expr; - _ = catalog; - None - } - /// See [`Expression::stat_expression`]. fn stat_expression( &self, diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index 7d9bd0d66cf..b697becbc0f 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -21,7 +21,6 @@ use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; -use vortex_array::expr::StatsCatalog; use vortex_array::expr::is_root; use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; @@ -94,9 +93,8 @@ impl FileStatsLayoutReader { }; let pruning_expr = self.lower_stats(pruning_expr)?; - // Given how we implemented the StatsCatalog, we know the expression must be all literals - // or row_count placeholders. We can therefore optimize with a null scope since there are - // no field references that need to be resolved. + // Stats lowering replaces available stats with literals and unavailable stats with nulls, + // so only row_count placeholders remain unresolved here. let simplified = pruning_expr.optimize_recursive(&DType::Null)?; if let Some(result) = simplified.as_opt::() { // Can prune if the result is non-nullable and true @@ -152,7 +150,7 @@ impl StatBinder for FileStatsBinder<'_> { let Some(field_path) = direct_field_path(input) else { return Ok(None); }; - Ok(self.reader.stats_ref(&field_path, stat)) + Ok(self.reader.stat_ref(&field_path, stat)) } } @@ -165,9 +163,8 @@ fn direct_field_path(expr: &Expression) -> Option { direct_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) } -/// Implements [`StatsCatalog`] to provide file-level stats to expressions during pruning evaluation. -impl StatsCatalog for FileStatsLayoutReader { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { +impl FileStatsLayoutReader { + fn stat_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { // FileStats currently only holds top-level field statistics. if field_path.parts().len() != 1 { return None; From ffbccee5e987e734132ebb8dcebe6b63b9ec3b93 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 16:45:16 -0400 Subject: [PATCH 08/28] Remove legacy stat expression hooks Signed-off-by: Nicholas Gates --- vortex-array/src/expr/expression.rs | 24 ----------- vortex-array/src/expr/mod.rs | 1 - vortex-array/src/expr/pruning/mod.rs | 17 -------- vortex-array/src/scalar_fn/erased.rs | 12 ------ vortex-array/src/scalar_fn/fns/cast/mod.rs | 36 ----------------- vortex-array/src/scalar_fn/fns/get_item.rs | 21 ---------- vortex-array/src/scalar_fn/fns/literal.rs | 47 ---------------------- vortex-array/src/scalar_fn/fns/root.rs | 13 ------ vortex-array/src/scalar_fn/typed.rs | 17 -------- vortex-array/src/scalar_fn/vtable.rs | 17 -------- 10 files changed, 205 deletions(-) diff --git a/vortex-array/src/expr/expression.rs b/vortex-array/src/expr/expression.rs index 043b61aaaf4..41cd61a369d 100644 --- a/vortex-array/src/expr/expression.rs +++ b/vortex-array/src/expr/expression.rs @@ -15,9 +15,7 @@ use vortex_error::vortex_ensure; use vortex_session::VortexSession; use crate::dtype::DType; -use crate::expr::StatsCatalog; use crate::expr::display::DisplayTreeExpr; -use crate::expr::stats::Stat; use crate::scalar_fn::ScalarFnRef; use crate::scalar_fn::fns::root::Root; @@ -142,28 +140,6 @@ impl Expression { crate::stats::rewrite::StatsRewriteCtx::new(session, scope).satisfy(self) } - /// Returns an expression representing the zoned statistic for the given stat, if available. - /// - /// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a - /// scope. Expressions can implement this function to propagate such statistics through the - /// expression tree. For example, the `a + 10` expression could propagate `min: min(a) + 10`. - /// - /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an - /// issue to discuss a solution to this. - pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option { - self.scalar_fn().stat_expression(self, stat, catalog) - } - - /// Returns an expression representing the zoned maximum statistic, if available. - pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option { - self.stat_expression(Stat::Min, catalog) - } - - /// Returns an expression representing the zoned maximum statistic, if available. - pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option { - self.stat_expression(Stat::Max, catalog) - } - /// Format the expression as a compact string. /// /// Since this is a recursive formatter, it is exposed on the public Expression type. diff --git a/vortex-array/src/expr/mod.rs b/vortex-array/src/expr/mod.rs index a5d32510443..72969baf23a 100644 --- a/vortex-array/src/expr/mod.rs +++ b/vortex-array/src/expr/mod.rs @@ -42,7 +42,6 @@ pub mod traversal; pub use analysis::*; pub use expression::*; pub use exprs::*; -pub use pruning::StatsCatalog; pub trait VortexExprExt { /// Accumulate all field references from this expression and its children in a set diff --git a/vortex-array/src/expr/pruning/mod.rs b/vortex-array/src/expr/pruning/mod.rs index 7c20508b7a8..5ce2785f446 100644 --- a/vortex-array/src/expr/pruning/mod.rs +++ b/vortex-array/src/expr/pruning/mod.rs @@ -8,20 +8,3 @@ pub use pruning_expr::RequiredStats; pub use pruning_expr::checked_pruning_expr; pub use pruning_expr::field_path_stat_field_name; pub use relation::Relation; - -use crate::dtype::FieldPath; -use crate::expr::Expression; -use crate::expr::stats::Stat; - -/// A catalog of available stats that are associated with field paths. -pub trait StatsCatalog { - /// Given a field path and statistic, return an expression that when evaluated over the catalog - /// will return that stat for the referenced field. - /// - /// This is likely to be a column expression, or a literal. - /// - /// Returns `None` if the stat is not available for the field path. - fn stats_ref(&self, _field_path: &FieldPath, _stat: Stat) -> Option { - None - } -} diff --git a/vortex-array/src/scalar_fn/erased.rs b/vortex-array/src/scalar_fn/erased.rs index 69befb405e0..6e0011c297a 100644 --- a/vortex-array/src/scalar_fn/erased.rs +++ b/vortex-array/src/scalar_fn/erased.rs @@ -20,8 +20,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::scalar_fn::EmptyOptions; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ReduceCtx; @@ -180,16 +178,6 @@ impl ScalarFnRef { pub(crate) fn simplify_untyped(&self, expr: &Expression) -> VortexResult> { self.0.simplify_untyped(expr) } - - /// Compute stat expression. - pub(crate) fn stat_expression( - &self, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - self.0.stat_expression(expr, stat, catalog) - } } impl Debug for ScalarFnRef { diff --git a/vortex-array/src/scalar_fn/fns/cast/mod.rs b/vortex-array/src/scalar_fn/fns/cast/mod.rs index abc59af2c9a..20852779d42 100644 --- a/vortex-array/src/scalar_fn/fns/cast/mod.rs +++ b/vortex-array/src/scalar_fn/fns/cast/mod.rs @@ -32,11 +32,8 @@ use crate::arrays::VarBinView; use crate::arrays::struct_::compute::cast::struct_cast; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; -use crate::expr::StatsCatalog; -use crate::expr::cast; use crate::expr::expression::Expression; use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; @@ -152,39 +149,6 @@ impl ScalarFnVTable for Cast { Ok(None) } - fn stat_expression( - &self, - dtype: &DType, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - match stat { - Stat::IsConstant - | Stat::IsSorted - | Stat::IsStrictSorted - | Stat::NaNCount - | Stat::Sum - | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog), - Stat::Max | Stat::Min => { - // We cast min/max to the new type - expr.child(0) - .stat_expression(stat, catalog) - .map(|x| cast(x, dtype.clone())) - } - Stat::NullCount => { - // if !expr.data().is_nullable() { - // NOTE(ngates): we should decide on the semantics here. In theory, the null - // count of something cast to non-nullable will be zero. But if we return - // that we know this to be zero, then a pruning predicate may eliminate data - // that would otherwise have caused the cast to error. - // return Some(lit(0u64)); - // } - None - } - } - } - fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult> { Ok(Some(if dtype.is_nullable() { expression.child(0).validity()? diff --git a/vortex-array/src/scalar_fn/fns/get_item.rs b/vortex-array/src/scalar_fn/fns/get_item.rs index de7e45ca9b0..b9cb9202f38 100644 --- a/vortex-array/src/scalar_fn/fns/get_item.rs +++ b/vortex-array/src/scalar_fn/fns/get_item.rs @@ -18,12 +18,9 @@ use crate::builtins::ArrayBuiltins; use crate::builtins::ExprBuiltins; use crate::dtype::DType; use crate::dtype::FieldName; -use crate::dtype::FieldPath; use crate::dtype::Nullability; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -188,24 +185,6 @@ impl ScalarFnVTable for GetItem { Ok(None) } - fn stat_expression( - &self, - field_name: &FieldName, - _expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - // TODO(ngates): I think we can do better here and support stats over nested fields. - // It would be nice if delegating to our child would return a struct of statistics - // matching the nested DType such that we can write: - // `get_item(expr.child(0).stat_expression(...), expr.data().field_name())` - - // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same - // name as a field in the root struct. This should be resolved with upcoming change to - // falsify expressions, but for now I'm preserving the existing buggy behavior. - catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat) - } - // This will apply struct nullability field. We could add a dtype?? fn is_null_sensitive(&self, _field_name: &FieldName) -> bool { true diff --git a/vortex-array/src/scalar_fn/fns/literal.rs b/vortex-array/src/scalar_fn/fns/literal.rs index 16b112e5a78..5181a5250dd 100644 --- a/vortex-array/src/scalar_fn/fns/literal.rs +++ b/vortex-array/src/scalar_fn/fns/literal.rs @@ -16,9 +16,6 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; -use crate::match_each_float_ptype; use crate::scalar::Scalar; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; @@ -96,50 +93,6 @@ impl ScalarFnVTable for Literal { Ok(ConstantArray::new(scalar.clone(), args.row_count()).into_array()) } - fn stat_expression( - &self, - scalar: &Scalar, - _expr: &Expression, - stat: Stat, - _catalog: &dyn StatsCatalog, - ) -> Option { - // NOTE(ngates): we return incorrect `1` values for counts here since we don't have - // row-count information. We could resolve this in the future by introducing a `count()` - // expression that evaluates to the row count of the provided scope. But since this is - // only currently used for pruning, it doesn't change the outcome. - - match stat { - Stat::Min | Stat::Max => Some(lit(scalar.clone())), - Stat::IsConstant => Some(lit(true)), - Stat::NaNCount => { - // The NaNCount for a non-float literal is not defined. - // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise. - let value = scalar.as_primitive_opt()?; - if !value.ptype().is_float() { - return None; - } - - match_each_float_ptype!(value.ptype(), |T| { - if value.typed_value::().is_some_and(|v| v.is_nan()) { - Some(lit(1u64)) - } else { - Some(lit(0u64)) - } - }) - } - Stat::NullCount => { - if scalar.is_null() { - Some(lit(1u64)) - } else { - Some(lit(0u64)) - } - } - Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => { - None - } - } - } - fn validity( &self, scalar: &Scalar, diff --git a/vortex-array/src/scalar_fn/fns/root.rs b/vortex-array/src/scalar_fn/fns/root.rs index 87b8b62ccf4..7bd5b758796 100644 --- a/vortex-array/src/scalar_fn/fns/root.rs +++ b/vortex-array/src/scalar_fn/fns/root.rs @@ -11,10 +11,7 @@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; -use crate::dtype::FieldPath; -use crate::expr::StatsCatalog; use crate::expr::expression::Expression; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -80,16 +77,6 @@ impl ScalarFnVTable for Root { vortex_bail!("Root expression is not executable") } - fn stat_expression( - &self, - _options: &Self::Options, - _expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - catalog.stats_ref(&FieldPath::root(), stat) - } - fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } diff --git a/vortex-array/src/scalar_fn/typed.rs b/vortex-array/src/scalar_fn/typed.rs index 83d2bfea496..e31cdc79ed0 100644 --- a/vortex-array/src/scalar_fn/typed.rs +++ b/vortex-array/src/scalar_fn/typed.rs @@ -22,8 +22,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; @@ -101,12 +99,6 @@ pub(super) trait DynScalarFn: 'static + Send + Sync + super::sealed::Sealed { ) -> VortexResult>; fn simplify_untyped(&self, expression: &Expression) -> VortexResult>; fn validity(&self, expression: &Expression) -> VortexResult>; - fn stat_expression( - &self, - expression: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option; // Options operations — self-contained fn options_serialize(&self) -> VortexResult>>; @@ -218,15 +210,6 @@ impl DynScalarFn for TypedScalarFnInstance { V::validity(&self.vtable, &self.options, expression) } - fn stat_expression( - &self, - expression: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - V::stat_expression(&self.vtable, &self.options, expression, stat, catalog) - } - fn options_serialize(&self) -> VortexResult>> { V::serialize(&self.vtable, &self.options) } diff --git a/vortex-array/src/scalar_fn/vtable.rs b/vortex-array/src/scalar_fn/vtable.rs index 1556354f9f1..c66afc34932 100644 --- a/vortex-array/src/scalar_fn/vtable.rs +++ b/vortex-array/src/scalar_fn/vtable.rs @@ -20,8 +20,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::expr::traversal::Node; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnRef; @@ -179,21 +177,6 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync { Ok(None) } - /// See [`Expression::stat_expression`]. - fn stat_expression( - &self, - options: &Self::Options, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - _ = options; - _ = expr; - _ = stat; - _ = catalog; - None - } - /// Returns an expression that evaluates to the validity of the result of this expression. /// /// If a validity expression cannot be constructed, returns `None` and the expression will From fe12e71a0b8112be719ef233ac2880f3b2ba17fe Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 16 Jun 2026 16:38:03 -0700 Subject: [PATCH 09/28] Address stats binding review comments Signed-off-by: "Nicholas Gates" --- .github/workflows/ci.yml | 4 -- docs/developer-guide/index.md | 3 +- .../internals/stats-pruning.md | 40 +++++++++++++++++++ vortex-array/src/expr/pruning/pruning_expr.rs | 37 ++++++++++++----- vortex-array/src/stats/bind.rs | 35 +++++++++------- vortex-array/src/stats/rewrite/builtins.rs | 27 +++++++++++++ vortex-file/src/v2/file_stats_reader.rs | 10 ++--- vortex-layout/src/layouts/zoned/zone_map.rs | 9 +++-- 8 files changed, 127 insertions(+), 38 deletions(-) create mode 100644 docs/developer-guide/internals/stats-pruning.md diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 57176093ea2..bf9c0a785b6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -368,10 +368,6 @@ jobs: with: sccache: s3 - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 - with: - distribution: "corretto" - java-version: "17" - uses: ./.github/actions/setup-prebuild - run: ./gradlew javadoc working-directory: ./java diff --git a/docs/developer-guide/index.md b/docs/developer-guide/index.md index 37ad276044d..e6eec4bfb70 100644 --- a/docs/developer-guide/index.md +++ b/docs/developer-guide/index.md @@ -24,6 +24,7 @@ internals/session internals/async-runtime internals/vtables internals/execution +internals/stats-pruning internals/io internals/serialization internals/cuda @@ -38,4 +39,4 @@ caption: Integrations integrations/datafusion integrations/duckdb integrations/spark -``` \ No newline at end of file +``` diff --git a/docs/developer-guide/internals/stats-pruning.md b/docs/developer-guide/internals/stats-pruning.md new file mode 100644 index 00000000000..6049a06455e --- /dev/null +++ b/docs/developer-guide/internals/stats-pruning.md @@ -0,0 +1,40 @@ +# Stats Pruning + +Vortex uses statistics to prove when a filter cannot match a row group, zone, or +file. The proof expression returns `true` when the input can be skipped. It +returns `false` or `null` when pruning is not proven. + +The pruning pipeline has two phases: + +1. `Expression::falsify(scope, session)` asks the session's + `StatsRewriteRule`s to rewrite a filter into an abstract proof expression. + Rules describe semantics in terms of `vortex.stat(input, aggregate_fn)` + placeholders, so the rewrite does not depend on a particular layout. +2. `bind_stats` lowers those abstract stat placeholders with a `StatBinder`. + The binder maps stats to the representation used by the caller, such as + zone-map table fields, file-level stat literals, or typed null literals for + missing stats. + +Missing stats lower to typed null literals. This preserves the three-valued +logic used by pruning: only a non-null `true` value proves that the scope can be +skipped. A missing stat therefore cannot accidentally prune data. + +## Binding Targets + +Zone maps bind stats to fields in their per-zone stats table. The lowered +expression is evaluated against that table and produces a mask where `true` +means the zone can be skipped. + +File-level stats bind stats to literal values from the file footer. The lowered +expression is reduced and evaluated once for the full file. If it evaluates to +`true`, the file stats reader can return an all-false pruning mask without +reading child layouts. + +Scan planning uses `checked_pruning_expr` to lower a falsified expression against +the available stats table schema. It returns the stats-table expression and the +set of stat fields still required after expression reduction. If all required +stats are missing, only a constant `true` proof is useful; all other results are +treated as no pruning expression. + +For the layout model around these pruning points, see +[Layouts](../../concepts/layouts.md) and [Scanning](../../concepts/scanning.md). diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index ee775b4f13e..f3b1aaf6f5d 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -14,6 +14,8 @@ use crate::dtype::Field; use crate::dtype::FieldName; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; +use crate::dtype::Nullability; +use crate::dtype::StructFields; use crate::expr::Expression; use crate::expr::analysis::referenced_field_paths; use crate::expr::get_item; @@ -70,11 +72,13 @@ pub fn checked_pruning_expr( scope, available_stats, required_stats: Relation::new(), + bound_stats: Vec::new(), }; - let Some(lowered) = bind_stats(predicate, &mut binder)? else { - return Ok(None); - }; + let lowered = bind_stats(predicate, &mut binder)?; let required_stats = filter_required_stats(&lowered, binder.required_stats); + // If no stats-table fields remain, only a constant `true` proof can prune. + // `false`, `null`, and non-constant expressions cannot justify building a + // stats-table pruning expression. if required_stats.map().is_empty() && !matches!(bool_literal(&lowered), Some(Some(true))) { return Ok(None); } @@ -86,6 +90,7 @@ struct RequiredStatsBinder<'a> { scope: &'a DType, available_stats: &'a FieldPathSet, required_stats: RequiredStats, + bound_stats: Vec<(FieldName, DType)>, } impl StatBinder for RequiredStatsBinder<'_> { @@ -93,11 +98,18 @@ impl StatBinder for RequiredStatsBinder<'_> { self.scope } + fn bound_scope(&self) -> DType { + DType::Struct( + StructFields::from_iter(self.bound_stats.iter().cloned()), + Nullability::NonNullable, + ) + } + fn bind_stat( &mut self, input: &Expression, stat: Stat, - _stat_dtype: &DType, + stat_dtype: &DType, ) -> VortexResult> { let field_path = match direct_stat_field_path(input) { Some(field_path) => field_path, @@ -114,11 +126,18 @@ impl StatBinder for RequiredStatsBinder<'_> { return Ok(None); } - self.required_stats.insert(field_path.clone(), stat); - Ok(Some(get_item( - field_path_stat_field_name(&field_path, stat), - root(), - ))) + let stat_field_name = field_path_stat_field_name(&field_path, stat); + if self + .bound_stats + .iter() + .all(|(field_name, _)| field_name != stat_field_name) + { + self.bound_stats + .push((stat_field_name.clone(), stat_dtype.clone())); + } + + self.required_stats.insert(field_path, stat); + Ok(Some(get_item(stat_field_name, root()))) } } diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 752664396c6..96bffc9c6c7 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -2,6 +2,11 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors //! Bind abstract `vortex.stat` expressions to a concrete stats representation. +//! +//! Stats rewrite rules describe pruning in terms of `vortex.stat(input, aggregate_fn)` placeholders +//! so the rewrite is independent of where statistics are stored. Binding is the later pass that +//! replaces those placeholders with the representation used by a caller: zone-map field references, +//! file-level stat literals, or typed nulls for missing stats. use vortex_error::VortexResult; @@ -20,6 +25,14 @@ pub trait StatBinder { /// The dtype scope used to type-check expressions before stats are bound. fn scope(&self) -> &DType; + /// The dtype scope used after stats have been bound. + /// + /// Binders that rewrite stats to a different root expression, such as a + /// stats-table row, should return that post-binding root dtype. + fn bound_scope(&self) -> DType { + self.scope().clone() + } + /// Bind `stat(input)` to a concrete expression. /// /// Returning `Ok(None)` marks the stat as unavailable. [`bind_stats`] will @@ -52,10 +65,9 @@ pub trait StatBinder { /// Expression to use when a stat is unavailable. /// /// The default is a nullable null literal, which preserves three-valued - /// pruning semantics for stats-table execution. Catalog-like binders can - /// return `Ok(None)` to reject expressions that require unavailable stats. - fn missing_stat(&mut self, dtype: DType) -> VortexResult> { - Ok(Some(null_expr(dtype))) + /// pruning semantics for stats-table execution. + fn missing_stat(&mut self, dtype: DType) -> VortexResult { + Ok(null_expr(dtype)) } } @@ -64,10 +76,7 @@ pub trait StatBinder { /// The predicate is usually the output of a stats rewrite rule. Rewrite rules /// are responsible for expressing stat semantics; binding maps aggregate-backed /// stat requests to the concrete stats representation supported by the binder. -pub fn bind_stats( - predicate: Expression, - binder: &mut impl StatBinder, -) -> VortexResult> { +pub fn bind_stats(predicate: Expression, binder: &mut impl StatBinder) -> VortexResult { let scope = binder.scope().clone(); let lowered = predicate .transform_down(|expr| { @@ -79,18 +88,13 @@ pub fn bind_stats( Some(bound) => Ok(Transformed::yes(bound)), None => { let dtype = expr.return_dtype(&scope)?; - match binder.missing_stat(dtype.clone())? { - Some(missing) => Ok(Transformed::yes(missing)), - None => Ok(Transformed::yes(null_expr(dtype))), - } + Ok(Transformed::yes(binder.missing_stat(dtype)?)) } } })? .into_inner(); - #[expect(deprecated)] - let lowered = lowered.simplify_untyped()?; - Ok(Some(lowered)) + lowered.optimize_recursive(&binder.bound_scope()) } fn bind_stat_fn( @@ -100,6 +104,7 @@ fn bind_stat_fn( ) -> VortexResult> { let options = expr.as_::(); let aggregate_fn = options.aggregate_fn(); + // `StatFn` has exactly one child: the expression the aggregate statistic is computed over. let input = expr.child(0); let stat_dtype = expr.return_dtype(scope)?; diff --git a/vortex-array/src/stats/rewrite/builtins.rs b/vortex-array/src/stats/rewrite/builtins.rs index 2b7316c7d98..c60f5869055 100644 --- a/vortex-array/src/stats/rewrite/builtins.rs +++ b/vortex-array/src/stats/rewrite/builtins.rs @@ -814,6 +814,33 @@ mod tests { )) ); + let expr = like(col("s"), lit(r"\%%")); + assert_eq!( + falsify(&expr)?, + Some(or( + gt_eq(stat(col("s"), Stat::Min), lit("&")), + lt(stat(col("s"), Stat::Max), lit("%")), + )) + ); + + let expr = like(col("s"), lit("pref%ix%")); + assert_eq!( + falsify(&expr)?, + Some(or( + gt_eq(stat(col("s"), Stat::Min), lit("preg")), + lt(stat(col("s"), Stat::Max), lit("pref")), + )) + ); + + let expr = like(col("s"), lit("pref_ix_")); + assert_eq!( + falsify(&expr)?, + Some(or( + gt_eq(stat(col("s"), Stat::Min), lit("preg")), + lt(stat(col("s"), Stat::Max), lit("pref")), + )) + ); + let expr = like(col("s"), lit("exact")); assert_eq!( falsify(&expr)?, diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index b697becbc0f..527fa46262a 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -31,7 +31,6 @@ use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::stats::bind::StatBinder; use vortex_array::stats::bind::bind_stats; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_layout::ArrayFuture; use vortex_layout::LayoutReader; use vortex_layout::LayoutReaderRef; @@ -121,10 +120,7 @@ impl FileStatsLayoutReader { fn lower_stats(&self, predicate: Expression) -> VortexResult { let mut binder = FileStatsBinder { reader: self }; - let Some(predicate) = bind_stats(predicate, &mut binder)? else { - vortex_bail!("missing stats should lower to null literals"); - }; - Ok(predicate) + bind_stats(predicate, &mut binder) } pub fn file_stats(&self) -> &FileStatistics { @@ -141,6 +137,10 @@ impl StatBinder for FileStatsBinder<'_> { self.reader.child.dtype() } + fn bound_scope(&self) -> DType { + DType::Null + } + fn bind_stat( &mut self, input: &Expression, diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 6309e6d225b..0530dec4d57 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -121,10 +121,7 @@ impl ZoneMap { fn lower_stats(&self, predicate: Expression) -> VortexResult { let mut binder = ZoneMapStatsBinder { zone_map: self }; - let Some(predicate) = bind_stats(predicate, &mut binder)? else { - vortex_bail!("missing stats should lower to null literals"); - }; - Ok(predicate) + bind_stats(predicate, &mut binder) } } @@ -137,6 +134,10 @@ impl StatBinder for ZoneMapStatsBinder<'_> { &self.zone_map.column_dtype } + fn bound_scope(&self) -> DType { + self.zone_map.array.dtype().clone() + } + fn bind_stat( &mut self, input: &Expression, From 0de2c098bbc8d494b469d3d1e41733de301e2ba6 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 16 Jun 2026 23:08:47 -0700 Subject: [PATCH 10/28] Bind abstract aggregate stats to legacy counts Signed-off-by: "Nicholas Gates" --- vortex-array/src/expr/pruning/pruning_expr.rs | 5 +- vortex-array/src/scalar_fn/fns/is_not_null.rs | 3 +- vortex-array/src/scalar_fn/fns/is_null.rs | 2 +- vortex-array/src/stats/bind.rs | 49 +++++++++++++++++++ vortex-layout/src/layouts/zoned/zone_map.rs | 12 ++--- 5 files changed, 59 insertions(+), 12 deletions(-) diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index f3b1aaf6f5d..5dca59d4f94 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -229,7 +229,6 @@ mod tests { use crate::expr::pruning::field_path_stat_field_name; use crate::expr::root; use crate::expr::stats::Stat; - use crate::scalar::Scalar; use crate::scalar_fn::fns::between::BetweenOptions; use crate::scalar_fn::fns::between::StrictComparison; use crate::stats::session::StatsSession; @@ -591,7 +590,7 @@ mod tests { &HashMap::from_iter([ ( FieldPath::from_name("float_col"), - HashSet::from_iter([Stat::Max]) + HashSet::from_iter([Stat::Max, Stat::NaNCount]) ), ( FieldPath::from_name("int_col"), @@ -603,7 +602,7 @@ mod tests { &converted, &or( and( - lit(Scalar::null(DType::Bool(Nullability::Nullable))), + eq(col("float_col_nan_count"), lit(0u64)), lt_eq(col("float_col_max"), lit(10f32)), ), gt_eq(col("int_col_min"), lit(10)), diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index a64aab32611..850b074c9ab 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -118,7 +118,6 @@ mod tests { use crate::expr::eq; use crate::expr::get_item; use crate::expr::is_not_null; - use crate::expr::lit; use crate::expr::or; use crate::expr::pruning::checked_pruning_expr; use crate::expr::root; @@ -266,7 +265,7 @@ mod tests { &pruning_expr, &or( eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, [])), - lit(Scalar::null(DType::Bool(Nullability::Nullable))), + eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, [])), ) ); assert_eq!( diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index 807b9d9a043..b4dc839a6cb 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -251,7 +251,7 @@ mod tests { &pruning_expr, &or( eq(col("a_null_count"), lit(0u64)), - lit(Scalar::null(DType::Bool(Nullability::Nullable))), + eq(col("a_null_count"), lit(0u64)), ) ); assert_eq!( diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 96bffc9c6c7..e4667597e3c 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -11,14 +11,22 @@ use vortex_error::VortexResult; use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::fns::all_nan::AllNan; +use crate::aggregate_fn::fns::all_non_nan::AllNonNan; +use crate::aggregate_fn::fns::all_non_null::AllNonNull; +use crate::aggregate_fn::fns::all_null::AllNull; use crate::dtype::DType; use crate::expr::Expression; +use crate::expr::eq; use crate::expr::lit; use crate::expr::stats::Stat; use crate::expr::traversal::NodeExt; use crate::expr::traversal::Transformed; use crate::scalar::Scalar; +use crate::scalar_fn::EmptyOptions; +use crate::scalar_fn::ScalarFnVTableExt; use crate::scalar_fn::fns::stat::StatFn; +use crate::scalar_fn::internal::row_count::RowCount; /// A target that can bind abstract statistics to concrete expressions. pub trait StatBinder { @@ -56,12 +64,53 @@ pub trait StatBinder { aggregate_fn: &AggregateFnRef, stat_dtype: &DType, ) -> VortexResult> { + if aggregate_fn.is::() { + let Some(nan_count) = self.bind_legacy_stat(input, Stat::NaNCount)? else { + return Ok(None); + }; + return Ok(Some(eq(nan_count, RowCount.new_expr(EmptyOptions, [])))); + } + + if aggregate_fn.is::() { + let Some(nan_count) = self.bind_legacy_stat(input, Stat::NaNCount)? else { + return Ok(None); + }; + return Ok(Some(eq(nan_count, lit(0u64)))); + } + + if aggregate_fn.is::() { + let Some(null_count) = self.bind_legacy_stat(input, Stat::NullCount)? else { + return Ok(None); + }; + return Ok(Some(eq(null_count, RowCount.new_expr(EmptyOptions, [])))); + } + + if aggregate_fn.is::() { + let Some(null_count) = self.bind_legacy_stat(input, Stat::NullCount)? else { + return Ok(None); + }; + return Ok(Some(eq(null_count, lit(0u64)))); + } + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { return Ok(None); }; self.bind_stat(input, stat, stat_dtype) } + /// Bind one of the legacy stat slots for `input`. + fn bind_legacy_stat( + &mut self, + input: &Expression, + stat: Stat, + ) -> VortexResult> { + let input_dtype = input.return_dtype(self.scope())?; + let Some(stat_dtype) = stat.dtype(&input_dtype) else { + return Ok(None); + }; + self.bind_stat(input, stat, &stat_dtype) + } + /// Expression to use when a stat is unavailable. /// /// The default is a nullable null literal, which preserves three-valued diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 0530dec4d57..f5d89007881 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -346,7 +346,7 @@ mod tests { } #[test] - fn all_null_stat_fn_lowers_to_unknown_mask() { + fn all_null_stat_fn_uses_null_count() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -363,12 +363,12 @@ mod tests { let mask = zone_map.prune(&all_null(root()), &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), - BoolArray::from_iter([false, false, false]) + BoolArray::from_iter([false, true, true]) ); } #[test] - fn all_non_null_stat_fn_lowers_to_unknown_mask() { + fn all_non_null_stat_fn_uses_null_count() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -385,7 +385,7 @@ mod tests { let mask = zone_map.prune(&all_non_null(root()), &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), - BoolArray::from_iter([false, false, false]) + BoolArray::from_iter([true, false, false]) ); } @@ -485,7 +485,7 @@ mod tests { let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), - BoolArray::from_iter([false, false, false]) + BoolArray::from_iter([true, false, false]) ); } @@ -527,7 +527,7 @@ mod tests { let pruning_expr = falsify(&expr, PType::F32.into()); let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, false])); + assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true])); } #[test] From 118c995b4fbc6c1f987b68480d05fa63204caa37 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 16 Jun 2026 23:26:30 -0700 Subject: [PATCH 11/28] Document aggregate stat binding coverage Signed-off-by: "Nicholas Gates" --- vortex-array/src/stats/bind.rs | 99 +++++++++++++++++++++ vortex-layout/src/layouts/zoned/zone_map.rs | 7 +- 2 files changed, 101 insertions(+), 5 deletions(-) diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index e4667597e3c..87f7e249c46 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -64,6 +64,9 @@ pub trait StatBinder { aggregate_fn: &AggregateFnRef, stat_dtype: &DType, ) -> VortexResult> { + // These aggregate stats are derived from legacy count slots rather than + // stored directly. Keep that storage mapping in binding so stats rewrite + // rules can continue to ask for the semantic aggregate they need. if aggregate_fn.is::() { let Some(nan_count) = self.bind_legacy_stat(input, Stat::NaNCount)? else { return Ok(None); @@ -163,3 +166,99 @@ fn bind_stat_fn( fn null_expr(dtype: DType) -> Expression { lit(Scalar::null(dtype.as_nullable())) } + +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + + use super::*; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::dtype::StructFields; + use crate::expr::col; + use crate::expr::get_item; + use crate::expr::is_null; + use crate::expr::root; + use crate::stats::all_non_nan; + + struct TestBinder { + input_scope: DType, + bound_scope: DType, + bind_nan_count: bool, + } + + impl TestBinder { + fn new(bind_nan_count: bool) -> Self { + Self { + input_scope: DType::Struct( + StructFields::from_iter([( + "f", + DType::Primitive(PType::F32, Nullability::NonNullable), + )]), + Nullability::NonNullable, + ), + bound_scope: DType::Struct( + StructFields::from_iter([( + "f_nan_count", + DType::Primitive(PType::U64, Nullability::NonNullable), + )]), + Nullability::NonNullable, + ), + bind_nan_count, + } + } + } + + impl StatBinder for TestBinder { + fn scope(&self) -> &DType { + &self.input_scope + } + + fn bound_scope(&self) -> DType { + self.bound_scope.clone() + } + + fn bind_stat( + &mut self, + _input: &Expression, + stat: Stat, + _stat_dtype: &DType, + ) -> VortexResult> { + if stat == Stat::NaNCount && self.bind_nan_count { + Ok(Some(get_item("f_nan_count", root()))) + } else { + Ok(None) + } + } + } + + #[test] + fn all_non_nan_binds_to_nan_count_zero() -> VortexResult<()> { + let mut binder = TestBinder::new(true); + + let bound = bind_stats(all_non_nan(col("f")), &mut binder)?; + + assert_eq!(bound, eq(col("f_nan_count"), lit(0u64))); + Ok(()) + } + + #[test] + fn all_non_nan_lowers_to_null_when_nan_count_is_missing() -> VortexResult<()> { + let mut binder = TestBinder::new(false); + + let bound = bind_stats(all_non_nan(col("f")), &mut binder)?; + + assert_eq!(bound, lit(Scalar::null(DType::Bool(Nullability::Nullable)))); + Ok(()) + } + + #[test] + fn unrelated_expressions_do_not_request_nan_count() -> VortexResult<()> { + let mut binder = TestBinder::new(false); + + let bound = bind_stats(is_null(col("f")), &mut binder)?; + + assert_eq!(bound, is_null(col("f"))); + Ok(()) + } +} diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index f5d89007881..ac7711ea17a 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -361,10 +361,7 @@ mod tests { .unwrap(); let mask = zone_map.prune(&all_null(root()), &SESSION).unwrap(); - assert_arrays_eq!( - mask.into_array(), - BoolArray::from_iter([false, true, true]) - ); + assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true])); } #[test] @@ -431,7 +428,7 @@ mod tests { } #[test] - fn float_min_max_stat_fn_requires_all_non_nan() { + fn float_min_max_prunes_only_with_all_non_nan_proof() { let zone_map = ZoneMap::try_new( PType::F32.into(), StructArray::from_fields(&[ From 32241d2434db132642f1a55bb2bb97fb31284a03 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 16 Jun 2026 23:58:48 -0700 Subject: [PATCH 12/28] Make legacy stat aggregate binding explicit Signed-off-by: "Nicholas Gates" --- vortex-array/src/expr/pruning/pruning_expr.rs | 11 ++ vortex-array/src/stats/bind.rs | 163 ++++++++++++++---- vortex-file/src/v2/file_stats_reader.rs | 11 ++ vortex-layout/src/layouts/zoned/zone_map.rs | 11 ++ 4 files changed, 161 insertions(+), 35 deletions(-) diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 5dca59d4f94..3ca196cfb66 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -9,6 +9,7 @@ use vortex_session::VortexSession; use vortex_utils::aliases::hash_set::HashSet; use super::relation::Relation; +use crate::aggregate_fn::AggregateFnRef; use crate::dtype::DType; use crate::dtype::Field; use crate::dtype::FieldName; @@ -26,6 +27,7 @@ use crate::scalar_fn::fns::cast::Cast; use crate::scalar_fn::fns::get_item::GetItem; use crate::scalar_fn::fns::literal::Literal; use crate::stats::bind::StatBinder; +use crate::stats::bind::bind_legacy_count_or_direct_aggregate; use crate::stats::bind::bind_stats; pub type RequiredStats = Relation; @@ -139,6 +141,15 @@ impl StatBinder for RequiredStatsBinder<'_> { self.required_stats.insert(field_path, stat); Ok(Some(get_item(stat_field_name, root()))) } + + fn bind_aggregate( + &mut self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult> { + bind_legacy_count_or_direct_aggregate(self, input, aggregate_fn, stat_dtype) + } } fn direct_stat_field_path(expr: &Expression) -> Option { diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 87f7e249c46..8eb0e916015 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -64,41 +64,7 @@ pub trait StatBinder { aggregate_fn: &AggregateFnRef, stat_dtype: &DType, ) -> VortexResult> { - // These aggregate stats are derived from legacy count slots rather than - // stored directly. Keep that storage mapping in binding so stats rewrite - // rules can continue to ask for the semantic aggregate they need. - if aggregate_fn.is::() { - let Some(nan_count) = self.bind_legacy_stat(input, Stat::NaNCount)? else { - return Ok(None); - }; - return Ok(Some(eq(nan_count, RowCount.new_expr(EmptyOptions, [])))); - } - - if aggregate_fn.is::() { - let Some(nan_count) = self.bind_legacy_stat(input, Stat::NaNCount)? else { - return Ok(None); - }; - return Ok(Some(eq(nan_count, lit(0u64)))); - } - - if aggregate_fn.is::() { - let Some(null_count) = self.bind_legacy_stat(input, Stat::NullCount)? else { - return Ok(None); - }; - return Ok(Some(eq(null_count, RowCount.new_expr(EmptyOptions, [])))); - } - - if aggregate_fn.is::() { - let Some(null_count) = self.bind_legacy_stat(input, Stat::NullCount)? else { - return Ok(None); - }; - return Ok(Some(eq(null_count, lit(0u64)))); - } - - let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { - return Ok(None); - }; - self.bind_stat(input, stat, stat_dtype) + bind_direct_aggregate_stat(self, input, aggregate_fn, stat_dtype) } /// Bind one of the legacy stat slots for `input`. @@ -149,6 +115,77 @@ pub fn bind_stats(predicate: Expression, binder: &mut impl StatBinder) -> Vortex lowered.optimize_recursive(&binder.bound_scope()) } +/// Bind aggregate stats that can be derived from legacy count stat slots. +/// +/// This is an opt-in helper for stats backends that materialize `NaNCount` and +/// `NullCount`, but do not materialize aggregate boolean stats directly. +pub fn bind_legacy_count_aggregate( + binder: &mut B, + input: &Expression, + aggregate_fn: &AggregateFnRef, +) -> VortexResult> { + if aggregate_fn.is::() { + let Some(nan_count) = binder.bind_legacy_stat(input, Stat::NaNCount)? else { + return Ok(None); + }; + return Ok(Some(eq(nan_count, RowCount.new_expr(EmptyOptions, [])))); + } + + if aggregate_fn.is::() { + let Some(nan_count) = binder.bind_legacy_stat(input, Stat::NaNCount)? else { + return Ok(None); + }; + return Ok(Some(eq(nan_count, lit(0u64)))); + } + + if aggregate_fn.is::() { + let Some(null_count) = binder.bind_legacy_stat(input, Stat::NullCount)? else { + return Ok(None); + }; + return Ok(Some(eq(null_count, RowCount.new_expr(EmptyOptions, [])))); + } + + if aggregate_fn.is::() { + let Some(null_count) = binder.bind_legacy_stat(input, Stat::NullCount)? else { + return Ok(None); + }; + return Ok(Some(eq(null_count, lit(0u64)))); + } + + Ok(None) +} + +/// Bind an aggregate function that has a direct legacy [`Stat`] slot. +pub fn bind_direct_aggregate_stat( + binder: &mut B, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, +) -> VortexResult> { + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; + binder.bind_stat(input, stat, stat_dtype) +} + +/// Bind aggregate stats for backends that expose legacy count-derived stats. +/// +/// Backends using this helper first bind aggregate facts derivable from +/// `NaNCount` and `NullCount`, then fall back to direct aggregate-to-stat +/// mappings. +pub fn bind_legacy_count_or_direct_aggregate( + binder: &mut B, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, +) -> VortexResult> { + if let Some(bound) = bind_legacy_count_aggregate(binder, input, aggregate_fn)? { + return Ok(Some(bound)); + } + + bind_direct_aggregate_stat(binder, input, aggregate_fn, stat_dtype) +} + fn bind_stat_fn( expr: &Expression, scope: &DType, @@ -175,9 +212,11 @@ mod tests { use crate::dtype::Nullability; use crate::dtype::PType; use crate::dtype::StructFields; + use crate::expr::and; use crate::expr::col; use crate::expr::get_item; use crate::expr::is_null; + use crate::expr::or; use crate::expr::root; use crate::stats::all_non_nan; @@ -230,6 +269,15 @@ mod tests { Ok(None) } } + + fn bind_aggregate( + &mut self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult> { + bind_legacy_count_or_direct_aggregate(self, input, aggregate_fn, stat_dtype) + } } #[test] @@ -252,6 +300,51 @@ mod tests { Ok(()) } + #[test] + fn missing_stats_fold_when_kleene_semantics_allow_it() -> VortexResult<()> { + let mut binder = TestBinder::new(false); + + let bound = bind_stats(and(lit(false), all_non_nan(col("f"))), &mut binder)?; + + assert_eq!(bound, lit(false)); + + let bound = bind_stats(or(lit(true), all_non_nan(col("f"))), &mut binder)?; + + assert_eq!(bound, lit(true)); + Ok(()) + } + + #[test] + fn default_binder_does_not_derive_all_non_nan_from_nan_count() -> VortexResult<()> { + struct DefaultBinder(TestBinder); + + impl StatBinder for DefaultBinder { + fn scope(&self) -> &DType { + self.0.scope() + } + + fn bound_scope(&self) -> DType { + self.0.bound_scope() + } + + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + stat_dtype: &DType, + ) -> VortexResult> { + self.0.bind_stat(input, stat, stat_dtype) + } + } + + let mut binder = DefaultBinder(TestBinder::new(true)); + + let bound = bind_stats(all_non_nan(col("f")), &mut binder)?; + + assert_eq!(bound, lit(Scalar::null(DType::Bool(Nullability::Nullable)))); + Ok(()) + } + #[test] fn unrelated_expressions_do_not_request_nan_count() -> VortexResult<()> { let mut binder = TestBinder::new(false); diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index 527fa46262a..16b39d293f2 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -14,6 +14,7 @@ use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::MaskFuture; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::AggregateFnRef; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::NullArray; use vortex_array::dtype::DType; @@ -29,6 +30,7 @@ use vortex_array::scalar_fn::fns::get_item::GetItem; use vortex_array::scalar_fn::fns::literal::Literal; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_legacy_count_or_direct_aggregate; use vortex_array::stats::bind::bind_stats; use vortex_error::VortexResult; use vortex_layout::ArrayFuture; @@ -152,6 +154,15 @@ impl StatBinder for FileStatsBinder<'_> { }; Ok(self.reader.stat_ref(&field_path, stat)) } + + fn bind_aggregate( + &mut self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult> { + bind_legacy_count_or_direct_aggregate(self, input, aggregate_fn, stat_dtype) + } } fn direct_field_path(expr: &Expression) -> Option { diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index ac7711ea17a..a50f7fc006c 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::AggregateFnRef; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; @@ -22,6 +23,7 @@ use vortex_array::expr::stats::Stat; use vortex_array::scalar_fn::internal::row_count::contains_row_count; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_legacy_count_or_direct_aggregate; use vortex_array::stats::bind::bind_stats; use vortex_array::validity::Validity; use vortex_buffer::buffer; @@ -157,6 +159,15 @@ impl StatBinder for ZoneMapStatsBinder<'_> { } Ok(Some(get_item(stat.name(), root()))) } + + fn bind_aggregate( + &mut self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult> { + bind_legacy_count_or_direct_aggregate(self, input, aggregate_fn, stat_dtype) + } } /// Build per-zone row counts for a zone map. From 360aa58c0dad2318979cfed08b861002fa37817a Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 17 Jun 2026 12:33:22 -0400 Subject: [PATCH 13/28] Make stats binders immutable Signed-off-by: "Nicholas Gates" --- vortex-array/src/expr/pruning/pruning_expr.rs | 30 +++++------ vortex-array/src/stats/bind.rs | 53 +++++++++---------- vortex-file/src/v2/file_stats_reader.rs | 8 +-- vortex-layout/src/layouts/zoned/zone_map.rs | 8 +-- 4 files changed, 49 insertions(+), 50 deletions(-) diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 3ca196cfb66..23e5ca63eaf 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::cell::RefCell; use std::iter; use itertools::Itertools; @@ -70,14 +71,14 @@ pub fn checked_pruning_expr( return Ok(None); }; - let mut binder = RequiredStatsBinder { + let binder = RequiredStatsBinder { scope, available_stats, - required_stats: Relation::new(), - bound_stats: Vec::new(), + required_stats: RefCell::new(Relation::new()), + bound_stats: RefCell::new(Vec::new()), }; - let lowered = bind_stats(predicate, &mut binder)?; - let required_stats = filter_required_stats(&lowered, binder.required_stats); + let lowered = bind_stats(predicate, &binder)?; + let required_stats = filter_required_stats(&lowered, binder.required_stats.into_inner()); // If no stats-table fields remain, only a constant `true` proof can prune. // `false`, `null`, and non-constant expressions cannot justify building a // stats-table pruning expression. @@ -91,8 +92,8 @@ pub fn checked_pruning_expr( struct RequiredStatsBinder<'a> { scope: &'a DType, available_stats: &'a FieldPathSet, - required_stats: RequiredStats, - bound_stats: Vec<(FieldName, DType)>, + required_stats: RefCell, + bound_stats: RefCell>, } impl StatBinder for RequiredStatsBinder<'_> { @@ -102,13 +103,13 @@ impl StatBinder for RequiredStatsBinder<'_> { fn bound_scope(&self) -> DType { DType::Struct( - StructFields::from_iter(self.bound_stats.iter().cloned()), + StructFields::from_iter(self.bound_stats.borrow().iter().cloned()), Nullability::NonNullable, ) } fn bind_stat( - &mut self, + &self, input: &Expression, stat: Stat, stat_dtype: &DType, @@ -129,21 +130,20 @@ impl StatBinder for RequiredStatsBinder<'_> { } let stat_field_name = field_path_stat_field_name(&field_path, stat); - if self - .bound_stats + let mut bound_stats = self.bound_stats.borrow_mut(); + if bound_stats .iter() .all(|(field_name, _)| field_name != stat_field_name) { - self.bound_stats - .push((stat_field_name.clone(), stat_dtype.clone())); + bound_stats.push((stat_field_name.clone(), stat_dtype.clone())); } - self.required_stats.insert(field_path, stat); + self.required_stats.borrow_mut().insert(field_path, stat); Ok(Some(get_item(stat_field_name, root()))) } fn bind_aggregate( - &mut self, + &self, input: &Expression, aggregate_fn: &AggregateFnRef, stat_dtype: &DType, diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 8eb0e916015..45381dbdef7 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -47,7 +47,7 @@ pub trait StatBinder { /// then call [`Self::missing_stat`] with the dtype expected from the /// original `vortex.stat` expression. fn bind_stat( - &mut self, + &self, input: &Expression, stat: Stat, stat_dtype: &DType, @@ -59,7 +59,7 @@ pub trait StatBinder { /// [`Stat`] slots. Binders that store richer aggregate stats can override /// this method without extending the generic stats binding walker. fn bind_aggregate( - &mut self, + &self, input: &Expression, aggregate_fn: &AggregateFnRef, stat_dtype: &DType, @@ -68,11 +68,7 @@ pub trait StatBinder { } /// Bind one of the legacy stat slots for `input`. - fn bind_legacy_stat( - &mut self, - input: &Expression, - stat: Stat, - ) -> VortexResult> { + fn bind_legacy_stat(&self, input: &Expression, stat: Stat) -> VortexResult> { let input_dtype = input.return_dtype(self.scope())?; let Some(stat_dtype) = stat.dtype(&input_dtype) else { return Ok(None); @@ -84,7 +80,7 @@ pub trait StatBinder { /// /// The default is a nullable null literal, which preserves three-valued /// pruning semantics for stats-table execution. - fn missing_stat(&mut self, dtype: DType) -> VortexResult { + fn missing_stat(&self, dtype: DType) -> VortexResult { Ok(null_expr(dtype)) } } @@ -94,7 +90,10 @@ pub trait StatBinder { /// The predicate is usually the output of a stats rewrite rule. Rewrite rules /// are responsible for expressing stat semantics; binding maps aggregate-backed /// stat requests to the concrete stats representation supported by the binder. -pub fn bind_stats(predicate: Expression, binder: &mut impl StatBinder) -> VortexResult { +pub fn bind_stats( + predicate: Expression, + binder: &B, +) -> VortexResult { let scope = binder.scope().clone(); let lowered = predicate .transform_down(|expr| { @@ -120,7 +119,7 @@ pub fn bind_stats(predicate: Expression, binder: &mut impl StatBinder) -> Vortex /// This is an opt-in helper for stats backends that materialize `NaNCount` and /// `NullCount`, but do not materialize aggregate boolean stats directly. pub fn bind_legacy_count_aggregate( - binder: &mut B, + binder: &B, input: &Expression, aggregate_fn: &AggregateFnRef, ) -> VortexResult> { @@ -157,7 +156,7 @@ pub fn bind_legacy_count_aggregate( /// Bind an aggregate function that has a direct legacy [`Stat`] slot. pub fn bind_direct_aggregate_stat( - binder: &mut B, + binder: &B, input: &Expression, aggregate_fn: &AggregateFnRef, stat_dtype: &DType, @@ -174,7 +173,7 @@ pub fn bind_direct_aggregate_stat( /// `NaNCount` and `NullCount`, then fall back to direct aggregate-to-stat /// mappings. pub fn bind_legacy_count_or_direct_aggregate( - binder: &mut B, + binder: &B, input: &Expression, aggregate_fn: &AggregateFnRef, stat_dtype: &DType, @@ -189,7 +188,7 @@ pub fn bind_legacy_count_or_direct_aggregate( fn bind_stat_fn( expr: &Expression, scope: &DType, - binder: &mut impl StatBinder, + binder: &(impl StatBinder + ?Sized), ) -> VortexResult> { let options = expr.as_::(); let aggregate_fn = options.aggregate_fn(); @@ -258,7 +257,7 @@ mod tests { } fn bind_stat( - &mut self, + &self, _input: &Expression, stat: Stat, _stat_dtype: &DType, @@ -271,7 +270,7 @@ mod tests { } fn bind_aggregate( - &mut self, + &self, input: &Expression, aggregate_fn: &AggregateFnRef, stat_dtype: &DType, @@ -282,9 +281,9 @@ mod tests { #[test] fn all_non_nan_binds_to_nan_count_zero() -> VortexResult<()> { - let mut binder = TestBinder::new(true); + let binder = TestBinder::new(true); - let bound = bind_stats(all_non_nan(col("f")), &mut binder)?; + let bound = bind_stats(all_non_nan(col("f")), &binder)?; assert_eq!(bound, eq(col("f_nan_count"), lit(0u64))); Ok(()) @@ -292,9 +291,9 @@ mod tests { #[test] fn all_non_nan_lowers_to_null_when_nan_count_is_missing() -> VortexResult<()> { - let mut binder = TestBinder::new(false); + let binder = TestBinder::new(false); - let bound = bind_stats(all_non_nan(col("f")), &mut binder)?; + let bound = bind_stats(all_non_nan(col("f")), &binder)?; assert_eq!(bound, lit(Scalar::null(DType::Bool(Nullability::Nullable)))); Ok(()) @@ -302,13 +301,13 @@ mod tests { #[test] fn missing_stats_fold_when_kleene_semantics_allow_it() -> VortexResult<()> { - let mut binder = TestBinder::new(false); + let binder = TestBinder::new(false); - let bound = bind_stats(and(lit(false), all_non_nan(col("f"))), &mut binder)?; + let bound = bind_stats(and(lit(false), all_non_nan(col("f"))), &binder)?; assert_eq!(bound, lit(false)); - let bound = bind_stats(or(lit(true), all_non_nan(col("f"))), &mut binder)?; + let bound = bind_stats(or(lit(true), all_non_nan(col("f"))), &binder)?; assert_eq!(bound, lit(true)); Ok(()) @@ -328,7 +327,7 @@ mod tests { } fn bind_stat( - &mut self, + &self, input: &Expression, stat: Stat, stat_dtype: &DType, @@ -337,9 +336,9 @@ mod tests { } } - let mut binder = DefaultBinder(TestBinder::new(true)); + let binder = DefaultBinder(TestBinder::new(true)); - let bound = bind_stats(all_non_nan(col("f")), &mut binder)?; + let bound = bind_stats(all_non_nan(col("f")), &binder)?; assert_eq!(bound, lit(Scalar::null(DType::Bool(Nullability::Nullable)))); Ok(()) @@ -347,9 +346,9 @@ mod tests { #[test] fn unrelated_expressions_do_not_request_nan_count() -> VortexResult<()> { - let mut binder = TestBinder::new(false); + let binder = TestBinder::new(false); - let bound = bind_stats(is_null(col("f")), &mut binder)?; + let bound = bind_stats(is_null(col("f")), &binder)?; assert_eq!(bound, is_null(col("f"))); Ok(()) diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index 16b39d293f2..c138c251ab6 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -121,8 +121,8 @@ impl FileStatsLayoutReader { } fn lower_stats(&self, predicate: Expression) -> VortexResult { - let mut binder = FileStatsBinder { reader: self }; - bind_stats(predicate, &mut binder) + let binder = FileStatsBinder { reader: self }; + bind_stats(predicate, &binder) } pub fn file_stats(&self) -> &FileStatistics { @@ -144,7 +144,7 @@ impl StatBinder for FileStatsBinder<'_> { } fn bind_stat( - &mut self, + &self, input: &Expression, stat: Stat, _stat_dtype: &DType, @@ -156,7 +156,7 @@ impl StatBinder for FileStatsBinder<'_> { } fn bind_aggregate( - &mut self, + &self, input: &Expression, aggregate_fn: &AggregateFnRef, stat_dtype: &DType, diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index a50f7fc006c..a14dc71691f 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -122,8 +122,8 @@ impl ZoneMap { } fn lower_stats(&self, predicate: Expression) -> VortexResult { - let mut binder = ZoneMapStatsBinder { zone_map: self }; - bind_stats(predicate, &mut binder) + let binder = ZoneMapStatsBinder { zone_map: self }; + bind_stats(predicate, &binder) } } @@ -141,7 +141,7 @@ impl StatBinder for ZoneMapStatsBinder<'_> { } fn bind_stat( - &mut self, + &self, input: &Expression, stat: Stat, _stat_dtype: &DType, @@ -161,7 +161,7 @@ impl StatBinder for ZoneMapStatsBinder<'_> { } fn bind_aggregate( - &mut self, + &self, input: &Expression, aggregate_fn: &AggregateFnRef, stat_dtype: &DType, From ec94e6995757443ad98b91a4467189c599444ed1 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 17 Jun 2026 12:44:43 -0400 Subject: [PATCH 14/28] Remove required stats pruning binder Signed-off-by: "Nicholas Gates" --- vortex-array/src/expr/mod.rs | 1 - vortex-array/src/expr/pruning/mod.rs | 10 - vortex-array/src/expr/pruning/pruning_expr.rs | 680 ------------------ vortex-array/src/expr/pruning/relation.rs | 50 -- vortex-array/src/scalar_fn/fns/is_not_null.rs | 41 -- vortex-array/src/scalar_fn/fns/is_null.rs | 39 - .../src/scalar_fn/fns/list_contains/mod.rs | 62 -- vortex-file/src/file.rs | 73 +- vortex-file/src/pruning.rs | 178 +++-- vortex-file/src/v2/file_stats_reader.rs | 125 +--- 10 files changed, 135 insertions(+), 1124 deletions(-) delete mode 100644 vortex-array/src/expr/pruning/mod.rs delete mode 100644 vortex-array/src/expr/pruning/pruning_expr.rs delete mode 100644 vortex-array/src/expr/pruning/relation.rs diff --git a/vortex-array/src/expr/mod.rs b/vortex-array/src/expr/mod.rs index 72969baf23a..c3728155517 100644 --- a/vortex-array/src/expr/mod.rs +++ b/vortex-array/src/expr/mod.rs @@ -34,7 +34,6 @@ pub(crate) mod field; pub mod forms; mod optimize; pub mod proto; -pub mod pruning; pub mod stats; pub mod transform; pub mod traversal; diff --git a/vortex-array/src/expr/pruning/mod.rs b/vortex-array/src/expr/pruning/mod.rs deleted file mode 100644 index 5ce2785f446..00000000000 --- a/vortex-array/src/expr/pruning/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -pub(crate) mod pruning_expr; -mod relation; - -pub use pruning_expr::RequiredStats; -pub use pruning_expr::checked_pruning_expr; -pub use pruning_expr::field_path_stat_field_name; -pub use relation::Relation; diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs deleted file mode 100644 index 23e5ca63eaf..00000000000 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ /dev/null @@ -1,680 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::cell::RefCell; -use std::iter; - -use itertools::Itertools; -use vortex_error::VortexResult; -use vortex_session::VortexSession; -use vortex_utils::aliases::hash_set::HashSet; - -use super::relation::Relation; -use crate::aggregate_fn::AggregateFnRef; -use crate::dtype::DType; -use crate::dtype::Field; -use crate::dtype::FieldName; -use crate::dtype::FieldPath; -use crate::dtype::FieldPathSet; -use crate::dtype::Nullability; -use crate::dtype::StructFields; -use crate::expr::Expression; -use crate::expr::analysis::referenced_field_paths; -use crate::expr::get_item; -use crate::expr::is_root; -use crate::expr::root; -use crate::expr::stats::Stat; -use crate::scalar_fn::fns::cast::Cast; -use crate::scalar_fn::fns::get_item::GetItem; -use crate::scalar_fn::fns::literal::Literal; -use crate::stats::bind::StatBinder; -use crate::stats::bind::bind_legacy_count_or_direct_aggregate; -use crate::stats::bind::bind_stats; - -pub type RequiredStats = Relation; - -#[doc(hidden)] -pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName { - field_path - .parts() - .iter() - .map(|f| match f { - Field::Name(n) => n.as_ref(), - Field::ElementType => todo!("element type not currently handled"), - }) - .chain(iter::once(stat.name())) - .join("_") - .into() -} - -/// Build a pruning expression using session-registered stats rewrite rules. -/// -/// A pruning expression is one that returns `true` for all positions where the original expression -/// cannot hold, and false if it cannot be determined from stats alone whether the positions can -/// be pruned. -/// -/// Some rewrites, such as `is_not_null(...)`, emit -/// [`row_count`][crate::scalar_fn::internal::row_count] placeholders. The evaluation layer must -/// replace those placeholders with the row count for its current scope before -/// executing the returned expression. -/// -/// The returned expression is lowered to stats-table field references. Stats not present in -/// `available_stats` are replaced with typed null literals, preserving three-valued pruning -/// semantics without requiring callers to materialize unavailable stats. -pub fn checked_pruning_expr( - expr: &Expression, - scope: &DType, - available_stats: &FieldPathSet, - session: &VortexSession, -) -> VortexResult> { - let Some(predicate) = expr.falsify(scope, session)? else { - return Ok(None); - }; - - let binder = RequiredStatsBinder { - scope, - available_stats, - required_stats: RefCell::new(Relation::new()), - bound_stats: RefCell::new(Vec::new()), - }; - let lowered = bind_stats(predicate, &binder)?; - let required_stats = filter_required_stats(&lowered, binder.required_stats.into_inner()); - // If no stats-table fields remain, only a constant `true` proof can prune. - // `false`, `null`, and non-constant expressions cannot justify building a - // stats-table pruning expression. - if required_stats.map().is_empty() && !matches!(bool_literal(&lowered), Some(Some(true))) { - return Ok(None); - } - - Ok(Some((lowered, required_stats))) -} - -struct RequiredStatsBinder<'a> { - scope: &'a DType, - available_stats: &'a FieldPathSet, - required_stats: RefCell, - bound_stats: RefCell>, -} - -impl StatBinder for RequiredStatsBinder<'_> { - fn scope(&self) -> &DType { - self.scope - } - - fn bound_scope(&self) -> DType { - DType::Struct( - StructFields::from_iter(self.bound_stats.borrow().iter().cloned()), - Nullability::NonNullable, - ) - } - - fn bind_stat( - &self, - input: &Expression, - stat: Stat, - stat_dtype: &DType, - ) -> VortexResult> { - let field_path = match direct_stat_field_path(input) { - Some(field_path) => field_path, - None => { - let field_paths = referenced_field_paths(input, self.scope)?; - let Some(field_path) = field_paths.iter().exactly_one().ok() else { - return Ok(None); - }; - field_path.clone() - } - }; - let stat_path = field_path.clone().push(stat.name()); - if !self.available_stats.contains(&stat_path) { - return Ok(None); - } - - let stat_field_name = field_path_stat_field_name(&field_path, stat); - let mut bound_stats = self.bound_stats.borrow_mut(); - if bound_stats - .iter() - .all(|(field_name, _)| field_name != stat_field_name) - { - bound_stats.push((stat_field_name.clone(), stat_dtype.clone())); - } - - self.required_stats.borrow_mut().insert(field_path, stat); - Ok(Some(get_item(stat_field_name, root()))) - } - - fn bind_aggregate( - &self, - input: &Expression, - aggregate_fn: &AggregateFnRef, - stat_dtype: &DType, - ) -> VortexResult> { - bind_legacy_count_or_direct_aggregate(self, input, aggregate_fn, stat_dtype) - } -} - -fn direct_stat_field_path(expr: &Expression) -> Option { - if is_root(expr) { - return Some(FieldPath::root()); - } - - if expr.is::() { - return direct_stat_field_path(expr.child(0)); - } - - let field_name = expr.as_opt::()?; - direct_stat_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) -} - -fn filter_required_stats(expr: &Expression, required_stats: RequiredStats) -> RequiredStats { - let referenced_names = referenced_stat_field_names(expr); - let mut filtered = Relation::new(); - for (field_path, stats) in required_stats { - for stat in stats { - if referenced_names.contains(&field_path_stat_field_name(&field_path, stat)) { - filtered.insert(field_path.clone(), stat); - } - } - } - filtered -} - -fn referenced_stat_field_names(expr: &Expression) -> HashSet { - let mut refs = HashSet::new(); - collect_referenced_stat_field_names(expr, &mut refs); - refs -} - -fn collect_referenced_stat_field_names(expr: &Expression, refs: &mut HashSet) { - if let Some(field_name) = expr.as_opt::() - && is_root(expr.child(0)) - { - refs.insert(field_name.clone()); - return; - } - - for child in expr.children().iter() { - collect_referenced_stat_field_names(child, refs); - } -} - -fn bool_literal(expr: &Expression) -> Option> { - expr.as_opt::()? - .as_bool_opt() - .map(|value| value.value()) -} - -#[cfg(test)] -mod tests { - use std::sync::LazyLock; - - use rstest::fixture; - use rstest::rstest; - use vortex_session::VortexSession; - use vortex_utils::aliases::hash_map::HashMap; - use vortex_utils::aliases::hash_set::HashSet; - - use super::RequiredStats; - use crate::dtype::DType; - use crate::dtype::FieldName; - use crate::dtype::FieldNames; - use crate::dtype::FieldPath; - use crate::dtype::FieldPathSet; - use crate::dtype::Nullability; - use crate::dtype::PType; - use crate::dtype::StructFields; - use crate::expr::Expression; - use crate::expr::and; - use crate::expr::between; - use crate::expr::cast; - use crate::expr::col; - use crate::expr::eq; - use crate::expr::get_item; - use crate::expr::gt; - use crate::expr::gt_eq; - use crate::expr::lit; - use crate::expr::lt; - use crate::expr::lt_eq; - use crate::expr::not_eq; - use crate::expr::or; - use crate::expr::pruning::checked_pruning_expr; - use crate::expr::pruning::field_path_stat_field_name; - use crate::expr::root; - use crate::expr::stats::Stat; - use crate::scalar_fn::fns::between::BetweenOptions; - use crate::scalar_fn::fns::between::StrictComparison; - use crate::stats::session::StatsSession; - - static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); - - fn test_scope() -> DType { - DType::Struct( - StructFields::from_iter([ - ("a", DType::Primitive(PType::I32, Nullability::NonNullable)), - ("b", DType::Primitive(PType::I32, Nullability::NonNullable)), - ("x", DType::Bool(Nullability::NonNullable)), - ("y", DType::Primitive(PType::I32, Nullability::NonNullable)), - ("z", DType::Primitive(PType::I32, Nullability::NonNullable)), - ( - "float_col", - DType::Primitive(PType::F32, Nullability::NonNullable), - ), - ( - "int_col", - DType::Primitive(PType::I32, Nullability::NonNullable), - ), - ]), - Nullability::NonNullable, - ) - } - - fn checked( - expr: &Expression, - available_stats: &FieldPathSet, - ) -> Option<(Expression, RequiredStats)> { - checked_pruning_expr(expr, &test_scope(), available_stats, &SESSION).unwrap() - } - - // Implement some checked pruning expressions. - #[fixture] - fn available_stats() -> FieldPathSet { - let field_a = FieldPath::from_name("a"); - let field_b = FieldPath::from_name("b"); - - FieldPathSet::from_iter([ - field_a.clone().push(Stat::Min.name()), - field_a.push(Stat::Max.name()), - field_b.clone().push(Stat::Min.name()), - field_b.push(Stat::Max.name()), - ]) - } - - #[rstest] - pub fn pruning_equals(available_stats: FieldPathSet) { - let name = FieldName::from("a"); - let literal_eq = lit(42); - let eq_expr = eq(get_item("a", root()), literal_eq.clone()); - let (converted, _refs) = checked(&eq_expr, &available_stats).unwrap(); - let expected_expr = or( - gt( - get_item( - field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min), - root(), - ), - literal_eq.clone(), - ), - gt( - literal_eq, - col(field_path_stat_field_name( - &FieldPath::from_name(name), - Stat::Max, - )), - ), - ); - assert_eq!(&converted, &expected_expr); - } - - #[rstest] - pub fn pruning_equals_column(available_stats: FieldPathSet) { - let column = FieldName::from("a"); - let other_col = FieldName::from("b"); - let eq_expr = eq(col(column.clone()), col(other_col.clone())); - - let (converted, refs) = checked(&eq_expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([ - ( - FieldPath::from_name(column.clone()), - HashSet::from_iter([Stat::Min, Stat::Max]) - ), - ( - FieldPath::from_name(other_col.clone()), - HashSet::from_iter([Stat::Max, Stat::Min]) - ) - ]) - ); - let expected_expr = or( - gt( - col(field_path_stat_field_name( - &FieldPath::from_name(column.clone()), - Stat::Min, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(other_col.clone()), - Stat::Max, - )), - ), - gt( - col(field_path_stat_field_name( - &FieldPath::from_name(other_col), - Stat::Min, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(column), - Stat::Max, - )), - ), - ); - assert_eq!(&converted, &expected_expr); - } - - #[rstest] - pub fn pruning_not_equals_column(available_stats: FieldPathSet) { - let column = FieldName::from("a"); - let other_col = FieldName::from("b"); - let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone())); - - let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([ - ( - FieldPath::from_name(column.clone()), - HashSet::from_iter([Stat::Min, Stat::Max]) - ), - ( - FieldPath::from_name(other_col.clone()), - HashSet::from_iter([Stat::Max, Stat::Min]) - ) - ]) - ); - let expected_expr = and( - eq( - col(field_path_stat_field_name( - &FieldPath::from_name(column.clone()), - Stat::Min, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(other_col.clone()), - Stat::Max, - )), - ), - eq( - col(field_path_stat_field_name( - &FieldPath::from_name(column), - Stat::Max, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(other_col), - Stat::Min, - )), - ), - ); - - assert_eq!(&converted, &expected_expr); - } - - #[rstest] - pub fn pruning_gt_column(available_stats: FieldPathSet) { - let column = FieldName::from("a"); - let other_col = FieldName::from("b"); - let other_expr = col(other_col.clone()); - let not_eq_expr = gt(col(column.clone()), other_expr); - - let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([ - ( - FieldPath::from_name(column.clone()), - HashSet::from_iter([Stat::Max]) - ), - ( - FieldPath::from_name(other_col.clone()), - HashSet::from_iter([Stat::Min]) - ) - ]) - ); - let expected_expr = lt_eq( - col(field_path_stat_field_name( - &FieldPath::from_name(column), - Stat::Max, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(other_col), - Stat::Min, - )), - ); - assert_eq!(&converted, &expected_expr); - } - - #[rstest] - pub fn pruning_gt_value(available_stats: FieldPathSet) { - let column = FieldName::from("a"); - let other_col = lit(42); - let not_eq_expr = gt(col(column.clone()), other_col.clone()); - - let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([( - FieldPath::from_name(column.clone()), - HashSet::from_iter([Stat::Max]) - ),]) - ); - let expected_expr = lt_eq( - col(field_path_stat_field_name( - &FieldPath::from_name(column), - Stat::Max, - )), - other_col, - ); - assert_eq!(&converted, &(expected_expr)); - } - - #[rstest] - pub fn pruning_lt_column(available_stats: FieldPathSet) { - let column = FieldName::from("a"); - let other_col = FieldName::from("b"); - let other_expr = col(other_col.clone()); - let not_eq_expr = lt(col(column.clone()), other_expr); - - let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([ - ( - FieldPath::from_name(column.clone()), - HashSet::from_iter([Stat::Min]) - ), - ( - FieldPath::from_name(other_col.clone()), - HashSet::from_iter([Stat::Max]) - ) - ]) - ); - let expected_expr = gt_eq( - col(field_path_stat_field_name( - &FieldPath::from_name(column), - Stat::Min, - )), - col(field_path_stat_field_name( - &FieldPath::from_name(other_col), - Stat::Max, - )), - ); - assert_eq!(&converted, &expected_expr); - } - - #[rstest] - pub fn pruning_lt_value(available_stats: FieldPathSet) { - // expression => a < 42 - // pruning expr => a.min >= 42 - let expr = lt(col("a"), lit(42)); - - let (converted, refs) = checked(&expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))]) - ); - assert_eq!(&converted, >_eq(col("a_min"), lit(42))); - } - - #[rstest] - fn pruning_identity(available_stats: FieldPathSet) { - let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50))); - - let (predicate, _) = checked(&expr, &available_stats).unwrap(); - - let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50))); - assert_eq!(&predicate.to_string(), &expected_expr.to_string()); - } - #[rstest] - pub fn pruning_and_or_operators(available_stats: FieldPathSet) { - // Test case: a > 10 AND a < 50 - let column = FieldName::from("a"); - let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50))); - let (predicate, _) = checked(&and_expr, &available_stats).unwrap(); - - // Expected: a_max <= 10 OR a_min >= 50 - assert_eq!( - &predicate, - &or( - lt_eq(col(FieldName::from("a_max")), lit(10)), - gt_eq(col(FieldName::from("a_min")), lit(50)), - ), - ); - } - - #[rstest] - fn test_gt_eq_with_booleans(available_stats: FieldPathSet) { - // Consider this unusual, but valid (in Arrow, BooleanArray implements ArrayOrd), filter expression: - // x > (y > z) - // The x column is a Boolean-valued column. The y and z columns are numeric. True > False. - // Suppose we had a Vortex zone whose min/max statistics for each column were: - // x: [True, True] - // y: [1, 2] - // z: [0, 2] - // The pruning predicate will convert the aforementioned expression into: - // x_max <= (y_min > z_min) - // If we evaluate that pruning expression on our zone we get: - // x_max <= (y_min > z_min) - // x_max <= (1 > 0 ) - // x_max <= True - // True <= True - // True - // If a pruning predicate evaluates to true then, as stated in PruningPredicate::evaluate: - // > a true value means the chunk can be pruned. - // But, the following record lies within the above intervals and *passes* the filter expression! We - // cannot prune this zone because we need this record! - // {x: True, y: 1, z: 2} - // x > (y > z) - // True > (1 > 2) - // True > False - // True - let expr = gt_eq(col("x"), gt(col("y"), col("z"))); - assert!(checked(&expr, &available_stats).is_none()); - // TODO(DK): a sufficiently complex pruner would produce: `x_max <= (y_max > z_min)` - } - - #[fixture] - fn available_stats_with_nans() -> FieldPathSet { - let float_col = FieldPath::from_name("float_col"); - let int_col = FieldPath::from_name("int_col"); - - FieldPathSet::from_iter([ - // Float columns will have a NaNCount. - float_col.clone().push(Stat::Min.name()), - float_col.clone().push(Stat::Max.name()), - float_col.push(Stat::NaNCount.name()), - // int columns will not have a NanCount serialized into the layout - int_col.clone().push(Stat::Min.name()), - int_col.push(Stat::Max.name()), - ]) - } - - #[rstest] - fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) { - let expr = gt_eq(col("float_col"), lit(f32::NAN)); - assert!(checked(&expr, &available_stats_with_nans).is_none()); - - // One half of the expression requires an all-non-NaN proof, the other half does not. - let expr = and( - gt(col("float_col"), lit(10f32)), - lt(col("int_col"), lit(10)), - ); - - let (converted, refs) = checked(&expr, &available_stats_with_nans).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([ - ( - FieldPath::from_name("float_col"), - HashSet::from_iter([Stat::Max, Stat::NaNCount]) - ), - ( - FieldPath::from_name("int_col"), - HashSet::from_iter([Stat::Min]) - ) - ]) - ); - assert_eq!( - &converted, - &or( - and( - eq(col("float_col_nan_count"), lit(0u64)), - lt_eq(col("float_col_max"), lit(10f32)), - ), - gt_eq(col("int_col_min"), lit(10)), - ) - ) - } - - #[rstest] - fn pruning_between(available_stats: FieldPathSet) { - let expr = between( - col("a"), - lit(10), - lit(50), - BetweenOptions { - lower_strict: StrictComparison::NonStrict, - upper_strict: StrictComparison::NonStrict, - }, - ); - let (converted, refs) = checked(&expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([( - FieldPath::from_name("a"), - HashSet::from_iter([Stat::Min, Stat::Max]) - )]) - ); - assert_eq!( - &converted, - &or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50))) - ); - } - - #[rstest] - fn pruning_cast_get_item_eq(available_stats: FieldPathSet) { - // This test verifies that cast properly forwards analysis methods to - // enable pruning. - let struct_dtype = DType::Struct( - StructFields::new( - FieldNames::from([FieldName::from("a"), FieldName::from("b")]), - vec![ - DType::Utf8(Nullability::Nullable), - DType::Utf8(Nullability::Nullable), - ], - ), - Nullability::NonNullable, - ); - let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value")); - let (converted, refs) = checked(&expr, &available_stats).unwrap(); - assert_eq!( - refs.map(), - &HashMap::from_iter([( - FieldPath::from_name("a"), - HashSet::from_iter([Stat::Min, Stat::Max]) - )]) - ); - assert_eq!( - &converted, - &or( - gt(col("a_min"), lit("value")), - gt(lit("value"), col("a_max")) - ) - ); - } -} diff --git a/vortex-array/src/expr/pruning/relation.rs b/vortex-array/src/expr/pruning/relation.rs deleted file mode 100644 index 3a33771c46e..00000000000 --- a/vortex-array/src/expr/pruning/relation.rs +++ /dev/null @@ -1,50 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::hash::Hash; - -use vortex_utils::aliases::hash_map::HashMap; -use vortex_utils::aliases::hash_map::IntoIter; -use vortex_utils::aliases::hash_set::HashSet; - -#[derive(Debug, Clone)] -pub struct Relation { - map: HashMap>, -} - -impl Default for Relation { - fn default() -> Self { - Self::new() - } -} - -impl Relation { - pub fn new() -> Self { - Relation { - map: HashMap::new(), - } - } - - pub fn insert(&mut self, k: K, v: V) { - self.map.entry(k).or_default().insert(v); - } - - pub fn map(&self) -> &HashMap> { - &self.map - } -} - -impl From>> for Relation { - fn from(value: HashMap>) -> Self { - Self { map: value } - } -} - -impl IntoIterator for Relation { - type Item = (K, HashSet); - type IntoIter = IntoIter>; - - fn into_iter(self) -> Self::IntoIter { - self.map.into_iter() - } -} diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index 850b074c9ab..4d9cbecede9 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -101,8 +101,6 @@ impl ScalarFnVTable for IsNotNull { mod tests { use vortex_buffer::buffer; use vortex_error::VortexExpect as _; - use vortex_utils::aliases::hash_map::HashMap; - use vortex_utils::aliases::hash_set::HashSet; use crate::IntoArray; use crate::LEGACY_SESSION; @@ -110,23 +108,13 @@ mod tests { use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; use crate::dtype::DType; - use crate::dtype::Field; - use crate::dtype::FieldPath; - use crate::dtype::FieldPathSet; use crate::dtype::Nullability; use crate::expr::col; - use crate::expr::eq; use crate::expr::get_item; use crate::expr::is_not_null; - use crate::expr::or; - use crate::expr::pruning::checked_pruning_expr; use crate::expr::root; - use crate::expr::stats::Stat; use crate::expr::test_harness; use crate::scalar::Scalar; - use crate::scalar_fn::EmptyOptions; - use crate::scalar_fn::internal::row_count::RowCount; - use crate::scalar_fn::vtable::ScalarFnVTableExt; #[test] fn dtype() { @@ -244,33 +232,4 @@ mod tests { fn test_is_not_null_sensitive() { assert!(is_not_null(col("a")).signature().is_null_sensitive()); } - - #[test] - fn test_is_not_null_falsification() { - let expr = is_not_null(col("a")); - - let (pruning_expr, st) = checked_pruning_expr( - &expr, - &test_harness::struct_dtype(), - &FieldPathSet::from_iter([FieldPath::from_iter([ - Field::Name("a".into()), - Field::Name("null_count".into()), - ])]), - &LEGACY_SESSION, - ) - .unwrap() - .unwrap(); - - assert_eq!( - &pruning_expr, - &or( - eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, [])), - eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, [])), - ) - ); - assert_eq!( - st.map(), - &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))]) - ); - } } diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index b4dc839a6cb..88cf0aa8830 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -92,8 +92,6 @@ impl ScalarFnVTable for IsNull { mod tests { use vortex_buffer::buffer; use vortex_error::VortexExpect as _; - use vortex_utils::aliases::hash_map::HashMap; - use vortex_utils::aliases::hash_set::HashSet; use crate::IntoArray; use crate::LEGACY_SESSION; @@ -101,19 +99,11 @@ mod tests { use crate::arrays::PrimitiveArray; use crate::arrays::StructArray; use crate::dtype::DType; - use crate::dtype::Field; - use crate::dtype::FieldPath; - use crate::dtype::FieldPathSet; use crate::dtype::Nullability; use crate::expr::col; - use crate::expr::eq; use crate::expr::get_item; use crate::expr::is_null; - use crate::expr::lit; - use crate::expr::or; - use crate::expr::pruning::checked_pruning_expr; use crate::expr::root; - use crate::expr::stats::Stat; use crate::expr::test_harness; use crate::scalar::Scalar; @@ -231,35 +221,6 @@ mod tests { assert_eq!(expr2.to_string(), "vortex.is_null($)"); } - #[test] - fn test_is_null_falsification() { - let expr = is_null(col("a")); - - let (pruning_expr, st) = checked_pruning_expr( - &expr, - &test_harness::struct_dtype(), - &FieldPathSet::from_iter([FieldPath::from_iter([ - Field::Name("a".into()), - Field::Name("null_count".into()), - ])]), - &LEGACY_SESSION, - ) - .unwrap() - .unwrap(); - - assert_eq!( - &pruning_expr, - &or( - eq(col("a_null_count"), lit(0u64)), - eq(col("a_null_count"), lit(0u64)), - ) - ); - assert_eq!( - st.map(), - &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))]) - ); - } - #[test] fn test_is_null_sensitive() { // is_null itself is null-sensitive diff --git a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs index 00f0bea10a6..79965f0c52f 100644 --- a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs +++ b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs @@ -399,8 +399,6 @@ mod tests { use rstest::rstest; use vortex_buffer::BitBuffer; use vortex_buffer::Buffer; - use vortex_utils::aliases::hash_map::HashMap; - use vortex_utils::aliases::hash_set::HashSet; use crate::ArrayRef; use crate::IntoArray; @@ -412,23 +410,13 @@ mod tests { #[expect(deprecated)] use crate::canonical::ToCanonical as _; use crate::dtype::DType; - use crate::dtype::Field; - use crate::dtype::FieldPath; - use crate::dtype::FieldPathSet; use crate::dtype::Nullability; use crate::dtype::PType::I32; use crate::dtype::StructFields; - use crate::expr::and; - use crate::expr::col; use crate::expr::get_item; - use crate::expr::gt; use crate::expr::list_contains; use crate::expr::lit; - use crate::expr::lt; - use crate::expr::or; - use crate::expr::pruning::checked_pruning_expr; use crate::expr::root; - use crate::expr::stats::Stat; use crate::scalar::Scalar; use crate::scalar_fn::fns::list_contains::BoolArray; use crate::scalar_fn::fns::list_contains::ConstantArray; @@ -575,56 +563,6 @@ mod tests { ); } - #[test] - pub fn list_falsification() { - let expr = list_contains( - lit(Scalar::list( - Arc::new(DType::Primitive(I32, Nullability::NonNullable)), - vec![1.into(), 2.into(), 3.into()], - Nullability::NonNullable, - )), - col("a"), - ); - let scope = DType::Struct( - StructFields::new( - ["a"].into(), - vec![DType::Primitive(I32, Nullability::NonNullable)], - ), - Nullability::NonNullable, - ); - - let (expr, st) = checked_pruning_expr( - &expr, - &scope, - &FieldPathSet::from_iter([ - FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]), - FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]), - ]), - &LEGACY_SESSION, - ) - .unwrap() - .unwrap(); - - assert_eq!( - &expr, - &and( - and( - or(lt(col("a_max"), lit(1i32)), gt(col("a_min"), lit(1i32)),), - or(lt(col("a_max"), lit(2i32)), gt(col("a_min"), lit(2i32)),) - ), - or(lt(col("a_max"), lit(3i32)), gt(col("a_min"), lit(3i32)),) - ) - ); - - assert_eq!( - st.map(), - &HashMap::from_iter([( - FieldPath::from_name("a"), - HashSet::from([Stat::Min, Stat::Max]) - )]) - ); - } - #[test] pub fn test_display() { let expr = list_contains(get_item("tags", root()), lit("urgent")); diff --git a/vortex-file/src/file.rs b/vortex-file/src/file.rs index f701565c8c6..162f0e347e3 100644 --- a/vortex-file/src/file.rs +++ b/vortex-file/src/file.rs @@ -12,18 +12,9 @@ use std::sync::OnceLock; use itertools::Itertools; use vortex_array::ArrayRef; -use vortex_array::Columnar; -use vortex_array::IntoArray; -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::ConstantArray; use vortex_array::dtype::DType; -use vortex_array::dtype::Field; use vortex_array::dtype::FieldMask; -use vortex_array::dtype::FieldPath; -use vortex_array::dtype::FieldPathSet; use vortex_array::expr::Expression; -use vortex_array::expr::pruning::checked_pruning_expr; -use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_error::VortexResult; use vortex_layout::LayoutReader; use vortex_layout::scan::layout::LayoutReaderDataSource; @@ -32,11 +23,10 @@ use vortex_layout::scan::split_by::SplitBy; use vortex_layout::segments::SegmentSource; use vortex_scan::DataSourceRef; use vortex_session::VortexSession; -use vortex_utils::aliases::hash_map::HashMap; use crate::FileStatistics; use crate::footer::Footer; -use crate::pruning::extract_relevant_file_stats_as_struct_row; +use crate::pruning::can_prune_file_stats; use crate::v2::FileStatsLayoutReader; /// Represents a Vortex file, providing access to its metadata and content. @@ -202,61 +192,14 @@ impl VortexFile { return Ok(false); }; - let set = FieldPathSet::from_iter( - fields - .names() - .iter() - .zip(stats.stats_sets().iter()) - .flat_map(|(name, stats)| { - stats.iter().map(|(stat, _)| { - FieldPath::from_iter([ - Field::Name(name.clone()), - Field::Name(stat.name().into()), - ]) - }) - }), - ); - - let Some((predicate, required_stats)) = - checked_pruning_expr(filter, self.footer.dtype(), &set, &self.session)? - else { - return Ok(false); - }; - - let required_file_stats = HashMap::from_iter( - required_stats - .map() - .iter() - .map(|(path, stats)| (path.clone(), stats.clone())), - ); - - let Some(file_stats) = extract_relevant_file_stats_as_struct_row( - &required_file_stats, - stats.stats_sets(), + can_prune_file_stats( + filter, + self.footer.dtype(), + self.footer.row_count(), + stats, fields, - )? - else { - return Ok(false); - }; - - // Apply the predicate, then substitute any row_count placeholders in the resulting array - // tree with a ConstantArray carrying the file-level row count. - let applied = file_stats.apply(&predicate)?; - let row_count_replacement = - ConstantArray::new(self.footer.row_count(), applied.len()).into_array(); - let applied = substitute_row_count(applied, &row_count_replacement)?; - - let mut ctx = self.session.create_execution_ctx(); - Ok(match applied.execute::(&mut ctx)? { - Columnar::Constant(s) => s.scalar().as_bool().value() == Some(true), - Columnar::Canonical(c) => { - c.into_array() - .execute_scalar(0, &mut ctx)? - .as_bool() - .value() - == Some(true) - } - }) + &self.session, + ) } pub fn splits(&self) -> VortexResult>> { diff --git a/vortex-file/src/pruning.rs b/vortex-file/src/pruning.rs index 74ef90a034e..1fc79447d1d 100644 --- a/vortex-file/src/pruning.rs +++ b/vortex-file/src/pruning.rs @@ -1,78 +1,136 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::Arc; - -use vortex_array::ArrayRef; +use vortex_array::Canonical; use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::AggregateFnRef; use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::StructArray; -use vortex_array::dtype::Field; -use vortex_array::dtype::FieldName; -use vortex_array::dtype::FieldNames; +use vortex_array::arrays::NullArray; +use vortex_array::dtype::DType; use vortex_array::dtype::FieldPath; use vortex_array::dtype::StructFields; -use vortex_array::expr::pruning::field_path_stat_field_name; +use vortex_array::expr::Expression; +use vortex_array::expr::is_root; +use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; -use vortex_array::expr::stats::StatsProvider; -use vortex_array::stats::StatsSet; -use vortex_array::validity::Validity; -use vortex_error::VortexExpect; +use vortex_array::scalar::Scalar; +use vortex_array::scalar_fn::fns::cast::Cast; +use vortex_array::scalar_fn::fns::get_item::GetItem; +use vortex_array::scalar_fn::fns::literal::Literal; +use vortex_array::scalar_fn::internal::row_count::substitute_row_count; +use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_legacy_count_or_direct_aggregate; +use vortex_array::stats::bind::bind_stats; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use vortex_utils::aliases::hash_map::HashMap; -use vortex_utils::aliases::hash_set::HashSet; - -pub fn extract_relevant_file_stats_as_struct_row( - access: &HashMap>, - stats_sets: &Arc<[StatsSet]>, - struct_dtype: &StructFields, -) -> VortexResult> { - if access.is_empty() { - return StructArray::try_new(FieldNames::default(), vec![], 1, Validity::NonNullable) - .map(|s| Some(s.into_array())); +use vortex_session::VortexSession; + +use crate::FileStatistics; + +pub(crate) fn can_prune_file_stats( + expr: &Expression, + dtype: &DType, + row_count: u64, + file_stats: &FileStatistics, + struct_fields: &StructFields, + session: &VortexSession, +) -> VortexResult { + let Some(pruning_expr) = expr.falsify(dtype, session)? else { + return Ok(false); + }; + + let binder = FileStatsBinder { + dtype, + file_stats, + struct_fields, + }; + let pruning_expr = bind_stats(pruning_expr, &binder)?; + + let simplified = pruning_expr.optimize_recursive(&DType::Null)?; + if let Some(result) = simplified.as_opt::() { + return Ok(result.as_bool().value() == Some(true)); } - let mut columns: Vec<(FieldName, ArrayRef)> = Vec::with_capacity(access.len() * 2); - for (field_path, stats) in access.into_iter() { - if field_path.parts().len() != 1 { - return Ok(None); - } - let Field::Name(field) = &field_path.parts()[0] else { + let pruning = NullArray::new(1).into_array().apply(&pruning_expr)?; + let row_count_replacement = ConstantArray::new(row_count, pruning.len()).into_array(); + let pruning = substitute_row_count(pruning, &row_count_replacement)?; + + let mut ctx = session.create_execution_ctx(); + let result = pruning + .execute::(&mut ctx)? + .into_bool() + .into_array() + .execute_scalar(0, &mut ctx)?; + + Ok(result.as_bool().value() == Some(true)) +} + +struct FileStatsBinder<'a> { + dtype: &'a DType, + file_stats: &'a FileStatistics, + struct_fields: &'a StructFields, +} + +impl StatBinder for FileStatsBinder<'_> { + fn scope(&self) -> &DType { + self.dtype + } + + fn bound_scope(&self) -> DType { + DType::Null + } + + fn bind_stat( + &self, + input: &Expression, + stat: Stat, + _stat_dtype: &DType, + ) -> VortexResult> { + let Some(field_path) = direct_field_path(input) else { return Ok(None); }; + Ok(self.stat_ref(&field_path, stat)) + } - let field_idx = struct_dtype - .find(field) - .ok_or_else(|| vortex_err!("Missing field: {field}"))?; - let field_dtype = struct_dtype - .field_by_index(field_idx) - .vortex_expect("Field must exist"); + fn bind_aggregate( + &self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult> { + bind_legacy_count_or_direct_aggregate(self, input, aggregate_fn, stat_dtype) + } +} - let Some(stat_set) = stats_sets.get(field_idx) else { - vortex_bail!("missing stat field {} from stats set", field) - }; - let typed_stats = stat_set.as_typed_ref(&field_dtype); - - for stat in stats { - if matches!( - stat, - Stat::Max | Stat::Min | Stat::NaNCount | Stat::NullCount - ) { - let Some(stat_value) = typed_stats.get(*stat).as_exact() else { - vortex_bail!("missing stat {}, {} from stats set", field, stat) - }; - columns.push(( - field_path_stat_field_name(field_path, *stat), - ConstantArray::new(stat_value, 1).into_array(), - )); - } else { - todo!("unsupported file prune stat {stat}") - } +impl FileStatsBinder<'_> { + fn stat_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { + // FileStats currently only holds top-level field statistics. + if field_path.parts().len() != 1 { + return None; } + + let field_name = field_path.parts()[0].as_name()?; + let field_idx = self.struct_fields.find(field_name)?; + let field_stats = self.file_stats.stats_sets().get(field_idx)?; + + let stat_value = field_stats.get(stat).as_exact()?; + let field_dtype = self.struct_fields.field_by_index(field_idx)?; + let stat_dtype = stat.dtype(&field_dtype)?; + let stat_scalar = Scalar::try_new(stat_dtype, Some(stat_value)).ok()?; + + Some(lit(stat_scalar)) + } +} + +fn direct_field_path(expr: &Expression) -> Option { + if is_root(expr) { + return Some(FieldPath::root()); + } + + if expr.is::() { + return direct_field_path(expr.child(0)); } - Ok(Some( - StructArray::from_fields(columns.as_slice())?.into_array(), - )) + + let field_name = expr.as_opt::()?; + direct_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) } diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index c138c251ab6..5f2fdf4f50f 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -10,28 +10,11 @@ use std::ops::Range; use std::sync::Arc; -use vortex_array::Canonical; -use vortex_array::IntoArray; use vortex_array::MaskFuture; -use vortex_array::VortexSessionExecute; -use vortex_array::aggregate_fn::AggregateFnRef; -use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::NullArray; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; -use vortex_array::dtype::FieldPath; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; -use vortex_array::expr::is_root; -use vortex_array::expr::lit; -use vortex_array::expr::stats::Stat; -use vortex_array::scalar::Scalar; -use vortex_array::scalar_fn::fns::get_item::GetItem; -use vortex_array::scalar_fn::fns::literal::Literal; -use vortex_array::scalar_fn::internal::row_count::substitute_row_count; -use vortex_array::stats::bind::StatBinder; -use vortex_array::stats::bind::bind_legacy_count_or_direct_aggregate; -use vortex_array::stats::bind::bind_stats; use vortex_error::VortexResult; use vortex_layout::ArrayFuture; use vortex_layout::LayoutReader; @@ -43,6 +26,7 @@ use vortex_session::VortexSession; use vortex_utils::aliases::dash_map::DashMap; use crate::FileStatistics; +use crate::pruning::can_prune_file_stats; /// A [`LayoutReader`] decorator that prunes entire files based on file-level statistics. /// @@ -88,41 +72,14 @@ impl FileStatsLayoutReader { /// Row-count placeholders are resolved against the full file row count, /// independent of the requested row range. fn evaluate_file_stats(&self, expr: &Expression) -> VortexResult { - let Some(pruning_expr) = expr.falsify(self.child.dtype(), &self.session)? else { - // If there is no pruning expression, we can't prune. - return Ok(false); - }; - let pruning_expr = self.lower_stats(pruning_expr)?; - - // Stats lowering replaces available stats with literals and unavailable stats with nulls, - // so only row_count placeholders remain unresolved here. - let simplified = pruning_expr.optimize_recursive(&DType::Null)?; - if let Some(result) = simplified.as_opt::() { - // Can prune if the result is non-nullable and true - return Ok(result.as_bool().value() == Some(true)); - } - - // Sometimes expressions don't implement constant folding to literals... In this case, - // we apply the expression over a null array and substitute any row_count placeholders - // in the resulting array tree with the file's row count. - let pruning = NullArray::new(1).into_array().apply(&pruning_expr)?; - let row_count_replacement = - ConstantArray::new(self.child.row_count(), pruning.len()).into_array(); - let pruning = substitute_row_count(pruning, &row_count_replacement)?; - - let mut ctx = self.session.create_execution_ctx(); - let result = pruning - .execute::(&mut ctx)? - .into_bool() - .into_array() - .execute_scalar(0, &mut ctx)?; - - Ok(result.as_bool().value() == Some(true)) - } - - fn lower_stats(&self, predicate: Expression) -> VortexResult { - let binder = FileStatsBinder { reader: self }; - bind_stats(predicate, &binder) + can_prune_file_stats( + expr, + self.child.dtype(), + self.child.row_count(), + &self.file_stats, + &self.struct_fields, + &self.session, + ) } pub fn file_stats(&self) -> &FileStatistics { @@ -130,70 +87,6 @@ impl FileStatsLayoutReader { } } -struct FileStatsBinder<'a> { - reader: &'a FileStatsLayoutReader, -} - -impl StatBinder for FileStatsBinder<'_> { - fn scope(&self) -> &DType { - self.reader.child.dtype() - } - - fn bound_scope(&self) -> DType { - DType::Null - } - - fn bind_stat( - &self, - input: &Expression, - stat: Stat, - _stat_dtype: &DType, - ) -> VortexResult> { - let Some(field_path) = direct_field_path(input) else { - return Ok(None); - }; - Ok(self.reader.stat_ref(&field_path, stat)) - } - - fn bind_aggregate( - &self, - input: &Expression, - aggregate_fn: &AggregateFnRef, - stat_dtype: &DType, - ) -> VortexResult> { - bind_legacy_count_or_direct_aggregate(self, input, aggregate_fn, stat_dtype) - } -} - -fn direct_field_path(expr: &Expression) -> Option { - if is_root(expr) { - return Some(FieldPath::root()); - } - - let field_name = expr.as_opt::()?; - direct_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) -} - -impl FileStatsLayoutReader { - fn stat_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { - // FileStats currently only holds top-level field statistics. - if field_path.parts().len() != 1 { - return None; - } - - let field_name = field_path.parts()[0].as_name()?; - let field_idx = self.struct_fields.find(field_name)?; - let field_stats = self.file_stats.stats_sets().get(field_idx)?; - - let stat_value = field_stats.get(stat).as_exact()?; - let field_dtype = self.struct_fields.field_by_index(field_idx)?; - let stat_dtype = stat.dtype(&field_dtype)?; - let stat_scalar = Scalar::try_new(stat_dtype, Some(stat_value)).ok()?; - - Some(lit(stat_scalar)) - } -} - impl LayoutReader for FileStatsLayoutReader { fn name(&self) -> &Arc { self.child.name() From fff66b6473d057ab698170ac3e05388663a13de6 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 17 Jun 2026 12:48:27 -0400 Subject: [PATCH 15/28] Restore session falsification tests Signed-off-by: "Nicholas Gates" --- vortex-array/src/scalar_fn/fns/is_not_null.rs | 29 +++++++++ vortex-array/src/scalar_fn/fns/is_null.rs | 27 +++++++++ .../src/scalar_fn/fns/list_contains/mod.rs | 59 +++++++++++++++++++ 3 files changed, 115 insertions(+) diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index 4d9cbecede9..f2849f53ccf 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -99,8 +99,12 @@ impl ScalarFnVTable for IsNotNull { #[cfg(test)] mod tests { + use std::sync::LazyLock; + use vortex_buffer::buffer; use vortex_error::VortexExpect as _; + use vortex_error::VortexResult; + use vortex_session::VortexSession; use crate::IntoArray; use crate::LEGACY_SESSION; @@ -110,11 +114,22 @@ mod tests { use crate::dtype::DType; use crate::dtype::Nullability; use crate::expr::col; + use crate::expr::eq; use crate::expr::get_item; use crate::expr::is_not_null; + use crate::expr::or; use crate::expr::root; use crate::expr::test_harness; use crate::scalar::Scalar; + use crate::scalar_fn::EmptyOptions; + use crate::scalar_fn::ScalarFnVTableExt; + use crate::scalar_fn::internal::row_count::RowCount; + use crate::stats::StatsSession; + use crate::stats::all_null; + use crate::stats::null_count; + + static STATS_SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); #[test] fn dtype() { @@ -232,4 +247,18 @@ mod tests { fn test_is_not_null_sensitive() { assert!(is_not_null(col("a")).signature().is_null_sensitive()); } + + #[test] + fn test_is_not_null_falsification() -> VortexResult<()> { + let expr = is_not_null(col("a")); + + assert_eq!( + expr.falsify(&test_harness::struct_dtype(), &STATS_SESSION)?, + Some(or( + eq(null_count(col("a")), RowCount.new_expr(EmptyOptions, []),), + all_null(col("a")), + )) + ); + Ok(()) + } } diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index 88cf0aa8830..8df263a4b22 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -90,8 +90,12 @@ impl ScalarFnVTable for IsNull { #[cfg(test)] mod tests { + use std::sync::LazyLock; + use vortex_buffer::buffer; use vortex_error::VortexExpect as _; + use vortex_error::VortexResult; + use vortex_session::VortexSession; use crate::IntoArray; use crate::LEGACY_SESSION; @@ -101,11 +105,20 @@ mod tests { use crate::dtype::DType; use crate::dtype::Nullability; use crate::expr::col; + use crate::expr::eq; use crate::expr::get_item; use crate::expr::is_null; + use crate::expr::lit; + use crate::expr::or; use crate::expr::root; use crate::expr::test_harness; use crate::scalar::Scalar; + use crate::stats::StatsSession; + use crate::stats::all_non_null; + use crate::stats::null_count; + + static STATS_SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); #[test] fn dtype() { @@ -221,6 +234,20 @@ mod tests { assert_eq!(expr2.to_string(), "vortex.is_null($)"); } + #[test] + fn test_is_null_falsification() -> VortexResult<()> { + let expr = is_null(col("a")); + + assert_eq!( + expr.falsify(&test_harness::struct_dtype(), &STATS_SESSION)?, + Some(or( + eq(null_count(col("a")), lit(0u64)), + all_non_null(col("a")), + )) + ); + Ok(()) + } + #[test] fn test_is_null_sensitive() { // is_null itself is null-sensitive diff --git a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs index 79965f0c52f..4b39f51a7f2 100644 --- a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs +++ b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs @@ -394,11 +394,14 @@ fn list_is_not_empty( #[cfg(test)] mod tests { use std::sync::Arc; + use std::sync::LazyLock; use itertools::Itertools; use rstest::rstest; use vortex_buffer::BitBuffer; use vortex_buffer::Buffer; + use vortex_error::VortexResult; + use vortex_session::VortexSession; use crate::ArrayRef; use crate::IntoArray; @@ -413,17 +416,33 @@ mod tests { use crate::dtype::Nullability; use crate::dtype::PType::I32; use crate::dtype::StructFields; + use crate::expr::Expression; + use crate::expr::and; + use crate::expr::col; use crate::expr::get_item; + use crate::expr::gt; use crate::expr::list_contains; use crate::expr::lit; + use crate::expr::lt; + use crate::expr::or; use crate::expr::root; + use crate::expr::stats::Stat; use crate::scalar::Scalar; use crate::scalar_fn::fns::list_contains::BoolArray; use crate::scalar_fn::fns::list_contains::ConstantArray; use crate::scalar_fn::fns::list_contains::ListViewArray; use crate::scalar_fn::fns::list_contains::PrimitiveArray; + use crate::stats::StatsSession; + use crate::stats::stat as stat_expr; use crate::validity::Validity; + static STATS_SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + fn stat(expr: Expression, stat: Stat) -> Expression { + stat_expr(expr, stat.aggregate_fn().unwrap()) + } + fn test_array() -> ArrayRef { ListArray::try_new( PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3]).into_array(), @@ -563,6 +582,46 @@ mod tests { ); } + #[test] + pub fn list_falsification() -> VortexResult<()> { + let expr = list_contains( + lit(Scalar::list( + Arc::new(DType::Primitive(I32, Nullability::NonNullable)), + vec![1.into(), 2.into(), 3.into()], + Nullability::NonNullable, + )), + col("a"), + ); + let scope = DType::Struct( + StructFields::new( + ["a"].into(), + vec![DType::Primitive(I32, Nullability::NonNullable)], + ), + Nullability::NonNullable, + ); + + assert_eq!( + expr.falsify(&scope, &STATS_SESSION)?, + Some(and( + and( + or( + lt(stat(col("a"), Stat::Max), lit(1i32)), + gt(stat(col("a"), Stat::Min), lit(1i32)), + ), + or( + lt(stat(col("a"), Stat::Max), lit(2i32)), + gt(stat(col("a"), Stat::Min), lit(2i32)), + ) + ), + or( + lt(stat(col("a"), Stat::Max), lit(3i32)), + gt(stat(col("a"), Stat::Min), lit(3i32)), + ) + )) + ); + Ok(()) + } + #[test] pub fn test_display() { let expr = list_contains(get_item("tags", root()), lit("urgent")); From 25e9400366d56b9a466cd10527405527f706660a Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 17 Jun 2026 12:58:18 -0400 Subject: [PATCH 16/28] Bind only direct stat aggregates Signed-off-by: "Nicholas Gates" --- vortex-array/src/stats/bind.rs | 117 +------------------- vortex-array/src/stats/rewrite/builtins.rs | 22 +++- vortex-file/src/pruning.rs | 11 -- vortex-layout/src/layouts/zoned/zone_map.rs | 33 +++--- 4 files changed, 39 insertions(+), 144 deletions(-) diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 45381dbdef7..070097e127f 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -11,22 +11,14 @@ use vortex_error::VortexResult; use crate::aggregate_fn::AggregateFnRef; -use crate::aggregate_fn::fns::all_nan::AllNan; -use crate::aggregate_fn::fns::all_non_nan::AllNonNan; -use crate::aggregate_fn::fns::all_non_null::AllNonNull; -use crate::aggregate_fn::fns::all_null::AllNull; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::eq; use crate::expr::lit; use crate::expr::stats::Stat; use crate::expr::traversal::NodeExt; use crate::expr::traversal::Transformed; use crate::scalar::Scalar; -use crate::scalar_fn::EmptyOptions; -use crate::scalar_fn::ScalarFnVTableExt; use crate::scalar_fn::fns::stat::StatFn; -use crate::scalar_fn::internal::row_count::RowCount; /// A target that can bind abstract statistics to concrete expressions. pub trait StatBinder { @@ -114,46 +106,6 @@ pub fn bind_stats( lowered.optimize_recursive(&binder.bound_scope()) } -/// Bind aggregate stats that can be derived from legacy count stat slots. -/// -/// This is an opt-in helper for stats backends that materialize `NaNCount` and -/// `NullCount`, but do not materialize aggregate boolean stats directly. -pub fn bind_legacy_count_aggregate( - binder: &B, - input: &Expression, - aggregate_fn: &AggregateFnRef, -) -> VortexResult> { - if aggregate_fn.is::() { - let Some(nan_count) = binder.bind_legacy_stat(input, Stat::NaNCount)? else { - return Ok(None); - }; - return Ok(Some(eq(nan_count, RowCount.new_expr(EmptyOptions, [])))); - } - - if aggregate_fn.is::() { - let Some(nan_count) = binder.bind_legacy_stat(input, Stat::NaNCount)? else { - return Ok(None); - }; - return Ok(Some(eq(nan_count, lit(0u64)))); - } - - if aggregate_fn.is::() { - let Some(null_count) = binder.bind_legacy_stat(input, Stat::NullCount)? else { - return Ok(None); - }; - return Ok(Some(eq(null_count, RowCount.new_expr(EmptyOptions, [])))); - } - - if aggregate_fn.is::() { - let Some(null_count) = binder.bind_legacy_stat(input, Stat::NullCount)? else { - return Ok(None); - }; - return Ok(Some(eq(null_count, lit(0u64)))); - } - - Ok(None) -} - /// Bind an aggregate function that has a direct legacy [`Stat`] slot. pub fn bind_direct_aggregate_stat( binder: &B, @@ -167,24 +119,6 @@ pub fn bind_direct_aggregate_stat( binder.bind_stat(input, stat, stat_dtype) } -/// Bind aggregate stats for backends that expose legacy count-derived stats. -/// -/// Backends using this helper first bind aggregate facts derivable from -/// `NaNCount` and `NullCount`, then fall back to direct aggregate-to-stat -/// mappings. -pub fn bind_legacy_count_or_direct_aggregate( - binder: &B, - input: &Expression, - aggregate_fn: &AggregateFnRef, - stat_dtype: &DType, -) -> VortexResult> { - if let Some(bound) = bind_legacy_count_aggregate(binder, input, aggregate_fn)? { - return Ok(Some(bound)); - } - - bind_direct_aggregate_stat(binder, input, aggregate_fn, stat_dtype) -} - fn bind_stat_fn( expr: &Expression, scope: &DType, @@ -218,6 +152,7 @@ mod tests { use crate::expr::or; use crate::expr::root; use crate::stats::all_non_nan; + use crate::stats::nan_count; struct TestBinder { input_scope: DType, @@ -268,30 +203,21 @@ mod tests { Ok(None) } } - - fn bind_aggregate( - &self, - input: &Expression, - aggregate_fn: &AggregateFnRef, - stat_dtype: &DType, - ) -> VortexResult> { - bind_legacy_count_or_direct_aggregate(self, input, aggregate_fn, stat_dtype) - } } #[test] - fn all_non_nan_binds_to_nan_count_zero() -> VortexResult<()> { + fn nan_count_binds_to_legacy_stat_slot() -> VortexResult<()> { let binder = TestBinder::new(true); - let bound = bind_stats(all_non_nan(col("f")), &binder)?; + let bound = bind_stats(nan_count(col("f")), &binder)?; - assert_eq!(bound, eq(col("f_nan_count"), lit(0u64))); + assert_eq!(bound, col("f_nan_count")); Ok(()) } #[test] - fn all_non_nan_lowers_to_null_when_nan_count_is_missing() -> VortexResult<()> { - let binder = TestBinder::new(false); + fn all_non_nan_does_not_derive_from_nan_count() -> VortexResult<()> { + let binder = TestBinder::new(true); let bound = bind_stats(all_non_nan(col("f")), &binder)?; @@ -313,37 +239,6 @@ mod tests { Ok(()) } - #[test] - fn default_binder_does_not_derive_all_non_nan_from_nan_count() -> VortexResult<()> { - struct DefaultBinder(TestBinder); - - impl StatBinder for DefaultBinder { - fn scope(&self) -> &DType { - self.0.scope() - } - - fn bound_scope(&self) -> DType { - self.0.bound_scope() - } - - fn bind_stat( - &self, - input: &Expression, - stat: Stat, - stat_dtype: &DType, - ) -> VortexResult> { - self.0.bind_stat(input, stat, stat_dtype) - } - } - - let binder = DefaultBinder(TestBinder::new(true)); - - let bound = bind_stats(all_non_nan(col("f")), &binder)?; - - assert_eq!(bound, lit(Scalar::null(DType::Bool(Nullability::Nullable)))); - Ok(()) - } - #[test] fn unrelated_expressions_do_not_request_nan_count() -> VortexResult<()> { let binder = TestBinder::new(false); diff --git a/vortex-array/src/stats/rewrite/builtins.rs b/vortex-array/src/stats/rewrite/builtins.rs index c60f5869055..ad6737d67c0 100644 --- a/vortex-array/src/stats/rewrite/builtins.rs +++ b/vortex-array/src/stats/rewrite/builtins.rs @@ -464,9 +464,16 @@ fn all_non_nan_stat( return Ok(None); } - Ok(Some(stat_fn( - expr.clone(), - AllNonNan.bind(AggregateEmptyOptions), + let Some(nan_count) = stat_expr(expr, Stat::NaNCount, ctx) else { + return Ok(Some(stat_fn( + expr.clone(), + AllNonNan.bind(AggregateEmptyOptions), + ))); + }; + + Ok(Some(or( + eq(nan_count, lit(0u64)), + stat_fn(expr.clone(), AllNonNan.bind(AggregateEmptyOptions)), ))) } @@ -659,8 +666,11 @@ mod tests { expr.satisfy(&test_scope(), &SESSION) } - fn nan_free(expr: Expression) -> Expression { - stat_fn(expr, AllNonNan.bind(AggregateEmptyOptions)) + fn nan_guard(expr: Expression) -> Expression { + or( + eq(stat(expr.clone(), Stat::NaNCount), lit(0u64)), + stat_fn(expr, AllNonNan.bind(AggregateEmptyOptions)), + ) } #[test] @@ -888,7 +898,7 @@ mod tests { assert_eq!( falsify(&expr)?, Some(and( - nan_free(col("f")), + nan_guard(col("f")), lt_eq(cast(stat(col("f"), Stat::Max), dtype), lit(5i32)), )) ); diff --git a/vortex-file/src/pruning.rs b/vortex-file/src/pruning.rs index 1fc79447d1d..af08a64b687 100644 --- a/vortex-file/src/pruning.rs +++ b/vortex-file/src/pruning.rs @@ -4,7 +4,6 @@ use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; -use vortex_array::aggregate_fn::AggregateFnRef; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::NullArray; use vortex_array::dtype::DType; @@ -20,7 +19,6 @@ use vortex_array::scalar_fn::fns::get_item::GetItem; use vortex_array::scalar_fn::fns::literal::Literal; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::stats::bind::StatBinder; -use vortex_array::stats::bind::bind_legacy_count_or_direct_aggregate; use vortex_array::stats::bind::bind_stats; use vortex_error::VortexResult; use vortex_session::VortexSession; @@ -91,15 +89,6 @@ impl StatBinder for FileStatsBinder<'_> { }; Ok(self.stat_ref(&field_path, stat)) } - - fn bind_aggregate( - &self, - input: &Expression, - aggregate_fn: &AggregateFnRef, - stat_dtype: &DType, - ) -> VortexResult> { - bind_legacy_count_or_direct_aggregate(self, input, aggregate_fn, stat_dtype) - } } impl FileStatsBinder<'_> { diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index a14dc71691f..efacf84b6bc 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -8,7 +8,6 @@ use std::sync::Arc; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; -use vortex_array::aggregate_fn::AggregateFnRef; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; @@ -23,7 +22,6 @@ use vortex_array::expr::stats::Stat; use vortex_array::scalar_fn::internal::row_count::contains_row_count; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::stats::bind::StatBinder; -use vortex_array::stats::bind::bind_legacy_count_or_direct_aggregate; use vortex_array::stats::bind::bind_stats; use vortex_array::validity::Validity; use vortex_buffer::buffer; @@ -159,15 +157,6 @@ impl StatBinder for ZoneMapStatsBinder<'_> { } Ok(Some(get_item(stat.name(), root()))) } - - fn bind_aggregate( - &self, - input: &Expression, - aggregate_fn: &AggregateFnRef, - stat_dtype: &DType, - ) -> VortexResult> { - bind_legacy_count_or_direct_aggregate(self, input, aggregate_fn, stat_dtype) - } } /// Build per-zone row counts for a zone map. @@ -357,7 +346,7 @@ mod tests { } #[test] - fn all_null_stat_fn_uses_null_count() { + fn is_null_falsification_uses_null_count() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -371,12 +360,18 @@ mod tests { ) .unwrap(); - let mask = zone_map.prune(&all_null(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true])); + let expr = is_null(root()); + let pruning_expr = falsify(&expr, PType::U64.into()); + + let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([true, false, false]) + ); } #[test] - fn all_non_null_stat_fn_uses_null_count() { + fn abstract_null_stats_do_not_derive_from_null_count() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -390,10 +385,16 @@ mod tests { ) .unwrap(); + let mask = zone_map.prune(&all_null(root()), &SESSION).unwrap(); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([false, false, false]) + ); + let mask = zone_map.prune(&all_non_null(root()), &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), - BoolArray::from_iter([true, false, false]) + BoolArray::from_iter([false, false, false]) ); } From b2224548c97172f0dd34ff217aadbcdd55ea5475 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 17 Jun 2026 13:03:49 -0400 Subject: [PATCH 17/28] Split NaN stat rewrite proofs Signed-off-by: "Nicholas Gates" --- vortex-array/src/stats/rewrite/builtins.rs | 200 ++++++++++++++++----- 1 file changed, 151 insertions(+), 49 deletions(-) diff --git a/vortex-array/src/stats/rewrite/builtins.rs b/vortex-array/src/stats/rewrite/builtins.rs index ad6737d67c0..559efd299fc 100644 --- a/vortex-array/src/stats/rewrite/builtins.rs +++ b/vortex-array/src/stats/rewrite/builtins.rs @@ -52,7 +52,8 @@ use crate::stats::session::StatsSession; /// Register built-in stats rewrite rules. pub(crate) fn register_builtins(session: &StatsSession) { - session.register_rewrite(BinaryStatsRewrite); + session.register_rewrite(BinaryStatsRewrite::legacy_nan_count()); + session.register_rewrite(BinaryStatsRewrite::all_non_nan()); session.register_rewrite(BetweenStatsRewrite); session.register_rewrite(IsNullLegacyStatsRewrite); session.register_rewrite(IsNullAllNonNullStatsRewrite); @@ -61,12 +62,30 @@ pub(crate) fn register_builtins(session: &StatsSession) { session.register_rewrite(IsNotNullAllNullStatsRewrite); session.register_rewrite(IsNotNullAllNonNullStatsRewrite); session.register_rewrite(LikeStatsRewrite); - session.register_rewrite(ListContainsStatsRewrite); - session.register_rewrite(DynamicComparisonStatsRewrite); + session.register_rewrite(ListContainsStatsRewrite::legacy_nan_count()); + session.register_rewrite(ListContainsStatsRewrite::all_non_nan()); + session.register_rewrite(DynamicComparisonStatsRewrite::legacy_nan_count()); + session.register_rewrite(DynamicComparisonStatsRewrite::all_non_nan()); } #[derive(Debug)] -struct BinaryStatsRewrite; +struct BinaryStatsRewrite { + nan_proof: NanProof, +} + +impl BinaryStatsRewrite { + fn legacy_nan_count() -> Self { + Self { + nan_proof: NanProof::LegacyNanCount, + } + } + + fn all_non_nan() -> Self { + Self { + nan_proof: NanProof::AllNonNan, + } + } +} impl StatsRewriteRule for BinaryStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { @@ -87,8 +106,11 @@ impl StatsRewriteRule for BinaryStatsRewrite { let left = min(lhs, ctx).zip(max(rhs, ctx)).map(|(a, b)| gt(a, b)); let right = min(rhs, ctx).zip(max(lhs, ctx)).map(|(a, b)| gt(a, b)); or_collect(left.into_iter().chain(right)) - .map(|value_predicate| with_nan_predicate(ctx, lhs, rhs, value_predicate)) + .map(|value_predicate| { + with_nan_predicate(ctx, self.nan_proof, lhs, rhs, value_predicate) + }) .transpose()? + .flatten() } Operator::NotEq => min(lhs, ctx) .zip(max(rhs, ctx)) @@ -96,35 +118,47 @@ impl StatsRewriteRule for BinaryStatsRewrite { .map(|((min_lhs, max_rhs), (max_lhs, min_rhs))| { with_nan_predicate( ctx, + self.nan_proof, lhs, rhs, and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)), ) }) - .transpose()?, + .transpose()? + .flatten(), Operator::Gt => max(lhs, ctx) .zip(min(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt_eq(a, b))) - .transpose()?, + .map(|(a, b)| with_nan_predicate(ctx, self.nan_proof, lhs, rhs, lt_eq(a, b))) + .transpose()? + .flatten(), Operator::Gte => max(lhs, ctx) .zip(min(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt(a, b))) - .transpose()?, + .map(|(a, b)| with_nan_predicate(ctx, self.nan_proof, lhs, rhs, lt(a, b))) + .transpose()? + .flatten(), Operator::Lt => min(lhs, ctx) .zip(max(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt_eq(a, b))) - .transpose()?, + .map(|(a, b)| with_nan_predicate(ctx, self.nan_proof, lhs, rhs, gt_eq(a, b))) + .transpose()? + .flatten(), Operator::Lte => min(lhs, ctx) .zip(max(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt(a, b))) - .transpose()?, + .map(|(a, b)| with_nan_predicate(ctx, self.nan_proof, lhs, rhs, gt(a, b))) + .transpose()? + .flatten(), Operator::And => { + if !self.nan_proof.emits_unguarded_rewrites() { + return Ok(None); + } + let lhs_falsifier = ctx.falsify(lhs)?; let rhs_falsifier = ctx.falsify(rhs)?; or_collect(lhs_falsifier.into_iter().chain(rhs_falsifier)) } Operator::Or => match (ctx.falsify(lhs)?, ctx.falsify(rhs)?) { - (Some(lhs), Some(rhs)) => Some(and(lhs, rhs)), + (Some(lhs), Some(rhs)) if self.nan_proof.emits_unguarded_rewrites() => { + Some(and(lhs, rhs)) + } _ => None, }, Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, @@ -332,7 +366,23 @@ impl StatsRewriteRule for LikeStatsRewrite { } #[derive(Debug)] -struct ListContainsStatsRewrite; +struct ListContainsStatsRewrite { + nan_proof: NanProof, +} + +impl ListContainsStatsRewrite { + fn legacy_nan_count() -> Self { + Self { + nan_proof: NanProof::LegacyNanCount, + } + } + + fn all_non_nan() -> Self { + Self { + nan_proof: NanProof::AllNonNan, + } + } +} impl StatsRewriteRule for ListContainsStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { @@ -375,13 +425,32 @@ impl StatsRewriteRule for ListContainsStatsRewrite { ) })); value_predicate - .map(|value_predicate| with_all_non_nan_predicate(ctx, [needle], value_predicate)) + .map(|value_predicate| { + with_all_non_nan_predicate(ctx, self.nan_proof, [needle], value_predicate) + }) .transpose() + .map(Option::flatten) } } #[derive(Debug)] -struct DynamicComparisonStatsRewrite; +struct DynamicComparisonStatsRewrite { + nan_proof: NanProof, +} + +impl DynamicComparisonStatsRewrite { + fn legacy_nan_count() -> Self { + Self { + nan_proof: NanProof::LegacyNanCount, + } + } + + fn all_non_nan() -> Self { + Self { + nan_proof: NanProof::AllNonNan, + } + } +} impl StatsRewriteRule for DynamicComparisonStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { @@ -414,7 +483,7 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite { }, [lhs_stat], ); - with_all_non_nan_predicate(ctx, [lhs], value_predicate).map(Some) + with_all_non_nan_predicate(ctx, self.nan_proof, [lhs], value_predicate) } } @@ -438,43 +507,64 @@ fn all_non_null(expr: &Expression) -> Expression { stat_fn(expr.clone(), AllNonNull.bind(AggregateEmptyOptions)) } +#[derive(Debug, Clone, Copy)] +enum NanProof { + LegacyNanCount, + AllNonNan, +} + +impl NanProof { + fn emits_unguarded_rewrites(self) -> bool { + matches!(self, Self::LegacyNanCount) + } +} + +enum NanCheck { + NotNeeded, + Check(Expression), + Unavailable, +} + // Min/max do not order NaN values, so comparison rewrites are only sound when every // candidate value is known to be non-NaN. Cast result dtypes are not enough: a cast // from float to non-float still needs a proof about the float source values. fn all_non_nan_stat( ctx: &StatsRewriteCtx<'_>, + nan_proof: NanProof, expr: &Expression, -) -> VortexResult> { +) -> VortexResult { if let Some(scalar) = expr.as_opt::() { let Some(value) = scalar.as_primitive_opt() else { - return Ok(None); + return Ok(NanCheck::NotNeeded); }; - return Ok(value.is_nan().then(|| lit(false))); + return Ok(if value.is_nan() { + NanCheck::Check(lit(false)) + } else { + NanCheck::NotNeeded + }); } if expr.is::() { if !has_nans(&ctx.return_dtype(expr.child(0))?) { - return Ok(None); + return Ok(NanCheck::NotNeeded); } - return all_non_nan_stat(ctx, expr.child(0)); + return all_non_nan_stat(ctx, nan_proof, expr.child(0)); } if !has_nans(&ctx.return_dtype(expr)?) { - return Ok(None); + return Ok(NanCheck::NotNeeded); } - let Some(nan_count) = stat_expr(expr, Stat::NaNCount, ctx) else { - return Ok(Some(stat_fn( - expr.clone(), - AllNonNan.bind(AggregateEmptyOptions), - ))); - }; - - Ok(Some(or( - eq(nan_count, lit(0u64)), - stat_fn(expr.clone(), AllNonNan.bind(AggregateEmptyOptions)), - ))) + Ok(match nan_proof { + NanProof::LegacyNanCount => match stat_expr(expr, Stat::NaNCount, ctx) { + Some(nan_count) => NanCheck::Check(eq(nan_count, lit(0u64))), + None => NanCheck::Unavailable, + }, + NanProof::AllNonNan => { + NanCheck::Check(stat_fn(expr.clone(), AllNonNan.bind(AggregateEmptyOptions))) + } + }) } fn has_nans(dtype: &DType) -> bool { @@ -510,31 +600,37 @@ fn stat_expr(expr: &Expression, stat: Stat, ctx: &StatsRewriteCtx<'_>) -> Option fn with_nan_predicate( ctx: &StatsRewriteCtx<'_>, + nan_proof: NanProof, lhs: &Expression, rhs: &Expression, value_predicate: Expression, -) -> VortexResult { - with_all_non_nan_predicate(ctx, [lhs, rhs], value_predicate) +) -> VortexResult> { + with_all_non_nan_predicate(ctx, nan_proof, [lhs, rhs], value_predicate) } fn with_all_non_nan_predicate<'a>( ctx: &StatsRewriteCtx<'_>, + nan_proof: NanProof, exprs: impl IntoIterator, value_predicate: Expression, -) -> VortexResult { +) -> VortexResult> { let mut nan_checks = Vec::new(); for expr in exprs { - if let Some(check) = all_non_nan_stat(ctx, expr)? { - nan_checks.push(check); + match all_non_nan_stat(ctx, nan_proof, expr)? { + NanCheck::NotNeeded => {} + NanCheck::Check(check) => nan_checks.push(check), + NanCheck::Unavailable => return Ok(None), } } let nan_predicate = and_collect(nan_checks); Ok(match nan_predicate { - Some(nan_check) => and(nan_check, value_predicate), + Some(nan_check) => Some(and(nan_check, value_predicate)), // No possible NaN-bearing expression remains, so the value predicate is - // already guarded. - None => value_predicate, + // already guarded. Only one registered rule emits this unguarded + // rewrite so non-float comparisons are not duplicated. + None if nan_proof.emits_unguarded_rewrites() => Some(value_predicate), + None => None, }) } @@ -666,10 +762,16 @@ mod tests { expr.satisfy(&test_scope(), &SESSION) } - fn nan_guard(expr: Expression) -> Expression { + fn nan_guarded(expr: Expression, value_predicate: Expression) -> Expression { or( - eq(stat(expr.clone(), Stat::NaNCount), lit(0u64)), - stat_fn(expr, AllNonNan.bind(AggregateEmptyOptions)), + and( + eq(stat(expr.clone(), Stat::NaNCount), lit(0u64)), + value_predicate.clone(), + ), + and( + stat_fn(expr, AllNonNan.bind(AggregateEmptyOptions)), + value_predicate, + ), ) } @@ -897,8 +999,8 @@ mod tests { assert_eq!( falsify(&expr)?, - Some(and( - nan_guard(col("f")), + Some(nan_guarded( + col("f"), lt_eq(cast(stat(col("f"), Stat::Max), dtype), lit(5i32)), )) ); From e82764f11d85616a114089ebc92f425705e20827 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 17 Jun 2026 13:09:38 -0400 Subject: [PATCH 18/28] Use concrete stat terminology Signed-off-by: "Nicholas Gates" --- vortex-array/src/stats/bind.rs | 17 +++------- vortex-array/src/stats/rewrite/builtins.rs | 36 +++++++++++----------- 2 files changed, 22 insertions(+), 31 deletions(-) diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 070097e127f..328266c356e 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -47,9 +47,9 @@ pub trait StatBinder { /// Bind `aggregate_fn(input)` to a concrete expression. /// - /// The default implementation supports aggregate functions with legacy - /// [`Stat`] slots. Binders that store richer aggregate stats can override - /// this method without extending the generic stats binding walker. + /// The default implementation supports aggregate functions that map + /// directly to [`Stat`] slots. Binders that store richer aggregate stats can + /// override this method without extending the generic stats binding walker. fn bind_aggregate( &self, input: &Expression, @@ -59,15 +59,6 @@ pub trait StatBinder { bind_direct_aggregate_stat(self, input, aggregate_fn, stat_dtype) } - /// Bind one of the legacy stat slots for `input`. - fn bind_legacy_stat(&self, input: &Expression, stat: Stat) -> VortexResult> { - let input_dtype = input.return_dtype(self.scope())?; - let Some(stat_dtype) = stat.dtype(&input_dtype) else { - return Ok(None); - }; - self.bind_stat(input, stat, &stat_dtype) - } - /// Expression to use when a stat is unavailable. /// /// The default is a nullable null literal, which preserves three-valued @@ -206,7 +197,7 @@ mod tests { } #[test] - fn nan_count_binds_to_legacy_stat_slot() -> VortexResult<()> { + fn nan_count_binds_to_direct_stat_slot() -> VortexResult<()> { let binder = TestBinder::new(true); let bound = bind_stats(nan_count(col("f")), &binder)?; diff --git a/vortex-array/src/stats/rewrite/builtins.rs b/vortex-array/src/stats/rewrite/builtins.rs index 559efd299fc..0717d035afe 100644 --- a/vortex-array/src/stats/rewrite/builtins.rs +++ b/vortex-array/src/stats/rewrite/builtins.rs @@ -52,19 +52,19 @@ use crate::stats::session::StatsSession; /// Register built-in stats rewrite rules. pub(crate) fn register_builtins(session: &StatsSession) { - session.register_rewrite(BinaryStatsRewrite::legacy_nan_count()); + session.register_rewrite(BinaryStatsRewrite::nan_count()); session.register_rewrite(BinaryStatsRewrite::all_non_nan()); session.register_rewrite(BetweenStatsRewrite); - session.register_rewrite(IsNullLegacyStatsRewrite); + session.register_rewrite(IsNullNullCountStatsRewrite); session.register_rewrite(IsNullAllNonNullStatsRewrite); session.register_rewrite(IsNullAllNullStatsRewrite); - session.register_rewrite(IsNotNullLegacyStatsRewrite); + session.register_rewrite(IsNotNullNullCountStatsRewrite); session.register_rewrite(IsNotNullAllNullStatsRewrite); session.register_rewrite(IsNotNullAllNonNullStatsRewrite); session.register_rewrite(LikeStatsRewrite); - session.register_rewrite(ListContainsStatsRewrite::legacy_nan_count()); + session.register_rewrite(ListContainsStatsRewrite::nan_count()); session.register_rewrite(ListContainsStatsRewrite::all_non_nan()); - session.register_rewrite(DynamicComparisonStatsRewrite::legacy_nan_count()); + session.register_rewrite(DynamicComparisonStatsRewrite::nan_count()); session.register_rewrite(DynamicComparisonStatsRewrite::all_non_nan()); } @@ -74,9 +74,9 @@ struct BinaryStatsRewrite { } impl BinaryStatsRewrite { - fn legacy_nan_count() -> Self { + fn nan_count() -> Self { Self { - nan_proof: NanProof::LegacyNanCount, + nan_proof: NanProof::NanCount, } } @@ -191,9 +191,9 @@ impl StatsRewriteRule for BetweenStatsRewrite { } #[derive(Debug)] -struct IsNullLegacyStatsRewrite; +struct IsNullNullCountStatsRewrite; -impl StatsRewriteRule for IsNullLegacyStatsRewrite { +impl StatsRewriteRule for IsNullNullCountStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { IsNull.id() } @@ -251,9 +251,9 @@ impl StatsRewriteRule for IsNullAllNullStatsRewrite { } #[derive(Debug)] -struct IsNotNullLegacyStatsRewrite; +struct IsNotNullNullCountStatsRewrite; -impl StatsRewriteRule for IsNotNullLegacyStatsRewrite { +impl StatsRewriteRule for IsNotNullNullCountStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { IsNotNull.id() } @@ -371,9 +371,9 @@ struct ListContainsStatsRewrite { } impl ListContainsStatsRewrite { - fn legacy_nan_count() -> Self { + fn nan_count() -> Self { Self { - nan_proof: NanProof::LegacyNanCount, + nan_proof: NanProof::NanCount, } } @@ -439,9 +439,9 @@ struct DynamicComparisonStatsRewrite { } impl DynamicComparisonStatsRewrite { - fn legacy_nan_count() -> Self { + fn nan_count() -> Self { Self { - nan_proof: NanProof::LegacyNanCount, + nan_proof: NanProof::NanCount, } } @@ -509,13 +509,13 @@ fn all_non_null(expr: &Expression) -> Expression { #[derive(Debug, Clone, Copy)] enum NanProof { - LegacyNanCount, + NanCount, AllNonNan, } impl NanProof { fn emits_unguarded_rewrites(self) -> bool { - matches!(self, Self::LegacyNanCount) + matches!(self, Self::NanCount) } } @@ -557,7 +557,7 @@ fn all_non_nan_stat( } Ok(match nan_proof { - NanProof::LegacyNanCount => match stat_expr(expr, Stat::NaNCount, ctx) { + NanProof::NanCount => match stat_expr(expr, Stat::NaNCount, ctx) { Some(nan_count) => NanCheck::Check(eq(nan_count, lit(0u64))), None => NanCheck::Unavailable, }, From 619706f1bdaa15ccd1a10a609278ccf1d781546f Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 17 Jun 2026 13:12:16 -0400 Subject: [PATCH 19/28] Inline stats rewrite proof variants Signed-off-by: "Nicholas Gates" --- vortex-array/src/stats/rewrite/builtins.rs | 66 ++++++---------------- 1 file changed, 18 insertions(+), 48 deletions(-) diff --git a/vortex-array/src/stats/rewrite/builtins.rs b/vortex-array/src/stats/rewrite/builtins.rs index 0717d035afe..f5ed2d3a0df 100644 --- a/vortex-array/src/stats/rewrite/builtins.rs +++ b/vortex-array/src/stats/rewrite/builtins.rs @@ -52,8 +52,12 @@ use crate::stats::session::StatsSession; /// Register built-in stats rewrite rules. pub(crate) fn register_builtins(session: &StatsSession) { - session.register_rewrite(BinaryStatsRewrite::nan_count()); - session.register_rewrite(BinaryStatsRewrite::all_non_nan()); + session.register_rewrite(BinaryStatsRewrite { + nan_proof: NanProof::NanCount, + }); + session.register_rewrite(BinaryStatsRewrite { + nan_proof: NanProof::AllNonNan, + }); session.register_rewrite(BetweenStatsRewrite); session.register_rewrite(IsNullNullCountStatsRewrite); session.register_rewrite(IsNullAllNonNullStatsRewrite); @@ -62,10 +66,18 @@ pub(crate) fn register_builtins(session: &StatsSession) { session.register_rewrite(IsNotNullAllNullStatsRewrite); session.register_rewrite(IsNotNullAllNonNullStatsRewrite); session.register_rewrite(LikeStatsRewrite); - session.register_rewrite(ListContainsStatsRewrite::nan_count()); - session.register_rewrite(ListContainsStatsRewrite::all_non_nan()); - session.register_rewrite(DynamicComparisonStatsRewrite::nan_count()); - session.register_rewrite(DynamicComparisonStatsRewrite::all_non_nan()); + session.register_rewrite(ListContainsStatsRewrite { + nan_proof: NanProof::NanCount, + }); + session.register_rewrite(ListContainsStatsRewrite { + nan_proof: NanProof::AllNonNan, + }); + session.register_rewrite(DynamicComparisonStatsRewrite { + nan_proof: NanProof::NanCount, + }); + session.register_rewrite(DynamicComparisonStatsRewrite { + nan_proof: NanProof::AllNonNan, + }); } #[derive(Debug)] @@ -73,20 +85,6 @@ struct BinaryStatsRewrite { nan_proof: NanProof, } -impl BinaryStatsRewrite { - fn nan_count() -> Self { - Self { - nan_proof: NanProof::NanCount, - } - } - - fn all_non_nan() -> Self { - Self { - nan_proof: NanProof::AllNonNan, - } - } -} - impl StatsRewriteRule for BinaryStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { Binary.id() @@ -370,20 +368,6 @@ struct ListContainsStatsRewrite { nan_proof: NanProof, } -impl ListContainsStatsRewrite { - fn nan_count() -> Self { - Self { - nan_proof: NanProof::NanCount, - } - } - - fn all_non_nan() -> Self { - Self { - nan_proof: NanProof::AllNonNan, - } - } -} - impl StatsRewriteRule for ListContainsStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { ListContains.id() @@ -438,20 +422,6 @@ struct DynamicComparisonStatsRewrite { nan_proof: NanProof, } -impl DynamicComparisonStatsRewrite { - fn nan_count() -> Self { - Self { - nan_proof: NanProof::NanCount, - } - } - - fn all_non_nan() -> Self { - Self { - nan_proof: NanProof::AllNonNan, - } - } -} - impl StatsRewriteRule for DynamicComparisonStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { DynamicComparison.id() From 9a7c7663fc6ad694164ae021b65608faf9386151 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 17 Jun 2026 13:25:49 -0400 Subject: [PATCH 20/28] Split stats rewrite rule structs Signed-off-by: "Nicholas Gates" --- vortex-array/src/stats/rewrite/builtins.rs | 443 +++++++++++++-------- 1 file changed, 273 insertions(+), 170 deletions(-) diff --git a/vortex-array/src/stats/rewrite/builtins.rs b/vortex-array/src/stats/rewrite/builtins.rs index f5ed2d3a0df..ab076738210 100644 --- a/vortex-array/src/stats/rewrite/builtins.rs +++ b/vortex-array/src/stats/rewrite/builtins.rs @@ -52,12 +52,8 @@ use crate::stats::session::StatsSession; /// Register built-in stats rewrite rules. pub(crate) fn register_builtins(session: &StatsSession) { - session.register_rewrite(BinaryStatsRewrite { - nan_proof: NanProof::NanCount, - }); - session.register_rewrite(BinaryStatsRewrite { - nan_proof: NanProof::AllNonNan, - }); + session.register_rewrite(BinaryNanCountStatsRewrite); + session.register_rewrite(BinaryAllNonNanStatsRewrite); session.register_rewrite(BetweenStatsRewrite); session.register_rewrite(IsNullNullCountStatsRewrite); session.register_rewrite(IsNullAllNonNullStatsRewrite); @@ -66,26 +62,33 @@ pub(crate) fn register_builtins(session: &StatsSession) { session.register_rewrite(IsNotNullAllNullStatsRewrite); session.register_rewrite(IsNotNullAllNonNullStatsRewrite); session.register_rewrite(LikeStatsRewrite); - session.register_rewrite(ListContainsStatsRewrite { - nan_proof: NanProof::NanCount, - }); - session.register_rewrite(ListContainsStatsRewrite { - nan_proof: NanProof::AllNonNan, - }); - session.register_rewrite(DynamicComparisonStatsRewrite { - nan_proof: NanProof::NanCount, - }); - session.register_rewrite(DynamicComparisonStatsRewrite { - nan_proof: NanProof::AllNonNan, - }); + session.register_rewrite(ListContainsNanCountStatsRewrite); + session.register_rewrite(ListContainsAllNonNanStatsRewrite); + session.register_rewrite(DynamicComparisonNanCountStatsRewrite); + session.register_rewrite(DynamicComparisonAllNonNanStatsRewrite); } #[derive(Debug)] -struct BinaryStatsRewrite { - nan_proof: NanProof, +struct BinaryNanCountStatsRewrite; + +impl StatsRewriteRule for BinaryNanCountStatsRewrite { + fn scalar_fn_id(&self) -> ScalarFnId { + Binary.id() + } + + fn falsify( + &self, + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, + ) -> VortexResult> { + binary_falsify(expr, ctx, nan_count_check, true) + } } -impl StatsRewriteRule for BinaryStatsRewrite { +#[derive(Debug)] +struct BinaryAllNonNanStatsRewrite; + +impl StatsRewriteRule for BinaryAllNonNanStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { Binary.id() } @@ -95,73 +98,110 @@ impl StatsRewriteRule for BinaryStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - let operator = expr.as_::(); - let lhs = expr.child(0); - let rhs = expr.child(1); - - Ok(match operator { - Operator::Eq => { - let left = min(lhs, ctx).zip(max(rhs, ctx)).map(|(a, b)| gt(a, b)); - let right = min(rhs, ctx).zip(max(lhs, ctx)).map(|(a, b)| gt(a, b)); - or_collect(left.into_iter().chain(right)) - .map(|value_predicate| { - with_nan_predicate(ctx, self.nan_proof, lhs, rhs, value_predicate) - }) - .transpose()? - .flatten() - } - Operator::NotEq => min(lhs, ctx) - .zip(max(rhs, ctx)) - .zip(max(lhs, ctx).zip(min(rhs, ctx))) - .map(|((min_lhs, max_rhs), (max_lhs, min_rhs))| { + binary_falsify(expr, ctx, all_non_nan_check, false) + } +} + +fn binary_falsify( + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, + nan_check: fn(&StatsRewriteCtx<'_>, &Expression) -> VortexResult, + emit_unguarded_rewrites: bool, +) -> VortexResult> { + let operator = expr.as_::(); + let lhs = expr.child(0); + let rhs = expr.child(1); + + Ok(match operator { + Operator::Eq => { + let left = min(lhs, ctx).zip(max(rhs, ctx)).map(|(a, b)| gt(a, b)); + let right = min(rhs, ctx).zip(max(lhs, ctx)).map(|(a, b)| gt(a, b)); + or_collect(left.into_iter().chain(right)) + .map(|value_predicate| { with_nan_predicate( ctx, - self.nan_proof, + nan_check, + emit_unguarded_rewrites, lhs, rhs, - and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)), + value_predicate, ) }) .transpose()? - .flatten(), - Operator::Gt => max(lhs, ctx) - .zip(min(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, self.nan_proof, lhs, rhs, lt_eq(a, b))) - .transpose()? - .flatten(), - Operator::Gte => max(lhs, ctx) - .zip(min(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, self.nan_proof, lhs, rhs, lt(a, b))) - .transpose()? - .flatten(), - Operator::Lt => min(lhs, ctx) - .zip(max(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, self.nan_proof, lhs, rhs, gt_eq(a, b))) - .transpose()? - .flatten(), - Operator::Lte => min(lhs, ctx) - .zip(max(rhs, ctx)) - .map(|(a, b)| with_nan_predicate(ctx, self.nan_proof, lhs, rhs, gt(a, b))) - .transpose()? - .flatten(), - Operator::And => { - if !self.nan_proof.emits_unguarded_rewrites() { - return Ok(None); - } - - let lhs_falsifier = ctx.falsify(lhs)?; - let rhs_falsifier = ctx.falsify(rhs)?; - or_collect(lhs_falsifier.into_iter().chain(rhs_falsifier)) + .flatten() + } + Operator::NotEq => min(lhs, ctx) + .zip(max(rhs, ctx)) + .zip(max(lhs, ctx).zip(min(rhs, ctx))) + .map(|((min_lhs, max_rhs), (max_lhs, min_rhs))| { + with_nan_predicate( + ctx, + nan_check, + emit_unguarded_rewrites, + lhs, + rhs, + and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)), + ) + }) + .transpose()? + .flatten(), + Operator::Gt => max(lhs, ctx) + .zip(min(rhs, ctx)) + .map(|(a, b)| { + with_nan_predicate( + ctx, + nan_check, + emit_unguarded_rewrites, + lhs, + rhs, + lt_eq(a, b), + ) + }) + .transpose()? + .flatten(), + Operator::Gte => max(lhs, ctx) + .zip(min(rhs, ctx)) + .map(|(a, b)| { + with_nan_predicate(ctx, nan_check, emit_unguarded_rewrites, lhs, rhs, lt(a, b)) + }) + .transpose()? + .flatten(), + Operator::Lt => min(lhs, ctx) + .zip(max(rhs, ctx)) + .map(|(a, b)| { + with_nan_predicate( + ctx, + nan_check, + emit_unguarded_rewrites, + lhs, + rhs, + gt_eq(a, b), + ) + }) + .transpose()? + .flatten(), + Operator::Lte => min(lhs, ctx) + .zip(max(rhs, ctx)) + .map(|(a, b)| { + with_nan_predicate(ctx, nan_check, emit_unguarded_rewrites, lhs, rhs, gt(a, b)) + }) + .transpose()? + .flatten(), + Operator::And => { + if !emit_unguarded_rewrites { + return Ok(None); } - Operator::Or => match (ctx.falsify(lhs)?, ctx.falsify(rhs)?) { - (Some(lhs), Some(rhs)) if self.nan_proof.emits_unguarded_rewrites() => { - Some(and(lhs, rhs)) - } - _ => None, - }, - Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, - }) - } + + let lhs_falsifier = ctx.falsify(lhs)?; + let rhs_falsifier = ctx.falsify(rhs)?; + or_collect(lhs_falsifier.into_iter().chain(rhs_falsifier)) + } + Operator::Or => match (ctx.falsify(lhs)?, ctx.falsify(rhs)?) { + (Some(lhs), Some(rhs)) if emit_unguarded_rewrites => Some(and(lhs, rhs)), + _ => None, + }, + Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, + }) } #[derive(Debug)] @@ -364,11 +404,26 @@ impl StatsRewriteRule for LikeStatsRewrite { } #[derive(Debug)] -struct ListContainsStatsRewrite { - nan_proof: NanProof, +struct ListContainsNanCountStatsRewrite; + +impl StatsRewriteRule for ListContainsNanCountStatsRewrite { + fn scalar_fn_id(&self) -> ScalarFnId { + ListContains.id() + } + + fn falsify( + &self, + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, + ) -> VortexResult> { + list_contains_falsify(expr, ctx, nan_count_check, true) + } } -impl StatsRewriteRule for ListContainsStatsRewrite { +#[derive(Debug)] +struct ListContainsAllNonNanStatsRewrite; + +impl StatsRewriteRule for ListContainsAllNonNanStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { ListContains.id() } @@ -378,51 +433,64 @@ impl StatsRewriteRule for ListContainsStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - let list = expr.child(0); - let needle = expr.child(1); + list_contains_falsify(expr, ctx, all_non_nan_check, false) + } +} - let Some(list_scalar) = literal_stat(list, Stat::Min) else { - return Ok(None); - }; - let elements = list_scalar - .as_opt::() - .and_then(|literal| literal.as_list_opt()) - .and_then(|list| list.elements()); - let Some(elements) = elements else { - return Ok(None); - }; - if elements.is_empty() { - return Ok(Some(lit(true))); - } +fn list_contains_falsify( + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, + nan_check: fn(&StatsRewriteCtx<'_>, &Expression) -> VortexResult, + emit_unguarded_rewrites: bool, +) -> VortexResult> { + let list = expr.child(0); + let needle = expr.child(1); + + let Some(list_scalar) = literal_stat(list, Stat::Min) else { + return Ok(None); + }; + let elements = list_scalar + .as_opt::() + .and_then(|literal| literal.as_list_opt()) + .and_then(|list| list.elements()); + let Some(elements) = elements else { + return Ok(None); + }; + if elements.is_empty() { + return Ok(emit_unguarded_rewrites.then(|| lit(true))); + } - let Some(value_max) = max(needle, ctx) else { - return Ok(None); - }; - let Some(value_min) = min(needle, ctx) else { - return Ok(None); - }; + let Some(value_max) = max(needle, ctx) else { + return Ok(None); + }; + let Some(value_min) = min(needle, ctx) else { + return Ok(None); + }; - let value_predicate = and_collect(elements.iter().map(|value| { - or( - lt(value_max.clone(), lit(value.clone())), - gt(value_min.clone(), lit(value.clone())), + let value_predicate = and_collect(elements.iter().map(|value| { + or( + lt(value_max.clone(), lit(value.clone())), + gt(value_min.clone(), lit(value.clone())), + ) + })); + value_predicate + .map(|value_predicate| { + with_all_non_nan_predicate( + ctx, + nan_check, + emit_unguarded_rewrites, + [needle], + value_predicate, ) - })); - value_predicate - .map(|value_predicate| { - with_all_non_nan_predicate(ctx, self.nan_proof, [needle], value_predicate) - }) - .transpose() - .map(Option::flatten) - } + }) + .transpose() + .map(Option::flatten) } #[derive(Debug)] -struct DynamicComparisonStatsRewrite { - nan_proof: NanProof, -} +struct DynamicComparisonNanCountStatsRewrite; -impl StatsRewriteRule for DynamicComparisonStatsRewrite { +impl StatsRewriteRule for DynamicComparisonNanCountStatsRewrite { fn scalar_fn_id(&self) -> ScalarFnId { DynamicComparison.id() } @@ -432,31 +500,63 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - let dynamic = expr.as_::(); - let lhs = expr.child(0); - - let Some((operator, lhs_stat)) = (match dynamic.operator { - CompareOperator::Eq | CompareOperator::NotEq => None, - CompareOperator::Gt => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)), - CompareOperator::Gte => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)), - CompareOperator::Lt => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)), - CompareOperator::Lte => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)), - }) else { - return Ok(None); - }; + dynamic_comparison_falsify(expr, ctx, nan_count_check, true) + } +} - let value_predicate = DynamicComparison.new_expr( - DynamicComparisonExpr { - operator, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - [lhs_stat], - ); - with_all_non_nan_predicate(ctx, self.nan_proof, [lhs], value_predicate) +#[derive(Debug)] +struct DynamicComparisonAllNonNanStatsRewrite; + +impl StatsRewriteRule for DynamicComparisonAllNonNanStatsRewrite { + fn scalar_fn_id(&self) -> ScalarFnId { + DynamicComparison.id() + } + + fn falsify( + &self, + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, + ) -> VortexResult> { + dynamic_comparison_falsify(expr, ctx, all_non_nan_check, false) } } +fn dynamic_comparison_falsify( + expr: &Expression, + ctx: &StatsRewriteCtx<'_>, + nan_check: fn(&StatsRewriteCtx<'_>, &Expression) -> VortexResult, + emit_unguarded_rewrites: bool, +) -> VortexResult> { + let dynamic = expr.as_::(); + let lhs = expr.child(0); + + let Some((operator, lhs_stat)) = (match dynamic.operator { + CompareOperator::Eq | CompareOperator::NotEq => None, + CompareOperator::Gt => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)), + CompareOperator::Gte => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)), + CompareOperator::Lt => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)), + CompareOperator::Lte => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)), + }) else { + return Ok(None); + }; + + let value_predicate = DynamicComparison.new_expr( + DynamicComparisonExpr { + operator, + rhs: Arc::clone(&dynamic.rhs), + default: !dynamic.default, + }, + [lhs_stat], + ); + with_all_non_nan_predicate( + ctx, + nan_check, + emit_unguarded_rewrites, + [lhs], + value_predicate, + ) +} + fn min(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option { stat_expr(expr, Stat::Min, ctx) } @@ -477,18 +577,6 @@ fn all_non_null(expr: &Expression) -> Expression { stat_fn(expr.clone(), AllNonNull.bind(AggregateEmptyOptions)) } -#[derive(Debug, Clone, Copy)] -enum NanProof { - NanCount, - AllNonNan, -} - -impl NanProof { - fn emits_unguarded_rewrites(self) -> bool { - matches!(self, Self::NanCount) - } -} - enum NanCheck { NotNeeded, Check(Expression), @@ -498,10 +586,25 @@ enum NanCheck { // Min/max do not order NaN values, so comparison rewrites are only sound when every // candidate value is known to be non-NaN. Cast result dtypes are not enough: a cast // from float to non-float still needs a proof about the float source values. -fn all_non_nan_stat( +fn nan_count_check(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> VortexResult { + nan_check(ctx, expr, |expr| { + match stat_expr(expr, Stat::NaNCount, ctx) { + Some(nan_count) => NanCheck::Check(eq(nan_count, lit(0u64))), + None => NanCheck::Unavailable, + } + }) +} + +fn all_non_nan_check(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> VortexResult { + nan_check(ctx, expr, |expr| { + NanCheck::Check(stat_fn(expr.clone(), AllNonNan.bind(AggregateEmptyOptions))) + }) +} + +fn nan_check( ctx: &StatsRewriteCtx<'_>, - nan_proof: NanProof, expr: &Expression, + proof: impl FnOnce(&Expression) -> NanCheck, ) -> VortexResult { if let Some(scalar) = expr.as_opt::() { let Some(value) = scalar.as_primitive_opt() else { @@ -519,22 +622,14 @@ fn all_non_nan_stat( return Ok(NanCheck::NotNeeded); } - return all_non_nan_stat(ctx, nan_proof, expr.child(0)); + return nan_check(ctx, expr.child(0), proof); } if !has_nans(&ctx.return_dtype(expr)?) { return Ok(NanCheck::NotNeeded); } - Ok(match nan_proof { - NanProof::NanCount => match stat_expr(expr, Stat::NaNCount, ctx) { - Some(nan_count) => NanCheck::Check(eq(nan_count, lit(0u64))), - None => NanCheck::Unavailable, - }, - NanProof::AllNonNan => { - NanCheck::Check(stat_fn(expr.clone(), AllNonNan.bind(AggregateEmptyOptions))) - } - }) + Ok(proof(expr)) } fn has_nans(dtype: &DType) -> bool { @@ -570,23 +665,31 @@ fn stat_expr(expr: &Expression, stat: Stat, ctx: &StatsRewriteCtx<'_>) -> Option fn with_nan_predicate( ctx: &StatsRewriteCtx<'_>, - nan_proof: NanProof, + nan_check: fn(&StatsRewriteCtx<'_>, &Expression) -> VortexResult, + emit_unguarded_rewrites: bool, lhs: &Expression, rhs: &Expression, value_predicate: Expression, ) -> VortexResult> { - with_all_non_nan_predicate(ctx, nan_proof, [lhs, rhs], value_predicate) + with_all_non_nan_predicate( + ctx, + nan_check, + emit_unguarded_rewrites, + [lhs, rhs], + value_predicate, + ) } fn with_all_non_nan_predicate<'a>( ctx: &StatsRewriteCtx<'_>, - nan_proof: NanProof, + nan_check: fn(&StatsRewriteCtx<'_>, &Expression) -> VortexResult, + emit_unguarded_rewrites: bool, exprs: impl IntoIterator, value_predicate: Expression, ) -> VortexResult> { let mut nan_checks = Vec::new(); for expr in exprs { - match all_non_nan_stat(ctx, nan_proof, expr)? { + match nan_check(ctx, expr)? { NanCheck::NotNeeded => {} NanCheck::Check(check) => nan_checks.push(check), NanCheck::Unavailable => return Ok(None), @@ -599,7 +702,7 @@ fn with_all_non_nan_predicate<'a>( // No possible NaN-bearing expression remains, so the value predicate is // already guarded. Only one registered rule emits this unguarded // rewrite so non-float comparisons are not duplicated. - None if nan_proof.emits_unguarded_rewrites() => Some(value_predicate), + None if emit_unguarded_rewrites => Some(value_predicate), None => None, }) } From 91fc94f6888637bc1dfc458b721e08cf4abb5a78 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 17 Jun 2026 13:46:26 -0400 Subject: [PATCH 21/28] Clean up stats rewrite proof helpers Signed-off-by: "Nicholas Gates" --- vortex-array/src/stats/rewrite/builtins.rs | 171 +++++++-------------- 1 file changed, 58 insertions(+), 113 deletions(-) diff --git a/vortex-array/src/stats/rewrite/builtins.rs b/vortex-array/src/stats/rewrite/builtins.rs index ab076738210..f5047d08c8f 100644 --- a/vortex-array/src/stats/rewrite/builtins.rs +++ b/vortex-array/src/stats/rewrite/builtins.rs @@ -81,7 +81,7 @@ impl StatsRewriteRule for BinaryNanCountStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - binary_falsify(expr, ctx, nan_count_check, true) + binary_falsify::(expr, ctx) } } @@ -98,15 +98,13 @@ impl StatsRewriteRule for BinaryAllNonNanStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - binary_falsify(expr, ctx, all_non_nan_check, false) + binary_falsify::(expr, ctx) } } -fn binary_falsify( +fn binary_falsify( expr: &Expression, ctx: &StatsRewriteCtx<'_>, - nan_check: fn(&StatsRewriteCtx<'_>, &Expression) -> VortexResult, - emit_unguarded_rewrites: bool, ) -> VortexResult> { let operator = expr.as_::(); let lhs = expr.child(0); @@ -117,16 +115,7 @@ fn binary_falsify( let left = min(lhs, ctx).zip(max(rhs, ctx)).map(|(a, b)| gt(a, b)); let right = min(rhs, ctx).zip(max(lhs, ctx)).map(|(a, b)| gt(a, b)); or_collect(left.into_iter().chain(right)) - .map(|value_predicate| { - with_nan_predicate( - ctx, - nan_check, - emit_unguarded_rewrites, - lhs, - rhs, - value_predicate, - ) - }) + .map(|value_predicate| with_non_nan_guards::

(ctx, [lhs, rhs], value_predicate)) .transpose()? .flatten() } @@ -134,12 +123,9 @@ fn binary_falsify( .zip(max(rhs, ctx)) .zip(max(lhs, ctx).zip(min(rhs, ctx))) .map(|((min_lhs, max_rhs), (max_lhs, min_rhs))| { - with_nan_predicate( + with_non_nan_guards::

( ctx, - nan_check, - emit_unguarded_rewrites, - lhs, - rhs, + [lhs, rhs], and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)), ) }) @@ -147,48 +133,26 @@ fn binary_falsify( .flatten(), Operator::Gt => max(lhs, ctx) .zip(min(rhs, ctx)) - .map(|(a, b)| { - with_nan_predicate( - ctx, - nan_check, - emit_unguarded_rewrites, - lhs, - rhs, - lt_eq(a, b), - ) - }) + .map(|(a, b)| with_non_nan_guards::

(ctx, [lhs, rhs], lt_eq(a, b))) .transpose()? .flatten(), Operator::Gte => max(lhs, ctx) .zip(min(rhs, ctx)) - .map(|(a, b)| { - with_nan_predicate(ctx, nan_check, emit_unguarded_rewrites, lhs, rhs, lt(a, b)) - }) + .map(|(a, b)| with_non_nan_guards::

(ctx, [lhs, rhs], lt(a, b))) .transpose()? .flatten(), Operator::Lt => min(lhs, ctx) .zip(max(rhs, ctx)) - .map(|(a, b)| { - with_nan_predicate( - ctx, - nan_check, - emit_unguarded_rewrites, - lhs, - rhs, - gt_eq(a, b), - ) - }) + .map(|(a, b)| with_non_nan_guards::

(ctx, [lhs, rhs], gt_eq(a, b))) .transpose()? .flatten(), Operator::Lte => min(lhs, ctx) .zip(max(rhs, ctx)) - .map(|(a, b)| { - with_nan_predicate(ctx, nan_check, emit_unguarded_rewrites, lhs, rhs, gt(a, b)) - }) + .map(|(a, b)| with_non_nan_guards::

(ctx, [lhs, rhs], gt(a, b))) .transpose()? .flatten(), Operator::And => { - if !emit_unguarded_rewrites { + if !P::EMIT_UNGUARDED_REWRITES { return Ok(None); } @@ -197,7 +161,7 @@ fn binary_falsify( or_collect(lhs_falsifier.into_iter().chain(rhs_falsifier)) } Operator::Or => match (ctx.falsify(lhs)?, ctx.falsify(rhs)?) { - (Some(lhs), Some(rhs)) if emit_unguarded_rewrites => Some(and(lhs, rhs)), + (Some(lhs), Some(rhs)) if P::EMIT_UNGUARDED_REWRITES => Some(and(lhs, rhs)), _ => None, }, Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, @@ -416,7 +380,7 @@ impl StatsRewriteRule for ListContainsNanCountStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - list_contains_falsify(expr, ctx, nan_count_check, true) + list_contains_falsify::(expr, ctx) } } @@ -433,15 +397,13 @@ impl StatsRewriteRule for ListContainsAllNonNanStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - list_contains_falsify(expr, ctx, all_non_nan_check, false) + list_contains_falsify::(expr, ctx) } } -fn list_contains_falsify( +fn list_contains_falsify( expr: &Expression, ctx: &StatsRewriteCtx<'_>, - nan_check: fn(&StatsRewriteCtx<'_>, &Expression) -> VortexResult, - emit_unguarded_rewrites: bool, ) -> VortexResult> { let list = expr.child(0); let needle = expr.child(1); @@ -457,7 +419,7 @@ fn list_contains_falsify( return Ok(None); }; if elements.is_empty() { - return Ok(emit_unguarded_rewrites.then(|| lit(true))); + return Ok(P::EMIT_UNGUARDED_REWRITES.then(|| lit(true))); } let Some(value_max) = max(needle, ctx) else { @@ -474,15 +436,7 @@ fn list_contains_falsify( ) })); value_predicate - .map(|value_predicate| { - with_all_non_nan_predicate( - ctx, - nan_check, - emit_unguarded_rewrites, - [needle], - value_predicate, - ) - }) + .map(|value_predicate| with_non_nan_guards::

(ctx, [needle], value_predicate)) .transpose() .map(Option::flatten) } @@ -500,7 +454,7 @@ impl StatsRewriteRule for DynamicComparisonNanCountStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - dynamic_comparison_falsify(expr, ctx, nan_count_check, true) + dynamic_comparison_falsify::(expr, ctx) } } @@ -517,15 +471,13 @@ impl StatsRewriteRule for DynamicComparisonAllNonNanStatsRewrite { expr: &Expression, ctx: &StatsRewriteCtx<'_>, ) -> VortexResult> { - dynamic_comparison_falsify(expr, ctx, all_non_nan_check, false) + dynamic_comparison_falsify::(expr, ctx) } } -fn dynamic_comparison_falsify( +fn dynamic_comparison_falsify( expr: &Expression, ctx: &StatsRewriteCtx<'_>, - nan_check: fn(&StatsRewriteCtx<'_>, &Expression) -> VortexResult, - emit_unguarded_rewrites: bool, ) -> VortexResult> { let dynamic = expr.as_::(); let lhs = expr.child(0); @@ -548,13 +500,7 @@ fn dynamic_comparison_falsify( }, [lhs_stat], ); - with_all_non_nan_predicate( - ctx, - nan_check, - emit_unguarded_rewrites, - [lhs], - value_predicate, - ) + with_non_nan_guards::

(ctx, [lhs], value_predicate) } fn min(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option { @@ -583,25 +529,43 @@ enum NanCheck { Unavailable, } -// Min/max do not order NaN values, so comparison rewrites are only sound when every -// candidate value is known to be non-NaN. Cast result dtypes are not enough: a cast -// from float to non-float still needs a proof about the float source values. -fn nan_count_check(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> VortexResult { - nan_check(ctx, expr, |expr| { - match stat_expr(expr, Stat::NaNCount, ctx) { - Some(nan_count) => NanCheck::Check(eq(nan_count, lit(0u64))), - None => NanCheck::Unavailable, - } - }) +trait NonNanProof { + const EMIT_UNGUARDED_REWRITES: bool; + + fn check(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> VortexResult; } -fn all_non_nan_check(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> VortexResult { - nan_check(ctx, expr, |expr| { - NanCheck::Check(stat_fn(expr.clone(), AllNonNan.bind(AggregateEmptyOptions))) - }) +struct NanCountProof; + +impl NonNanProof for NanCountProof { + const EMIT_UNGUARDED_REWRITES: bool = true; + + fn check(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> VortexResult { + non_nan_check(ctx, expr, |expr| { + match stat_expr(expr, Stat::NaNCount, ctx) { + Some(nan_count) => NanCheck::Check(eq(nan_count, lit(0u64))), + None => NanCheck::Unavailable, + } + }) + } +} + +struct AllNonNanProof; + +impl NonNanProof for AllNonNanProof { + const EMIT_UNGUARDED_REWRITES: bool = false; + + fn check(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> VortexResult { + non_nan_check(ctx, expr, |expr| { + NanCheck::Check(stat_fn(expr.clone(), AllNonNan.bind(AggregateEmptyOptions))) + }) + } } -fn nan_check( +// Min/max do not order NaN values, so comparison rewrites are only sound when every +// candidate value is known to be non-NaN. Cast result dtypes are not enough: a cast +// from float to non-float still needs a proof about the float source values. +fn non_nan_check( ctx: &StatsRewriteCtx<'_>, expr: &Expression, proof: impl FnOnce(&Expression) -> NanCheck, @@ -622,7 +586,7 @@ fn nan_check( return Ok(NanCheck::NotNeeded); } - return nan_check(ctx, expr.child(0), proof); + return non_nan_check(ctx, expr.child(0), proof); } if !has_nans(&ctx.return_dtype(expr)?) { @@ -663,33 +627,14 @@ fn stat_expr(expr: &Expression, stat: Stat, ctx: &StatsRewriteCtx<'_>) -> Option .then(|| stat_fn(expr.clone(), aggregate_fn)) } -fn with_nan_predicate( - ctx: &StatsRewriteCtx<'_>, - nan_check: fn(&StatsRewriteCtx<'_>, &Expression) -> VortexResult, - emit_unguarded_rewrites: bool, - lhs: &Expression, - rhs: &Expression, - value_predicate: Expression, -) -> VortexResult> { - with_all_non_nan_predicate( - ctx, - nan_check, - emit_unguarded_rewrites, - [lhs, rhs], - value_predicate, - ) -} - -fn with_all_non_nan_predicate<'a>( +fn with_non_nan_guards<'a, P: NonNanProof>( ctx: &StatsRewriteCtx<'_>, - nan_check: fn(&StatsRewriteCtx<'_>, &Expression) -> VortexResult, - emit_unguarded_rewrites: bool, exprs: impl IntoIterator, value_predicate: Expression, ) -> VortexResult> { let mut nan_checks = Vec::new(); for expr in exprs { - match nan_check(ctx, expr)? { + match P::check(ctx, expr)? { NanCheck::NotNeeded => {} NanCheck::Check(check) => nan_checks.push(check), NanCheck::Unavailable => return Ok(None), @@ -702,7 +647,7 @@ fn with_all_non_nan_predicate<'a>( // No possible NaN-bearing expression remains, so the value predicate is // already guarded. Only one registered rule emits this unguarded // rewrite so non-float comparisons are not duplicated. - None if emit_unguarded_rewrites => Some(value_predicate), + None if P::EMIT_UNGUARDED_REWRITES => Some(value_predicate), None => None, }) } From 22cd9de13f4b6a6dc94b61e5d417d575c61478b5 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 17 Jun 2026 15:54:34 -0400 Subject: [PATCH 22/28] Preserve DuckDB filter order Signed-off-by: Nicholas Gates --- vortex-duckdb/src/projection.rs | 61 ++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/vortex-duckdb/src/projection.rs b/vortex-duckdb/src/projection.rs index a27056e5f01..593f15946e0 100644 --- a/vortex-duckdb/src/projection.rs +++ b/vortex-duckdb/src/projection.rs @@ -18,7 +18,6 @@ use vortex::expr::root; use vortex::expr::select; use vortex::layout::layouts::row_idx::row_idx; use vortex::scan::selection::Selection; -use vortex_utils::aliases::hash_set::HashSet; use crate::convert::try_from_table_filter; use crate::convert::try_from_virtual_column_filter; @@ -146,6 +145,12 @@ pub struct Filter { pub has_non_optional_filter: bool, } +fn push_filter_expr(filter_exprs: &mut Vec, expr: Expression) { + if !filter_exprs.iter().any(|existing| existing == &expr) { + filter_exprs.push(expr); + } +} + impl Filter { /// Creates a table filter expression, row selection, and row range from the table filter set, /// column metadata, additional filter expressions, and the top-level DType. @@ -158,29 +163,26 @@ impl Filter { ) -> VortexResult { let mut has_non_optional_filter = false; - let mut table_filter_exprs: HashSet = if let Some(filter) = table_filter_set { - filter - .into_iter() - .filter(|(idx, _)| { - let idx_u: usize = idx.as_(); - !is_virtual_column(column_ids[idx_u]) - }) - .map(|(idx, ex)| { - has_non_optional_filter |= - !matches!(ex.as_class(), TableFilterClass::Optional(_)); - - let idx_u: usize = idx.as_(); - let col_idx: usize = column_ids[idx_u].as_(); - let name = &column_fields.get(col_idx).vortex_expect("exists").name; - try_from_table_filter(ex, &col(name.as_str()), dtype) - }) - .collect::>>>()? - .unwrap_or_else(HashSet::new) - } else { - HashSet::new() - }; + let mut table_filter_exprs = Vec::new(); + if let Some(filter) = table_filter_set { + for (idx, ex) in filter.into_iter().filter(|(idx, _)| { + let idx_u: usize = idx.as_(); + !is_virtual_column(column_ids[idx_u]) + }) { + has_non_optional_filter |= !matches!(ex.as_class(), TableFilterClass::Optional(_)); + + let idx_u: usize = idx.as_(); + let col_idx: usize = column_ids[idx_u].as_(); + let name = &column_fields.get(col_idx).vortex_expect("exists").name; + if let Some(expr) = try_from_table_filter(ex, &col(name.as_str()), dtype)? { + push_filter_expr(&mut table_filter_exprs, expr); + } + } + } - table_filter_exprs.extend(additional_filters.iter().cloned()); + for expr in additional_filters.iter().cloned() { + push_filter_expr(&mut table_filter_exprs, expr); + } let mut file_selection = Selection::All; let mut row_selection = Selection::All; @@ -286,4 +288,17 @@ mod tests { let ids = [2, 1, 0]; assert_ne!(Projection::new(None, &ids, &fields).projection, root()); } + + #[test] + fn test_push_filter_expr_preserves_order() { + let first = col("first"); + let second = col("second"); + + let mut filter_exprs = Vec::new(); + push_filter_expr(&mut filter_exprs, first.clone()); + push_filter_expr(&mut filter_exprs, second.clone()); + push_filter_expr(&mut filter_exprs, first.clone()); + + assert_eq!(filter_exprs, vec![first, second]); + } } From d81033e531f0a43b1c32e8938311e2b66d155495 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 17 Jun 2026 17:14:44 -0400 Subject: [PATCH 23/28] Restore integer Delta compression scheme Signed-off-by: "Nicholas Gates" --- vortex-btrblocks/src/builder.rs | 7 + vortex-btrblocks/src/schemes/integer/delta.rs | 199 ++++++++++++++++++ vortex-btrblocks/src/schemes/integer/mod.rs | 4 + .../schemes/integer/scheme_selection_tests.rs | 61 +++++- 4 files changed, 270 insertions(+), 1 deletion(-) create mode 100644 vortex-btrblocks/src/schemes/integer/delta.rs diff --git a/vortex-btrblocks/src/builder.rs b/vortex-btrblocks/src/builder.rs index 43df788473a..523ffa7635b 100644 --- a/vortex-btrblocks/src/builder.rs +++ b/vortex-btrblocks/src/builder.rs @@ -41,6 +41,9 @@ pub const ALL_SCHEMES: &[&dyn Scheme] = &[ &integer::RunEndScheme, &integer::SequenceScheme, &integer::IntRLEScheme, + // Prefer all other schemes above delta, for now (since its slower to decompress). + #[cfg(feature = "unstable_encodings")] + &integer::DeltaScheme::new(1.25), //////////////////////////////////////////////////////////////////////////////////////////////// // Float schemes. //////////////////////////////////////////////////////////////////////////////////////////////// @@ -180,6 +183,10 @@ impl BtrBlocksCompressorBuilder { ]; #[cfg(feature = "unstable_encodings")] excluded.push(string::OnPairScheme.id()); + // Delta has no GPU decode kernel and its prefix-sum decode is inherently sequential, so it + // is incompatible with pure-GPU decompression paths. + #[cfg(feature = "unstable_encodings")] + excluded.push(integer::DeltaScheme::default().id()); let builder = self.exclude_schemes(excluded); #[cfg(all(feature = "zstd", feature = "unstable_encodings"))] diff --git a/vortex-btrblocks/src/schemes/integer/delta.rs b/vortex-btrblocks/src/schemes/integer/delta.rs new file mode 100644 index 00000000000..2abc4d578f5 --- /dev/null +++ b/vortex-btrblocks/src/schemes/integer/delta.rs @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! FastLanes Delta integer encoding. + +use vortex_array::ArrayRef; +use vortex_array::Canonical; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_compressor::builtins::BinaryDictScheme; +use vortex_compressor::builtins::FloatDictScheme; +use vortex_compressor::builtins::IntDictScheme; +use vortex_compressor::builtins::StringDictScheme; +use vortex_compressor::estimate::CompressionEstimate; +use vortex_compressor::estimate::DeferredEstimate; +use vortex_compressor::estimate::EstimateScore; +use vortex_compressor::estimate::EstimateVerdict; +use vortex_compressor::scheme::AncestorExclusion; +use vortex_compressor::scheme::ChildSelection; +use vortex_compressor::scheme::DescendantExclusion; +use vortex_error::VortexResult; +use vortex_fastlanes::Delta; + +use crate::ArrayAndStats; +use crate::CascadingCompressor; +use crate::CompressorContext; +use crate::GenerateStatsOptions; +use crate::Scheme; +use crate::SchemeExt; + +/// FastLanes Delta encoding for smooth / near-monotone integers. +/// +/// Delta replaces each value with its difference from an earlier value (at the FastLanes lane +/// stride), so a later cascade layer (FoR / BitPacking) packs the smaller residuals. It only +/// pays off when those residuals span meaningfully fewer bits than the values themselves. +/// +/// The minimum penalized compression ratio required for Delta to be selected is configurable via +/// [`DeltaScheme::new`]; [`DeltaScheme::default`] uses a ratio of `1.25`. +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct DeltaScheme { + min_ratio: f64, +} + +impl DeltaScheme { + /// Creates a Delta scheme requiring `min_ratio` (after the [`DELTA_PENALTY`]) before it wins. + /// + /// Pass a higher ratio to make Delta more conservative, or a lower one to select it more + /// eagerly. [`DeltaScheme::default`] uses a ratio of `1.25`. + pub const fn new(min_ratio: f64) -> Self { + Self { min_ratio } + } +} + +impl Default for DeltaScheme { + fn default() -> Self { + Self::new(1.25) + } +} + +/// Multiplicative penalty applied to Delta's estimated compression ratio. +/// +/// Unlike FoR/BitPacking, Delta breaks random access and adds a prefix-sum decode pass, and it +/// carries a structural sign bit on its residuals. We therefore require Delta to be meaningfully +/// (~5%) smaller than the best alternative before it wins, rather than picking it for a +/// single-bit gain. This factor encodes that "delta tax". +const DELTA_PENALTY: f64 = 0.95; + +/// Minimum length before Delta is worth considering (one FastLanes chunk). +const MIN_DELTA_LEN: usize = 1024; + +impl Scheme for DeltaScheme { + fn scheme_name(&self) -> &'static str { + "vortex.int.delta" + } + + fn matches(&self, canonical: &Canonical) -> bool { + canonical.dtype().is_int() + } + + fn num_children(&self) -> usize { + 2 + } + + /// Delta-encode the data at most once per path: exclude Delta from the subtrees of both the + /// bases and the deltas children so we never delta-encode data that was already delta-encoded. + fn descendant_exclusions(&self) -> Vec { + vec![DescendantExclusion { + excluded: self.id(), + children: ChildSelection::All, + }] + } + + /// Delta over dictionary codes just adds indirection: codes are compact integers with no + /// monotone structure, so (like FoR/Sequence) skip the codes child. + fn ancestor_exclusions(&self) -> Vec { + vec![ + AncestorExclusion { + ancestor: IntDictScheme.id(), + children: ChildSelection::One(1), + }, + AncestorExclusion { + ancestor: FloatDictScheme.id(), + children: ChildSelection::One(1), + }, + AncestorExclusion { + ancestor: StringDictScheme.id(), + children: ChildSelection::One(1), + }, + AncestorExclusion { + ancestor: BinaryDictScheme.id(), + children: ChildSelection::One(1), + }, + ] + } + + fn expected_compression_ratio( + &self, + data: &ArrayAndStats, + compress_ctx: CompressorContext, + _exec_ctx: &mut ExecutionCtx, + ) -> CompressionEstimate { + // Delta only pays off if a later cascade layer (FoR/BitPacking) packs the residuals. + if compress_ctx.finished_cascading() { + return CompressionEstimate::Verdict(EstimateVerdict::Skip); + } + // Too short to transpose into FastLanes chunks meaningfully. + if data.array_len() < MIN_DELTA_LEN { + return CompressionEstimate::Verdict(EstimateVerdict::Skip); + } + + // Estimating Delta needs the real transposed-delta span, so defer to a callback that + // delta-encodes the array and measures the residual range. + let min_ratio = self.min_ratio; + CompressionEstimate::Deferred(DeferredEstimate::Callback(Box::new( + move |_compressor, data, best_so_far, _ctx, exec_ctx| { + let primitive = data.array().clone().execute::(exec_ctx)?; + let full_width = primitive.ptype().bit_width() as f64; + + // Delta's best case is residuals collapsing to a single bit. If even that, after + // the penalty, can't beat the incumbent, skip before doing the encode work. + let threshold = best_so_far.and_then(EstimateScore::finite_ratio); + if threshold.is_some_and(|t| full_width * DELTA_PENALTY <= t) { + return Ok(EstimateVerdict::Skip); + } + + // Measure the actual FastLanes transposed-delta span. This is the lane-stride + // difference that gets bit-packed, not the lag-1 difference (which the transpose + // makes optimistic), so it is what truly drives the compressed size. + let (_bases, deltas) = vortex_fastlanes::delta_compress(&primitive, exec_ctx)?; + let delta_stats = + ArrayAndStats::new(deltas.into_array(), GenerateStatsOptions::default()); + let span = delta_stats.integer_stats(exec_ctx).erased().max_minus_min(); + + // Bits needed to FoR-pack the residuals. A zero span means constant deltas, which + // SequenceScheme already captures more cheaply, so defer to it. + let delta_bits = match span.checked_ilog2() { + Some(l) => (l + 1) as f64, + None => return Ok(EstimateVerdict::Skip), + }; + + let ratio = full_width / delta_bits * DELTA_PENALTY; + if ratio <= min_ratio { + return Ok(EstimateVerdict::Skip); + } + Ok(EstimateVerdict::Ratio(ratio)) + }, + ))) + } + + fn compress( + &self, + compressor: &CascadingCompressor, + data: &ArrayAndStats, + compress_ctx: CompressorContext, + exec_ctx: &mut ExecutionCtx, + ) -> VortexResult { + let primitive = data.array().clone().execute::(exec_ctx)?; + let len = primitive.len(); + let (bases, deltas) = vortex_fastlanes::delta_compress(&primitive, exec_ctx)?; + + let compressed_bases = compressor.compress_child( + &bases.into_array(), + &compress_ctx, + self.id(), + 0, + exec_ctx, + )?; + let compressed_deltas = compressor.compress_child( + &deltas.into_array(), + &compress_ctx, + self.id(), + 1, + exec_ctx, + )?; + + Delta::try_new(compressed_bases, compressed_deltas, 0, len).map(IntoArray::into_array) + } +} diff --git a/vortex-btrblocks/src/schemes/integer/mod.rs b/vortex-btrblocks/src/schemes/integer/mod.rs index aed29f1ad3d..abe5868f5c8 100644 --- a/vortex-btrblocks/src/schemes/integer/mod.rs +++ b/vortex-btrblocks/src/schemes/integer/mod.rs @@ -4,6 +4,8 @@ //! Integer compression schemes. mod bitpacking; +#[cfg(feature = "unstable_encodings")] +mod delta; mod for_; mod rle; mod runend; @@ -15,6 +17,8 @@ mod zigzag; mod pco; pub use bitpacking::BitPackingScheme; +#[cfg(feature = "unstable_encodings")] +pub use delta::DeltaScheme; pub use for_::FoRScheme; #[cfg(feature = "pco")] pub use pco::PcoScheme; diff --git a/vortex-btrblocks/src/schemes/integer/scheme_selection_tests.rs b/vortex-btrblocks/src/schemes/integer/scheme_selection_tests.rs index 2e0fb269fda..993827d2057 100644 --- a/vortex-btrblocks/src/schemes/integer/scheme_selection_tests.rs +++ b/vortex-btrblocks/src/schemes/integer/scheme_selection_tests.rs @@ -143,7 +143,11 @@ fn test_sequence_compressed() -> VortexResult<()> { fn test_rle_compressed() -> VortexResult<()> { let mut values: Vec = Vec::new(); for i in 0..1024 { - values.extend(iter::repeat_n(i, 10)); + // Scramble the per-run value so the data is run-length-dominant but not monotone: this + // keeps RunEnd the winner instead of Delta (whose residuals would be small on a smooth + // ramp). + let v = (i as u32).wrapping_mul(2_654_435_761) as i32; + values.extend(iter::repeat_n(v, 10)); } let array = PrimitiveArray::new(Buffer::copy_from(&values), Validity::NonNullable); let btr = BtrBlocksCompressor::default(); @@ -152,3 +156,58 @@ fn test_rle_compressed() -> VortexResult<()> { assert!(compressed.is::()); Ok(()) } + +/// A strictly-increasing column with small, irregular steps: not a perfect arithmetic sequence +/// (so Sequence skips), all-unique with no runs (so RunEnd/Dict skip), and a wide absolute range. +/// Delta's residuals are far smaller than the FoR span, so Delta should win and round-trip, and +/// it must appear at most once in the tree. +#[cfg(feature = "unstable_encodings")] +#[test] +fn test_delta_compressed() -> VortexResult<()> { + use vortex_array::assert_arrays_eq; + use vortex_fastlanes::Delta; + + let mut rng = StdRng::seed_from_u64(7u64); + let mut value = 500_000i32; + let values: Vec = (0..4096) + .map(|_| { + value += 1 + (rng.next_u32() % 6) as i32; + value + }) + .collect(); + let array = PrimitiveArray::new(Buffer::copy_from(&values), Validity::NonNullable); + + let btr = BtrBlocksCompressor::default(); + let compressed = btr.compress( + &array.clone().into_array(), + &mut SESSION.create_execution_ctx(), + )?; + assert!( + compressed.is::(), + "expected Delta, got tree:\n{}", + compressed.display_tree() + ); + // Delta must appear at most once per tree: no Delta node may be nested under another. + assert!( + !has_nested_delta(&compressed, false), + "Delta was applied more than once in the tree:\n{}", + compressed.display_tree() + ); + assert_arrays_eq!(compressed, array.into_array()); + Ok(()) +} + +/// Returns true if any `Delta` array appears below an ancestor `Delta` in the tree. +#[cfg(feature = "unstable_encodings")] +fn has_nested_delta(array: &vortex_array::ArrayRef, under_delta: bool) -> bool { + use vortex_fastlanes::Delta; + + let is_delta = array.is::(); + if is_delta && under_delta { + return true; + } + array + .children() + .iter() + .any(|child| has_nested_delta(child, under_delta || is_delta)) +} From baae2dd3b960e912b2272001fb94275114ee09f0 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 18 Jun 2026 09:53:49 -0400 Subject: [PATCH 24/28] Address stats pruning review comments Signed-off-by: "Nicholas Gates" --- .../internals/stats-pruning.md | 7 +- vortex-array/src/scalar_fn/fns/binary/mod.rs | 129 ++++++++++++------ vortex-array/src/stats/bind.rs | 16 ++- 3 files changed, 108 insertions(+), 44 deletions(-) diff --git a/docs/developer-guide/internals/stats-pruning.md b/docs/developer-guide/internals/stats-pruning.md index 6049a06455e..505c5a898cb 100644 --- a/docs/developer-guide/internals/stats-pruning.md +++ b/docs/developer-guide/internals/stats-pruning.md @@ -4,12 +4,17 @@ Vortex uses statistics to prove when a filter cannot match a row group, zone, or file. The proof expression returns `true` when the input can be skipped. It returns `false` or `null` when pruning is not proven. +Both `false` and `null` are non-pruning outcomes, but they mean different +things. `false` means the available stats disproved the skip proof. `null` means +the proof was unknown, usually because a required stat was missing or inexact. + The pruning pipeline has two phases: 1. `Expression::falsify(scope, session)` asks the session's `StatsRewriteRule`s to rewrite a filter into an abstract proof expression. Rules describe semantics in terms of `vortex.stat(input, aggregate_fn)` - placeholders, so the rewrite does not depend on a particular layout. + placeholders. These placeholders name the statistic needed by the proof, but + not where that statistic is stored. 2. `bind_stats` lowers those abstract stat placeholders with a `StatBinder`. The binder maps stats to the representation used by the caller, such as zone-map table fields, file-level stat literals, or typed null literals for diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index b51f86b3188..e09a982ba1f 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -26,6 +26,7 @@ use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; +use crate::scalar_fn::SimplifyCtx; use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::operators::CompareOperator; use crate::scalar_fn::fns::operators::Operator; @@ -43,40 +44,6 @@ use crate::scalar::Scalar; #[derive(Clone)] pub struct Binary; -fn simplify_and(lhs: &Expression, rhs: &Expression) -> Option { - match (bool_literal(lhs), bool_literal(rhs)) { - (Some(Some(false)), _) | (_, Some(Some(false))) => Some(lit(false)), - (Some(Some(true)), _) => Some(rhs.clone()), - (_, Some(Some(true))) => Some(lhs.clone()), - (Some(None), Some(None)) => Some(lhs.clone()), - _ => None, - } -} - -fn simplify_or(lhs: &Expression, rhs: &Expression) -> Option { - match (bool_literal(lhs), bool_literal(rhs)) { - (Some(Some(true)), _) | (_, Some(Some(true))) => Some(lit(true)), - (Some(Some(false)), _) => Some(rhs.clone()), - (_, Some(Some(false))) => Some(lhs.clone()), - (Some(None), Some(None)) => Some(lhs.clone()), - _ => None, - } -} - -fn bool_literal(expr: &Expression) -> Option> { - expr.as_opt::()? - .as_bool_opt() - .map(|value| value.value()) -} - -fn is_null_literal(expr: &Expression) -> bool { - expr.as_opt::().is_some_and(Scalar::is_null) -} - -fn null_bool() -> Expression { - lit(Scalar::null(DType::Bool(Nullability::Nullable))) -} - impl ScalarFnVTable for Binary { type Options = Operator; @@ -201,17 +168,67 @@ impl ScalarFnVTable for Binary { let lhs = expr.child(0); let rhs = expr.child(1); - if operator.is_comparison() && (is_null_literal(lhs) || is_null_literal(rhs)) { - return Ok(Some(null_bool())); - } - + let bool_literal = |expr: &Expression| { + expr.as_opt::()? + .as_bool_opt() + .map(|value| value.value()) + }; + + // AND/OR use Kleene three-valued logic. `None` below is a boolean null. + // + // AND: + // - false AND x => false + // - true AND x => x + // - null AND null => null + // + // OR: + // - true OR x => true + // - false OR x => x + // - null OR null => null + // + // Other null cases either fall out of the identity/annihilator rules + // above (`null AND true`, `null OR false`) or cannot be simplified under + // Kleene semantics (`null AND x`, `null OR x` for non-literal `x`). Ok(match operator { - Operator::And => simplify_and(lhs, rhs), - Operator::Or => simplify_or(lhs, rhs), + Operator::And => match (bool_literal(lhs), bool_literal(rhs)) { + (Some(Some(false)), _) | (_, Some(Some(false))) => Some(lit(false)), + (Some(Some(true)), _) => Some(rhs.clone()), + (_, Some(Some(true))) => Some(lhs.clone()), + (Some(None), Some(None)) => Some(lhs.clone()), + _ => None, + }, + Operator::Or => match (bool_literal(lhs), bool_literal(rhs)) { + (Some(Some(true)), _) | (_, Some(Some(true))) => Some(lit(true)), + (Some(Some(false)), _) => Some(rhs.clone()), + (_, Some(Some(false))) => Some(lhs.clone()), + (Some(None), Some(None)) => Some(lhs.clone()), + _ => None, + }, _ => None, }) } + fn simplify( + &self, + operator: &Operator, + expr: &Expression, + ctx: &dyn SimplifyCtx, + ) -> VortexResult> { + let is_literal_null = + |expr: &Expression| expr.as_opt::().is_some_and(Scalar::is_null); + + if operator.is_comparison() + && (is_literal_null(expr.child(0)) || is_literal_null(expr.child(1))) + { + // Validate the comparison before reducing it. This preserves type + // errors for expressions like `int_col = null_utf8`. + ctx.return_dtype(expr)?; + return Ok(Some(lit(Scalar::null(DType::Bool(Nullability::Nullable))))); + } + + Ok(None) + } + fn validity( &self, operator: &Operator, @@ -257,6 +274,7 @@ impl ScalarFnVTable for Binary { #[cfg(test)] mod tests { use vortex_error::VortexExpect; + use vortex_error::VortexResult; use super::*; use crate::LEGACY_SESSION; @@ -265,6 +283,7 @@ mod tests { use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; + use crate::dtype::PType; use crate::expr::Expression; use crate::expr::and_collect; use crate::expr::col; @@ -389,6 +408,36 @@ mod tests { ); } + #[test] + fn comparison_with_typed_null_simplifies_after_type_check() -> VortexResult<()> { + let dtype = test_harness::struct_dtype(); + + let expr = eq( + col("col1"), + lit(Scalar::null(DType::Primitive( + PType::U16, + Nullability::Nullable, + ))), + ); + + assert_eq!( + expr.optimize_recursive(&dtype)?, + lit(Scalar::null(DType::Bool(Nullability::Nullable))) + ); + Ok(()) + } + + #[test] + fn comparison_with_incompatible_null_still_type_checks() { + let dtype = test_harness::struct_dtype(); + let expr = eq( + col("col1"), + lit(Scalar::null(DType::Utf8(Nullability::Nullable))), + ); + + assert!(expr.optimize_recursive(&dtype).is_err()); + } + #[test] fn test_display_print() { let expr = gt(lit(1), lit(2)); diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 328266c356e..95cfcffc311 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -4,9 +4,14 @@ //! Bind abstract `vortex.stat` expressions to a concrete stats representation. //! //! Stats rewrite rules describe pruning in terms of `vortex.stat(input, aggregate_fn)` placeholders -//! so the rewrite is independent of where statistics are stored. Binding is the later pass that -//! replaces those placeholders with the representation used by a caller: zone-map field references, -//! file-level stat literals, or typed nulls for missing stats. +//! so the rewrite is independent of where statistics are stored. These stat placeholders are +//! abstract because they name the statistic needed for a proof, but not how that statistic is +//! represented by a specific layout or reader. +//! +//! Binding is the later pass that replaces each abstract placeholder with the representation used +//! by a caller: zone-map field references, file-level stat literals, or typed nulls for missing +//! stats. This lets all callers share the same falsification rules while keeping layout-specific +//! stat storage behind [`StatBinder`]. use vortex_error::VortexResult; @@ -21,6 +26,11 @@ use crate::scalar::Scalar; use crate::scalar_fn::fns::stat::StatFn; /// A target that can bind abstract statistics to concrete expressions. +/// +/// Implementations define how a pruning proof should read stats from a specific backing +/// representation. For example, a zone-map binder can translate a `max(col)` placeholder into a +/// field reference in the per-zone stats table, while a file-stats binder can translate the same +/// placeholder into a literal value from the file footer. pub trait StatBinder { /// The dtype scope used to type-check expressions before stats are bound. fn scope(&self) -> &DType; From 9a737c90eb756578187cd194403de667f9fedca0 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 18 Jun 2026 12:37:40 -0400 Subject: [PATCH 25/28] Address follow-up stats rewrite review comments Signed-off-by: "Nicholas Gates" --- .../internals/stats-pruning.md | 6 ------ vortex-array/src/stats/bind.rs | 20 +++++++++++++------ vortex-duckdb/src/projection.rs | 10 +++++----- vortex-file/src/pruning.rs | 11 ++++++++++ vortex-layout/src/layouts/zoned/zone_map.rs | 11 ++++++++++ 5 files changed, 41 insertions(+), 17 deletions(-) diff --git a/docs/developer-guide/internals/stats-pruning.md b/docs/developer-guide/internals/stats-pruning.md index 505c5a898cb..4acaf9efc2a 100644 --- a/docs/developer-guide/internals/stats-pruning.md +++ b/docs/developer-guide/internals/stats-pruning.md @@ -35,11 +35,5 @@ expression is reduced and evaluated once for the full file. If it evaluates to `true`, the file stats reader can return an all-false pruning mask without reading child layouts. -Scan planning uses `checked_pruning_expr` to lower a falsified expression against -the available stats table schema. It returns the stats-table expression and the -set of stat fields still required after expression reduction. If all required -stats are missing, only a constant `true` proof is useful; all other results are -treated as no pruning expression. - For the layout model around these pruning points, see [Layouts](../../concepts/layouts.md) and [Scanning](../../concepts/scanning.md). diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 95cfcffc311..cc4816ab423 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -57,17 +57,16 @@ pub trait StatBinder { /// Bind `aggregate_fn(input)` to a concrete expression. /// - /// The default implementation supports aggregate functions that map - /// directly to [`Stat`] slots. Binders that store richer aggregate stats can - /// override this method without extending the generic stats binding walker. + /// Implementations should return `Ok(None)` when the requested aggregate + /// statistic is unavailable in their backing representation. Binders that + /// support only direct legacy [`Stat`] slots can delegate to + /// [`bind_direct_aggregate_stat`]. fn bind_aggregate( &self, input: &Expression, aggregate_fn: &AggregateFnRef, stat_dtype: &DType, - ) -> VortexResult> { - bind_direct_aggregate_stat(self, input, aggregate_fn, stat_dtype) - } + ) -> VortexResult>; /// Expression to use when a stat is unavailable. /// @@ -204,6 +203,15 @@ mod tests { Ok(None) } } + + fn bind_aggregate( + &self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult> { + bind_direct_aggregate_stat(self, input, aggregate_fn, stat_dtype) + } } #[test] diff --git a/vortex-duckdb/src/projection.rs b/vortex-duckdb/src/projection.rs index 968cf217fe6..4298c3131e4 100644 --- a/vortex-duckdb/src/projection.rs +++ b/vortex-duckdb/src/projection.rs @@ -198,9 +198,9 @@ pub struct Filter { pub has_non_optional_filter: bool, } -fn push_filter_expr(filter_exprs: &mut Vec, expr: Expression) { - if !filter_exprs.iter().any(|existing| existing == &expr) { - filter_exprs.push(expr); +fn push_filter_expr(filter_exprs: &mut Vec, expr: &Expression) { + if !filter_exprs.iter().any(|existing| existing == expr) { + filter_exprs.push(expr.clone()); } } @@ -228,12 +228,12 @@ impl Filter { let col_idx: usize = column_ids[idx_u].as_(); let name = &column_fields.get(col_idx).vortex_expect("exists").name; if let Some(expr) = try_from_table_filter(ex, &col(name.as_str()), dtype)? { - push_filter_expr(&mut table_filter_exprs, expr); + push_filter_expr(&mut table_filter_exprs, &expr); } } } - for expr in additional_filters.iter().cloned() { + for expr in additional_filters { push_filter_expr(&mut table_filter_exprs, expr); } diff --git a/vortex-file/src/pruning.rs b/vortex-file/src/pruning.rs index af08a64b687..9cd2be73f77 100644 --- a/vortex-file/src/pruning.rs +++ b/vortex-file/src/pruning.rs @@ -4,6 +4,7 @@ use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::AggregateFnRef; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::NullArray; use vortex_array::dtype::DType; @@ -19,6 +20,7 @@ use vortex_array::scalar_fn::fns::get_item::GetItem; use vortex_array::scalar_fn::fns::literal::Literal; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_direct_aggregate_stat; use vortex_array::stats::bind::bind_stats; use vortex_error::VortexResult; use vortex_session::VortexSession; @@ -89,6 +91,15 @@ impl StatBinder for FileStatsBinder<'_> { }; Ok(self.stat_ref(&field_path, stat)) } + + fn bind_aggregate( + &self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult> { + bind_direct_aggregate_stat(self, input, aggregate_fn, stat_dtype) + } } impl FileStatsBinder<'_> { diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index ad82fc7fc5b..cd8bbd6f881 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::AggregateFnRef; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; @@ -21,6 +22,7 @@ use vortex_array::expr::stats::Stat; use vortex_array::scalar_fn::internal::row_count::contains_row_count; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_direct_aggregate_stat; use vortex_array::stats::bind::bind_stats; use vortex_array::validity::Validity; use vortex_buffer::buffer; @@ -156,6 +158,15 @@ impl StatBinder for ZoneMapStatsBinder<'_> { } Ok(Some(get_item(stat.name(), root()))) } + + fn bind_aggregate( + &self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult> { + bind_direct_aggregate_stat(self, input, aggregate_fn, stat_dtype) + } } /// Build per-zone row counts for a zone map. From 83b133255f3f2ae8c4f4e34e21241f06fcd54cef Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 18 Jun 2026 12:47:25 -0400 Subject: [PATCH 26/28] Localize legacy stat aggregate binding Signed-off-by: "Nicholas Gates" --- vortex-array/src/stats/bind.rs | 48 ++++----------------- vortex-file/src/pruning.rs | 17 +++----- vortex-layout/src/layouts/zoned/zone_map.rs | 17 +++----- 3 files changed, 18 insertions(+), 64 deletions(-) diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index cc4816ab423..c26cc8a5b40 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -19,7 +19,6 @@ use crate::aggregate_fn::AggregateFnRef; use crate::dtype::DType; use crate::expr::Expression; use crate::expr::lit; -use crate::expr::stats::Stat; use crate::expr::traversal::NodeExt; use crate::expr::traversal::Transformed; use crate::scalar::Scalar; @@ -43,24 +42,10 @@ pub trait StatBinder { self.scope().clone() } - /// Bind `stat(input)` to a concrete expression. - /// - /// Returning `Ok(None)` marks the stat as unavailable. [`bind_stats`] will - /// then call [`Self::missing_stat`] with the dtype expected from the - /// original `vortex.stat` expression. - fn bind_stat( - &self, - input: &Expression, - stat: Stat, - stat_dtype: &DType, - ) -> VortexResult>; - /// Bind `aggregate_fn(input)` to a concrete expression. /// /// Implementations should return `Ok(None)` when the requested aggregate - /// statistic is unavailable in their backing representation. Binders that - /// support only direct legacy [`Stat`] slots can delegate to - /// [`bind_direct_aggregate_stat`]. + /// statistic is unavailable in their backing representation. fn bind_aggregate( &self, input: &Expression, @@ -106,19 +91,6 @@ pub fn bind_stats( lowered.optimize_recursive(&binder.bound_scope()) } -/// Bind an aggregate function that has a direct legacy [`Stat`] slot. -pub fn bind_direct_aggregate_stat( - binder: &B, - input: &Expression, - aggregate_fn: &AggregateFnRef, - stat_dtype: &DType, -) -> VortexResult> { - let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { - return Ok(None); - }; - binder.bind_stat(input, stat, stat_dtype) -} - fn bind_stat_fn( expr: &Expression, scope: &DType, @@ -151,6 +123,7 @@ mod tests { use crate::expr::is_null; use crate::expr::or; use crate::expr::root; + use crate::expr::stats::Stat; use crate::stats::all_non_nan; use crate::stats::nan_count; @@ -191,27 +164,22 @@ mod tests { self.bound_scope.clone() } - fn bind_stat( + fn bind_aggregate( &self, _input: &Expression, - stat: Stat, + aggregate_fn: &AggregateFnRef, _stat_dtype: &DType, ) -> VortexResult> { + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; + if stat == Stat::NaNCount && self.bind_nan_count { Ok(Some(get_item("f_nan_count", root()))) } else { Ok(None) } } - - fn bind_aggregate( - &self, - input: &Expression, - aggregate_fn: &AggregateFnRef, - stat_dtype: &DType, - ) -> VortexResult> { - bind_direct_aggregate_stat(self, input, aggregate_fn, stat_dtype) - } } #[test] diff --git a/vortex-file/src/pruning.rs b/vortex-file/src/pruning.rs index 9cd2be73f77..67267086adf 100644 --- a/vortex-file/src/pruning.rs +++ b/vortex-file/src/pruning.rs @@ -20,7 +20,6 @@ use vortex_array::scalar_fn::fns::get_item::GetItem; use vortex_array::scalar_fn::fns::literal::Literal; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::stats::bind::StatBinder; -use vortex_array::stats::bind::bind_direct_aggregate_stat; use vortex_array::stats::bind::bind_stats; use vortex_error::VortexResult; use vortex_session::VortexSession; @@ -80,26 +79,20 @@ impl StatBinder for FileStatsBinder<'_> { DType::Null } - fn bind_stat( + fn bind_aggregate( &self, input: &Expression, - stat: Stat, + aggregate_fn: &AggregateFnRef, _stat_dtype: &DType, ) -> VortexResult> { + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; let Some(field_path) = direct_field_path(input) else { return Ok(None); }; Ok(self.stat_ref(&field_path, stat)) } - - fn bind_aggregate( - &self, - input: &Expression, - aggregate_fn: &AggregateFnRef, - stat_dtype: &DType, - ) -> VortexResult> { - bind_direct_aggregate_stat(self, input, aggregate_fn, stat_dtype) - } } impl FileStatsBinder<'_> { diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index cd8bbd6f881..f3432d6f93f 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -22,7 +22,6 @@ use vortex_array::expr::stats::Stat; use vortex_array::scalar_fn::internal::row_count::contains_row_count; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::stats::bind::StatBinder; -use vortex_array::stats::bind::bind_direct_aggregate_stat; use vortex_array::stats::bind::bind_stats; use vortex_array::validity::Validity; use vortex_buffer::buffer; @@ -139,12 +138,15 @@ impl StatBinder for ZoneMapStatsBinder<'_> { self.zone_map.array.dtype().clone() } - fn bind_stat( + fn bind_aggregate( &self, input: &Expression, - stat: Stat, + aggregate_fn: &AggregateFnRef, _stat_dtype: &DType, ) -> VortexResult> { + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; if !is_root(input) { return Ok(None); } @@ -158,15 +160,6 @@ impl StatBinder for ZoneMapStatsBinder<'_> { } Ok(Some(get_item(stat.name(), root()))) } - - fn bind_aggregate( - &self, - input: &Expression, - aggregate_fn: &AggregateFnRef, - stat_dtype: &DType, - ) -> VortexResult> { - bind_direct_aggregate_stat(self, input, aggregate_fn, stat_dtype) - } } /// Build per-zone row counts for a zone map. From 4dac5958a9f4d82f348f80521080136f57b359c7 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 18 Jun 2026 12:52:47 -0400 Subject: [PATCH 27/28] Fix DuckDB projection test compile Signed-off-by: "Nicholas Gates" --- vortex-duckdb/src/projection.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vortex-duckdb/src/projection.rs b/vortex-duckdb/src/projection.rs index 4298c3131e4..4521115666c 100644 --- a/vortex-duckdb/src/projection.rs +++ b/vortex-duckdb/src/projection.rs @@ -358,9 +358,9 @@ mod tests { let second = col("second"); let mut filter_exprs = Vec::new(); - push_filter_expr(&mut filter_exprs, first.clone()); - push_filter_expr(&mut filter_exprs, second.clone()); - push_filter_expr(&mut filter_exprs, first.clone()); + push_filter_expr(&mut filter_exprs, &first); + push_filter_expr(&mut filter_exprs, &second); + push_filter_expr(&mut filter_exprs, &first); assert_eq!(filter_exprs, vec![first, second]); } From 3e7c2d6d99a1f27d7b64bcb0bf16caa41c2875d7 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 18 Jun 2026 13:41:56 -0400 Subject: [PATCH 28/28] Keep stats binding free of reduction Signed-off-by: "Nicholas Gates" --- vortex-array/src/stats/bind.rs | 33 ++++----------------- vortex-file/src/pruning.rs | 4 --- vortex-layout/src/layouts/zoned/zone_map.rs | 6 +--- 3 files changed, 7 insertions(+), 36 deletions(-) diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index c26cc8a5b40..6d921fd0ab1 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -34,14 +34,6 @@ pub trait StatBinder { /// The dtype scope used to type-check expressions before stats are bound. fn scope(&self) -> &DType; - /// The dtype scope used after stats have been bound. - /// - /// Binders that rewrite stats to a different root expression, such as a - /// stats-table row, should return that post-binding root dtype. - fn bound_scope(&self) -> DType { - self.scope().clone() - } - /// Bind `aggregate_fn(input)` to a concrete expression. /// /// Implementations should return `Ok(None)` when the requested aggregate @@ -72,7 +64,7 @@ pub fn bind_stats( binder: &B, ) -> VortexResult { let scope = binder.scope().clone(); - let lowered = predicate + Ok(predicate .transform_down(|expr| { if !expr.is::() { return Ok(Transformed::no(expr)); @@ -86,9 +78,7 @@ pub fn bind_stats( } } })? - .into_inner(); - - lowered.optimize_recursive(&binder.bound_scope()) + .into_inner()) } fn bind_stat_fn( @@ -129,7 +119,6 @@ mod tests { struct TestBinder { input_scope: DType, - bound_scope: DType, bind_nan_count: bool, } @@ -143,13 +132,6 @@ mod tests { )]), Nullability::NonNullable, ), - bound_scope: DType::Struct( - StructFields::from_iter([( - "f_nan_count", - DType::Primitive(PType::U64, Nullability::NonNullable), - )]), - Nullability::NonNullable, - ), bind_nan_count, } } @@ -160,10 +142,6 @@ mod tests { &self.input_scope } - fn bound_scope(&self) -> DType { - self.bound_scope.clone() - } - fn bind_aggregate( &self, _input: &Expression, @@ -203,16 +181,17 @@ mod tests { } #[test] - fn missing_stats_fold_when_kleene_semantics_allow_it() -> VortexResult<()> { + fn missing_stats_bind_to_null_without_reducing() -> VortexResult<()> { let binder = TestBinder::new(false); + let null_bool = lit(Scalar::null(DType::Bool(Nullability::Nullable))); let bound = bind_stats(and(lit(false), all_non_nan(col("f"))), &binder)?; - assert_eq!(bound, lit(false)); + assert_eq!(bound, and(lit(false), null_bool.clone())); let bound = bind_stats(or(lit(true), all_non_nan(col("f"))), &binder)?; - assert_eq!(bound, lit(true)); + assert_eq!(bound, or(lit(true), null_bool)); Ok(()) } diff --git a/vortex-file/src/pruning.rs b/vortex-file/src/pruning.rs index 67267086adf..559df97d2d0 100644 --- a/vortex-file/src/pruning.rs +++ b/vortex-file/src/pruning.rs @@ -75,10 +75,6 @@ impl StatBinder for FileStatsBinder<'_> { self.dtype } - fn bound_scope(&self) -> DType { - DType::Null - } - fn bind_aggregate( &self, input: &Expression, diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index f3432d6f93f..dbfaab93910 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -121,7 +121,7 @@ impl ZoneMap { fn lower_stats(&self, predicate: Expression) -> VortexResult { let binder = ZoneMapStatsBinder { zone_map: self }; - bind_stats(predicate, &binder) + bind_stats(predicate, &binder)?.optimize_recursive(self.array.dtype()) } } @@ -134,10 +134,6 @@ impl StatBinder for ZoneMapStatsBinder<'_> { &self.zone_map.column_dtype } - fn bound_scope(&self) -> DType { - self.zone_map.array.dtype().clone() - } - fn bind_aggregate( &self, input: &Expression,