From 02c2969c626caab1a439dfd0c1b5f1950e3de817 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 16 Aug 2021 15:41:07 +0800 Subject: [PATCH 01/12] init drop last str arg format sql in ut support both ShuffleQueryStageExec and BroadcastQueryStageExec as leaves nit reslove conflicts narrow valid operator whitelist move agg stringArgs to subclasses add doc && resolve conflicts del sample node resolve conflicts nit --- .../apache/spark/sql/internal/SQLConf.scala | 19 + .../adaptive/OptimizeSkewedJoin.scala | 485 +++++++++++++----- .../aggregate/BaseAggregateExec.scala | 3 + .../aggregate/HashAggregateExec.scala | 5 +- .../aggregate/ObjectHashAggregateExec.scala | 5 +- .../aggregate/SortAggregateExec.scala | 5 +- .../sql/execution/window/WindowExec.scala | 14 +- .../execution/WholeStageCodegenSuite.scala | 2 +- .../adaptive/AdaptiveQueryExecSuite.scala | 227 ++++++++ 9 files changed, 645 insertions(+), 120 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8876d780799d2..33a2d0d564fbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -704,6 +704,25 @@ object SQLConf { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("256MB") + val SKEW_JOIN_MAX_JOINS = + buildConf("spark.sql.adaptive.skewJoin.maxJoins") + .doc(s"When '${ADAPTIVE_EXECUTION_ENABLED.key}' and '${SKEW_JOIN_ENABLED.key}' " + + s"are true, the max number (inclusive) of shuffled joins in a stage that general " + + s"skew algorithm can handle.") + .version("3.3.0") + .intConf + .checkValue(_ > 0, "The max joins must be positive.") + .createWithDefault(5) + + val SKEW_JOIN_MAX_SPLITS_PER_PARTITION = + buildConf("spark.sql.adaptive.skewJoin.maxSplitsPerPartition") + .doc(s"When '${ADAPTIVE_EXECUTION_ENABLED.key}' and '${SKEW_JOIN_ENABLED.key}' " + + s"are true, the 
max number (inclusive) of splits from a partition.") + .version("3.3.0") + .intConf + .checkValue(_ >= 10, "The max splits must be no less than 10.") + .createWithDefault(1000) + val NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN = buildConf("spark.sql.adaptive.nonEmptyPartitionRatioForBroadcastJoin") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index d4a173bb9cceb..33f3d194418c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -28,8 +28,11 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ValidateRequirements} -import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BatchEvalPythonExec, MapInPandasExec} +import org.apache.spark.sql.execution.window.{WindowExec, WindowExecBase} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -98,123 +101,280 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) /* * This method aim to optimize the skewed join with the following steps: - * 1. Check whether the shuffle partition is skewed based on the median size - * and the skewed partition threshold in origin shuffled join (smj and shj). - * 2. 
Assuming partition0 is skewed in left side, and it has 5 mappers (Map0, Map1...Map4). - * And we may split the 5 Mappers into 3 mapper ranges [(Map0, Map1), (Map2, Map3), (Map4)] - * based on the map size and the max split number. - * 3. Wrap the join left child with a special shuffle read that loads each mapper range with one - * task, so total 3 tasks. - * 4. Wrap the join right child with a special shuffle read that loads partition0 3 times by - * 3 tasks separately. + * 0. Collect all ShuffledJoin in this plan. Find the top level ShuffledJoin as the root + * for following steps; + * 1. Check whether this plan satisfies the required pattern of optimization algorithm: + * all the nodes under the top level ShuffledJoin MUST have types in a whitelist including: + * JoinExec/AggExec/WindowExec/SortExec/etc; + * 2. Collect all ShuffleQueryStages under the top level ShuffledJoin; + * 3. Collect all splittable ShuffleQueryStages by the semantics of internal nodes. + * A ShuffleQueryStage is splittable if it can be split into specs, each spec can be + * processed independently, and the original data result can be obtained by unioning all + * the outputs of specs. + * Splittable ShuffleQueryStages are collected in this way: + * 0, start at the top level ShuffledJoin; + * 1, at a Join node, select the splittable paths according to its JoinType; + * 2, at an Agg/Window node, skip all its descendants; + * 3, all the reached leaves are splittable; + * For example, in the following stage, ShuffleQueryStages s6/s7/s8 are splittable. + * cross + * / \ + * agg \ + * / \ + * left cross + * / \ / \ + * inner s3 agg inner + * / \ / / \ + * s0 right inner inner left + * / \ / \ / \ / \ + * s1 s2 s4 s5 s6 s7 s8 s9 + * + * 4. Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec; + * 5. For each splittable ShuffleQueryStageExec, check whether skew partitions exist, if true, + * split them into specs. 
This step also detects and handles Combinatorial Explosion: for + * each skew partition, check whether the combination number is too large, if so, re-split the + * ShuffleQueryStageExecs. + * For example, for partition 0, stage s6/s7/s8 are split into 100/100/100 specs, + * respectively. Then there are 1M combinations, which is too large, and will cause + * performance regression. Given a threshold (1k by default), the numbers of specs will + * be optimized to 10/10/10. + * 6. Generate final specs. Suppose above splittable ShuffleQueryStages s6/s7/s8 are finally + * split into 2/2/3 specs, then there will be following 2X2X3=12 combinations: + * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec0, s8_spec0, s9] + * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec0, s8_spec1, s9] + * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec0, s8_spec2, s9] + * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec1, s8_spec0, s9] + * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec1, s8_spec1, s9] + * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec1, s8_spec2, s9] + * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec0, s8_spec0, s9] + * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec0, s8_spec1, s9] + * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec0, s8_spec2, s9] + * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec1, s8_spec0, s9] + * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec1, s8_spec1, s9] + * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec1, s8_spec2, s9] + * 7. 
Generate optimized plan by attaching new specs to ShuffleQueryStageExecs; */ - private def tryOptimizeJoinChildren( - left: ShuffleQueryStageExec, - right: ShuffleQueryStageExec, - joinType: JoinType): Option[(SparkPlan, SparkPlan)] = { - val canSplitLeft = canSplitLeftSide(joinType) - val canSplitRight = canSplitRightSide(joinType) - if (!canSplitLeft && !canSplitRight) return None - - val leftSizes = left.mapStats.get.bytesByPartitionId - val rightSizes = right.mapStats.get.bytesByPartitionId - assert(leftSizes.length == rightSizes.length) - val numPartitions = leftSizes.length - // We use the median size of the original shuffle partitions to detect skewed partitions. - val leftMedSize = Utils.median(leftSizes, false) - val rightMedSize = Utils.median(rightSizes, false) - logDebug( - s""" - |Optimizing skewed join. - |Left side partitions size info: - |${getSizeInfo(leftMedSize, leftSizes)} - |Right side partitions size info: - |${getSizeInfo(rightMedSize, rightSizes)} - """.stripMargin) - - val leftSkewThreshold = getSkewThreshold(leftMedSize) - val rightSkewThreshold = getSkewThreshold(rightMedSize) - val leftTargetSize = targetSize(leftSizes, leftSkewThreshold) - val rightTargetSize = targetSize(rightSizes, rightSkewThreshold) - - val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec] - val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec] - var numSkewedLeft = 0 - var numSkewedRight = 0 - for (partitionIndex <- 0 until numPartitions) { - val leftSize = leftSizes(partitionIndex) - val isLeftSkew = canSplitLeft && leftSize > leftSkewThreshold - val rightSize = rightSizes(partitionIndex) - val isRightSkew = canSplitRight && rightSize > rightSkewThreshold - val leftNoSkewPartitionSpec = - Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, leftSize)) - val rightNoSkewPartitionSpec = - Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) - - val leftParts = if (isLeftSkew) { - val skewSpecs = 
ShufflePartitionsUtil.createSkewPartitionSpecs( - left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) - if (skewSpecs.isDefined) { - logDebug(s"Left side partition $partitionIndex " + - s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + - s"split it into ${skewSpecs.get.length} parts.") - numSkewedLeft += 1 - } - skewSpecs.getOrElse(leftNoSkewPartitionSpec) - } else { - leftNoSkewPartitionSpec - } + private def optimize(plan: SparkPlan): SparkPlan = { + import OptimizeSkewedJoin._ + val logPrefix = s"Optimizing ${plan.nodeName} #${plan.id}" + + // Step 0: Collect all ShuffledJoins (SMJ/SHJ) + def collectShuffledJoins(plan: SparkPlan): Seq[ShuffledJoin] = plan match { + case join: ShuffledJoin => Seq(join) ++ join.children.flatMap(collectShuffledJoins) + case _ => plan.children.flatMap(collectShuffledJoins) + } + val joins = collectShuffledJoins(plan) + logDebug(s"$logPrefix: ShuffledJoins: ${joins.map(_.nodeName).mkString("[", ", ", "]")}") + if (joins.isEmpty || joins.exists(_.isSkewJoin)) return plan + val topJoin = joins.head + + // Step 1: Validate physical operators + // There are more and more physical operators, this whitelist is for data correctness + // TODO: support more operators like AggregateInPandasExec/FlatMapCoGroupsInPandasExec/etc + val invalidOperators = topJoin.collect { + case _: ShuffleQueryStageExec => None + case _: BroadcastQueryStageExec => None + + case _: SortMergeJoinExec => None + case _: ShuffledHashJoinExec => None + case _: BroadcastHashJoinExec => None + case _: BroadcastNestedLoopJoinExec => None + case _: CartesianProductExec => None + + case _: ObjectHashAggregateExec => None + case _: HashAggregateExec => None + case _: SortAggregateExec => None + + case _: WindowExec => None + + case _: SortExec => None + case _: FilterExec => None + case _: ProjectExec => None + case _: GenerateExec => None + case _: CollectMetricsExec => None + case _: WholeStageCodegenExec => None + + case _: ColumnarToRowExec => None 
+ case _: RowToColumnarExec => None + + case _: DeserializeToObjectExec => None + case _: SerializeFromObjectExec => None + + case _: MapElementsExec => None + case _: MapPartitionsExec => None + case _: MapPartitionsInRWithArrowExec => None + case _: MapInPandasExec => None + case _: ArrowEvalPythonExec => None + case _: BatchEvalPythonExec => None + + case invalid => Some(invalid) + }.flatten + if (invalidOperators.nonEmpty) { + logDebug(s"$logPrefix: Do NOT support operators " + + s"${invalidOperators.map(_.nodeName).mkString("[", ", ", "]")}") + return plan + } + + // Step 2: Collect all ShuffleQueryStages + // TODO: support Bucket Join with other types of leaves. + val leaves = topJoin.collectLeaves() + if (leaves.exists(!_.isInstanceOf[QueryStageExec])) return plan + val stages = leaves.filter(_.isInstanceOf[ShuffleQueryStageExec]) + // for a N-Join stage, there should be N+1 ShuffleQueryStages. + if (stages.size != joins.size + 1) return plan + // stageId -> MapOutputStatistics + val stageStats = stages.flatMap { + case ShuffleStage(stage: ShuffleQueryStageExec) => + stage.mapStats.filter(_.bytesByPartitionId.nonEmpty).map(stats => stage.id -> stats) + case _ => None + }.toMap + if (stageStats.size != joins.size + 1) return plan + val stageIds = stageStats.keysIterator.toArray + logDebug(s"$logPrefix: ShuffleQueryStages: ${stageIds.mkString("[", ", ", "]")}") + val numPartitions = stageStats.head._2.bytesByPartitionId.length + if (stageStats.exists(_._2.bytesByPartitionId.length != numPartitions)) return plan + + // Step 3: Collect all splittable ShuffleQueryStageExecs + // How to determine splittable ShuffleQueryStageExecs: + // 0, start at the top Join node; + // 1, at Join node, select the splittable paths according to its JoinType; + // 2, at Agg/Window node, skip all its descendants; + // 3, all the reached leave are splittable; + def collectSplittableStageIds(plan: SparkPlan): Seq[Int] = plan match { + case stage: ShuffleQueryStageExec => Seq(stage.id) 
- val rightParts = if (isRightSkew) { - val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( - right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) - if (skewSpecs.isDefined) { - logDebug(s"Right side partition $partitionIndex " + - s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + - s"split it into ${skewSpecs.get.length} parts.") - numSkewedRight += 1 + case join: ShuffledJoin => + var splittableChildren = Seq.empty[SparkPlan] + if (canSplitLeftSide(join.joinType)) splittableChildren :+= join.left + if (canSplitRightSide(join.joinType)) splittableChildren :+= join.right + splittableChildren.flatMap(collectSplittableStageIds) + + case _: BaseAggregateExec => Seq.empty + + case _: WindowExecBase => Seq.empty + + case _ => plan.children.flatMap(collectSplittableStageIds) + } + val splittableStageIds = collectSplittableStageIds(topJoin) + logDebug(s"$logPrefix: Splittable ShuffleQueryStages: " + + s"${splittableStageIds.mkString("[", ", ", "]")}") + if (splittableStageIds.isEmpty || !splittableStageIds.forall(stageStats.contains)) return plan + + // Step 4: Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec + val splittableStageInfos = splittableStageIds.map { stageId => + val sizes = stageStats(stageId).bytesByPartitionId + val medSize = Utils.median(sizes) + val threshold = getSkewThreshold(medSize) + val target = targetSize(sizes, threshold) + logDebug(s"$logPrefix: Optimizing ShuffleQueryStage #$stageId in " + + s"skew join, size info: ${getSizeInfo(medSize, sizes)}") + stageId -> (threshold, target) + }.toMap + + // Step 5: Split skew partitions + // within each partition, find and split the splittable skew ShuffleQueryStageExecs + // (partitionIndex, stageId) -> skew splits + val skewSpecs = mutable.OpenHashMap.empty[(Int, Int), Seq[PartialReducerPartitionSpec]] + val partSpecs = mutable.OpenHashMap.empty[Int, Seq[PartialReducerPartitionSpec]] + val maxCombinations = 
conf.getConf(SQLConf.SKEW_JOIN_MAX_SPLITS_PER_PARTITION) + + Range(0, numPartitions).foreach { partitionIndex => + partSpecs.clear() + splittableStageInfos.foreach { case (stageId, (threshold, target)) => + val stats = stageStats(stageId) + val size = stats.bytesByPartitionId(partitionIndex) + if (size > threshold) { + ShufflePartitionsUtil + .createSkewPartitionSpecs(stats.shuffleId, partitionIndex, target) + .foreach { splits => + logDebug(s"$logPrefix: Splitting ShuffleQueryStage #$stageId: " + + s"partition $partitionIndex(${FileUtils.byteCountToDisplaySize(size)}) -> " + + s"${splits.size} splits") + partSpecs(stageId) = splits + } } - skewSpecs.getOrElse(rightNoSkewPartitionSpec) - } else { - rightNoSkewPartitionSpec } - for { - leftSidePartition <- leftParts - rightSidePartition <- rightParts - } { - leftSidePartitions += leftSidePartition - rightSidePartitions += rightSidePartition + // Handle Combinatorial Explosion. + val numCombinations = safeProduct(partSpecs.valuesIterator.map(_.size)) + if (numCombinations > maxCombinations) { + val (splitStageIds, numSplits) = partSpecs.mapValues(_.size).toArray.unzip + val combinedNumSplits = combine(maxCombinations, numSplits) + logDebug(s"$logPrefix: partition $partitionIndex: Combinatorial Explosion! 
" + + s"Try to combine $numCombinations(${numSplits.mkString("[", ", ", "]")}) " + + s"to ${safeProduct(combinedNumSplits)}(${combinedNumSplits.mkString("[", ", ", "]")})") + + partSpecs.clear() + splitStageIds.zip(combinedNumSplits) + .filter(_._2 > 1) + .foreach { case (stageId, newNumSplits) => + val stats = stageStats(stageId) + val size = stats.bytesByPartitionId(partitionIndex) + // TODO: ShufflePartitionsUtil supports target number of specs + // simply adjust the target size to control the number of splits for now + val newTarget = (1.1 * size.toDouble / newNumSplits).toLong + 1L + ShufflePartitionsUtil + .createSkewPartitionSpecs(stats.shuffleId, partitionIndex, newTarget) + .foreach { splits => + logDebug(s"$logPrefix: Re-splitting ShuffleQueryStage #$stageId: " + + s"partition $partitionIndex(${FileUtils.byteCountToDisplaySize(size)}) -> " + + s"${splits.size} splits") + partSpecs(stageId) = splits + } + } } + + partSpecs.foreach { case (stageId, splits) => skewSpecs((partitionIndex, stageId)) = splits } } - logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight") - if (numSkewedLeft > 0 || numSkewedRight > 0) { - Some(( - SkewJoinChildWrapper(AQEShuffleReadExec(left, leftSidePartitions.toSeq)), - SkewJoinChildWrapper(AQEShuffleReadExec(right, rightSidePartitions.toSeq)) - )) - } else { - None + partSpecs.clear() + logDebug(s"$logPrefix: Totally ${skewSpecs.size} skew partitions found") + if (skewSpecs.isEmpty) return plan + + // Step 6: Generate final specs + // within a partition, split the skew ShuffleQueryStageExecs, and duplicate others + def createNonSkewSpec(partitionIndex: Int, stageId: Int) = { + val size = stageStats(stageId).bytesByPartitionId(partitionIndex) + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, size)) } - } - def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp { - case smj @ SortMergeJoinExec(_, _, joinType, _, - s1 @ SortExec(_, _, ShuffleStage(left: 
ShuffleQueryStageExec), _), - s2 @ SortExec(_, _, ShuffleStage(right: ShuffleQueryStageExec), _), false) => - tryOptimizeJoinChildren(left, right, joinType).map { - case (newLeft, newRight) => - smj.copy( - left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true) - }.getOrElse(smj) - - case shj @ ShuffledHashJoinExec(_, _, joinType, _, _, - ShuffleStage(left: ShuffleQueryStageExec), - ShuffleStage(right: ShuffleQueryStageExec), false) => - tryOptimizeJoinChildren(left, right, joinType).map { - case (newLeft, newRight) => - shj.copy(left = newLeft, right = newRight, isSkewJoin = true) - }.getOrElse(shj) + def traverseCombinations(seqs: Seq[Seq[ShufflePartitionSpec]]) = { + require(seqs.nonEmpty) + val iter = seqs.head.iterator.map(item => Seq(item)) + seqs.tail.foldLeft(iter)((iter, seq) => iter.flatMap(comb => seq.map(item => comb :+ item))) + } + + val buffers = stageIds.map(_ => mutable.ArrayBuffer.empty[ShufflePartitionSpec]) + Range(0, numPartitions).foreach { partitionIndex => + val specs = stageIds.map { stageId => + skewSpecs.getOrElse((partitionIndex, stageId), createNonSkewSpec(partitionIndex, stageId)) + } + traverseCombinations(specs).foreach { combination => + combination.indices.foreach(i => buffers(i) += combination(i)) + } + } + val newSpecs = stageIds.zip(buffers.map(_.toSeq)).toMap + + // Step 7: Generate final plan + // 0, start at the top Join node; + // 1, mark all Join/Agg/Window nodes skew; + // 2, attach new specs to ShuffleQueryStageExecs; + val topJoinId = topJoin.id + plan transform { + case join: ShuffledJoin if join.id == topJoinId => + join transform { + case smj: SortMergeJoinExec => smj.copy(isSkewJoin = true) + case shj: ShuffledHashJoinExec => shj.copy(isSkewJoin = true) + + case obj: ObjectHashAggregateExec => obj.copy(isSkew = true) + case hash: HashAggregateExec => hash.copy(isSkew = true) + case sort: SortAggregateExec => sort.copy(isSkew = true) + + case win: WindowExec => win.copy(isSkew = true) + 
+ case stage: ShuffleQueryStageExec => + SkewJoinChildWrapper(AQEShuffleReadExec(stage, newSpecs(stage.id))) + } + } } override def apply(plan: SparkPlan): SparkPlan = { @@ -222,17 +382,36 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) return plan } - // We try to optimize every skewed sort-merge/shuffle-hash joins in the query plan. If this - // introduces extra shuffles, we give up the optimization and return the original query plan, or - // accept the extra shuffles if the force-apply config is true. - // TODO: It's possible that only one skewed join in the query plan leads to extra shuffles and - // we only need to skip optimizing that join. We should make the strategy smarter here. - val optimized = optimizeSkewJoin(plan) + val shuffledJoins = plan.collect { case s: ShuffledJoin => s } + if (shuffledJoins.isEmpty || shuffledJoins.exists(_.isSkewJoin)) return plan + if (shuffledJoins.size > conf.getConf(SQLConf.SKEW_JOIN_MAX_JOINS)) { + logDebug(s"${shuffledJoins.size} ShuffledJoins in ${plan.nodeName} " + + s"exceeds threshold ${conf.getConf(SQLConf.SKEW_JOIN_MAX_JOINS)}") + return plan + } + + val unions = plan.collect { case u: UnionExec => u } + // there should be at most one UnionExec in one stage, skip here for safety + if (unions.size > 1) return plan + + val optimized = if (unions.size == 1) { + plan transform { + // TODO: if extra shuffle is NOT allowed, only accept children without shuffle. + case u @ UnionExec(children) => u.withNewChildren(children.map(optimize)) + } + } else { + optimize(plan) + } + if (optimized.collect { case s: ShuffledJoin if s.isSkewJoin => s }.isEmpty) return plan + val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) { ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get) } else { ValidateRequirements.validate(optimized) } + // Two cases we will apply the skewed join optimization: + // 1. 
optimize the skew join without extra shuffle + // 2. optimize the skew join with extra shuffle but the force-apply config is true. if (requirementSatisfied) { optimized.transform { case SkewJoinChildWrapper(child) => child @@ -256,6 +435,82 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) } } + +private[adaptive] object OptimizeSkewedJoin { + + /** + * same as values.product, but make sure NO overflow: + * Iterator(10, 20, 30, 4, 10, 2, 1, 999, 88).product -> -751,912,960 + */ + def safeProduct(values: TraversableOnce[Int]): BigInt = values.foldLeft(BigInt(1))(_ * _) + + /** + * Combine splits to make sure the total number of combinations no greater than given threshold. + * This algorithm iteratively estimates an appropriate shrinkage factor for remaining combinable + * stages (with more than 1 splits), and then perform split merge. Note that it tries to keep the + * proportional relationship among input numbers of splits. + */ + def combine(maxCombinations: Int, numSplits: Array[Int]): Array[Int] = { + require(maxCombinations > 0) + require(numSplits.nonEmpty && numSplits.forall(_ > 0)) + + var numCombinations = safeProduct(numSplits) + if (numCombinations <= maxCombinations) return numSplits + + val intNumSplits = numSplits.clone() + val floatNumSplits = intNumSplits.map(_.toDouble) + var numCombinables = intNumSplits.count(_ > 1) + + val maxShrinkage = 0.999 + val minShrinkage = 0.1 + val maxIterations = 1000 // 20 iterations should be enough in most cases, set 1000 for safety + var iteration = 0 + while (numCombinations > maxCombinations && numCombinables > 0 && iteration < maxIterations) { + var shrinkage = math.pow( + (BigDecimal(maxCombinations) / BigDecimal(numCombinations)).doubleValue, + 1.0 / numCombinables + ) + if (shrinkage.isNaN) { + shrinkage = maxShrinkage + } else { + // clip shrinkage for numeric stability + shrinkage = math.min(shrinkage, maxShrinkage) + shrinkage = math.max(shrinkage, minShrinkage) + } + + 
floatNumSplits.indices.foreach { i => + floatNumSplits(i) = math.max(1.0, floatNumSplits(i) * shrinkage) + } + + Iterator.tabulate(floatNumSplits.length) { i => + val prevIntNumSplits = intNumSplits(i) + val currIntNumSplits = floatNumSplits(i).round.toInt + (i, prevIntNumSplits, currIntNumSplits) + }.filter { case (_, prevIntNumSplits, currIntNumSplits) => + currIntNumSplits < prevIntNumSplits + }.toArray.sortBy { case (i, prevIntNumSplits, currIntNumSplits) => + // first try small updates to numCombinations + (1.0 - currIntNumSplits.toDouble / prevIntNumSplits, i) + }.foreach { case (i, prevIntNumSplits, currIntNumSplits) => + if (numCombinations > maxCombinations) { + intNumSplits(i) = currIntNumSplits + numCombinations /= prevIntNumSplits + numCombinations *= currIntNumSplits + } + } + + numCombinables = intNumSplits.count(_ > 1) + iteration += 1 + } + + if (numCombinations <= maxCombinations) { + intNumSplits + } else { + Array.fill(numSplits.length)(1) + } + } +} + // After optimizing skew joins, we need to run EnsureRequirements again to add necessary shuffles // caused by skew join optimization. However, this shouldn't apply to the sub-plan under skew join, // as it's guaranteed to satisfy distribution requirement. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index 756b5eb09d0b9..cc4974fe76717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -35,6 +35,8 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning def aggregateAttributes: Seq[Attribute] def initialInputBufferOffset: Int def resultExpressions: Seq[NamedExpression] + def isSkew: Boolean = false + override def nodeName: String = if (isSkew) super.nodeName + "(skew=true)" else super.nodeName override def verboseStringWithOperatorId(): String = { s""" @@ -94,6 +96,7 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { + case _ if isSkew => UnspecifiedDistribution :: Nil case Some(exprs) if exprs.isEmpty => AllTuples :: Nil case Some(exprs) => if (isStreaming) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 935844d96d9a0..8fc913f3fb049 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -53,11 +53,14 @@ case class HashAggregateExec( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan) + child: SparkPlan, + override val isSkew: Boolean = false) extends AggregateCodegenSupport { require(Aggregate.supportsHashAggregate(aggregateBufferAttributes)) + override def stringArgs: Iterator[Any] = 
super.stringArgs.toSeq.dropRight(1).iterator + override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index e6530e94701f9..416ec5abf0e58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -66,9 +66,12 @@ case class ObjectHashAggregateExec( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan) + child: SparkPlan, + override val isSkew: Boolean = false) extends BaseAggregateExec { + override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator + override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 3cf63a5318dcf..6f6a3446dc930 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -39,10 +39,13 @@ case class SortAggregateExec( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan) + child: SparkPlan, + override val isSkew: Boolean = false) extends AggregateCodegenSupport with AliasAwareOutputOrdering { + override def stringArgs: 
Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 33c37e871e385..c5a02005019a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.window import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} /** @@ -87,8 +88,19 @@ case class WindowExec( windowExpression: Seq[NamedExpression], partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], - child: SparkPlan) + child: SparkPlan, + isSkew: Boolean = false) extends WindowExecBase { + override def nodeName: String = if (isSkew) super.nodeName + "(skew=true)" else super.nodeName + override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator + + override def requiredChildDistribution: Seq[Distribution] = { + if (isSkew) { + UnspecifiedDistribution :: Nil + } else { + super[WindowExecBase].requiredChildDistribution + } + } protected override def doExecute(): RDD[InternalRow] = { // Unwrap the window expressions and window frame factories from the map. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 622cce7e8b3bc..b9c2a1918fe5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -744,7 +744,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession assert( executedPlan.exists { case WholeStageCodegenExec( - HashAggregateExec(_, _, _, _, _, _, _, _, _: LocalTableScanExec)) => true + HashAggregateExec(_, _, _, _, _, _, _, _, _: LocalTableScanExec, _)) => true case _ => false }, "LocalTableScanExec should be within a WholeStageCodegen domain.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index eff19302c5b20..949b9cfe90b82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -115,6 +115,12 @@ class AdaptiveQueryExecSuite } } + private def findTopLevelShuffledJoin(plan: SparkPlan): Seq[ShuffledJoin] = { + collect(plan) { + case j: ShuffledJoin => j + } + } + private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = { collect(plan) { case j: BaseJoinExec => j @@ -843,6 +849,227 @@ class AdaptiveQueryExecSuite } } + test("SPARK-36638: General Skew Join: 3-table join") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SKEW_JOIN_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "80", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "80") { + 
withTempView("skewData1", "skewData2", "skewData3") { + spark + .range(0, 10000, 1, 9) + .select( + when('id < 2500, 2499) + .when('id >= 7500, 10000) + .otherwise('id).as("key1"), + 'id as "value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 100000, 1, 7) + .select( + when('id < 5000, 4999) + .when('id >= 10000, 10001) + .otherwise('id).as("key2"), + 'id as "value2") + .createOrReplaceTempView("skewData2") + spark + .range(0, 10000, 1, 5) + .select('id as "key3", 'id as "value3") + .repartition(11) + .createOrReplaceTempView("skewData3") + + val join = + s""" + | SELECT value1, value3, row FROM + | (SELECT key1, value1, row_number() OVER (PARTITION BY key1 ORDER BY value1 DESC) + | AS row FROM skewData1) + | JOIN skewData2 ON key1 = key2 + | LEFT JOIN (SELECT key3, max(value3) AS value3 FROM skewData3 GROUP BY key3) + | ON key1 = key3 + |""".stripMargin + val query = s"SELECT value1, max(row) FROM ($join) GROUP BY value1" + + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val joins = findTopLevelShuffledJoin(adaptivePlan) + assert(joins.size === 2) + assert(joins.forall(_.isSkewJoin)) + } + } + } + + test("SPARK-36638: General Skew Join: 3-table join UNION 2-table join") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SKEW_JOIN_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "80", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "80") { + withTempView("skewData1", "skewData2", "skewData3") { + spark + .range(0, 10000, 1, 9) + .select( + when('id < 2500, 2499) + .when('id >= 7500, 10000) + .otherwise('id).as("key1"), + 'id as "value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 10000, 1, 7) + .select( + when('id < 2500, 2499) + .otherwise('id).as("key2"), + 'id as "value2") + .createOrReplaceTempView("skewData2") + spark + .range(0, 10000, 1, 5) + .select('id as 
"key3", 'id as "value3") + .createOrReplaceTempView("skewData3") + + Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint => + val union = + s""" + | SELECT value1, value3 FROM + | skewData1 JOIN skewData2 ON key1 = key2 + | LEFT JOIN + | (SELECT key3, max(value3) AS value3 FROM skewData3 GROUP BY key3) ON key1 = key3 + | UNION ALL + | SELECT /*+ $joinHint(skewData1) */ value1, value2 FROM skewData1 + | LEFT JOIN skewData2 ON key1 = key2 + |""".stripMargin + val query = s"SELECT value1, max(value3) FROM ($union) GROUP BY value1" + + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val joins = findTopLevelShuffledJoin(adaptivePlan) + assert(joins.size === 3) + assert(joins.forall(_.isSkewJoin)) + } + } + } + } + + test("SPARK-36638: General Skew Join: 5-table join") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SKEW_JOIN_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "80", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "80") { + withTempView("skewData1", "skewData2", "skewData3", "skewData4", "skewData5") { + spark + .range(0, 10000, 1, 9) + .select( + when('id < 2500, 2499) + .when('id >= 7500, 10000) + .otherwise('id).as("key1"), + 'id as "value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 10000, 1, 7) + .select( + when('id < 2500, 2499) + .otherwise('id).as("key2"), + 'id as "value2") + .createOrReplaceTempView("skewData2") + spark + .range(0, 10000, 1, 5) + .select('id as "key3", 'id as "value3") + .createOrReplaceTempView("skewData3") + spark + .range(0, 10000, 1, 3) + .select( + when('id < 2000, 1999) + .otherwise('id).as("key4"), + 'id as "value4") + .createOrReplaceTempView("skewData4") + spark + .range(0, 10000, 1, 11) + .select( + when('id > 6000, 6000) + .otherwise('id).as("key5"), + 'id as "value5") + .createOrReplaceTempView("skewData5") + + + for (joinType12 <- 
Seq("LEFT", "RIGHT"); + joinType123 <- Seq("INNER", "CROSS"); + joinType45 <- Seq("RIGHT", "INNER")) { + + val join = + s""" + | SELECT * FROM + | skewData1 $joinType12 JOIN skewData2 ON key1 = key2 + | $joinType123 JOIN skewData3 ON key1 = key3 + | JOIN + | ((SELECT /*+ SHUFFLE_HASH(skewData4) */ key4, max(value4) AS value4 + | FROM skewData4 GROUP BY key4) + | $joinType45 JOIN skewData5 ON key4 = key5) + | ON key1 = key5 + |""".stripMargin + val query = s"SELECT value1, max(value3) FROM ($join) GROUP BY value1" + + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val joins = findTopLevelShuffledJoin(adaptivePlan) + assert(joins.size === 4) + assert(joins.forall(_.isSkewJoin)) + } + } + } + } + + test("SPARK-36638: General Skew Join: combine splits to handle Combinatorial Explosion") { + import OptimizeSkewedJoin.combine + + val array0 = Array(511, 622) // 317,842 splits + assert(combine(1, array0) === Array(1, 1)) // 1 split + assert(combine(10, array0) === Array(3, 3)) // 9 splits + assert(combine(100, array0) === Array(9, 11)) // 99 splits + assert(combine(1000, array0) === Array(29, 34)) // 986 splits + assert(combine(10000, array0) === Array(90, 110)) // 9,900 splits + assert(combine(100000, array0) === Array(286, 349)) // 99,814 splits + + val array1 = Array(1111, 4096, 128, 8) // 364,904,448 splits + assert(combine(1, array1) === Array(1, 1, 1, 1)) // 1 split + assert(combine(10, array1) === Array(2, 5, 1, 1)) // 10 splits + assert(combine(100, array1) === Array(5, 19, 1, 1)) // 95 splits + assert(combine(1000, array1) === Array(12, 45, 1, 1)) // 540 splits + assert(combine(10000, array1) === Array(29, 105, 3, 1)) // 9,135 splits + assert(combine(100000, array1) === Array(61, 226, 7, 1)) // 96,502 splits + + val array2 = Array(77, 99, 77) // 586,971 splits + assert(combine(1, array2) === Array(1, 1, 1)) // 1 split + assert(combine(10, array2) === Array(2, 2, 2)) // 8 splits + assert(combine(100, array2) === Array(4, 5, 4)) // 80 splits + 
assert(combine(1000, array2) === Array(9, 12, 9)) // 972 splits + assert(combine(10000, array2) === Array(20, 25, 20)) // 10,000 splits + assert(combine(100000, array2) === Array(42, 55, 43)) // 99,330 splits + + val array3 = Array(9999) // 9,999 splits + assert(combine(1, array3) === Array(1)) // 1 split + assert(combine(10, array3) === Array(10)) // 10 splits + assert(combine(100, array3) === Array(100)) // 100 splits + assert(combine(1000, array3) === Array(1000)) // 1000 splits + assert(combine(10000, array3) === Array(9999)) // 9,999 splits + + val array4 = Array(10, 20, 30, 4, 10, 2, 1, 999, 88) // 42,197,760,000 splits + assert(combine(1, array4) === Array(1, 1, 1, 1, 1, 1, 1, 1, 1)) // 1 split + assert(combine(10, array4) === Array(1, 1, 1, 1, 1, 1, 1, 10, 1)) // 10 splits + assert(combine(100, array4) === Array(1, 1, 1, 1, 1, 1, 1, 33, 3)) // 99 splits + assert(combine(1000, array4) === Array(1, 1, 2, 1, 1, 1, 1, 69, 6)) // 828 splits + assert(combine(10000, array4) === Array(1, 2, 4, 1, 1, 1, 1, 119, 10)) // 9,520 splits + assert(combine(100000, array4) === Array(2, 3, 4, 1, 2, 1, 1, 148, 13)) // 92,352 splits + + val array5 = Array.fill(10)(2) // 1,024 splits + assert(combine(1, array5) === Array(1, 1, 1, 1, 1, 1, 1, 1, 1, 1)) // 1 split + assert(combine(10, array5) === Array(1, 1, 1, 1, 1, 1, 1, 2, 2, 2)) // 8 splits + assert(combine(100, array5) === Array(1, 1, 1, 1, 2, 2, 2, 2, 2, 2)) // 64 splits + assert(combine(1000, array5) === Array(1, 2, 2, 2, 2, 2, 2, 2, 2, 2)) // 512 splits + assert(combine(10000, array5) === Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2)) // 1,024 splits + } + test("SPARK-30291: AQE should catch the exceptions when doing materialize") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { From b8f014162e826286f768df61421daa9749b2347d Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 15 Dec 2021 15:40:09 +0800 Subject: [PATCH 02/12] support intermediate skew join --- .../adaptive/OptimizeSkewedJoin.scala | 74 
+++++++++---------- 1 file changed, 36 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 33f3d194418c4..151e862da62a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -99,10 +99,15 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) sizes.sum / sizes.length } + private def optimize(plan: SparkPlan): SparkPlan = { + plan transform { + case join: ShuffledJoin if !join.isSkewJoin => optimizeShuffledJoin(join) + } + } + /* * This method aim to optimize the skewed join with the following steps: - * 0. Collect all ShuffledJoin in this plan. Find the top level ShuffledJoin as the root - * for following steps; + * 0. Collect all ShuffledJoin in this plan; * 1. Check whether this plan satisfy the required pattern of optimization algorithm: * all the nodes under the top level ShuffledJoin MUST have types in a whitelist including: * JoinExec/AggExec/WindowExec/SortExec/etc; @@ -154,24 +159,23 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec1, s8_spec2, s9] * 7. 
Generate optimized plan by attaching new specs to ShuffleQueryStageExecs; */ - private def optimize(plan: SparkPlan): SparkPlan = { + private def optimizeShuffledJoin(join: ShuffledJoin): SparkPlan = { import OptimizeSkewedJoin._ - val logPrefix = s"Optimizing ${plan.nodeName} #${plan.id}" + val logPrefix = s"Optimizing ${join.nodeName} #${join.id}" // Step 0: Collect all ShuffledJoins (SMJ/SHJ) def collectShuffledJoins(plan: SparkPlan): Seq[ShuffledJoin] = plan match { case join: ShuffledJoin => Seq(join) ++ join.children.flatMap(collectShuffledJoins) case _ => plan.children.flatMap(collectShuffledJoins) } - val joins = collectShuffledJoins(plan) + val joins = collectShuffledJoins(join) logDebug(s"$logPrefix: ShuffledJoins: ${joins.map(_.nodeName).mkString("[", ", ", "]")}") - if (joins.isEmpty || joins.exists(_.isSkewJoin)) return plan - val topJoin = joins.head + if (joins.isEmpty || joins.exists(_.isSkewJoin)) return join // Step 1: Validate physical operators // There are more and more physical operators, this whitelist is for data correctness // TODO: support more operators like AggregateInPandasExec/FlatMapCoGroupsInPandasExec/etc - val invalidOperators = topJoin.collect { + val invalidOperators = join.collect { case _: ShuffleQueryStageExec => None case _: BroadcastQueryStageExec => None @@ -212,34 +216,33 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) if (invalidOperators.nonEmpty) { logDebug(s"$logPrefix: Do NOT support operators " + s"${invalidOperators.map(_.nodeName).mkString("[", ", ", "]")}") - return plan + return join } // Step 2: Collect all ShuffleQueryStages // TODO: support Bucket Join with other types of leaves. 
- val leaves = topJoin.collectLeaves() - if (leaves.exists(!_.isInstanceOf[QueryStageExec])) return plan + val leaves = join.collectLeaves() + if (leaves.exists(!_.isInstanceOf[QueryStageExec])) return join val stages = leaves.filter(_.isInstanceOf[ShuffleQueryStageExec]) // for a N-Join stage, there should be N+1 ShuffleQueryStages. - if (stages.size != joins.size + 1) return plan + if (stages.size != joins.size + 1) return join // stageId -> MapOutputStatistics val stageStats = stages.flatMap { case ShuffleStage(stage: ShuffleQueryStageExec) => stage.mapStats.filter(_.bytesByPartitionId.nonEmpty).map(stats => stage.id -> stats) case _ => None }.toMap - if (stageStats.size != joins.size + 1) return plan + if (stageStats.size != joins.size + 1) return join val stageIds = stageStats.keysIterator.toArray logDebug(s"$logPrefix: ShuffleQueryStages: ${stageIds.mkString("[", ", ", "]")}") val numPartitions = stageStats.head._2.bytesByPartitionId.length - if (stageStats.exists(_._2.bytesByPartitionId.length != numPartitions)) return plan + if (stageStats.exists(_._2.bytesByPartitionId.length != numPartitions)) return join // Step 3: Collect all splittable ShuffleQueryStageExecs // How to determine splittable ShuffleQueryStageExecs: - // 0, start at the top Join node; - // 1, at Join node, select the splittable paths according to its JoinType; - // 2, at Agg/Window node, skip all its descendants; - // 3, all the reached leave are splittable; + // 0, at Join node, select the splittable paths according to its JoinType; + // 1, at Agg/Window node, stop; + // 2, all the reached leave are splittable; def collectSplittableStageIds(plan: SparkPlan): Seq[Int] = plan match { case stage: ShuffleQueryStageExec => Seq(stage.id) @@ -255,10 +258,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) case _ => plan.children.flatMap(collectSplittableStageIds) } - val splittableStageIds = collectSplittableStageIds(topJoin) + val splittableStageIds = 
collectSplittableStageIds(join) logDebug(s"$logPrefix: Splittable ShuffleQueryStages: " + s"${splittableStageIds.mkString("[", ", ", "]")}") - if (splittableStageIds.isEmpty || !splittableStageIds.forall(stageStats.contains)) return plan + if (splittableStageIds.isEmpty || !splittableStageIds.forall(stageStats.contains)) return join // Step 4: Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec val splittableStageInfos = splittableStageIds.map { stageId => @@ -328,7 +331,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) } partSpecs.clear() logDebug(s"$logPrefix: Totally ${skewSpecs.size} skew partitions found") - if (skewSpecs.isEmpty) return plan + if (skewSpecs.isEmpty) return join // Step 6: Generate final specs // within a partition, split the skew ShuffleQueryStageExecs, and duplicate others @@ -355,25 +358,20 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) val newSpecs = stageIds.zip(buffers.map(_.toSeq)).toMap // Step 7: Generate final plan - // 0, start at the top Join node; - // 1, mark all Join/Agg/Window nodes skew; - // 2, attach new specs to ShuffleQueryStageExecs; - val topJoinId = topJoin.id - plan transform { - case join: ShuffledJoin if join.id == topJoinId => - join transform { - case smj: SortMergeJoinExec => smj.copy(isSkewJoin = true) - case shj: ShuffledHashJoinExec => shj.copy(isSkewJoin = true) + // 0, mark all Join/Agg/Window nodes skew; + // 1, attach new specs to ShuffleQueryStageExecs; + join transform { + case smj: SortMergeJoinExec => smj.copy(isSkewJoin = true) + case shj: ShuffledHashJoinExec => shj.copy(isSkewJoin = true) - case obj: ObjectHashAggregateExec => obj.copy(isSkew = true) - case hash: HashAggregateExec => hash.copy(isSkew = true) - case sort: SortAggregateExec => sort.copy(isSkew = true) + case obj: ObjectHashAggregateExec => obj.copy(isSkew = true) + case hash: HashAggregateExec => hash.copy(isSkew = true) + case sort: SortAggregateExec => 
sort.copy(isSkew = true) - case win: WindowExec => win.copy(isSkew = true) + case win: WindowExec => win.copy(isSkew = true) - case stage: ShuffleQueryStageExec => - SkewJoinChildWrapper(AQEShuffleReadExec(stage, newSpecs(stage.id))) - } + case stage: ShuffleQueryStageExec => + SkewJoinChildWrapper(AQEShuffleReadExec(stage, newSpecs(stage.id))) } } @@ -402,7 +400,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) } else { optimize(plan) } - if (optimized.collect { case s: ShuffledJoin if s.isSkewJoin => s }.isEmpty) return plan + if (optimized.collectFirst { case s: ShuffledJoin if s.isSkewJoin => s }.isEmpty) return plan val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) { ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get) From 3a602754abc9fda706d2a47db657c3cba0fcaab3 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 22 Dec 2021 11:25:43 +0800 Subject: [PATCH 03/12] simplify collecting nodes --- .../adaptive/OptimizeSkewedJoin.scala | 166 ++++++++---------- 1 file changed, 71 insertions(+), 95 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 151e862da62a0..9ee8b6f577b51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -100,26 +100,21 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) } private def optimize(plan: SparkPlan): SparkPlan = { - plan transform { + plan transformDown { case join: ShuffledJoin if !join.isSkewJoin => optimizeShuffledJoin(join) } } /* * This method aim to optimize the skewed join with the following steps: - * 0. Collect all ShuffledJoin in this plan; - * 1. 
Check whether this plan satisfy the required pattern of optimization algorithm: - * all the nodes under the top level ShuffledJoin MUST have types in a whitelist including: - * JoinExec/AggExec/WindowExec/SortExec/etc; - * 2. Collect all ShuffleQueryStages under the top level ShuffledJoin; - * 3. Collect all splittable ShuffleQueryStages by the semantics of internal nodes. + * 1. Collect all ShuffledJoins/ShuffleQueryStages, and check valid operators; + * 2. Collect all splittable ShuffleQueryStages by the semantics of internal nodes. * A ShuffleQueryStages is splittable if it can be split into specs, each spec can be * processed independently, and the original data result can be obtained by union all * the outputs of specs. * Splittable ShuffleQueryStages are collected in this way: - * 0, start at the top level ShuffledJoin; - * 1, at a Join node, select the splittable paths according to its JoinType; - * 2, at a Agg/Window node, skip all its descendants; + * 1, at Join node, select the splittable paths according to its JoinType; + * 2, at Agg/Window node, stop; * 3, all the reached leave are splittable; * For example, in the following stage, ShuffleQueryStages s6/s7/s8 are splittable. * cross @@ -134,8 +129,8 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) * / \ / \ / \ / \ * s1 s2 s4 s5 s6 s7 s8 s9 * - * 4. Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec; - * 5. For each splittable ShuffleQueryStageExec, check whether skew partitions exists, if true, + * 3. Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec; + * 4. For each splittable ShuffleQueryStageExec, check whether skew partitions exists, if true, * split them into specs. This step also detects and handles Combinatorial Explosion: for * each skew partition, check whether the combination number is too large, if so, re-split the * ShuffleQueryStageExecs. 
@@ -143,7 +138,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) * respectively. Then there are 1M combinations, which is too large, and will cause * performance regression. Given a threshold (1k by default), the numbers of specs will * be optimized to 10/10/10. - * 6. Generate final specs. Suppose above splittable ShuffleQueryStages s6/s7/s8 are finally + * 5. Generate final specs. Suppose above splittable ShuffleQueryStages s6/s7/s8 are finally * split into 2/2/3 specs, then there will be following 2X2X3=12 combinations: * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec0, s8_spec0, s9] * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec0, s8_spec1, s9] @@ -157,76 +152,68 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec1, s8_spec0, s9] * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec1, s8_spec1, s9] * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec1, s8_spec2, s9] - * 7. Generate optimized plan by attaching new specs to ShuffleQueryStageExecs; + * 6. Generate optimized plan by attaching new specs to ShuffleQueryStageExecs; */ private def optimizeShuffledJoin(join: ShuffledJoin): SparkPlan = { import OptimizeSkewedJoin._ val logPrefix = s"Optimizing ${join.nodeName} #${join.id}" - // Step 0: Collect all ShuffledJoins (SMJ/SHJ) - def collectShuffledJoins(plan: SparkPlan): Seq[ShuffledJoin] = plan match { - case join: ShuffledJoin => Seq(join) ++ join.children.flatMap(collectShuffledJoins) - case _ => plan.children.flatMap(collectShuffledJoins) + // Step 1: Collect all ShuffledJoins/ShuffleQueryStages, and validate operators. + val joins = mutable.ArrayBuffer.empty[ShuffledJoin] + val stages = mutable.ArrayBuffer.empty[ShuffleQueryStageExec] + val invalids = mutable.ArrayBuffer.empty[SparkPlan] + + join.foreach { + // All leave must be QueryStage for now + // TODO: support Bucket Join with other types of leaves. 
+ case s: ShuffleQueryStageExec if s.isMaterialized => stages.append(s) + case b: BroadcastQueryStageExec if b.isMaterialized => + + case j: SortMergeJoinExec if !j.isSkewJoin => joins.append(j) + case j: ShuffledHashJoinExec if !j.isSkewJoin => joins.append(j) + case _: BroadcastHashJoinExec => + case _: BroadcastNestedLoopJoinExec => + case _: CartesianProductExec => + + case a: ObjectHashAggregateExec if !a.isSkew => + case a: HashAggregateExec if !a.isSkew => + case a: SortAggregateExec if !a.isSkew => + + case w: WindowExec if !w.isSkew => + + case _: SortExec => + case _: FilterExec => + case _: ProjectExec => + case _: GenerateExec => + case _: CollectMetricsExec => + case _: WholeStageCodegenExec => + + case _: ColumnarToRowExec => + case _: RowToColumnarExec => + + case _: DeserializeToObjectExec => + case _: SerializeFromObjectExec => + + case _: MapElementsExec => + case _: MapPartitionsExec => + case _: MapPartitionsInRWithArrowExec => + case _: MapInPandasExec => + case _: ArrowEvalPythonExec => + case _: BatchEvalPythonExec => + + // There are more and more physical operators, this check is for data correctness + // TODO: support more operators like AggregateInPandasExec/FlatMapCoGroupsInPandasExec/etc + case invalid => invalids.append(invalid) } - val joins = collectShuffledJoins(join) - logDebug(s"$logPrefix: ShuffledJoins: ${joins.map(_.nodeName).mkString("[", ", ", "]")}") - if (joins.isEmpty || joins.exists(_.isSkewJoin)) return join - - // Step 1: Validate physical operators - // There are more and more physical operators, this whitelist is for data correctness - // TODO: support more operators like AggregateInPandasExec/FlatMapCoGroupsInPandasExec/etc - val invalidOperators = join.collect { - case _: ShuffleQueryStageExec => None - case _: BroadcastQueryStageExec => None - - case _: SortMergeJoinExec => None - case _: ShuffledHashJoinExec => None - case _: BroadcastHashJoinExec => None - case _: BroadcastNestedLoopJoinExec => None - case _: 
CartesianProductExec => None - - case _: ObjectHashAggregateExec => None - case _: HashAggregateExec => None - case _: SortAggregateExec => None - - case _: WindowExec => None - - case _: SortExec => None - case _: FilterExec => None - case _: ProjectExec => None - case _: GenerateExec => None - case _: CollectMetricsExec => None - case _: WholeStageCodegenExec => None - - case _: ColumnarToRowExec => None - case _: RowToColumnarExec => None - - case _: DeserializeToObjectExec => None - case _: SerializeFromObjectExec => None - - case _: MapElementsExec => None - case _: MapPartitionsExec => None - case _: MapPartitionsInRWithArrowExec => None - case _: MapInPandasExec => None - case _: ArrowEvalPythonExec => None - case _: BatchEvalPythonExec => None - - case invalid => Some(invalid) - }.flatten - if (invalidOperators.nonEmpty) { + + if (invalids.nonEmpty) { logDebug(s"$logPrefix: Do NOT support operators " + - s"${invalidOperators.map(_.nodeName).mkString("[", ", ", "]")}") + s"${invalids.map(_.nodeName).mkString("[", ", ", "]")}") return join } - - // Step 2: Collect all ShuffleQueryStages - // TODO: support Bucket Join with other types of leaves. - val leaves = join.collectLeaves() - if (leaves.exists(!_.isInstanceOf[QueryStageExec])) return join - val stages = leaves.filter(_.isInstanceOf[ShuffleQueryStageExec]) // for a N-Join stage, there should be N+1 ShuffleQueryStages. 
- if (stages.size != joins.size + 1) return join - // stageId -> MapOutputStatistics + if (joins.isEmpty || stages.size != joins.size + 1) return join + val stageStats = stages.flatMap { case ShuffleStage(stage: ShuffleQueryStageExec) => stage.mapStats.filter(_.bytesByPartitionId.nonEmpty).map(stats => stage.id -> stats) @@ -238,11 +225,11 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) val numPartitions = stageStats.head._2.bytesByPartitionId.length if (stageStats.exists(_._2.bytesByPartitionId.length != numPartitions)) return join - // Step 3: Collect all splittable ShuffleQueryStageExecs + // Step 2: Collect all splittable ShuffleQueryStageExecs // How to determine splittable ShuffleQueryStageExecs: - // 0, at Join node, select the splittable paths according to its JoinType; - // 1, at Agg/Window node, stop; - // 2, all the reached leave are splittable; + // 1, at Join node, select the splittable paths according to its JoinType; + // 2, at Agg/Window node, stop; + // 3, all the reached leave are splittable; def collectSplittableStageIds(plan: SparkPlan): Seq[Int] = plan match { case stage: ShuffleQueryStageExec => Seq(stage.id) @@ -263,7 +250,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) s"${splittableStageIds.mkString("[", ", ", "]")}") if (splittableStageIds.isEmpty || !splittableStageIds.forall(stageStats.contains)) return join - // Step 4: Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec + // Step 3: Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec val splittableStageInfos = splittableStageIds.map { stageId => val sizes = stageStats(stageId).bytesByPartitionId val medSize = Utils.median(sizes) @@ -274,7 +261,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) stageId -> (threshold, target) }.toMap - // Step 5: Split skew partitions + // Step 4: Split skew partitions // within each partition, find and split the 
splittable skew ShuffleQueryStageExecs // (partitionIndex, stageId) -> skew splits val skewSpecs = mutable.OpenHashMap.empty[(Int, Int), Seq[PartialReducerPartitionSpec]] @@ -333,7 +320,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) logDebug(s"$logPrefix: Totally ${skewSpecs.size} skew partitions found") if (skewSpecs.isEmpty) return join - // Step 6: Generate final specs + // Step 5: Generate final specs // within a partition, split the skew ShuffleQueryStageExecs, and duplicate others def createNonSkewSpec(partitionIndex: Int, stageId: Int) = { val size = stageStats(stageId).bytesByPartitionId(partitionIndex) @@ -357,9 +344,9 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) } val newSpecs = stageIds.zip(buffers.map(_.toSeq)).toMap - // Step 7: Generate final plan - // 0, mark all Join/Agg/Window nodes skew; - // 1, attach new specs to ShuffleQueryStageExecs; + // Step 6: Generate final plan + // 1, mark all Join/Agg/Window nodes skew; + // 2, attach new specs to ShuffleQueryStageExecs; join transform { case smj: SortMergeJoinExec => smj.copy(isSkewJoin = true) case shj: ShuffledHashJoinExec => shj.copy(isSkewJoin = true) @@ -388,18 +375,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) return plan } - val unions = plan.collect { case u: UnionExec => u } - // there should be at most one UnionExec in one stage, skip here for safety - if (unions.size > 1) return plan - - val optimized = if (unions.size == 1) { - plan transform { - // TODO: if extra shuffle is NOT allowed, only accept children without shuffle. 
- case u @ UnionExec(children) => u.withNewChildren(children.map(optimize)) - } - } else { - optimize(plan) - } + val optimized = optimize(plan) if (optimized.collectFirst { case s: ShuffledJoin if s.isSkewJoin => s }.isEmpty) return plan val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) { From 2273484e967cc5bb01b0df336beb50da29cc4fc4 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 22 Dec 2021 18:19:15 +0800 Subject: [PATCH 04/12] rebase --- .../spark/sql/execution/adaptive/OptimizeSkewedJoin.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 9ee8b6f577b51..f61008fdaedba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -377,15 +377,11 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) val optimized = optimize(plan) if (optimized.collectFirst { case s: ShuffledJoin if s.isSkewJoin => s }.isEmpty) return plan - val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) { ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get) } else { ValidateRequirements.validate(optimized) } - // Two cases we will apply the skewed join optimization: - // 1. optimize the skew join without extra shuffle - // 2. optimize the skew join with extra shuffle but the force-apply config is true. 
if (requirementSatisfied) { optimized.transform { case SkewJoinChildWrapper(child) => child From fbdb8e1fedacfa77383b0d0bc1339d33f2b78ec9 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 22 Dec 2021 19:03:03 +0800 Subject: [PATCH 05/12] del maxJoins --- .../org/apache/spark/sql/internal/SQLConf.scala | 10 ---------- .../execution/adaptive/OptimizeSkewedJoin.scala | 16 +++++++--------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 33a2d0d564fbf..861beb104a2d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -704,16 +704,6 @@ object SQLConf { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("256MB") - val SKEW_JOIN_MAX_JOINS = - buildConf("spark.sql.adaptive.skewJoin.maxJoins") - .doc(s"When '${ADAPTIVE_EXECUTION_ENABLED.key}' and '${SKEW_JOIN_ENABLED.key}' " + - s"are true, the max number (inclusive) of shuffled joins in a stage that general " + - s"skew algorithm can handle.") - .version("3.3.0") - .intConf - .checkValue(_ > 0, "The max joins must be positive.") - .createWithDefault(5) - val SKEW_JOIN_MAX_SPLITS_PER_PARTITION = buildConf("spark.sql.adaptive.skewJoin.maxSplitsPerPartition") .doc(s"When '${ADAPTIVE_EXECUTION_ENABLED.key}' and '${SKEW_JOIN_ENABLED.key}' " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index f61008fdaedba..1b1deb6940dd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -211,8 +211,9 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) 
s"${invalids.map(_.nodeName).mkString("[", ", ", "]")}") return join } - // for a N-Join stage, there should be N+1 ShuffleQueryStages. - if (joins.isEmpty || stages.size != joins.size + 1) return join + if (joins.isEmpty || joins.size != stages.size - 1) { + return join + } val stageStats = stages.flatMap { case ShuffleStage(stage: ShuffleQueryStageExec) => @@ -366,17 +367,14 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) if (!conf.getConf(SQLConf.SKEW_JOIN_ENABLED)) { return plan } - - val shuffledJoins = plan.collect { case s: ShuffledJoin => s } - if (shuffledJoins.isEmpty || shuffledJoins.exists(_.isSkewJoin)) return plan - if (shuffledJoins.size > conf.getConf(SQLConf.SKEW_JOIN_MAX_JOINS)) { - logDebug(s"${shuffledJoins.size} ShuffledJoins in ${plan.nodeName} " + - s"exceeds threshold ${conf.getConf(SQLConf.SKEW_JOIN_MAX_JOINS)}") + if (plan.collectFirst { case s: ShuffledJoin if !s.isSkewJoin => s }.isEmpty) { return plan } val optimized = optimize(plan) - if (optimized.collectFirst { case s: ShuffledJoin if s.isSkewJoin => s }.isEmpty) return plan + if (optimized.collectFirst { case s: ShuffledJoin if s.isSkewJoin => s }.isEmpty) { + return plan + } val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) { ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get) } else { From fcb2e1e9c49d736701fd6c416f3170e3f28ba2b4 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 22 Dec 2021 19:13:12 +0800 Subject: [PATCH 06/12] nit nit --- .../sql/execution/adaptive/OptimizeSkewedJoin.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 1b1deb6940dd6..abbef54c1d0ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -220,11 +220,13 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) stage.mapStats.filter(_.bytesByPartitionId.nonEmpty).map(stats => stage.id -> stats) case _ => None }.toMap - if (stageStats.size != joins.size + 1) return join + if (stageStats.size != stages.size || + stageStats.values.map(_.bytesByPartitionId.length).toSet.size != 1) { + return join + } val stageIds = stageStats.keysIterator.toArray logDebug(s"$logPrefix: ShuffleQueryStages: ${stageIds.mkString("[", ", ", "]")}") val numPartitions = stageStats.head._2.bytesByPartitionId.length - if (stageStats.exists(_._2.bytesByPartitionId.length != numPartitions)) return join // Step 2: Collect all splittable ShuffleQueryStageExecs // How to determine splittable ShuffleQueryStageExecs: @@ -249,7 +251,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) val splittableStageIds = collectSplittableStageIds(join) logDebug(s"$logPrefix: Splittable ShuffleQueryStages: " + s"${splittableStageIds.mkString("[", ", ", "]")}") - if (splittableStageIds.isEmpty || !splittableStageIds.forall(stageStats.contains)) return join + if (splittableStageIds.isEmpty || + !splittableStageIds.forall(stageStats.contains)) { + return join + } // Step 3: Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec val splittableStageInfos = splittableStageIds.map { stageId => From 733606269e4dba04f480d1f2270e8e01137022bc Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 19 Jan 2022 22:43:08 +0800 Subject: [PATCH 07/12] fix conflict --- .../adaptive/OptimizeSkewedJoin.scala | 51 +++++++++---------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index abbef54c1d0ad..277fe087d63d1 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -116,42 +116,41 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) * 1, at Join node, select the splittable paths according to its JoinType; * 2, at Agg/Window node, stop; * 3, all the reached leave are splittable; - * For example, in the following stage, ShuffleQueryStages s6/s7/s8 are splittable. - * cross - * / \ - * agg \ - * / \ - * left cross - * / \ / \ - * inner s3 agg inner - * / \ / / \ - * s0 right inner inner left - * / \ / \ / \ / \ - * s1 s2 s4 s5 s6 s7 s8 s9 + * For example, in the following stage, ShuffleQueryStages s0/s2/s4 are splittable. + * + * inner + * / \ + * cross right + * / \ / \ + * inner s2 s3 inner + * / \ / \ + * s0 agg s4 win + * | | + * s1 s5 * * 3. Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec; * 4. For each splittable ShuffleQueryStageExec, check whether skew partitions exists, if true, * split them into specs. This step also detects and handles Combinatorial Explosion: for * each skew partition, check whether the combination number is too large, if so, re-split the * ShuffleQueryStageExecs. - * For example, for partition 0, stage s6/s7/s8 are split into 100/100/100 specs, + * For example, for partition 0, stages s0/s2/s4 are split into 100/100/100 specs, * respectively. Then there are 1M combinations, which is too large, and will cause * performance regression. Given a threshold (1k by default), the numbers of specs will * be optimized to 10/10/10. - * 5. Generate final specs. 
Suppose above splittable ShuffleQueryStages s0/s2/s4 are finally * split into 2/2/3 specs, then there will be following 2X2X3=12 combinations: - * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec0, s8_spec0, s9] - * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec0, s8_spec1, s9] - * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec0, s8_spec2, s9] - * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec1, s8_spec0, s9] - * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec1, s8_spec1, s9] - * [s0, s1, s2, s3, s4, s5, s6_spec0, s7_spec1, s8_spec2, s9] - * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec0, s8_spec0, s9] - * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec0, s8_spec1, s9] - * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec0, s8_spec2, s9] - * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec1, s8_spec0, s9] - * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec1, s8_spec1, s9] - * [s0, s1, s2, s3, s4, s5, s6_spec1, s7_spec1, s8_spec2, s9] + * [s0_spec0, s1, s2_spec0, s3, s4_spec0, s5] + * [s0_spec0, s1, s2_spec0, s3, s4_spec1, s5] + * [s0_spec0, s1, s2_spec0, s3, s4_spec2, s5] + * [s0_spec0, s1, s2_spec1, s3, s4_spec0, s5] + * [s0_spec0, s1, s2_spec1, s3, s4_spec1, s5] + * [s0_spec0, s1, s2_spec1, s3, s4_spec2, s5] + * [s0_spec1, s1, s2_spec0, s3, s4_spec0, s5] + * [s0_spec1, s1, s2_spec0, s3, s4_spec1, s5] + * [s0_spec1, s1, s2_spec0, s3, s4_spec2, s5] + * [s0_spec1, s1, s2_spec1, s3, s4_spec0, s5] + * [s0_spec1, s1, s2_spec1, s3, s4_spec1, s5] + * [s0_spec1, s1, s2_spec1, s3, s4_spec2, s5] * 6. 
Generate optimized plan by attaching new specs to ShuffleQueryStageExecs; */ private def optimizeShuffledJoin(join: ShuffledJoin): SparkPlan = { From 1c580a7532ee10e36f11db263f8f8fdc8764d31f Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 25 Jan 2022 14:27:54 +0800 Subject: [PATCH 08/12] nit --- .../execution/adaptive/OptimizeSkewedJoin.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 277fe087d63d1..da1580faca6ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -99,12 +99,6 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) sizes.sum / sizes.length } - private def optimize(plan: SparkPlan): SparkPlan = { - plan transformDown { - case join: ShuffledJoin if !join.isSkewJoin => optimizeShuffledJoin(join) - } - } - /* * This method aim to optimize the skewed join with the following steps: * 1. 
Collect all ShuffledJoins/ShuffleQueryStages, and check valid operators; @@ -261,7 +255,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) val medSize = Utils.median(sizes) val threshold = getSkewThreshold(medSize) val target = targetSize(sizes, threshold) - logDebug(s"$logPrefix: Optimizing ShuffleQueryStage #$stageId in " + + logDebug(s"$logPrefix: Analyzing ShuffleQueryStage #$stageId in " + s"skew join, size info: ${getSizeInfo(medSize, sizes)}") stageId -> (threshold, target) }.toMap @@ -367,6 +361,12 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) } } + private def optimize(plan: SparkPlan): SparkPlan = { + plan transformDown { + case join: ShuffledJoin if !join.isSkewJoin => optimizeShuffledJoin(join) + } + } + override def apply(plan: SparkPlan): SparkPlan = { if (!conf.getConf(SQLConf.SKEW_JOIN_ENABLED)) { return plan From b328f2b0c94883fe5011977d52b9b73cbb20891c Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 14 Feb 2022 16:19:28 +0800 Subject: [PATCH 09/12] fix SPARK-37652 case --- .../spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 949b9cfe90b82..70fc8f65f2599 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -2747,11 +2747,10 @@ class AdaptiveQueryExecSuite "UNION ALL SELECT key2 FROM skewData2 GROUP BY key2", 1, 1) // skewJoin1 union (skewJoin2 join aggregate) - // skewJoin2 will lead to extra shuffles, but skew1 cannot be optimized checkSkewJoin( "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 UNION ALL " + "SELECT key1 from (SELECT key1 FROM skewData1 JOIN 
skewData2 ON key1 = key2) tmp1 " + - "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2", 3, 0) + "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2", 3, 3) } } } From bb808537ffa51c41ea44f207efb5aff5cb995436 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 14 Mar 2022 18:56:32 +0800 Subject: [PATCH 10/12] resolve conflict --- .../adaptive/OptimizeSkewedJoin.scala | 39 ++++++++----------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index da1580faca6ab..0f9979ad958c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -149,24 +149,21 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) */ private def optimizeShuffledJoin(join: ShuffledJoin): SparkPlan = { import OptimizeSkewedJoin._ - val logPrefix = s"Optimizing ${join.nodeName} #${join.id}" // Step 1: Collect all ShuffledJoins/ShuffleQueryStages, and validate operators. val joins = mutable.ArrayBuffer.empty[ShuffledJoin] val stages = mutable.ArrayBuffer.empty[ShuffleQueryStageExec] - val invalids = mutable.ArrayBuffer.empty[SparkPlan] - join.foreach { + join foreach { // All leave must be QueryStage for now // TODO: support Bucket Join with other types of leaves. 
case s: ShuffleQueryStageExec if s.isMaterialized => stages.append(s) - case b: BroadcastQueryStageExec if b.isMaterialized => + case _: BroadcastQueryStageExec => case j: SortMergeJoinExec if !j.isSkewJoin => joins.append(j) case j: ShuffledHashJoinExec if !j.isSkewJoin => joins.append(j) case _: BroadcastHashJoinExec => case _: BroadcastNestedLoopJoinExec => - case _: CartesianProductExec => case a: ObjectHashAggregateExec if !a.isSkew => case a: HashAggregateExec if !a.isSkew => @@ -196,13 +193,9 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) // There are more and more physical operators, this check is for data correctness // TODO: support more operators like AggregateInPandasExec/FlatMapCoGroupsInPandasExec/etc - case invalid => invalids.append(invalid) - } - - if (invalids.nonEmpty) { - logDebug(s"$logPrefix: Do NOT support operators " + - s"${invalids.map(_.nodeName).mkString("[", ", ", "]")}") - return join + case invalid => + logDebug(s"Do NOT support operator $invalid") + return join } if (joins.isEmpty || joins.size != stages.size - 1) { return join @@ -218,7 +211,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) return join } val stageIds = stageStats.keysIterator.toArray - logDebug(s"$logPrefix: ShuffleQueryStages: ${stageIds.mkString("[", ", ", "]")}") + logDebug(s"ShuffleQueryStages: ${stageIds.mkString("[", ", ", "]")}") val numPartitions = stageStats.head._2.bytesByPartitionId.length // Step 2: Collect all splittable ShuffleQueryStageExecs @@ -242,8 +235,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) case _ => plan.children.flatMap(collectSplittableStageIds) } val splittableStageIds = collectSplittableStageIds(join) - logDebug(s"$logPrefix: Splittable ShuffleQueryStages: " + - s"${splittableStageIds.mkString("[", ", ", "]")}") + logDebug(s"Splittable ShuffleQueryStages: ${splittableStageIds.mkString("[", ", ", "]")}") if (splittableStageIds.isEmpty || 
!splittableStageIds.forall(stageStats.contains)) { return join @@ -252,10 +244,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) // Step 3: Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec val splittableStageInfos = splittableStageIds.map { stageId => val sizes = stageStats(stageId).bytesByPartitionId - val medSize = Utils.median(sizes) + val medSize = Utils.median(sizes, false) val threshold = getSkewThreshold(medSize) val target = targetSize(sizes, threshold) - logDebug(s"$logPrefix: Analyzing ShuffleQueryStage #$stageId in " + + logDebug(s"Analyzing ShuffleQueryStage #$stageId in " + s"skew join, size info: ${getSizeInfo(medSize, sizes)}") stageId -> (threshold, target) }.toMap @@ -276,8 +268,8 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) ShufflePartitionsUtil .createSkewPartitionSpecs(stats.shuffleId, partitionIndex, target) .foreach { splits => - logDebug(s"$logPrefix: Splitting ShuffleQueryStage #$stageId: " + - s"partition $partitionIndex(${FileUtils.byteCountToDisplaySize(size)}) -> " + + logDebug(s"Splitting ShuffleQueryStage #$stageId: partition " + + s"$partitionIndex(${FileUtils.byteCountToDisplaySize(size)}) -> " + s"${splits.size} splits") partSpecs(stageId) = splits } @@ -289,7 +281,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) if (numCombinations > maxCombinations) { val (splitStageIds, numSplits) = partSpecs.mapValues(_.size).toArray.unzip val combinedNumSplits = combine(maxCombinations, numSplits) - logDebug(s"$logPrefix: partition $partitionIndex: Combinatorial Explosion! " + + logDebug(s"partition $partitionIndex: Combinatorial Explosion! 
" + s"Try to combine $numCombinations(${numSplits.mkString("[", ", ", "]")}) " + s"to ${safeProduct(combinedNumSplits)}(${combinedNumSplits.mkString("[", ", ", "]")})") @@ -305,8 +297,8 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) ShufflePartitionsUtil .createSkewPartitionSpecs(stats.shuffleId, partitionIndex, newTarget) .foreach { splits => - logDebug(s"$logPrefix: Re-splitting ShuffleQueryStage #$stageId: " + - s"partition $partitionIndex(${FileUtils.byteCountToDisplaySize(size)}) -> " + + logDebug(s"Re-splitting ShuffleQueryStage #$stageId: partition " + + s"$partitionIndex(${FileUtils.byteCountToDisplaySize(size)}) -> " + s"${splits.size} splits") partSpecs(stageId) = splits } @@ -316,7 +308,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) partSpecs.foreach { case (stageId, splits) => skewSpecs((partitionIndex, stageId)) = splits } } partSpecs.clear() - logDebug(s"$logPrefix: Totally ${skewSpecs.size} skew partitions found") + logDebug(s"Totally ${skewSpecs.size} skew partitions found") if (skewSpecs.isEmpty) return join // Step 5: Generate final specs @@ -361,6 +353,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) } } + // try to optimize at top join in each stage private def optimize(plan: SparkPlan): SparkPlan = { plan transformDown { case join: ShuffledJoin if !join.isSkewJoin => optimizeShuffledJoin(join) From cb493c47e245bbc20d4b29716e74f0f3d95a9073 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 22 Mar 2022 23:53:11 +0800 Subject: [PATCH 11/12] add ut --- .../adaptive/OptimizeSkewedJoin.scala | 1 - .../adaptive/AdaptiveQueryExecSuite.scala | 53 +++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 0f9979ad958c3..3865a36ed6d6d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -353,7 +353,6 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) } } - // try to optimize at top join in each stage private def optimize(plan: SparkPlan): SparkPlan = { plan transformDown { case join: ShuffledJoin if !join.isSkewJoin => optimizeShuffledJoin(join) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 70fc8f65f2599..a2634fe22620a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -1021,6 +1021,59 @@ class AdaptiveQueryExecSuite } test("SPARK-36638: General Skew Join: combine splits to handle Combinatorial Explosion") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SKEW_JOIN_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") { + withTempView("skewData1", "skewData2") { + spark + .range(0, 2000, 1, 9) + .select( + when('id < 1000, 999) + .when('id >= 1500, 2000) + .otherwise('id).as("key1"), 'id as "value1") + .createOrReplaceTempView("skewData1") + + spark + .range(0, 2000, 1, 11) + .select( + when('id < 1000, 999) + .when('id >= 1500, 2000) + .otherwise('id).as("key2"), 'id as "value2") + .createOrReplaceTempView("skewData2") + + val query = "SELECT * FROM skewData1 JOIN skewData2 ON key1 = key2" + + withSQLConf(SQLConf.SKEW_JOIN_MAX_SPLITS_PER_PARTITION.key -> "1000") { + 
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + assert(adaptivePlan.execute().getNumPartitions > 150) + } + + withSQLConf(SQLConf.SKEW_JOIN_MAX_SPLITS_PER_PARTITION.key -> "10") { + /** + * related log: + * Splittable ShuffleQueryStages: [0, 1] + * Splitting ShuffleQueryStage #0: partition 57(6 KB) -> 6 splits + * Splitting ShuffleQueryStage #1: partition 57(6 KB) -> 7 splits + * Re-splitting ShuffleQueryStage #0: partition 57(6 KB) -> 3 splits + * Re-splitting ShuffleQueryStage #1: partition 57(6 KB) -> 3 splits + * Splitting ShuffleQueryStage #0: partition 86(2 KB) -> 4 splits + * Splitting ShuffleQueryStage #1: partition 86(2 KB) -> 4 splits + * Re-splitting ShuffleQueryStage #0: partition 86(2 KB) -> 3 splits + * Re-splitting ShuffleQueryStage #1: partition 86(2 KB) -> 3 splits + */ + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + assert(adaptivePlan.execute().getNumPartitions < 120) + } + } + } + } + + test("SPARK-36638: General Skew Join: combine splits to handle Combinatorial Explosion II") { import OptimizeSkewedJoin.combine val array0 = Array(511, 622) // 317,842 splits From 0d250b004e78bc7d5d07393e8d1b64f41cee974c Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 4 May 2022 23:07:19 +0800 Subject: [PATCH 12/12] rebase, update version, fix ut --- .../org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../execution/adaptive/AdaptiveQueryExecSuite.scala | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 861beb104a2d4..0b4b3f2a3f465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -708,7 +708,7 @@ object SQLConf { buildConf("spark.sql.adaptive.skewJoin.maxSplitsPerPartition") .doc(s"When '${ADAPTIVE_EXECUTION_ENABLED.key}' and 
'${SKEW_JOIN_ENABLED.key}' " + s"are true, the max number (inclusive) of splits from a partition.") - .version("3.3.0") + .version("3.4.0") .intConf .checkValue(_ >= 10, "The max splits must be no less than 10.") .createWithDefault(1000) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index a2634fe22620a..def5340cb8e72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -849,7 +849,7 @@ class AdaptiveQueryExecSuite } } - test("SPARK-36638: General Skew Join: 3-table join") { + test("SPARK-36638: Generalize OptimizeSkewedJoin - 3-table join with Window and Aggregate") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.SKEW_JOIN_ENABLED.key -> "true", @@ -889,7 +889,7 @@ class AdaptiveQueryExecSuite | LEFT JOIN (SELECT key3, max(value3) AS value3 FROM skewData3 GROUP BY key3) | ON key1 = key3 |""".stripMargin - val query = s"SELECT value1, max(row) FROM ($join) GROUP BY value1" + val query = s"SELECT value1, min(value3), max(row) FROM ($join) GROUP BY value1" val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) val joins = findTopLevelShuffledJoin(adaptivePlan) @@ -899,7 +899,7 @@ class AdaptiveQueryExecSuite } } - test("SPARK-36638: General Skew Join: 3-table join UNION 2-table join") { + test("SPARK-36638: Generalize OptimizeSkewedJoin - 3-table join UNION 2-table join") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.SKEW_JOIN_ENABLED.key -> "true", @@ -950,7 +950,7 @@ class AdaptiveQueryExecSuite } } - test("SPARK-36638: General Skew Join: 5-table join") { + test("SPARK-36638: Generalize OptimizeSkewedJoin - 5-table join") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", 
SQLConf.SKEW_JOIN_ENABLED.key -> "true", @@ -1020,7 +1020,7 @@ class AdaptiveQueryExecSuite } } - test("SPARK-36638: General Skew Join: combine splits to handle Combinatorial Explosion") { + test("SPARK-36638: Generalize OptimizeSkewedJoin - handle Combinatorial Explosion") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.SKEW_JOIN_ENABLED.key -> "true", @@ -1073,7 +1073,7 @@ class AdaptiveQueryExecSuite } } - test("SPARK-36638: General Skew Join: combine splits to handle Combinatorial Explosion II") { + test("SPARK-36638: Generalize OptimizeSkewedJoin - handle Combinatorial Explosion II") { import OptimizeSkewedJoin.combine val array0 = Array(511, 622) // 317,842 splits