From ca09339faef343352cd021789b8ea20b20abc0c1 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Tue, 23 Feb 2021 19:30:21 -0800 Subject: [PATCH 1/3] Push down limit for LEFT SEMI and LEFT ANTI join --- .../sql/catalyst/optimizer/Optimizer.scala | 18 +++++++---- .../optimizer/LimitPushdownSuite.scala | 20 ++++++++++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 30 +++++++++++++++++++ 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 46a90f600b2a3..2fc3e0d37578b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -539,12 +539,16 @@ object LimitPushDown extends Rule[LogicalPlan] { // pushdown Limit. case LocalLimit(exp, u: Union) => LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _)))) - // Add extra limits below JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to - // the left and right sides, respectively. For INNER and CROSS JOIN we push limits to - // both the left and right sides if join condition is empty. It's not safe to push limits - // below FULL OUTER JOIN in the general case without a more invasive rewrite. - // We also need to ensure that this limit pushdown rule will not eventually introduce limits - // on both sides if it is applied multiple times. Therefore: + // Add extra limits below JOIN: + // 1. For LEFT OUTER and RIGHT OUTER JOIN, we push limits to the left and right sides, + // respectively. + // 2. For INNER and CROSS JOIN, we push limits to both the left and right sides if join + // condition is empty. + // 3. For LEFT SEMI and LEFT ANTI JOIN, we push limits to the left side if join condition + // is empty. 
+ // It's not safe to push limits below FULL OUTER JOIN in the general case without a more + // invasive rewrite. We also need to ensure that this limit pushdown rule will not eventually + // introduce limits on both sides if it is applied multiple times. Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. case LocalLimit(exp, join @ Join(left, right, joinType, conditionOpt, _)) => @@ -555,6 +559,8 @@ object LimitPushDown extends Rule[LogicalPlan] { join.copy( left = maybePushLocalLimit(exp, left), right = maybePushLocalLimit(exp, right)) + case LeftSemi | LeftAnti if conditionOpt.isEmpty => + join.copy(left = maybePushLocalLimit(exp, left)) case _ => join } LocalLimit(exp, newJoin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 5c760264ff219..7a33b5b4b53df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Add -import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -212,4 +212,22 @@ class LimitPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } } + + test("SPARK-34514: Push down limit through LEFT SEMI and LEFT 
ANTI join") { + // Push down when condition is empty + Seq(LeftSemi, LeftAnti).foreach { joinType => + val originalQuery = x.join(y, joinType).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, joinType)).analyze + comparePlans(optimized, correctAnswer) + } + + // No push down when condition is not empty + Seq(LeftSemi, LeftAnti).foreach { joinType => + val originalQuery = x.join(y, joinType, Some("x.a".attr === "y.b".attr)).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(y, joinType, Some("x.a".attr === "y.b".attr))).analyze + comparePlans(optimized, correctAnswer) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index fe8a080ac5aeb..82f79fb4dc083 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4034,6 +4034,36 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(df, Row(0, 0) :: Row(0, 1) :: Row(0, 2) :: Nil) } } + + test("SPARK-34514: Push down limit through LEFT SEMI and LEFT ANTI join") { + withTable("left_table", "nonempty_right_table", "empty_right_table") { + spark.range(5).toDF().repartition(1).write.saveAsTable("left_table") + spark.range(3).write.saveAsTable("nonempty_right_table") + spark.range(0).write.saveAsTable("empty_right_table") + Seq("LEFT SEMI").foreach { joinType => + val joinWithNonEmptyRightDf = spark.sql( + s"SELECT * FROM left_table $joinType JOIN nonempty_right_table LIMIT 3") + val joinWithEmptyRightDf = spark.sql( + s"SELECT * FROM left_table $joinType JOIN empty_right_table LIMIT 3") + + Seq(joinWithNonEmptyRightDf, joinWithEmptyRightDf).foreach { df => + val pushedLocalLimits = df.queryExecution.optimizedPlan.collect { + case l @ LocalLimit(_, _: 
LogicalRelation) => l + } + assert(pushedLocalLimits.length === 1) + } + + val expectedAnswer = Seq(Row(0), Row(1), Row(2)) + if (joinType == "LEFT SEMI") { + checkAnswer(joinWithNonEmptyRightDf, expectedAnswer) + checkAnswer(joinWithEmptyRightDf, Seq.empty) + } else { + checkAnswer(joinWithNonEmptyRightDf, Seq.empty) + checkAnswer(joinWithEmptyRightDf, expectedAnswer) + } + } + } + } } case class Foo(bar: Option[String]) From d188602647c91d7f1259cffac4cb26b86ffc1fc3 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Tue, 23 Feb 2021 19:37:45 -0800 Subject: [PATCH 2/3] Update previous comment for the rule --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2fc3e0d37578b..b08187d0bc3be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -502,7 +502,7 @@ object RemoveNoopOperators extends Rule[LogicalPlan] { } /** - * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins. + * Pushes down [[LocalLimit]] beneath UNION ALL and joins. 
*/ object LimitPushDown extends Rule[LogicalPlan] { From 22bfd5e13ffb5b51a6a29851161146a1af6a5666 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Tue, 23 Feb 2021 22:42:24 -0800 Subject: [PATCH 3/3] Cover LEFT ANTI join in SQLQuerySuite limit pushdown test --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 82f79fb4dc083..82c49f9cbf29a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4040,7 +4040,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark spark.range(5).toDF().repartition(1).write.saveAsTable("left_table") spark.range(3).write.saveAsTable("nonempty_right_table") spark.range(0).write.saveAsTable("empty_right_table") - Seq("LEFT SEMI").foreach { joinType => + Seq("LEFT SEMI", "LEFT ANTI").foreach { joinType => val joinWithNonEmptyRightDf = spark.sql( s"SELECT * FROM left_table $joinType JOIN nonempty_right_table LIMIT 3") val joinWithEmptyRightDf = spark.sql(