Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 76 additions & 45 deletions vortex-array/src/stats/rewrite/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ impl StatsRewriteRule for BinaryStatsRewrite {

Ok(match operator {
Operator::Eq => {
let left = min(lhs).zip(max(rhs)).map(|(a, b)| gt(a, b));
let right = min(rhs).zip(max(lhs)).map(|(a, b)| gt(a, b));
let left = min(ctx, lhs).zip(max(ctx, rhs)).map(|(a, b)| gt(a, b));
let right = min(ctx, rhs).zip(max(ctx, lhs)).map(|(a, b)| gt(a, b));
or_collect(left.into_iter().chain(right))
.map(|value_predicate| with_nan_predicate(ctx, lhs, rhs, value_predicate))
.transpose()?
}
Operator::NotEq => min(lhs)
.zip(max(rhs))
.zip(max(lhs).zip(min(rhs)))
Operator::NotEq => min(ctx, lhs)
.zip(max(ctx, rhs))
.zip(max(ctx, lhs).zip(min(ctx, rhs)))
.map(|((min_lhs, max_rhs), (max_lhs, min_rhs))| {
with_nan_predicate(
ctx,
Expand All @@ -102,20 +102,20 @@ impl StatsRewriteRule for BinaryStatsRewrite {
)
})
.transpose()?,
Operator::Gt => max(lhs)
.zip(min(rhs))
Operator::Gt => max(ctx, lhs)
.zip(min(ctx, rhs))
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt_eq(a, b)))
.transpose()?,
Operator::Gte => max(lhs)
.zip(min(rhs))
Operator::Gte => max(ctx, lhs)
.zip(min(ctx, rhs))
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt(a, b)))
.transpose()?,
Operator::Lt => min(lhs)
.zip(max(rhs))
Operator::Lt => min(ctx, lhs)
.zip(max(ctx, rhs))
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt_eq(a, b)))
.transpose()?,
Operator::Lte => min(lhs)
.zip(max(rhs))
Operator::Lte => min(ctx, lhs)
.zip(max(ctx, rhs))
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt(a, b)))
.transpose()?,
Operator::And => {
Expand Down Expand Up @@ -167,17 +167,17 @@ impl StatsRewriteRule for IsNullLegacyStatsRewrite {
fn falsify(
&self,
expr: &Expression,
_ctx: &StatsRewriteCtx<'_>,
ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
Ok(null_count(expr.child(0)).map(|null_count| eq(null_count, lit(0u64))))
Ok(null_count(ctx, expr.child(0)).map(|null_count| eq(null_count, lit(0u64))))
}

fn satisfy(
&self,
expr: &Expression,
_ctx: &StatsRewriteCtx<'_>,
ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
Ok(null_count(expr.child(0))
Ok(null_count(ctx, expr.child(0))
.map(|null_count| eq(null_count, RowCount.new_expr(EmptyOptions, []))))
}
}
Expand Down Expand Up @@ -227,18 +227,18 @@ impl StatsRewriteRule for IsNotNullLegacyStatsRewrite {
fn falsify(
&self,
expr: &Expression,
_ctx: &StatsRewriteCtx<'_>,
ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
Ok(null_count(expr.child(0))
Ok(null_count(ctx, expr.child(0))
.map(|null_count| eq(null_count, RowCount.new_expr(EmptyOptions, []))))
}

fn satisfy(
&self,
expr: &Expression,
_ctx: &StatsRewriteCtx<'_>,
ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
Ok(null_count(expr.child(0)).map(|null_count| eq(null_count, lit(0u64))))
Ok(null_count(ctx, expr.child(0)).map(|null_count| eq(null_count, lit(0u64))))
}
}

Expand Down Expand Up @@ -287,7 +287,7 @@ impl StatsRewriteRule for LikeStatsRewrite {
fn falsify(
&self,
expr: &Expression,
_ctx: &StatsRewriteCtx<'_>,
ctx: &StatsRewriteCtx<'_>,
) -> VortexResult<Option<Expression>> {
let like_options = expr.as_::<Like>();
if like_options.negated || like_options.case_insensitive {
Expand All @@ -304,8 +304,8 @@ impl StatsRewriteRule for LikeStatsRewrite {
let source = expr.child(0);
Ok(match LikeVariant::from_str(pattern) {
Some(LikeVariant::Exact(text)) => {
min(source)
.zip(max(source))
min(ctx, source)
.zip(max(ctx, source))
.map(|(source_min, source_max)| {
or(
gt(source_min, lit(text.as_ref())),
Expand All @@ -317,8 +317,8 @@ impl StatsRewriteRule for LikeStatsRewrite {
let Some(successor) = prefix.to_string().increment().ok() else {
return Ok(None);
};
min(source)
.zip(max(source))
min(ctx, source)
.zip(max(ctx, source))
.map(|(source_min, source_max)| {
or(
gt_eq(source_min, lit(successor)),
Expand Down Expand Up @@ -361,10 +361,10 @@ impl StatsRewriteRule for ListContainsStatsRewrite {
return Ok(Some(lit(true)));
}

let Some(value_max) = max(needle) else {
let Some(value_max) = max(ctx, needle) else {
return Ok(None);
};
let Some(value_min) = min(needle) else {
let Some(value_min) = min(ctx, needle) else {
return Ok(None);
};

Expand Down Expand Up @@ -398,10 +398,10 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite {

let Some((operator, lhs_stat)) = (match dynamic.operator {
CompareOperator::Eq | CompareOperator::NotEq => None,
CompareOperator::Gt => max(lhs).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)),
CompareOperator::Gte => max(lhs).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)),
CompareOperator::Lt => min(lhs).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)),
CompareOperator::Lte => min(lhs).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)),
CompareOperator::Gt => max(ctx, lhs).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)),
CompareOperator::Gte => max(ctx, lhs).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)),
CompareOperator::Lt => min(ctx, lhs).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)),
CompareOperator::Lte => min(ctx, lhs).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)),
}) else {
return Ok(None);
};
Expand All @@ -418,16 +418,16 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite {
}
}

fn min(expr: &Expression) -> Option<Expression> {
stat_expr(expr, Stat::Min)
fn min(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> Option<Expression> {
stat_expr(ctx, expr, Stat::Min)
}

fn max(expr: &Expression) -> Option<Expression> {
stat_expr(expr, Stat::Max)
fn max(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> Option<Expression> {
stat_expr(ctx, expr, Stat::Max)
}

fn null_count(expr: &Expression) -> Option<Expression> {
stat_expr(expr, Stat::NullCount)
fn null_count(ctx: &StatsRewriteCtx<'_>, expr: &Expression) -> Option<Expression> {
stat_expr(ctx, expr, Stat::NullCount)
}

fn all_null(expr: &Expression) -> Expression {
Expand Down Expand Up @@ -474,7 +474,7 @@ fn has_nans(dtype: &DType) -> bool {
matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float())
}

fn stat_expr(expr: &Expression, stat: Stat) -> Option<Expression> {
fn stat_expr(ctx: &StatsRewriteCtx<'_>, expr: &Expression, stat: Stat) -> Option<Expression> {
if let Some(literal) = literal_stat(expr, stat) {
return Some(literal);
}
Expand All @@ -487,11 +487,19 @@ fn stat_expr(expr: &Expression, stat: Stat) -> Option<Expression> {
}

if let Some(dtype) = expr.as_opt::<Cast>() {
return cast_stat(expr.child(0), dtype, stat);
return cast_stat(ctx, expr.child(0), dtype, stat);
}

stat.aggregate_fn()
.map(|aggregate_fn| stat_fn(expr.clone(), aggregate_fn))
let aggregate_fn = stat.aggregate_fn()?;

// Only manufacture a `StatFn` when the aggregate can actually produce a value for
// this input dtype. For example, `min`/`max` over a struct created by `pack` is
// undefined; such a predicate cannot be lowered against a zone map and would abort
// the scan, so leave the expression un-rewritten (no pruning) instead.
let input_dtype = ctx.return_dtype(expr).ok()?;
aggregate_fn.return_dtype(&input_dtype)?;

Some(stat_fn(expr.clone(), aggregate_fn))
}

fn with_nan_predicate(
Expand Down Expand Up @@ -545,10 +553,15 @@ fn literal_stat(expr: &Expression, stat: Stat) -> Option<Expression> {
}
}

fn cast_stat(expr: &Expression, dtype: &DType, stat: Stat) -> Option<Expression> {
fn cast_stat(
ctx: &StatsRewriteCtx<'_>,
expr: &Expression,
dtype: &DType,
stat: Stat,
) -> Option<Expression> {
match stat {
Stat::Min | Stat::Max => stat_expr(expr, stat).map(|stat| cast(stat, dtype.clone())),
Stat::NaNCount | Stat::Sum | Stat::UncompressedSizeInBytes => stat_expr(expr, stat),
Stat::Min | Stat::Max => stat_expr(ctx, expr, stat).map(|stat| cast(stat, dtype.clone())),
Stat::NaNCount | Stat::Sum | Stat::UncompressedSizeInBytes => stat_expr(ctx, expr, stat),
Stat::NullCount | Stat::IsConstant | Stat::IsSorted | Stat::IsStrictSorted => None,
}
}
Expand Down Expand Up @@ -594,6 +607,7 @@ mod tests {
use crate::expr::lt;
use crate::expr::lt_eq;
use crate::expr::or;
use crate::expr::pack;
use crate::expr::stats::Stat;
use crate::scalar::Scalar;
use crate::scalar_fn::EmptyOptions;
Expand Down Expand Up @@ -671,6 +685,23 @@ mod tests {
Ok(())
}

#[test]
fn does_not_rewrite_min_max_over_unsupported_input() -> VortexResult<()> {
// Regression test for issue #8249. `min`/`max` are undefined for struct inputs
// such as the one produced by `pack`. A comparison against such an expression
// must not be rewritten into a min/max pruning predicate, because that predicate
// cannot be lowered against a zone map and would otherwise abort the scan.
let packed = pack([("a", col("a"))], Nullability::NonNullable);
let bound = Scalar::struct_(
packed.return_dtype(&test_scope())?,
[Scalar::primitive(5i32, Nullability::NonNullable)],
);

assert_eq!(falsify(&lt(packed.clone(), lit(bound.clone())))?, None);
assert_eq!(falsify(&gt(packed, lit(bound)))?, None);
Ok(())
}

#[test]
fn rewrites_boolean_falsifiers() -> VortexResult<()> {
let expr = and(gt(col("a"), lit(10)), lt(col("a"), lit(50)));
Expand Down
Loading