diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 99ef23e54c74b..28a9225b6ce23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.plans.physical -import java.util.Objects - import scala.annotation.tailrec import scala.collection.mutable @@ -360,12 +358,11 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa * preserved through `GroupPartitionsExec`. The sorted order is critical for storage-partitioned * join compatibility. * - * 2. '''In KeyGroupedShuffleSpec''': When used within `KeyGroupedShuffleSpec`, the `partitionKeys` - * may not be in sorted order. This occurs because `KeyGroupedShuffleSpec` can project the - * partition keys by join key positions. The `EnsureRequirements` rule ensures that either the - * unordered keys from both sides of a join match exactly, or it builds a common ordered set of - * keys and pushes them down to `GroupPartitionsExec` on both sides to establish a compatible - * ordering. + * 2. '''In KeyedShuffleSpec''': When used within `KeyedShuffleSpec`, the `partitionKeys` may not be + * in sorted order. This occurs because `KeyedShuffleSpec` can project the partition keys by join + * key positions. The `EnsureRequirements` rule ensures that either the unordered keys from both + * sides of a join match exactly, or it builds a common ordered set of keys and pushes them down + * to `GroupPartitionsExec` on both sides to establish a compatible ordering. * * == Partition Keys == * - `partitionKeys`: The partition keys, one per partition. 
May contain duplicates initially @@ -427,7 +424,7 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa * @param expressions Partition transform expressions (e.g., `years(col)`, `bucket(10, col)`). * @param partitionKeys Partition keys wrapped in InternalRowComparableWrapper for efficient * comparison and grouping. One per partition. When used as outputPartitioning, - * always in sorted order. When used in KeyGroupedShuffleSpec, may be unsorted + * always in sorted order. When used in `KeyedShuffleSpec`, may be unsorted * after projection. May contain duplicates when ungrouped. * @param isGrouped Whether partition keys are unique (no duplicates). Computed on first * creation, then preserved through copy operations to avoid recomputation. @@ -509,7 +506,7 @@ case class KeyedPartitioning( } override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = { - val result = KeyGroupedShuffleSpec(this, distribution) + val result = KeyedShuffleSpec(this, distribution) if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { // If allowing join keys to be subset of clustering keys, we should create a new // `KeyedPartitioning` here that is grouped on the join keys instead, and use that as @@ -525,16 +522,6 @@ case class KeyedPartitioning( result } } - - override def equals(that: Any): Boolean = that match { - case k: KeyedPartitioning if this.expressions == k.expressions => - this.partitionKeys == k.partitionKeys - - case _ => false - } - - override def hashCode(): Int = - Objects.hash(expressions, partitionKeys) } object KeyedPartitioning { @@ -954,7 +941,7 @@ case class CoalescedHashShuffleSpec( * @param joinKeyPositions position of join keys among cluster keys. * This is set if joining on a subset of cluster keys is allowed. 
*/ -case class KeyGroupedShuffleSpec( +case class KeyedShuffleSpec( partitioning: KeyedPartitioning, distribution: ClusteredDistribution, joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec { @@ -992,7 +979,7 @@ case class KeyGroupedShuffleSpec( // 3.3 each pair of partition expressions at the same index must share compatible // transform functions. // 4. the partition values from both sides are following the same order. - case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => + case otherSpec @ KeyedShuffleSpec(otherPartitioning, otherDistribution, _) => distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionKeys == otherPartitioning.partitionKeys @@ -1003,7 +990,7 @@ case class KeyGroupedShuffleSpec( // Whether the partition keys (i.e., partition expressions) are compatible between this and the // `other` spec. - def areKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = { + def areKeysCompatible(other: KeyedShuffleSpec): Boolean = { val expressions = partitioning.expressions val otherExpressions = other.partitioning.expressions @@ -1047,7 +1034,7 @@ case class KeyGroupedShuffleSpec( * * @param other other key-grouped shuffle spec */ - def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = { + def reducers(other: KeyedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = { val results = partitioning.expressions.zip(other.partitioning.expressions).map { case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2) case (_, _) => None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 84f2f6b90aa7a..2887733df5876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2075,11 +2075,11 @@ object SQLConf { val V2_BUCKETING_PUSH_PART_VALUES_ENABLED = buildConf("spark.sql.sources.v2.bucketing.pushPartValues.enabled") .doc(s"Whether to pushdown common partition values when ${V2_BUCKETING_ENABLED.key} is " + - "enabled. When turned on, if both sides of a join are of KeyGroupedPartitioning and if " + + "enabled. When turned on, if both sides of a join are of KeyedPartitioning and if " + "they share compatible partition keys, even if they don't have the exact same partition " + "values, Spark will calculate a superset of partition values and pushdown that info to " + - "scan nodes, which will use empty partitions for the missing partition values on either " + - "side. This could help to eliminate unnecessary shuffles") + "group partition nodes, which will use empty partitions for the missing partition values " + + "on either side. This could help to eliminate unnecessary shuffles") .version("3.4.0") .booleanConf .createWithDefault(true) @@ -2087,7 +2087,7 @@ object SQLConf { val V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED = buildConf("spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled") .doc("During a storage-partitioned join, whether to allow input partitions to be " + - "partially clustered, when both sides of the join are of KeyGroupedPartitioning. At " + + "partially clustered, when both sides of the join are of KeyedPartitioning. At " + "planning time, Spark will pick the side with less data size based on table " + "statistics, group and replicate them to match the other side. This is an optimization " + "on skew join and can help to reduce data skewness when certain partitions are assigned " + @@ -2100,7 +2100,7 @@ object SQLConf { val V2_BUCKETING_SHUFFLE_ENABLED = buildConf("spark.sql.sources.v2.bucketing.shuffle.enabled") .doc("During a storage-partitioned join, whether to allow to shuffle only one side. 
" + - "When only one side is KeyGroupedPartitioning, if the conditions are met, spark will " + + "When only one side is KeyedPartitioning, if the conditions are met, spark will " + "only shuffle the other side. This optimization will reduce the amount of data that " + s"needs to be shuffle. This config requires ${V2_BUCKETING_ENABLED.key} to be enabled") .version("4.0.0") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 9910c4eb788cc..221741a56fe09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -45,16 +45,16 @@ import org.apache.spark.sql.vectorized.ColumnarBatch * @param expectedPartitionKeys Optional sequence of expected partition key values and their * split counts * @param reducers Optional reducers to apply to partition keys for grouping compatibility - * @param applyPartialClustering Whether to apply partial clustering for skewed data - * @param replicatePartitions Whether to replicate partitions across multiple keys + * @param distributePartitions When true, splits for a key are distributed across the expected + * partitions (padding with empty partitions). When false, all splits + * are replicated to every expected partition for that key. 
*/ case class GroupPartitionsExec( child: SparkPlan, @transient joinKeyPositions: Option[Seq[Int]] = None, @transient expectedPartitionKeys: Option[Seq[(InternalRowComparableWrapper, Int)]] = None, @transient reducers: Option[Seq[Option[Reducer[_, _]]]] = None, - @transient applyPartialClustering: Boolean = false, - @transient replicatePartitions: Boolean = false + @transient distributePartitions: Boolean = false ) extends UnaryExecNode { override def outputPartitioning: Partitioning = { @@ -91,7 +91,7 @@ case class GroupPartitionsExec( val alignedPartitions = expectedPartitionKeys.get.flatMap { case (key, numSplits) => if (numSplits > 1) isGrouped = false val splits = keyMap.getOrElse(key, Seq.empty) - if (applyPartialClustering && !replicatePartitions) { + if (distributePartitions) { // Distribute splits across expected partitions, padding with empty sequences val paddedSplits = splits.map(Seq(_)).padTo(numSplits, Seq.empty) paddedSplits.map((key, _)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 39da546256132..cca37558584f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -71,26 +71,40 @@ case class EnsureRequirements( child } else { // Check KeyedPartitioning satisfaction conditions - val groupedSatisfies = grouped.exists(_.satisfies(distribution)) + val groupedSatisfies = grouped.find(_.satisfies(distribution)) val nonGroupedSatisfiesAsIs = nonGrouped.exists(_.nonGroupedSatisfies(distribution)) - val nonGroupedSatisfiesWhenGrouped = nonGrouped.exists(_.groupedSatisfies(distribution)) + val nonGroupedSatisfiesWhenGrouped = nonGrouped.find(_.groupedSatisfies(distribution)) // Check if any KeyedPartitioning satisfies the distribution - if (groupedSatisfies || 
nonGroupedSatisfiesAsIs || nonGroupedSatisfiesWhenGrouped) { + if (groupedSatisfies.isDefined || nonGroupedSatisfiesAsIs + || nonGroupedSatisfiesWhenGrouped.isDefined) { distribution match { case o: OrderedDistribution => - // OrderedDistribution requires grouped KeyedPartitioning with sorted keys. + // OrderedDistribution requires grouped KeyedPartitioning with sorted keys + // according to the distribution's ordering. // Find any KeyedPartitioning that satisfies via groupedSatisfies. val satisfyingKeyedPartitioning = - (grouped ++ nonGrouped).find(_.groupedSatisfies(distribution)).get + groupedSatisfies.orElse(nonGroupedSatisfiesWhenGrouped).get val attrs = satisfyingKeyedPartitioning.expressions.flatMap(_.collectLeaves()) .map(_.asInstanceOf[Attribute]) val keyRowOrdering = RowOrdering.create(o.ordering, attrs) val keyOrdering = keyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) - val sorted = satisfyingKeyedPartitioning.partitionKeys.sorted(keyOrdering) - GroupPartitionsExec(child, expectedPartitionKeys = Some(sorted.map((_, 1)))) - - case _ if groupedSatisfies => + val keys = satisfyingKeyedPartitioning.partitionKeys + // Compare adjacent keys; unlike sliding(2), a single key yields no pair to match + if (keys.zip(keys.drop(1)).forall { case (k1, k2) => keyOrdering.lteq(k1, k2) }) { + child + } else { + // Use distributePartitions to spread splits across expected partitions + val sortedGroupedKeys = satisfyingKeyedPartitioning.partitionKeys + .groupBy(identity).view.mapValues(_.size) + .toSeq.sortBy(_._1)(keyOrdering) + GroupPartitionsExec(child, + expectedPartitionKeys = Some(sortedGroupedKeys), + distributePartitions = true + ) + } + + case _ if groupedSatisfies.isDefined => // Grouped KeyedPartitioning already satisfies child @@ -238,7 +252,7 @@ case class EnsureRequirements( // Hence we need to ensure that after this call, the outputPartitioning of the // partitioned side's BatchScanExec is grouped by join keys to match, // and we do that by pushing down the join keys - case Some(KeyGroupedShuffleSpec(_, _, Some(joinKeyPositions))) => + case 
Some(KeyedShuffleSpec(_, _, Some(joinKeyPositions))) => withJoinKeyPositions(child, joinKeyPositions) case _ => child } @@ -258,7 +272,7 @@ case class EnsureRequirements( child match { case ShuffleExchangeExec(_, c, so, ps) => ShuffleExchangeExec(newPartitioning, c, so, ps) - case GroupPartitionsExec(c, _, _, _, _, _) => ShuffleExchangeExec(newPartitioning, c) + case GroupPartitionsExec(c, _, _, _, _) => ShuffleExchangeExec(newPartitioning, c) case _ => ShuffleExchangeExec(newPartitioning, child) } } @@ -440,7 +454,7 @@ case class EnsureRequirements( val specs = Seq(left, right).zip(requiredChildDistribution).map { case (p, d) => if (!d.isInstanceOf[ClusteredDistribution]) return None val cd = d.asInstanceOf[ClusteredDistribution] - val specOpt = createKeyGroupedShuffleSpec(p.outputPartitioning, cd) + val specOpt = createKeyedShuffleSpec(p.outputPartitioning, cd) if (specOpt.isEmpty) return None specOpt.get } @@ -454,7 +468,7 @@ case class EnsureRequirements( // partitionings are not modified (projected) in specs and left and right side partitionings are // compatible with each other. // Left and right `outputPartitioning` is a `PartitioningCollection` or a `KeyedPartitioning` - // otherwise `createKeyGroupedShuffleSpec()` would have returned `None`. + // otherwise `createKeyedShuffleSpec()` would have returned `None`. var isCompatible = left.outputPartitioning.asInstanceOf[Expression].exists(_ == leftPartitioning) && right.outputPartitioning.asInstanceOf[Expression].exists(_ == rightPartitioning) && @@ -593,7 +607,7 @@ case class EnsureRequirements( val originalPartitioning = partiallyClusteredChild.outputPartitioning.asInstanceOf[Expression] // `outputPartitioning` is either a `PartitioningCollection` or a `KeyedPartitioning` - // otherwise `createKeyGroupedShuffleSpec()` would have returned `None`. + // otherwise `createKeyedShuffleSpec()` would have returned `None`. 
val originalKeyedPartitioning = originalPartitioning.collectFirst { case k: KeyedPartitioning => k }.get val projectedOriginalPartitionKeys = partiallyClusteredSpec.joinKeyPositions @@ -616,9 +630,9 @@ case class EnsureRequirements( // Now we need to push-down the common partition information to the `GroupPartitionsExec`s. newLeft = applyGroupPartitions(left, leftSpec.joinKeyPositions, mergedPartitionKeys, - leftReducers, applyPartialClustering, replicateLeftSide) + leftReducers, distributePartitions = applyPartialClustering && !replicateLeftSide) newRight = applyGroupPartitions(right, rightSpec.joinKeyPositions, mergedPartitionKeys, - rightReducers, applyPartialClustering, replicateRightSide) + rightReducers, distributePartitions = applyPartialClustering && !replicateRightSide) } } @@ -673,21 +687,19 @@ case class EnsureRequirements( joinKeyPositions: Option[Seq[Int]], mergedPartitionKeys: Seq[(InternalRowComparableWrapper, Int)], reducers: Option[Seq[Option[Reducer[_, _]]]], - applyPartialClustering: Boolean, - replicatePartitions: Boolean): SparkPlan = { + distributePartitions: Boolean): SparkPlan = { plan match { case g: GroupPartitionsExec => val newGroupPartitions = g.copy( joinKeyPositions = joinKeyPositions, expectedPartitionKeys = Some(mergedPartitionKeys), reducers = reducers, - applyPartialClustering = applyPartialClustering, - replicatePartitions = replicatePartitions) + distributePartitions = distributePartitions) newGroupPartitions.copyTagsFrom(g) newGroupPartitions case _ => GroupPartitionsExec(plan, joinKeyPositions, Some(mergedPartitionKeys), reducers, - applyPartialClustering, replicatePartitions) + distributePartitions) } } @@ -705,14 +717,14 @@ case class EnsureRequirements( } /** - * Tries to create a [[KeyGroupedShuffleSpec]] from the input partitioning and distribution, if - * the partitioning is a [[KeyedPartitioning]] (either directly or indirectly), and - * satisfies the given distribution. 
+ * Tries to create a [[KeyedShuffleSpec]] from the input partitioning and distribution, if the + * partitioning is a [[KeyedPartitioning]] (either directly or indirectly), and satisfies the + * given distribution. */ - private def createKeyGroupedShuffleSpec( + private def createKeyedShuffleSpec( partitioning: Partitioning, - distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = { - def tryCreate(partitioning: KeyedPartitioning): Option[KeyGroupedShuffleSpec] = { + distribution: ClusteredDistribution): Option[KeyedShuffleSpec] = { + def tryCreate(partitioning: KeyedPartitioning): Option[KeyedShuffleSpec] = { val attributes = partitioning.expressions.flatMap(_.collectLeaves()) val clustering = distribution.clustering @@ -725,7 +737,7 @@ case class EnsureRequirements( } if (satisfies) { - Some(partitioning.createShuffleSpec(distribution).asInstanceOf[KeyGroupedShuffleSpec]) + Some(partitioning.createShuffleSpec(distribution).asInstanceOf[KeyedShuffleSpec]) } else { None } @@ -734,7 +746,7 @@ case class EnsureRequirements( partitioning match { case p: KeyedPartitioning => tryCreate(p) case PartitioningCollection(partitionings) => - partitionings.collectFirst(Function.unlift(createKeyGroupedShuffleSpec(_, distribution))) + partitionings.collectFirst(Function.unlift(createKeyedShuffleSpec(_, distribution))) case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 61384bf9f1fca..ace0040049efe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -439,12 +439,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Seq(true, false).foreach { sortingEnabled => withSQLConf(SQLConf.V2_BUCKETING_SORTING_ENABLED.key -> 
sortingEnabled.toString) { - def verifyShuffle(cmd: String, answer: Seq[Row]): Unit = { + def verifyShuffle(cmd: String, answer: Seq[Row], expectedGroupPartitions: Int): Unit = { val df = sql(cmd) if (sortingEnabled) { assert(collectAllShuffles(df.queryExecution.executedPlan).isEmpty, "should contain no shuffle when sorting by partition values") - assert(collectAllGroupPartitions(df.queryExecution.executedPlan).size == 1, + assert(collectAllGroupPartitions(df.queryExecution.executedPlan).size == + expectedGroupPartitions, "should contain partition grouping when sorting by partition values") } else { assert(collectAllShuffles(df.queryExecution.executedPlan).size == 1, @@ -457,30 +458,32 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { verifyShuffle( s"SELECT price, id FROM testcat.ns.$items ORDER BY price ASC, id ASC", + // Default ordering of partitions matches requested ordering so we don't expect any + // shuffles or group partitions Seq(Row(null, 3), Row(10.0, 2), Row(15.5, null), - Row(15.5, 3), Row(40.0, 1), Row(41.0, 1))) + Row(15.5, 3), Row(40.0, 1), Row(41.0, 1)), 0) verifyShuffle( s"SELECT price, id FROM testcat.ns.$items " + s"ORDER BY price ASC NULLS LAST, id ASC NULLS LAST", Seq(Row(10.0, 2), Row(15.5, 3), Row(15.5, null), - Row(40.0, 1), Row(41.0, 1), Row(null, 3))) + Row(40.0, 1), Row(41.0, 1), Row(null, 3)), 1) verifyShuffle( s"SELECT price, id FROM testcat.ns.$items ORDER BY price DESC, id ASC", Seq(Row(41.0, 1), Row(40.0, 1), Row(15.5, null), - Row(15.5, 3), Row(10.0, 2), Row(null, 3))) + Row(15.5, 3), Row(10.0, 2), Row(null, 3)), 1) verifyShuffle( s"SELECT price, id FROM testcat.ns.$items ORDER BY price DESC, id DESC", Seq(Row(41.0, 1), Row(40.0, 1), Row(15.5, 3), - Row(15.5, null), Row(10.0, 2), Row(null, 3))) + Row(15.5, null), Row(10.0, 2), Row(null, 3)), 1) verifyShuffle( s"SELECT price, id FROM testcat.ns.$items " + s"ORDER BY price DESC NULLS FIRST, id DESC NULLS FIRST", Seq(Row(null, 3), Row(41.0, 1), 
Row(40.0, 1), - Row(15.5, null), Row(15.5, 3), Row(10.0, 2))); + Row(15.5, null), Row(15.5, 3), Row(10.0, 2)), 1); } } } @@ -3142,7 +3145,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } - test("SPARK-55535: Empty group partitions due filtered partitions") { + test("SPARK-55535: Empty group partitions due to filtered partitions") { val items_partitions = Array(identity("id")) createTable(items, itemsColumns, items_partitions) @@ -3167,4 +3170,50 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { "group partitions should not have any (common) partitions") } } + + test("SPARK-55535: Order by on partitions keys") { + withSQLConf(SQLConf.V2_BUCKETING_SORTING_ENABLED.key -> "true") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(2, 'aa', 10.0, cast('2021-01-01' as timestamp)), " + + "(3, 'aa', 20.0, cast('2022-01-01' as timestamp)), " + + "(1, 'aa', 40.0, cast('2022-01-01' as timestamp))") + + val df = sql(s"SELECT id FROM testcat.ns.$items i ORDER BY id") + + val expected = (1 to 3).map(Row(_)) + checkAnswer(df, expected) + + val reverseDf = sql(s"SELECT id FROM testcat.ns.$items i ORDER BY id DESC") + + checkAnswer(reverseDf, expected.reverse) + + sql(s"INSERT INTO testcat.ns.$items VALUES (2, 'aa', 30.0, cast('2021-01-01' as timestamp))") + + val dfWithDuplicate = sql(s"SELECT id FROM testcat.ns.$items i ORDER BY id") + + val expectedWithDuplicate = Seq(1, 2, 2, 3).map(Row(_)) + checkAnswer(dfWithDuplicate, expectedWithDuplicate) + + val reverseDfWithDuplicate = sql(s"SELECT id FROM testcat.ns.$items i ORDER BY id DESC") + + checkAnswer(reverseDfWithDuplicate, expectedWithDuplicate.reverse) + + Seq( + df -> Seq.empty, + reverseDf -> Seq(3), + dfWithDuplicate -> Seq.empty, + reverseDfWithDuplicate -> Seq(4) + ).foreach { + case (df, expectedPartitions) => + val shuffles = 
collectAllShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not contain any shuffle") + + val groupPartitions = collectAllGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.map(_.outputPartitioning.numPartitions) == expectedPartitions) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 7512cbe7f90b1..9c67a334c801c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1037,10 +1037,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, GroupPartitionsExec(DummySparkPlan(_, _, left: KeyedPartitioning, _, _), - _, _, _, _, _), _), + _, _, _, _), _), SortExec(_, _, GroupPartitionsExec(DummySparkPlan(_, _, right: KeyedPartitioning, _, _), - _, _, _, _, _), _), + _, _, _, _), _), _) => assert(left.expressions === Seq(bucket(4, exprB), bucket(8, exprC))) assert(right.expressions === Seq(bucket(4, exprC), bucket(8, exprB))) @@ -1061,10 +1061,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, GroupPartitionsExec(DummySparkPlan(_, _, left: PartitioningCollection, _, _), - _, _, _, _, _), _), + _, _, _, _), _), SortExec(_, _, GroupPartitionsExec(DummySparkPlan(_, _, right: KeyedPartitioning, _, _), - _, _, _, _, _), _), + _, _, _, _), _), _) => assert(left.partitionings.length == 2) assert(left.partitionings.head.isInstanceOf[KeyedPartitioning]) @@ -1096,10 +1096,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, GroupPartitionsExec(DummySparkPlan(_, _, left: PartitioningCollection, _, _), - _, _, _, _, _), _), + _, _, _, 
_), _), SortExec(_, _, GroupPartitionsExec(DummySparkPlan(_, _, right: PartitioningCollection, _, _), - _, _, _, _, _), _), + _, _, _, _), _), _) => assert(left.partitionings.length == 2) assert(left.partitionings.head.isInstanceOf[KeyedPartitioning])