Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -361,154 +361,167 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
_.containsAnyPattern(AND, OR, NOT), ruleId) {
case q: LogicalPlan => q.transformExpressionsUpWithPruning(
_.containsAnyPattern(AND, OR, NOT), ruleId) {
case TrueLiteral And e => e
case e And TrueLiteral => e
case FalseLiteral Or e => e
case e Or FalseLiteral => e

case FalseLiteral And _ => FalseLiteral
case _ And FalseLiteral => FalseLiteral
case TrueLiteral Or _ => TrueLiteral
case _ Or TrueLiteral => TrueLiteral

case a And b if Not(a).semanticEquals(b) =>
If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral)
case a And b if a.semanticEquals(Not(b)) =>
If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral)

case a Or b if Not(a).semanticEquals(b) =>
If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral)
case a Or b if a.semanticEquals(Not(b)) =>
If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral)

case a And b if a.semanticEquals(b) => a
case a Or b if a.semanticEquals(b) => a

// The following optimizations are applicable only when the operands are not nullable,
// since the three-value logic of AND and OR are different in NULL handling.
// See the chart:
// +---------+---------+---------+---------+
// | operand | operand | OR | AND |
// +---------+---------+---------+---------+
// | TRUE | TRUE | TRUE | TRUE |
// | TRUE | FALSE | TRUE | FALSE |
// | FALSE | FALSE | FALSE | FALSE |
// | UNKNOWN | TRUE | TRUE | UNKNOWN |
// | UNKNOWN | FALSE | UNKNOWN | FALSE |
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
// +---------+---------+---------+---------+

// (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
// (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
// ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c)
// ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c)

// (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
// (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
// ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c)
// ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c)

// Common factor elimination for conjunction
case and @ (left And right) =>
// 1. Split left and right to get the disjunctive predicates,
// i.e. lhs = (a || b), rhs = (a || c)
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
// 4. If common is non-empty, apply the formula to get the optimized predicate:
// common || (ldiff && rdiff)
// 5. Else if common is empty, split left and right to get the conjunctive predicates.
// for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, c), distinct = (a, b, c)
// optimized predicate: (a && b && c)
val lhs = splitDisjunctivePredicates(left)
val rhs = splitDisjunctivePredicates(right)
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
if (common.nonEmpty) {
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
if (ldiff.isEmpty || rdiff.isEmpty) {
// (a || b || c || ...) && (a || b) => (a || b)
common.reduce(Or)
} else {
// (a || b || c || ...) && (a || b || d || ...) =>
// a || b || ((c || ...) && (d || ...))
(common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
}
actualExprTransformer
}
}

val actualExprTransformer: PartialFunction[Expression, Expression] = {
case TrueLiteral And e => e
case e And TrueLiteral => e
case FalseLiteral Or e => e
case e Or FalseLiteral => e

case FalseLiteral And _ => FalseLiteral
case _ And FalseLiteral => FalseLiteral
case TrueLiteral Or _ => TrueLiteral
case _ Or TrueLiteral => TrueLiteral

case a And b if Not(a).semanticEquals(b) =>
If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral)
case a And b if a.semanticEquals(Not(b)) =>
If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral)

case a Or b if Not(a).semanticEquals(b) =>
If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral)
case a Or b if a.semanticEquals(Not(b)) =>
If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral)

case a And b if a.semanticEquals(b) => a
case a Or b if a.semanticEquals(b) => a

// The following optimizations are applicable only when the operands are not nullable,
// since the three-value logic of AND and OR are different in NULL handling.
// See the chart:
// +---------+---------+---------+---------+
// | operand | operand | OR | AND |
// +---------+---------+---------+---------+
// | TRUE | TRUE | TRUE | TRUE |
// | TRUE | FALSE | TRUE | FALSE |
// | FALSE | FALSE | FALSE | FALSE |
// | UNKNOWN | TRUE | TRUE | UNKNOWN |
// | UNKNOWN | FALSE | UNKNOWN | FALSE |
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
// +---------+---------+---------+---------+

// (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
// (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
// ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c)
// ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c)

// (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
// (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
// ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c)
// ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c)

// Common factor elimination for conjunction
case and @ (left And right) =>
// 1. Split left and right to get the disjunctive predicates,
// i.e. lhs = (a || b), rhs = (a || c)
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
// 4. If common is non-empty, apply the formula to get the optimized predicate:
// common || (ldiff && rdiff)
// 5. Else if common is empty, split left and right to get the conjunctive predicates.
// for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, c), distinct = (a, b, c)
// optimized predicate: (a && b && c)
val lhs = splitDisjunctivePredicates(left)
val rhs = splitDisjunctivePredicates(right)
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
if (common.nonEmpty) {
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
if (ldiff.isEmpty || rdiff.isEmpty) {
// (a || b || c || ...) && (a || b) => (a || b)
common.reduce(Or)
} else {
// No common factors from disjunctive predicates, reduce common factor from conjunction
val all = splitConjunctivePredicates(left) ++ splitConjunctivePredicates(right)
val distinct = ExpressionSet(all)
if (all.size == distinct.size) {
// No common factors, return the original predicate
and
} else {
// (a && b) && a && (a && c) => a && b && c
buildBalancedPredicate(distinct.toSeq, And)
}
// (a || b || c || ...) && (a || b || d || ...) =>
// a || b || ((c || ...) && (d || ...))
(common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
}
} else {
// No common factors from disjunctive predicates, reduce common factor from conjunction
val all = splitConjunctivePredicates(left) ++ splitConjunctivePredicates(right)
val distinct = ExpressionSet(all)
if (all.size == distinct.size) {
// No common factors, return the original predicate
and
} else {
// (a && b) && a && (a && c) => a && b && c
buildBalancedPredicate(distinct.toSeq, And)
}
}

// Common factor elimination for disjunction
case or @ (left Or right) =>
// 1. Split left and right to get the conjunctive predicates,
// i.e. lhs = (a && b), rhs = (a && c)
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
// 4. If common is non-empty, apply the formula to get the optimized predicate:
// common && (ldiff || rdiff)
// 5. Else if common is empty, split left and right to get the conjunctive predicates.
// for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c), distinct = (a, b, c)
// optimized predicate: (a || b || c)
val lhs = splitConjunctivePredicates(left)
val rhs = splitConjunctivePredicates(right)
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
if (common.nonEmpty) {
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
if (ldiff.isEmpty || rdiff.isEmpty) {
// (a && b) || (a && b && c && ...) => a && b
common.reduce(And)
} else {
// (a && b && c && ...) || (a && b && d && ...) =>
// a && b && ((c && ...) || (d && ...))
(common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
}
// Common factor elimination for disjunction
case or @ (left Or right) =>
// 1. Split left and right to get the conjunctive predicates,
// i.e. lhs = (a && b), rhs = (a && c)
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
// 4. If common is non-empty, apply the formula to get the optimized predicate:
// common && (ldiff || rdiff)
// 5. Else if common is empty, split left and right to get the conjunctive predicates.
// for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c), distinct = (a, b, c)
// optimized predicate: (a || b || c)
val lhs = splitConjunctivePredicates(left)
val rhs = splitConjunctivePredicates(right)
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
if (common.nonEmpty) {
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
if (ldiff.isEmpty || rdiff.isEmpty) {
// (a && b) || (a && b && c && ...) => a && b
common.reduce(And)
} else {
// No common factors in conjunctive predicates, reduce common factor from disjunction
val all = splitDisjunctivePredicates(left) ++ splitDisjunctivePredicates(right)
val distinct = ExpressionSet(all)
if (all.size == distinct.size) {
// No common factors, return the original predicate
or
} else {
// (a || b) || a || (a || c) => a || b || c
buildBalancedPredicate(distinct.toSeq, Or)
}
// (a && b && c && ...) || (a && b && d && ...) =>
// a && b && ((c && ...) || (d && ...))
(common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
}
} else {
// No common factors in conjunctive predicates, reduce common factor from disjunction
val all = splitDisjunctivePredicates(left) ++ splitDisjunctivePredicates(right)
val distinct = ExpressionSet(all)
if (all.size == distinct.size) {
// No common factors, return the original predicate
or
} else {
// (a || b) || a || (a || c) => a || b || c
buildBalancedPredicate(distinct.toSeq, Or)
}
}

case Not(TrueLiteral) => FalseLiteral
case Not(FalseLiteral) => TrueLiteral
case Not(TrueLiteral) => FalseLiteral
case Not(FalseLiteral) => TrueLiteral

case Not(a GreaterThan b) => LessThanOrEqual(a, b)
case Not(a GreaterThanOrEqual b) => LessThan(a, b)
case Not(a GreaterThan b) => LessThanOrEqual(a, b)
case Not(a GreaterThanOrEqual b) => LessThan(a, b)

case Not(a LessThan b) => GreaterThanOrEqual(a, b)
case Not(a LessThanOrEqual b) => GreaterThan(a, b)
case Not(a LessThan b) => GreaterThanOrEqual(a, b)
case Not(a LessThanOrEqual b) => GreaterThan(a, b)

case Not(a Or b) => And(Not(a), Not(b))
case Not(a And b) => Or(Not(a), Not(b))
// SPARK-54881: push down the NOT operators on children, before attaching the junction Node
// to the main tree. This ensures idempotency in an optimal way and avoids an extra rule
// iteration.
case Not(a Or b) =>
And(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT), ruleId) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe? I mean, before this PR the simplification logic of actualExprTransformer was called with transformUp..., but now you call it with transformDown... (please note that a Not node can be deep down in a or b). Is there any reason why we invoke the logic with transformUp or could the whole rule use transformDown on expression trees?

@peter-toth peter-toth Jan 6, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not something like And(actualExprTransformer.applyOrElse(Not(a), identity), actualExprTransformer.applyOrElse(Not(b), identity)) just to be on the safe side?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe? I mean, before this PR the simplification logic of actualExprTransformer was called with transformUp..., but now you call it with transformDown... (please note that a Not node can be deep down in a or b). Is there any reason why we invoke the logic with transformUp or could the whole rule use transformDown on expression trees?

I believe it's safe..
If the original logic is modified such that instead of transform up ,
transform down is used, then this bug would be fixed, but other cases like
that mentioned in Constant folding suite will break in idempotency.
To take care of both the cases, use of transform up and transform down is
needed...as in the pr. This reason is also mentioned in the initial PR details.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would make sense to split the logic into 2 traversals? Keep the current transformExpressionsUpWithPruning() with the current cases excluding these 2 Not "pushdowns" and then a transformExpressionsDownWithPruning() with these 2 cases.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That in my view, would defeat the purpose of achieving idempotency in a minimum possible tree traversal. If we separate it in 2 traversals, then only for a part of subtree , the whole traversal will have to happen again.
As such I do not see any issue with the current code of subtree traversal of the newly added children to cause any issue.. Is there something which is making it suspicious?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides that it is hard to reason about a nested traversals, my problem with the current inner transformDownWithPruning() is that it can call actualExprTransformer top-down way not only on the new And and Not nodes, but also on nodes of a and b subtrees if those contain Not nodes.
The current rule might be safe in top-down manner as well, but I feel it would be a bit cleaner to separate the traversals. But, on the other hand, separating the traversals would require 2 unique rule ids so the current PR has pros as well.

Anyways, I'm ok with this PR.

@cloud-fan, do you have any concerns or comments on this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that rules like BooleanSimplification would work same bottom - up, or top - down in terms of functionality, so long as number of iterations to achieve idempotency is ignored.
If one goes top -down, some cases (Not) become optimal, while if you bottom - up ( other cases like depicted in ConstantFoldinghSuite become optimal).

The point is that the subtrees in NOT (Junction) before being acted upon by top- down rule , have already undergone the traversal of bottom- up, so the top - down would act only for pushing of Not, and moreover the traversal would terminate the moment subtree has no NOT pushed.

In my mind, I am comfortable with the behaviour.

actualExprTransformer
}
case Not(a And b) =>
Or(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT), ruleId) {
actualExprTransformer
}

case Not(Not(e)) => e
case Not(Not(e)) => e

case Not(IsNull(e)) => IsNotNull(e)
case Not(IsNotNull(e)) => IsNull(e)
}
case Not(IsNull(e)) => IsNotNull(e)
case Not(IsNotNull(e)) => IsNull(e)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,37 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper {
checkCondition(Not(IsNull($"b")), IsNotNull($"b"))
}

test("SPARK-54881: simplify Not(Expr) in single pass") {
def executeRuleOnce(exprToTest: Expression, optimizedExprExpected: Expression): Unit = {
val planAfterRuleApp = BooleanSimplification.apply(testRelation.where(exprToTest).analyze)
val expectedOptPlan = testRelation.where(optimizedExprExpected).analyze
comparePlans(expectedOptPlan, planAfterRuleApp)
}
// check simplify Not(A <= B OR A >= B) to (a > b AND a < b) in single pass
executeRuleOnce(
Not(($"a" <= $"b") || ($"a" >= $"b")),
$"a" > $"b" && $"a" < $"b"
)

// check simplify Not((expr1 OR expr2) OR (expr3 AND expr4)) in single pass
executeRuleOnce(
Not(($"a" <= $"b" || $"c" > $"a" + 4) || ($"a" >= $"b" && $"c" < $"a")),
And(
And($"a" > $"b", $"c" <= $"a" + 4),
Or($"a" < $"b", $"c" >= $"a")
)
)

// check simplify Not((expr1 OR expr2) AND (expr3 OR expr4)) in single pass
executeRuleOnce(
Not(($"a" <= $"b" || $"c" > $"a" + 4) && ($"a" >= $"b" || $"c" < $"a")),
Or(
And($"a" > $"b", $"c" <= $"a" + 4),
And($"a" < $"b", $"c" >= $"a")
)
)
}

protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
Expand Down