diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 661e43f8548b2..98379241c366a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -361,154 +361,167 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { _.containsAnyPattern(AND, OR, NOT), ruleId) { case q: LogicalPlan => q.transformExpressionsUpWithPruning( _.containsAnyPattern(AND, OR, NOT), ruleId) { - case TrueLiteral And e => e - case e And TrueLiteral => e - case FalseLiteral Or e => e - case e Or FalseLiteral => e - - case FalseLiteral And _ => FalseLiteral - case _ And FalseLiteral => FalseLiteral - case TrueLiteral Or _ => TrueLiteral - case _ Or TrueLiteral => TrueLiteral - - case a And b if Not(a).semanticEquals(b) => - If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral) - case a And b if a.semanticEquals(Not(b)) => - If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral) - - case a Or b if Not(a).semanticEquals(b) => - If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral) - case a Or b if a.semanticEquals(Not(b)) => - If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral) - - case a And b if a.semanticEquals(b) => a - case a Or b if a.semanticEquals(b) => a - - // The following optimizations are applicable only when the operands are not nullable, - // since the three-value logic of AND and OR are different in NULL handling. - // See the chart: - // +---------+---------+---------+---------+ - // | operand | operand | OR | AND | - // +---------+---------+---------+---------+ - // | TRUE | TRUE | TRUE | TRUE | - // | TRUE | FALSE | TRUE | FALSE | - // | FALSE | FALSE | FALSE | FALSE | - // | UNKNOWN | TRUE | TRUE | UNKNOWN | - // | UNKNOWN | FALSE | UNKNOWN | FALSE | - // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN | - // +---------+---------+---------+---------+ - - // (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable. - case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c) - // (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable. - case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b) - // ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable. - case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c) - // ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable. - case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c) - - // (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable. - case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c) - // (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable. - case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b) - // ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable. - case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c) - // ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable. - case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c) - - // Common factor elimination for conjunction - case and @ (left And right) => - // 1. Split left and right to get the disjunctive predicates, - // i.e. lhs = (a || b), rhs = (a || c) - // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) - // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) - // 4. If common is non-empty, apply the formula to get the optimized predicate: - // common || (ldiff && rdiff) - // 5. Else if common is empty, split left and right to get the conjunctive predicates. - // for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, c), distinct = (a, b, c) - // optimized predicate: (a && b && c) - val lhs = splitDisjunctivePredicates(left) - val rhs = splitDisjunctivePredicates(right) - val common = lhs.filter(e => rhs.exists(e.semanticEquals)) - if (common.nonEmpty) { - val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals)) - val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals)) - if (ldiff.isEmpty || rdiff.isEmpty) { - // (a || b || c || ...) && (a || b) => (a || b) - common.reduce(Or) - } else { - // (a || b || c || ...) && (a || b || d || ...) => - // a || b || ((c || ...) && (d || ...)) - (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) - } + actualExprTransformer + } + } + + val actualExprTransformer: PartialFunction[Expression, Expression] = { + case TrueLiteral And e => e + case e And TrueLiteral => e + case FalseLiteral Or e => e + case e Or FalseLiteral => e + + case FalseLiteral And _ => FalseLiteral + case _ And FalseLiteral => FalseLiteral + case TrueLiteral Or _ => TrueLiteral + case _ Or TrueLiteral => TrueLiteral + + case a And b if Not(a).semanticEquals(b) => + If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral) + case a And b if a.semanticEquals(Not(b)) => + If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral) + + case a Or b if Not(a).semanticEquals(b) => + If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral) + case a Or b if a.semanticEquals(Not(b)) => + If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral) + + case a And b if a.semanticEquals(b) => a + case a Or b if a.semanticEquals(b) => a + + // The following optimizations are applicable only when the operands are not nullable, + // since the three-value logic of AND and OR are different in NULL handling. + // See the chart: + // +---------+---------+---------+---------+ + // | operand | operand | OR | AND | + // +---------+---------+---------+---------+ + // | TRUE | TRUE | TRUE | TRUE | + // | TRUE | FALSE | TRUE | FALSE | + // | FALSE | FALSE | FALSE | FALSE | + // | UNKNOWN | TRUE | TRUE | UNKNOWN | + // | UNKNOWN | FALSE | UNKNOWN | FALSE | + // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN | + // +---------+---------+---------+---------+ + + // (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable. + case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c) + // (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable. + case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b) + // ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable. + case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c) + // ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable. + case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c) + + // (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable. + case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c) + // (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable. + case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b) + // ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable. + case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c) + // ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable. + case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c) + + // Common factor elimination for conjunction + case and @ (left And right) => + // 1. Split left and right to get the disjunctive predicates, + // i.e. lhs = (a || b), rhs = (a || c) + // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) + // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) + // 4. If common is non-empty, apply the formula to get the optimized predicate: + // common || (ldiff && rdiff) + // 5. Else if common is empty, split left and right to get the conjunctive predicates. + // for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, c), distinct = (a, b, c) + // optimized predicate: (a && b && c) + val lhs = splitDisjunctivePredicates(left) + val rhs = splitDisjunctivePredicates(right) + val common = lhs.filter(e => rhs.exists(e.semanticEquals)) + if (common.nonEmpty) { + val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals)) + val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals)) + if (ldiff.isEmpty || rdiff.isEmpty) { + // (a || b || c || ...) && (a || b) => (a || b) + common.reduce(Or) } else { - // No common factors from disjunctive predicates, reduce common factor from conjunction - val all = splitConjunctivePredicates(left) ++ splitConjunctivePredicates(right) - val distinct = ExpressionSet(all) - if (all.size == distinct.size) { - // No common factors, return the original predicate - and - } else { - // (a && b) && a && (a && c) => a && b && c - buildBalancedPredicate(distinct.toSeq, And) - } + // (a || b || c || ...) && (a || b || d || ...) => + // a || b || ((c || ...) && (d || ...)) + (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) } + } else { + // No common factors from disjunctive predicates, reduce common factor from conjunction + val all = splitConjunctivePredicates(left) ++ splitConjunctivePredicates(right) + val distinct = ExpressionSet(all) + if (all.size == distinct.size) { + // No common factors, return the original predicate + and + } else { + // (a && b) && a && (a && c) => a && b && c + buildBalancedPredicate(distinct.toSeq, And) + } + } - // Common factor elimination for disjunction - case or @ (left Or right) => - // 1. Split left and right to get the conjunctive predicates, - // i.e. lhs = (a && b), rhs = (a && c) - // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) - // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) - // 4. If common is non-empty, apply the formula to get the optimized predicate: - // common && (ldiff || rdiff) - // 5. Else if common is empty, split left and right to get the conjunctive predicates. - // for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c), distinct = (a, b, c) - // optimized predicate: (a || b || c) - val lhs = splitConjunctivePredicates(left) - val rhs = splitConjunctivePredicates(right) - val common = lhs.filter(e => rhs.exists(e.semanticEquals)) - if (common.nonEmpty) { - val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals)) - val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals)) - if (ldiff.isEmpty || rdiff.isEmpty) { - // (a && b) || (a && b && c && ...) => a && b - common.reduce(And) - } else { - // (a && b && c && ...) || (a && b && d && ...) => - // a && b && ((c && ...) || (d && ...)) - (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) - } + // Common factor elimination for disjunction + case or @ (left Or right) => + // 1. Split left and right to get the conjunctive predicates, + // i.e. lhs = (a && b), rhs = (a && c) + // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) + // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) + // 4. If common is non-empty, apply the formula to get the optimized predicate: + // common && (ldiff || rdiff) + // 5. Else if common is empty, split left and right to get the conjunctive predicates. + // for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c), distinct = (a, b, c) + // optimized predicate: (a || b || c) + val lhs = splitConjunctivePredicates(left) + val rhs = splitConjunctivePredicates(right) + val common = lhs.filter(e => rhs.exists(e.semanticEquals)) + if (common.nonEmpty) { + val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals)) + val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals)) + if (ldiff.isEmpty || rdiff.isEmpty) { + // (a && b) || (a && b && c && ...) => a && b + common.reduce(And) } else { - // No common factors in conjunctive predicates, reduce common factor from disjunction - val all = splitDisjunctivePredicates(left) ++ splitDisjunctivePredicates(right) - val distinct = ExpressionSet(all) - if (all.size == distinct.size) { - // No common factors, return the original predicate - or - } else { - // (a || b) || a || (a || c) => a || b || c - buildBalancedPredicate(distinct.toSeq, Or) - } + // (a && b && c && ...) || (a && b && d && ...) => + // a && b && ((c && ...) || (d && ...)) + (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) } + } else { + // No common factors in conjunctive predicates, reduce common factor from disjunction + val all = splitDisjunctivePredicates(left) ++ splitDisjunctivePredicates(right) + val distinct = ExpressionSet(all) + if (all.size == distinct.size) { + // No common factors, return the original predicate + or + } else { + // (a || b) || a || (a || c) => a || b || c + buildBalancedPredicate(distinct.toSeq, Or) + } + } - case Not(TrueLiteral) => FalseLiteral - case Not(FalseLiteral) => TrueLiteral + case Not(TrueLiteral) => FalseLiteral + case Not(FalseLiteral) => TrueLiteral - case Not(a GreaterThan b) => LessThanOrEqual(a, b) - case Not(a GreaterThanOrEqual b) => LessThan(a, b) + case Not(a GreaterThan b) => LessThanOrEqual(a, b) + case Not(a GreaterThanOrEqual b) => LessThan(a, b) - case Not(a LessThan b) => GreaterThanOrEqual(a, b) - case Not(a LessThanOrEqual b) => GreaterThan(a, b) + case Not(a LessThan b) => GreaterThanOrEqual(a, b) + case Not(a LessThanOrEqual b) => GreaterThan(a, b) - case Not(a Or b) => And(Not(a), Not(b)) - case Not(a And b) => Or(Not(a), Not(b)) + // SPARK-54881: push down the NOT operators on children, before attaching the junction Node + // to the main tree. This ensures idempotency in an optimal way and avoids an extra rule + // iteration. + case Not(a Or b) => + And(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT), ruleId) { + actualExprTransformer + } + case Not(a And b) => + Or(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT), ruleId) { + actualExprTransformer + } - case Not(Not(e)) => e + case Not(Not(e)) => e - case Not(IsNull(e)) => IsNotNull(e) - case Not(IsNotNull(e)) => IsNull(e) - } + case Not(IsNull(e)) => IsNotNull(e) + case Not(IsNotNull(e)) => IsNull(e) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 4cc2ee99284a5..5a44119bf0495 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -291,6 +291,37 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper { checkCondition(Not(IsNull($"b")), IsNotNull($"b")) } + test("SPARK-54881: simplify Not(Expr) in single pass") { + def executeRuleOnce(exprToTest: Expression, optimizedExprExpected: Expression): Unit = { + val planAfterRuleApp = BooleanSimplification.apply(testRelation.where(exprToTest).analyze) + val expectedOptPlan = testRelation.where(optimizedExprExpected).analyze + comparePlans(expectedOptPlan, planAfterRuleApp) + } + // check simplify Not(A <= B OR A >= B) to (a > b AND a < b) in single pass + executeRuleOnce( + Not(($"a" <= $"b") || ($"a" >= $"b")), + $"a" > $"b" && $"a" < $"b" + ) + + // check simplify Not((expr1 OR expr2) OR (expr3 AND expr4)) in single pass + executeRuleOnce( + Not(($"a" <= $"b" || $"c" > $"a" + 4) || ($"a" >= $"b" && $"c" < $"a")), + And( + And($"a" > $"b", $"c" <= $"a" + 4), + Or($"a" < $"b", $"c" >= $"a") + ) + ) + + // check simplify Not((expr1 OR expr2) AND (expr3 OR expr4)) in single pass + executeRuleOnce( + Not(($"a" <= $"b" || $"c" > $"a" + 4) && ($"a" >= $"b" || $"c" < $"a")), + Or( + And($"a" > $"b", $"c" <= $"a" + 4), + And($"a" < $"b", $"c" >= $"a") + ) + ) + } + protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)