From d4b423c9b0835b399310f4830a609e2d76e4fd15 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Tue, 19 Apr 2022 20:43:55 +0800 Subject: [PATCH] refactor --- .../apache/spark/sql/internal/SQLConf.scala | 9 +++ .../spark/sql/execution/QueryExecution.scala | 2 + .../adaptive/AdaptiveSparkPlanExec.scala | 2 + .../sql/execution/joins/SwitchJoinSides.scala | 56 +++++++++++++++++++ .../org/apache/spark/sql/JoinSuite.scala | 19 +++++++ .../execution/benchmark/JoinBenchmark.scala | 18 ++++++ 6 files changed, 106 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SwitchJoinSides.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 36b666fd59c90..fda03f1fa6dcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2421,6 +2421,15 @@ object SQLConf { .doubleConf .createWithDefault(0.9) + val SWITCH_SORT_MERGE_JOIN_SIDES_ENABLED = + buildConf("spark.sql.switchSortMergeJoinSides.enabled") + .internal() + .doc("If true, switch the inner like join side for sort merge join according to the " + + "plan size and child unique keys.") + .version("3.4.0") + .booleanConf + .createWithDefault(true) + private def isValidTimezone(zone: String): Boolean = { Try { DateTimeUtils.getZoneId(zone) }.isSuccess } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9ea769b4cf153..bb2276f368751 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, Insert import 
org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan} import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.execution.joins.SwitchJoinSides import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} import org.apache.spark.sql.internal.SQLConf @@ -405,6 +406,7 @@ object QueryExecution { // as the original plan is hidden behind `AdaptiveSparkPlanExec`. adaptiveExecutionRule.toSeq ++ Seq( + SwitchJoinSides, CoalesceBucketsInJoin, PlanDynamicPruningFilters(sparkSession), PlanSubqueries(sparkSession), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 4a2740656688f..36a3acd82cd0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._ import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan import org.apache.spark.sql.execution.exchange._ +import org.apache.spark.sql.execution.joins.SwitchJoinSides import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch @@ -116,6 +117,7 @@ case class AdaptiveSparkPlanExec( val ensureRequirements = EnsureRequirements(requiredDistribution.isDefined, requiredDistribution) Seq( + SwitchJoinSides, RemoveRedundantProjects, ensureRequirements, 
ReplaceHashWithSortAgg, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SwitchJoinSides.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SwitchJoinSides.scala new file mode 100644 index 0000000000000..7e8b94ea1a29d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SwitchJoinSides.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.expressions.ExpressionSet +import org.apache.spark.sql.catalyst.plans.InnerLike +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} +import org.apache.spark.sql.internal.SQLConf + +/** + * Switches the sides of a join, keeping the original output order, if the join satisfies: + * - its logical plan is an inner-like join + * - its physical plan is a SortMergeJoinExec + * - its left (streamed) side is estimated at less than a third of its right (buffered) side + * - its left (streamed) side is unique on the join keys + */ +object SwitchJoinSides extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + if (!conf.getConf(SQLConf.SWITCH_SORT_MERGE_JOIN_SIDES_ENABLED)) { + return plan + } + + plan transformUp { + case j @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right, hint) + if j.logicalLink.isDefined => + j.logicalLink.get match { + case Join(logicalLeft, logicalRight, _: InnerLike, _, _) + if logicalLeft.distinctKeys.exists(_.subsetOf(ExpressionSet(leftKeys))) && + logicalLeft.stats.sizeInBytes * 3 < logicalRight.stats.sizeInBytes => + ProjectExec( + j.output, + SortMergeJoinExec(rightKeys, leftKeys, joinType, condition, right, left, hint) + ) + + case _ => j + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index f41944d2ed53d..61f6f19b24c0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrd import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper 
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python.BatchEvalPythonExec @@ -1440,4 +1441,22 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan } } } + + test("SPARK-38887: Support switch inner join side for sort merge join") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.range(2).selectExpr("id as c1") + val df2 = spark.range(100).selectExpr("id as c2") + val plan1 = df1.groupBy($"c1").agg($"c1").join(df2, $"c1" === $"c2", "inner") + .queryExecution.executedPlan + val smj1 = find(plan1)(_.isInstanceOf[SortMergeJoinExec]).get.asInstanceOf[SortMergeJoinExec] + assert(!smj1.left.exists(_.isInstanceOf[HashAggregateExec])) + assert(smj1.right.exists(_.isInstanceOf[HashAggregateExec])) + + val plan2 = df2.groupBy($"c2").agg($"c2").join(df1, $"c1" === $"c2", "inner") + .queryExecution.executedPlan + val smj2 = find(plan2)(_.isInstanceOf[SortMergeJoinExec]).get.asInstanceOf[SortMergeJoinExec] + assert(smj2.left.exists(_.isInstanceOf[HashAggregateExec])) + assert(!smj2.right.exists(_.isInstanceOf[HashAggregateExec])) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala index 787fdc7b59d67..77b3dc16a08e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala @@ -149,6 +149,22 @@ object JoinBenchmark extends SqlBasedBenchmark { } } + def sortMergeJoinWithBufferedSideDuplicates(switch: Boolean): Unit = { + val N1 = 2 << 20 + val N2 = 2 << 24 + withSQLConf(SQLConf.SWITCH_SORT_MERGE_JOIN_SIDES_ENABLED.key -> switch.toString) { + codegenBenchmark(s"sort merge join with 
buffered side duplicates, switched: $switch,", N2) { + val df1 = spark.range(N1).distinct() + .selectExpr(s"id as k1") + val df2 = spark.range(N2) + .selectExpr(s"id % 1000 as k2") + val df = df1.join(df2, col("k1") === col("k2")) + assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[SortMergeJoinExec])) + df.noop() + } + } + } + def shuffleHashJoin(): Unit = { val N: Long = 4 << 20 withSQLConf( @@ -188,6 +204,8 @@ object JoinBenchmark extends SqlBasedBenchmark { broadcastHashJoinSemiJoinLongKey() sortMergeJoin() sortMergeJoinWithDuplicates() + sortMergeJoinWithBufferedSideDuplicates(true) + sortMergeJoinWithBufferedSideDuplicates(false) shuffleHashJoin() broadcastNestedLoopJoin() }