From 1ee454d057493d4d4118baa643216bf1fde85d3d Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Fri, 22 Jul 2022 13:47:07 +0800 Subject: [PATCH 1/2] Fix EliminateSorts remove global sort below the local sort --- .../sql/catalyst/optimizer/Optimizer.scala | 28 +++++++++++++++---- .../optimizer/EliminateSortsSuite.scala | 16 +++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 78fb8b5de8886..651565801e695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1528,21 +1528,33 @@ object EliminateSorts extends Rule[LogicalPlan] { } case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => applyLocally.lift(child).getOrElse(child) - case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) + case s @ Sort(_, global, child) => s.copy(child = recursiveRemoveSort(child, global)) case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) => - j.copy(left = recursiveRemoveSort(originLeft), right = recursiveRemoveSort(originRight)) + j.copy(left = recursiveRemoveSort(originLeft, true), + right = recursiveRemoveSort(originRight, true)) case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) => - g.copy(child = recursiveRemoveSort(originChild)) + g.copy(child = recursiveRemoveSort(originChild, true)) } - private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = { + /** + * If the upper sort is global then we can remove the global or local sort recursively. + * If the upper sort is local then we can only remove the local sort recursively. + */ + private def recursiveRemoveSort( + plan: LogicalPlan, + canRemoveGlobalSort: Boolean): LogicalPlan = { if (!plan.containsPattern(SORT)) { return plan } plan match { - case Sort(_, _, child) => recursiveRemoveSort(child) + case Sort(_, _, child) if canRemoveGlobalSort => + recursiveRemoveSort(child, canRemoveGlobalSort) + case Sort(_, false, child) => + recursiveRemoveSort(child, canRemoveGlobalSort) case other if canEliminateSort(other) => - other.withNewChildren(other.children.map(recursiveRemoveSort)) + other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, canRemoveGlobalSort))) + case other if canEliminateGlobalSort(other) => + other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, true))) case _ => plan } } @@ -1550,6 +1562,10 @@ object EliminateSorts extends Rule[LogicalPlan] { private def canEliminateSort(plan: LogicalPlan): Boolean = plan match { case p: Project => p.projectList.forall(_.deterministic) case f: Filter => f.condition.deterministic + case _ => false + } + + private def canEliminateGlobalSort(plan: LogicalPlan): Boolean = plan match { case r: RepartitionByExpression => r.partitionExpressions.forall(_.deterministic) case r: RebalancePartitions => r.partitionExpressions.forall(_.deterministic) case _: Repartition => true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 865a06368a42e..1eef0bd407786 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -424,4 +424,20 @@ class EliminateSortsSuite extends AnalysisTest { comparePlans(optimized, correctAnswer) } } + + test("SPARK-39835: Fix EliminateSorts remove global sort below the local sort") { + // local - global + val plan = testRelation.orderBy($"a".asc).sortBy($"c".asc).analyze + comparePlans(Optimize.execute(plan), plan) + + // local - global - global + val plan2 = testRelation.orderBy($"a".asc).orderBy($"b".asc).sortBy($"c".asc).analyze + val expected2 = testRelation.orderBy($"b".asc).sortBy($"c".asc).analyze + comparePlans(Optimize.execute(plan2), expected2) + + // local - global - local + val plan3 = testRelation.sortBy($"a".asc).orderBy($"b".asc).sortBy($"c".asc).analyze + val expected3 = testRelation.orderBy($"b".asc).sortBy($"c".asc).analyze + comparePlans(Optimize.execute(plan3), expected3) + } } From 09efde846d88884502de2208973873325ed79021 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Sat, 23 Jul 2022 12:42:05 +0800 Subject: [PATCH 2/2] address comment --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 +--- .../spark/sql/catalyst/optimizer/EliminateSortsSuite.scala | 6 +++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 651565801e695..653d735da263a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1547,9 +1547,7 @@ object EliminateSorts extends Rule[LogicalPlan] { return plan } plan match { - case Sort(_, _, child) if canRemoveGlobalSort => - recursiveRemoveSort(child, canRemoveGlobalSort) - case Sort(_, false, child) => + case Sort(_, global, child) if canRemoveGlobalSort || !global => recursiveRemoveSort(child, canRemoveGlobalSort) case other if canEliminateSort(other) => other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, canRemoveGlobalSort))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 1eef0bd407786..1d879a7065e92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -426,16 +426,16 @@ class EliminateSortsSuite extends AnalysisTest { } test("SPARK-39835: Fix EliminateSorts remove global sort below the local sort") { - // local - global + // global -> local val plan = testRelation.orderBy($"a".asc).sortBy($"c".asc).analyze comparePlans(Optimize.execute(plan), plan) - // local - global - global + // global -> global -> local val plan2 = testRelation.orderBy($"a".asc).orderBy($"b".asc).sortBy($"c".asc).analyze val expected2 = testRelation.orderBy($"b".asc).sortBy($"c".asc).analyze comparePlans(Optimize.execute(plan2), expected2) - // local - global - local + // local -> global -> local val plan3 = testRelation.sortBy($"a".asc).orderBy($"b".asc).sortBy($"c".asc).analyze val expected3 = testRelation.orderBy($"b".asc).sortBy($"c".asc).analyze comparePlans(Optimize.execute(plan3), expected3)