diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8876d780799d2..0b4b3f2a3f465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -704,6 +704,15 @@ object SQLConf { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("256MB") + val SKEW_JOIN_MAX_SPLITS_PER_PARTITION = + buildConf("spark.sql.adaptive.skewJoin.maxSplitsPerPartition") + .doc(s"When '${ADAPTIVE_EXECUTION_ENABLED.key}' and '${SKEW_JOIN_ENABLED.key}' " + + s"are true, the max number (inclusive) of splits from a partition.") + .version("3.4.0") + .intConf + .checkValue(_ >= 10, "The max splits must be no less than 10.") + .createWithDefault(1000) + val NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN = buildConf("spark.sql.adaptive.nonEmptyPartitionRatioForBroadcastJoin") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index d4a173bb9cceb..3865a36ed6d6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -28,8 +28,11 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ValidateRequirements} -import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins._ 
+import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BatchEvalPythonExec, MapInPandasExec} +import org.apache.spark.sql.execution.window.{WindowExec, WindowExecBase} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -98,136 +101,276 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) /* * This method aim to optimize the skewed join with the following steps: - * 1. Check whether the shuffle partition is skewed based on the median size - * and the skewed partition threshold in origin shuffled join (smj and shj). - * 2. Assuming partition0 is skewed in left side, and it has 5 mappers (Map0, Map1...Map4). - * And we may split the 5 Mappers into 3 mapper ranges [(Map0, Map1), (Map2, Map3), (Map4)] - * based on the map size and the max split number. - * 3. Wrap the join left child with a special shuffle read that loads each mapper range with one - * task, so total 3 tasks. - * 4. Wrap the join right child with a special shuffle read that loads partition0 3 times by - * 3 tasks separately. + * 1. Collect all ShuffledJoins/ShuffleQueryStages, and check valid operators; + * 2. Collect all splittable ShuffleQueryStages by the semantics of internal nodes. + * A ShuffleQueryStage is splittable if it can be split into specs, each spec can be + * processed independently, and the original data result can be obtained by union all + * the outputs of specs. + * Splittable ShuffleQueryStages are collected in this way: + * 1, at Join node, select the splittable paths according to its JoinType; + * 2, at Agg/Window node, stop; + * 3, all the reached leaves are splittable; + * For example, in the following example stage, ShuffleQueryStages s0/s2/s4 are splittable. + * + * inner + * / \ + * cross right + * / \ / \ + * inner s2 s3 inner + * / \ / \ + * s0 agg s4 win + * | | + * s1 s5 + * + * 3. Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec; + * 4. 
For each splittable ShuffleQueryStageExec, check whether skew partitions exist; if so, + * split them into specs. This step also detects and handles Combinatorial Explosion: for + * each skew partition, check whether the combination number is too large, if so, re-split the + * ShuffleQueryStageExecs. + * For example, for partition 0, stages s0/s2/s4 are split into 100/100/100 specs, + * respectively. Then there are 1M combinations, which is too large, and will cause + * performance regression. Given a threshold (1k by default), the numbers of specs will + * be optimized to 10/10/10. + * 5. Generate final specs. Suppose above splittable ShuffleQueryStages s0/s2/s4 are finally + * split into 2/2/3 specs, then there will be the following 2X2X3=12 combinations: + * [s0_spec0, s1, s2_spec0, s3, s4_spec0, s5] + * [s0_spec0, s1, s2_spec0, s3, s4_spec1, s5] + * [s0_spec0, s1, s2_spec0, s3, s4_spec2, s5] + * [s0_spec0, s1, s2_spec1, s3, s4_spec0, s5] + * [s0_spec0, s1, s2_spec1, s3, s4_spec1, s5] + * [s0_spec0, s1, s2_spec1, s3, s4_spec2, s5] + * [s0_spec1, s1, s2_spec0, s3, s4_spec0, s5] + * [s0_spec1, s1, s2_spec0, s3, s4_spec1, s5] + * [s0_spec1, s1, s2_spec0, s3, s4_spec2, s5] + * [s0_spec1, s1, s2_spec1, s3, s4_spec0, s5] + * [s0_spec1, s1, s2_spec1, s3, s4_spec1, s5] + * [s0_spec1, s1, s2_spec1, s3, s4_spec2, s5] + * 6. 
Generate optimized plan by attaching new specs to ShuffleQueryStageExecs; */ - private def tryOptimizeJoinChildren( - left: ShuffleQueryStageExec, - right: ShuffleQueryStageExec, - joinType: JoinType): Option[(SparkPlan, SparkPlan)] = { - val canSplitLeft = canSplitLeftSide(joinType) - val canSplitRight = canSplitRightSide(joinType) - if (!canSplitLeft && !canSplitRight) return None - - val leftSizes = left.mapStats.get.bytesByPartitionId - val rightSizes = right.mapStats.get.bytesByPartitionId - assert(leftSizes.length == rightSizes.length) - val numPartitions = leftSizes.length - // We use the median size of the original shuffle partitions to detect skewed partitions. - val leftMedSize = Utils.median(leftSizes, false) - val rightMedSize = Utils.median(rightSizes, false) - logDebug( - s""" - |Optimizing skewed join. - |Left side partitions size info: - |${getSizeInfo(leftMedSize, leftSizes)} - |Right side partitions size info: - |${getSizeInfo(rightMedSize, rightSizes)} - """.stripMargin) - - val leftSkewThreshold = getSkewThreshold(leftMedSize) - val rightSkewThreshold = getSkewThreshold(rightMedSize) - val leftTargetSize = targetSize(leftSizes, leftSkewThreshold) - val rightTargetSize = targetSize(rightSizes, rightSkewThreshold) - - val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec] - val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec] - var numSkewedLeft = 0 - var numSkewedRight = 0 - for (partitionIndex <- 0 until numPartitions) { - val leftSize = leftSizes(partitionIndex) - val isLeftSkew = canSplitLeft && leftSize > leftSkewThreshold - val rightSize = rightSizes(partitionIndex) - val isRightSkew = canSplitRight && rightSize > rightSkewThreshold - val leftNoSkewPartitionSpec = - Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, leftSize)) - val rightNoSkewPartitionSpec = - Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) - - val leftParts = if (isLeftSkew) { - val skewSpecs = 
ShufflePartitionsUtil.createSkewPartitionSpecs( - left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) - if (skewSpecs.isDefined) { - logDebug(s"Left side partition $partitionIndex " + - s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + - s"split it into ${skewSpecs.get.length} parts.") - numSkewedLeft += 1 + private def optimizeShuffledJoin(join: ShuffledJoin): SparkPlan = { + import OptimizeSkewedJoin._ + + // Step 1: Collect all ShuffledJoins/ShuffleQueryStages, and validate operators. + val joins = mutable.ArrayBuffer.empty[ShuffledJoin] + val stages = mutable.ArrayBuffer.empty[ShuffleQueryStageExec] + + join foreach { + // All leave must be QueryStage for now + // TODO: support Bucket Join with other types of leaves. + case s: ShuffleQueryStageExec if s.isMaterialized => stages.append(s) + case _: BroadcastQueryStageExec => + + case j: SortMergeJoinExec if !j.isSkewJoin => joins.append(j) + case j: ShuffledHashJoinExec if !j.isSkewJoin => joins.append(j) + case _: BroadcastHashJoinExec => + case _: BroadcastNestedLoopJoinExec => + + case a: ObjectHashAggregateExec if !a.isSkew => + case a: HashAggregateExec if !a.isSkew => + case a: SortAggregateExec if !a.isSkew => + + case w: WindowExec if !w.isSkew => + + case _: SortExec => + case _: FilterExec => + case _: ProjectExec => + case _: GenerateExec => + case _: CollectMetricsExec => + case _: WholeStageCodegenExec => + + case _: ColumnarToRowExec => + case _: RowToColumnarExec => + + case _: DeserializeToObjectExec => + case _: SerializeFromObjectExec => + + case _: MapElementsExec => + case _: MapPartitionsExec => + case _: MapPartitionsInRWithArrowExec => + case _: MapInPandasExec => + case _: ArrowEvalPythonExec => + case _: BatchEvalPythonExec => + + // There are more and more physical operators, this check is for data correctness + // TODO: support more operators like AggregateInPandasExec/FlatMapCoGroupsInPandasExec/etc + case invalid => + logDebug(s"Do NOT support operator 
$invalid") + return join + } + if (joins.isEmpty || joins.size != stages.size - 1) { + return join + } + + val stageStats = stages.flatMap { + case ShuffleStage(stage: ShuffleQueryStageExec) => + stage.mapStats.filter(_.bytesByPartitionId.nonEmpty).map(stats => stage.id -> stats) + case _ => None + }.toMap + if (stageStats.size != stages.size || + stageStats.values.map(_.bytesByPartitionId.length).toSet.size != 1) { + return join + } + val stageIds = stageStats.keysIterator.toArray + logDebug(s"ShuffleQueryStages: ${stageIds.mkString("[", ", ", "]")}") + val numPartitions = stageStats.head._2.bytesByPartitionId.length + + // Step 2: Collect all splittable ShuffleQueryStageExecs + // How to determine splittable ShuffleQueryStageExecs: + // 1, at Join node, select the splittable paths according to its JoinType; + // 2, at Agg/Window node, stop; + // 3, all the reached leave are splittable; + def collectSplittableStageIds(plan: SparkPlan): Seq[Int] = plan match { + case stage: ShuffleQueryStageExec => Seq(stage.id) + + case join: ShuffledJoin => + var splittableChildren = Seq.empty[SparkPlan] + if (canSplitLeftSide(join.joinType)) splittableChildren :+= join.left + if (canSplitRightSide(join.joinType)) splittableChildren :+= join.right + splittableChildren.flatMap(collectSplittableStageIds) + + case _: BaseAggregateExec => Seq.empty + + case _: WindowExecBase => Seq.empty + + case _ => plan.children.flatMap(collectSplittableStageIds) + } + val splittableStageIds = collectSplittableStageIds(join) + logDebug(s"Splittable ShuffleQueryStages: ${splittableStageIds.mkString("[", ", ", "]")}") + if (splittableStageIds.isEmpty || + !splittableStageIds.forall(stageStats.contains)) { + return join + } + + // Step 3: Precompute skewThreshold and targetSize for each splittable ShuffleQueryStageExec + val splittableStageInfos = splittableStageIds.map { stageId => + val sizes = stageStats(stageId).bytesByPartitionId + val medSize = Utils.median(sizes, false) + val threshold = 
getSkewThreshold(medSize) + val target = targetSize(sizes, threshold) + logDebug(s"Analyzing ShuffleQueryStage #$stageId in " + + s"skew join, size info: ${getSizeInfo(medSize, sizes)}") + stageId -> (threshold, target) + }.toMap + + // Step 4: Split skew partitions + // within each partition, find and split the splittable skew ShuffleQueryStageExecs + // (partitionIndex, stageId) -> skew splits + val skewSpecs = mutable.OpenHashMap.empty[(Int, Int), Seq[PartialReducerPartitionSpec]] + val partSpecs = mutable.OpenHashMap.empty[Int, Seq[PartialReducerPartitionSpec]] + val maxCombinations = conf.getConf(SQLConf.SKEW_JOIN_MAX_SPLITS_PER_PARTITION) + + Range(0, numPartitions).foreach { partitionIndex => + partSpecs.clear() + splittableStageInfos.foreach { case (stageId, (threshold, target)) => + val stats = stageStats(stageId) + val size = stats.bytesByPartitionId(partitionIndex) + if (size > threshold) { + ShufflePartitionsUtil + .createSkewPartitionSpecs(stats.shuffleId, partitionIndex, target) + .foreach { splits => + logDebug(s"Splitting ShuffleQueryStage #$stageId: partition " + + s"$partitionIndex(${FileUtils.byteCountToDisplaySize(size)}) -> " + + s"${splits.size} splits") + partSpecs(stageId) = splits + } } - skewSpecs.getOrElse(leftNoSkewPartitionSpec) - } else { - leftNoSkewPartitionSpec } - val rightParts = if (isRightSkew) { - val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( - right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) - if (skewSpecs.isDefined) { - logDebug(s"Right side partition $partitionIndex " + - s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + - s"split it into ${skewSpecs.get.length} parts.") - numSkewedRight += 1 - } - skewSpecs.getOrElse(rightNoSkewPartitionSpec) - } else { - rightNoSkewPartitionSpec + // Handle Combinatorial Explosion. 
+ val numCombinations = safeProduct(partSpecs.valuesIterator.map(_.size)) + if (numCombinations > maxCombinations) { + val (splitStageIds, numSplits) = partSpecs.mapValues(_.size).toArray.unzip + val combinedNumSplits = combine(maxCombinations, numSplits) + logDebug(s"partition $partitionIndex: Combinatorial Explosion! " + + s"Try to combine $numCombinations(${numSplits.mkString("[", ", ", "]")}) " + + s"to ${safeProduct(combinedNumSplits)}(${combinedNumSplits.mkString("[", ", ", "]")})") + + partSpecs.clear() + splitStageIds.zip(combinedNumSplits) + .filter(_._2 > 1) + .foreach { case (stageId, newNumSplits) => + val stats = stageStats(stageId) + val size = stats.bytesByPartitionId(partitionIndex) + // TODO: ShufflePartitionsUtil supports target number of specs + // simply adjust the target size to control the number of splits for now + val newTarget = (1.1 * size.toDouble / newNumSplits).toLong + 1L + ShufflePartitionsUtil + .createSkewPartitionSpecs(stats.shuffleId, partitionIndex, newTarget) + .foreach { splits => + logDebug(s"Re-splitting ShuffleQueryStage #$stageId: partition " + + s"$partitionIndex(${FileUtils.byteCountToDisplaySize(size)}) -> " + + s"${splits.size} splits") + partSpecs(stageId) = splits + } + } } - for { - leftSidePartition <- leftParts - rightSidePartition <- rightParts - } { - leftSidePartitions += leftSidePartition - rightSidePartitions += rightSidePartition + partSpecs.foreach { case (stageId, splits) => skewSpecs((partitionIndex, stageId)) = splits } + } + partSpecs.clear() + logDebug(s"Totally ${skewSpecs.size} skew partitions found") + if (skewSpecs.isEmpty) return join + + // Step 5: Generate final specs + // within a partition, split the skew ShuffleQueryStageExecs, and duplicate others + def createNonSkewSpec(partitionIndex: Int, stageId: Int) = { + val size = stageStats(stageId).bytesByPartitionId(partitionIndex) + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, size)) + } + + def traverseCombinations(seqs: 
Seq[Seq[ShufflePartitionSpec]]) = { + require(seqs.nonEmpty) + val iter = seqs.head.iterator.map(item => Seq(item)) + seqs.tail.foldLeft(iter)((iter, seq) => iter.flatMap(comb => seq.map(item => comb :+ item))) + } + + val buffers = stageIds.map(_ => mutable.ArrayBuffer.empty[ShufflePartitionSpec]) + Range(0, numPartitions).foreach { partitionIndex => + val specs = stageIds.map { stageId => + skewSpecs.getOrElse((partitionIndex, stageId), createNonSkewSpec(partitionIndex, stageId)) + } + traverseCombinations(specs).foreach { combination => + combination.indices.foreach(i => buffers(i) += combination(i)) } } - logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight") - if (numSkewedLeft > 0 || numSkewedRight > 0) { - Some(( - SkewJoinChildWrapper(AQEShuffleReadExec(left, leftSidePartitions.toSeq)), - SkewJoinChildWrapper(AQEShuffleReadExec(right, rightSidePartitions.toSeq)) - )) - } else { - None + val newSpecs = stageIds.zip(buffers.map(_.toSeq)).toMap + + // Step 6: Generate final plan + // 1, mark all Join/Agg/Window nodes skew; + // 2, attach new specs to ShuffleQueryStageExecs; + join transform { + case smj: SortMergeJoinExec => smj.copy(isSkewJoin = true) + case shj: ShuffledHashJoinExec => shj.copy(isSkewJoin = true) + + case obj: ObjectHashAggregateExec => obj.copy(isSkew = true) + case hash: HashAggregateExec => hash.copy(isSkew = true) + case sort: SortAggregateExec => sort.copy(isSkew = true) + + case win: WindowExec => win.copy(isSkew = true) + + case stage: ShuffleQueryStageExec => + SkewJoinChildWrapper(AQEShuffleReadExec(stage, newSpecs(stage.id))) } } - def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp { - case smj @ SortMergeJoinExec(_, _, joinType, _, - s1 @ SortExec(_, _, ShuffleStage(left: ShuffleQueryStageExec), _), - s2 @ SortExec(_, _, ShuffleStage(right: ShuffleQueryStageExec), _), false) => - tryOptimizeJoinChildren(left, right, joinType).map { - case (newLeft, newRight) => - smj.copy( - left = 
s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true) - }.getOrElse(smj) - - case shj @ ShuffledHashJoinExec(_, _, joinType, _, _, - ShuffleStage(left: ShuffleQueryStageExec), - ShuffleStage(right: ShuffleQueryStageExec), false) => - tryOptimizeJoinChildren(left, right, joinType).map { - case (newLeft, newRight) => - shj.copy(left = newLeft, right = newRight, isSkewJoin = true) - }.getOrElse(shj) + private def optimize(plan: SparkPlan): SparkPlan = { + plan transformDown { + case join: ShuffledJoin if !join.isSkewJoin => optimizeShuffledJoin(join) + } } override def apply(plan: SparkPlan): SparkPlan = { if (!conf.getConf(SQLConf.SKEW_JOIN_ENABLED)) { return plan } + if (plan.collectFirst { case s: ShuffledJoin if !s.isSkewJoin => s }.isEmpty) { + return plan + } - // We try to optimize every skewed sort-merge/shuffle-hash joins in the query plan. If this - // introduces extra shuffles, we give up the optimization and return the original query plan, or - // accept the extra shuffles if the force-apply config is true. - // TODO: It's possible that only one skewed join in the query plan leads to extra shuffles and - // we only need to skip optimizing that join. We should make the strategy smarter here. 
- val optimized = optimizeSkewJoin(plan) + val optimized = optimize(plan) + if (optimized.collectFirst { case s: ShuffledJoin if s.isSkewJoin => s }.isEmpty) { + return plan + } val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) { ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get) } else { @@ -256,6 +399,82 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) } } + +private[adaptive] object OptimizeSkewedJoin { + + /** + * same as values.product, but make sure NO overflow: + * Iterator(10, 20, 30, 4, 10, 2, 1, 999, 88).product -> -751,912,960 + */ + def safeProduct(values: TraversableOnce[Int]): BigInt = values.foldLeft(BigInt(1))(_ * _) + + /** + * Combine splits to make sure the total number of combinations no greater than given threshold. + * This algorithm iteratively estimates an appropriate shrinkage factor for remaining combinable + * stages (with more than 1 splits), and then perform split merge. Note that it tries to keep the + * proportional relationship among input numbers of splits. 
+ */ + def combine(maxCombinations: Int, numSplits: Array[Int]): Array[Int] = { + require(maxCombinations > 0) + require(numSplits.nonEmpty && numSplits.forall(_ > 0)) + + var numCombinations = safeProduct(numSplits) + if (numCombinations <= maxCombinations) return numSplits + + val intNumSplits = numSplits.clone() + val floatNumSplits = intNumSplits.map(_.toDouble) + var numCombinables = intNumSplits.count(_ > 1) + + val maxShrinkage = 0.999 + val minShrinkage = 0.1 + val maxIterations = 1000 // 20 iterations should be enough in most cases, set 1000 for safety + var iteration = 0 + while (numCombinations > maxCombinations && numCombinables > 0 && iteration < maxIterations) { + var shrinkage = math.pow( + (BigDecimal(maxCombinations) / BigDecimal(numCombinations)).doubleValue, + 1.0 / numCombinables + ) + if (shrinkage.isNaN) { + shrinkage = maxShrinkage + } else { + // clip shrinkage for numeric stability + shrinkage = math.min(shrinkage, maxShrinkage) + shrinkage = math.max(shrinkage, minShrinkage) + } + + floatNumSplits.indices.foreach { i => + floatNumSplits(i) = math.max(1.0, floatNumSplits(i) * shrinkage) + } + + Iterator.tabulate(floatNumSplits.length) { i => + val prevIntNumSplits = intNumSplits(i) + val currIntNumSplits = floatNumSplits(i).round.toInt + (i, prevIntNumSplits, currIntNumSplits) + }.filter { case (_, prevIntNumSplits, currIntNumSplits) => + currIntNumSplits < prevIntNumSplits + }.toArray.sortBy { case (i, prevIntNumSplits, currIntNumSplits) => + // first try small updates to numCombinations + (1.0 - currIntNumSplits.toDouble / prevIntNumSplits, i) + }.foreach { case (i, prevIntNumSplits, currIntNumSplits) => + if (numCombinations > maxCombinations) { + intNumSplits(i) = currIntNumSplits + numCombinations /= prevIntNumSplits + numCombinations *= currIntNumSplits + } + } + + numCombinables = intNumSplits.count(_ > 1) + iteration += 1 + } + + if (numCombinations <= maxCombinations) { + intNumSplits + } else { + Array.fill(numSplits.length)(1) + 
} + } +} + // After optimizing skew joins, we need to run EnsureRequirements again to add necessary shuffles // caused by skew join optimization. However, this shouldn't apply to the sub-plan under skew join, // as it's guaranteed to satisfy distribution requirement. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index 756b5eb09d0b9..cc4974fe76717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -35,6 +35,8 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning def aggregateAttributes: Seq[Attribute] def initialInputBufferOffset: Int def resultExpressions: Seq[NamedExpression] + def isSkew: Boolean = false + override def nodeName: String = if (isSkew) super.nodeName + "(skew=true)" else super.nodeName override def verboseStringWithOperatorId(): String = { s""" @@ -94,6 +96,7 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { + case _ if isSkew => UnspecifiedDistribution :: Nil case Some(exprs) if exprs.isEmpty => AllTuples :: Nil case Some(exprs) => if (isStreaming) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 935844d96d9a0..8fc913f3fb049 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -53,11 +53,14 @@ case class HashAggregateExec( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, 
resultExpressions: Seq[NamedExpression], - child: SparkPlan) + child: SparkPlan, + override val isSkew: Boolean = false) extends AggregateCodegenSupport { require(Aggregate.supportsHashAggregate(aggregateBufferAttributes)) + override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator + override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index e6530e94701f9..416ec5abf0e58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -66,9 +66,12 @@ case class ObjectHashAggregateExec( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan) + child: SparkPlan, + override val isSkew: Boolean = false) extends BaseAggregateExec { + override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator + override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 3cf63a5318dcf..6f6a3446dc930 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -39,10 +39,13 @@ case class SortAggregateExec( 
aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - child: SparkPlan) + child: SparkPlan, + override val isSkew: Boolean = false) extends AggregateCodegenSupport with AliasAwareOutputOrdering { + override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 33c37e871e385..c5a02005019a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.window import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} /** @@ -87,8 +88,19 @@ case class WindowExec( windowExpression: Seq[NamedExpression], partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], - child: SparkPlan) + child: SparkPlan, + isSkew: Boolean = false) extends WindowExecBase { + override def nodeName: String = if (isSkew) super.nodeName + "(skew=true)" else super.nodeName + override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator + + override def requiredChildDistribution: Seq[Distribution] = { + if (isSkew) { + UnspecifiedDistribution :: Nil + } else { + super[WindowExecBase].requiredChildDistribution + } + } protected override def doExecute(): RDD[InternalRow] = { // Unwrap the window expressions and window frame factories from the map. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 622cce7e8b3bc..b9c2a1918fe5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -744,7 +744,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession assert( executedPlan.exists { case WholeStageCodegenExec( - HashAggregateExec(_, _, _, _, _, _, _, _, _: LocalTableScanExec)) => true + HashAggregateExec(_, _, _, _, _, _, _, _, _: LocalTableScanExec, _)) => true case _ => false }, "LocalTableScanExec should be within a WholeStageCodegen domain.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index eff19302c5b20..def5340cb8e72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -115,6 +115,12 @@ class AdaptiveQueryExecSuite } } + private def findTopLevelShuffledJoin(plan: SparkPlan): Seq[ShuffledJoin] = { + collect(plan) { + case j: ShuffledJoin => j + } + } + private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = { collect(plan) { case j: BaseJoinExec => j @@ -843,6 +849,280 @@ class AdaptiveQueryExecSuite } } + test("SPARK-36638: Generalize OptimizeSkewedJoin - 3-table join with Window and Aggregate") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SKEW_JOIN_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "80", + 
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "80") { + withTempView("skewData1", "skewData2", "skewData3") { + spark + .range(0, 10000, 1, 9) + .select( + when('id < 2500, 2499) + .when('id >= 7500, 10000) + .otherwise('id).as("key1"), + 'id as "value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 100000, 1, 7) + .select( + when('id < 5000, 4999) + .when('id >= 10000, 10001) + .otherwise('id).as("key2"), + 'id as "value2") + .createOrReplaceTempView("skewData2") + spark + .range(0, 10000, 1, 5) + .select('id as "key3", 'id as "value3") + .repartition(11) + .createOrReplaceTempView("skewData3") + + val join = + s""" + | SELECT value1, value3, row FROM + | (SELECT key1, value1, row_number() OVER (PARTITION BY key1 ORDER BY value1 DESC) + | AS row FROM skewData1) + | JOIN skewData2 ON key1 = key2 + | LEFT JOIN (SELECT key3, max(value3) AS value3 FROM skewData3 GROUP BY key3) + | ON key1 = key3 + |""".stripMargin + val query = s"SELECT value1, min(value3), max(row) FROM ($join) GROUP BY value1" + + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val joins = findTopLevelShuffledJoin(adaptivePlan) + assert(joins.size === 2) + assert(joins.forall(_.isSkewJoin)) + } + } + } + + test("SPARK-36638: Generalize OptimizeSkewedJoin - 3-table join UNION 2-table join") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SKEW_JOIN_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "80", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "80") { + withTempView("skewData1", "skewData2", "skewData3") { + spark + .range(0, 10000, 1, 9) + .select( + when('id < 2500, 2499) + .when('id >= 7500, 10000) + .otherwise('id).as("key1"), + 'id as "value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 10000, 1, 7) + .select( + when('id < 2500, 2499) + .otherwise('id).as("key2"), + 'id as "value2") + 
.createOrReplaceTempView("skewData2") + spark + .range(0, 10000, 1, 5) + .select('id as "key3", 'id as "value3") + .createOrReplaceTempView("skewData3") + + Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint => + val union = + s""" + | SELECT value1, value3 FROM + | skewData1 JOIN skewData2 ON key1 = key2 + | LEFT JOIN + | (SELECT key3, max(value3) AS value3 FROM skewData3 GROUP BY key3) ON key1 = key3 + | UNION ALL + | SELECT /*+ $joinHint(skewData1) */ value1, value2 FROM skewData1 + | LEFT JOIN skewData2 ON key1 = key2 + |""".stripMargin + val query = s"SELECT value1, max(value3) FROM ($union) GROUP BY value1" + + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val joins = findTopLevelShuffledJoin(adaptivePlan) + assert(joins.size === 3) + assert(joins.forall(_.isSkewJoin)) + } + } + } + } + + test("SPARK-36638: Generalize OptimizeSkewedJoin - 5-table join") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SKEW_JOIN_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "80", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "80") { + withTempView("skewData1", "skewData2", "skewData3", "skewData4", "skewData5") { + spark + .range(0, 10000, 1, 9) + .select( + when('id < 2500, 2499) + .when('id >= 7500, 10000) + .otherwise('id).as("key1"), + 'id as "value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 10000, 1, 7) + .select( + when('id < 2500, 2499) + .otherwise('id).as("key2"), + 'id as "value2") + .createOrReplaceTempView("skewData2") + spark + .range(0, 10000, 1, 5) + .select('id as "key3", 'id as "value3") + .createOrReplaceTempView("skewData3") + spark + .range(0, 10000, 1, 3) + .select( + when('id < 2000, 1999) + .otherwise('id).as("key4"), + 'id as "value4") + .createOrReplaceTempView("skewData4") + spark + .range(0, 10000, 1, 11) + .select( + when('id > 6000, 6000) + 
.otherwise('id).as("key5"), + 'id as "value5") + .createOrReplaceTempView("skewData5") + + + for (joinType12 <- Seq("LEFT", "RIGHT"); + joinType123 <- Seq("INNER", "CROSS"); + joinType45 <- Seq("RIGHT", "INNER")) { + + val join = + s""" + | SELECT * FROM + | skewData1 $joinType12 JOIN skewData2 ON key1 = key2 + | $joinType123 JOIN skewData3 ON key1 = key3 + | JOIN + | ((SELECT /*+ SHUFFLE_HASH(skewData4) */ key4, max(value4) AS value4 + | FROM skewData4 GROUP BY key4) + | $joinType45 JOIN skewData5 ON key4 = key5) + | ON key1 = key5 + |""".stripMargin + val query = s"SELECT value1, max(value3) FROM ($join) GROUP BY value1" + + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val joins = findTopLevelShuffledJoin(adaptivePlan) + assert(joins.size === 4) + assert(joins.forall(_.isSkewJoin)) + } + } + } + } + + test("SPARK-36638: Generalize OptimizeSkewedJoin - handle Combinatorial Explosion") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SKEW_JOIN_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") { + withTempView("skewData1", "skewData2") { + spark + .range(0, 2000, 1, 9) + .select( + when('id < 1000, 999) + .when('id >= 1500, 2000) + .otherwise('id).as("key1"), 'id as "value1") + .createOrReplaceTempView("skewData1") + + spark + .range(0, 2000, 1, 11) + .select( + when('id < 1000, 999) + .when('id >= 1500, 2000) + .otherwise('id).as("key2"), 'id as "value2") + .createOrReplaceTempView("skewData2") + + val query = "SELECT * FROM skewData1 JOIN skewData2 ON key1 = key2" + + withSQLConf(SQLConf.SKEW_JOIN_MAX_SPLITS_PER_PARTITION.key -> "1000") { + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + assert(adaptivePlan.execute().getNumPartitions > 150) + } + + 
withSQLConf(SQLConf.SKEW_JOIN_MAX_SPLITS_PER_PARTITION.key -> "10") { + /** + * related log: + * Splittable ShuffleQueryStages: [0, 1] + * Splitting ShuffleQueryStage #0: partition 57(6 KB) -> 6 splits + * Splitting ShuffleQueryStage #1: partition 57(6 KB) -> 7 splits + * Re-splitting ShuffleQueryStage #0: partition 57(6 KB) -> 3 splits + * Re-splitting ShuffleQueryStage #1: partition 57(6 KB) -> 3 splits + * Splitting ShuffleQueryStage #0: partition 86(2 KB) -> 4 splits + * Splitting ShuffleQueryStage #1: partition 86(2 KB) -> 4 splits + * Re-splitting ShuffleQueryStage #0: partition 86(2 KB) -> 3 splits + * Re-splitting ShuffleQueryStage #1: partition 86(2 KB) -> 3 splits + */ + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + assert(adaptivePlan.execute().getNumPartitions < 120) + } + } + } + } + + test("SPARK-36638: Generalize OptimizeSkewedJoin - handle Combinatorial Explosion II") { + import OptimizeSkewedJoin.combine + + val array0 = Array(511, 622) // 317,842 splits + assert(combine(1, array0) === Array(1, 1)) // 1 split + assert(combine(10, array0) === Array(3, 3)) // 9 splits + assert(combine(100, array0) === Array(9, 11)) // 99 splits + assert(combine(1000, array0) === Array(29, 34)) // 986 splits + assert(combine(10000, array0) === Array(90, 110)) // 9,900 splits + assert(combine(100000, array0) === Array(286, 349)) // 99,814 splits + + val array1 = Array(1111, 4096, 128, 8) // 4,659,871,744 splits (exceeds Int.MaxValue) + assert(combine(1, array1) === Array(1, 1, 1, 1)) // 1 split + assert(combine(10, array1) === Array(2, 5, 1, 1)) // 10 splits + assert(combine(100, array1) === Array(5, 19, 1, 1)) // 95 splits + assert(combine(1000, array1) === Array(12, 45, 1, 1)) // 540 splits + assert(combine(10000, array1) === Array(29, 105, 3, 1)) // 9,135 splits + assert(combine(100000, array1) === Array(61, 226, 7, 1)) // 96,502 splits + + val array2 = Array(77, 99, 77) // 586,971 splits + assert(combine(1, array2) === Array(1, 1, 1)) // 1 split + assert(combine(10, array2)
=== Array(2, 2, 2)) // 8 splits + assert(combine(100, array2) === Array(4, 5, 4)) // 80 splits + assert(combine(1000, array2) === Array(9, 12, 9)) // 972 splits + assert(combine(10000, array2) === Array(20, 25, 20)) // 10,000 splits + assert(combine(100000, array2) === Array(42, 55, 43)) // 99,330 splits + + val array3 = Array(9999) // 9,999 splits + assert(combine(1, array3) === Array(1)) // 1 split + assert(combine(10, array3) === Array(10)) // 10 splits + assert(combine(100, array3) === Array(100)) // 100 splits + assert(combine(1000, array3) === Array(1000)) // 1,000 splits + assert(combine(10000, array3) === Array(9999)) // 9,999 splits + + val array4 = Array(10, 20, 30, 4, 10, 2, 1, 999, 88) // 42,197,760,000 splits + assert(combine(1, array4) === Array(1, 1, 1, 1, 1, 1, 1, 1, 1)) // 1 split + assert(combine(10, array4) === Array(1, 1, 1, 1, 1, 1, 1, 10, 1)) // 10 splits + assert(combine(100, array4) === Array(1, 1, 1, 1, 1, 1, 1, 33, 3)) // 99 splits + assert(combine(1000, array4) === Array(1, 1, 2, 1, 1, 1, 1, 69, 6)) // 828 splits + assert(combine(10000, array4) === Array(1, 2, 4, 1, 1, 1, 1, 119, 10)) // 9,520 splits + assert(combine(100000, array4) === Array(2, 3, 4, 1, 2, 1, 1, 148, 13)) // 92,352 splits + + val array5 = Array.fill(10)(2) // 1,024 splits + assert(combine(1, array5) === Array(1, 1, 1, 1, 1, 1, 1, 1, 1, 1)) // 1 split + assert(combine(10, array5) === Array(1, 1, 1, 1, 1, 1, 1, 2, 2, 2)) // 8 splits + assert(combine(100, array5) === Array(1, 1, 1, 1, 2, 2, 2, 2, 2, 2)) // 64 splits + assert(combine(1000, array5) === Array(1, 2, 2, 2, 2, 2, 2, 2, 2, 2)) // 512 splits + assert(combine(10000, array5) === Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2)) // 1,024 splits + } + test("SPARK-30291: AQE should catch the exceptions when doing materialize") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { @@ -2520,11 +2800,10 @@ class AdaptiveQueryExecSuite "UNION ALL SELECT key2 FROM skewData2 GROUP BY key2", 1, 1) // skewJoin1 union (skewJoin2
join aggregate) - // skewJoin2 will lead to extra shuffles, but skew1 cannot be optimized checkSkewJoin( "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 UNION ALL " + "SELECT key1 from (SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2) tmp1 " + - "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2", 3, 0) + "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2", 3, 3) } } }