diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 7d086f34f6983..7dd6015b2f87e 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -142,7 +142,7 @@ private[spark] class PartitionIdPassthrough(override val numPartitions: Int) ext /** * A [[org.apache.spark.Partitioner]] that partitions all records using partition value map. * The `valueMap` is a map that contains tuples of (partition value, partition id). It is generated - * by [[org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning]], used to partition + * by [[org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning]], used to partition * the other side of a join to make sure records with same partition value are in the same * partition. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index b0fa4f889cda1..99ef23e54c74b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical +import java.util.Objects + import scala.annotation.tailrec import scala.collection.mutable @@ -346,69 +348,163 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa } /** - * Represents a partitioning where rows are split across partitions based on transforms defined - * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in - * ascending order, after evaluated by the transforms in `expressions`, for each input partition. 
- * In addition, its length must be the same as the number of Spark partitions (and thus is a 1-1 - * mapping), and each row in `partitionValues` must be unique. + * Represents a partitioning where rows are split across partitions based on transforms defined by + * `expressions`. + * + * == Usage Forms == + * `KeyedPartitioning` is used in two distinct forms: + * + * 1. '''As outputPartitioning''': When used as a node's output partitioning (e.g., in + * `BatchScanExec` or `GroupPartitionsExec`), the `partitionKeys` are always in sorted order. + * This is how leaf data source nodes produce partition keys originally, and this ordering is + * preserved through `GroupPartitionsExec`. The sorted order is critical for storage-partitioned + * join compatibility. + * + * 2. '''In KeyGroupedShuffleSpec''': When used within `KeyGroupedShuffleSpec`, the `partitionKeys` + * may not be in sorted order. This occurs because `KeyGroupedShuffleSpec` can project the + * partition keys by join key positions. The `EnsureRequirements` rule ensures that either the + * unordered keys from both sides of a join match exactly, or it builds a common ordered set of + * keys and pushes them down to `GroupPartitionsExec` on both sides to establish a compatible + * ordering. + * + * == Partition Keys == + * - `partitionKeys`: The partition keys, one per partition. May contain duplicates initially + * (ungrouped state), but becomes unique after `GroupPartitionsExec` applies grouping. + * + * == Grouping State == + * A KeyedPartitioning can be in two states: + * + * - '''Ungrouped''' (when `isGrouped == false`): `partitionKeys` contains duplicates, meaning + * multiple input partitions share the same key. This occurs when a data source has multiple + * splits for the same partition value. + * + * - '''Grouped''' (when `isGrouped == true`): `partitionKeys` contains only unique values, with + * each partition having a distinct key. 
This occurs when: (1) a data source natively produces + * unique partition keys, or (2) `GroupPartitionsExec` coalesces partitions with duplicate keys. + * + * == Distribution Satisfaction and Grouping == + * The `satisfies()` method returns true if this partitioning can satisfy a distribution, + * regardless of whether the partitioning is actually grouped. The method delegates to: + * - `nonGroupedSatisfies()`: Returns true for basic distributions (UnspecifiedDistribution, + * AllTuples when single partition) + * - `groupedSatisfies()`: Returns true for distributions requiring grouped partitioning + * (ClusteredDistribution, OrderedDistribution) * - * The `originalPartitionValues`, on the other hand, are partition values from the original input - * splits returned by data sources. It may contain duplicated values. + * If `satisfies()` returns true but `isGrouped == false`, the partitioning does NOT actually + * satisfy the distribution yet. The `EnsureRequirements` rule must insert `GroupPartitionsExec` to + * coalesce duplicate partition keys before the distribution requirement is truly satisfied. * - * For example, if a data source reports partition transform expressions `[years(ts_col)]` with 4 - * input splits whose corresponding partition values are `[0, 1, 2, 2]`, then the `expressions` - * in this case is `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, which - * represents 3 input partitions with distinct partition values. All rows in each partition have - * the same value for column `ts_col` (which is of timestamp type), after being applied by the - * `years` transform. This is generated after combining the two splits with partition value `2` - * into a single Spark partition. + * For example, an ungrouped KeyedPartitioning with keys `[1, 2, 2, 3]` will return + * `satisfies(ClusteredDistribution(...)) == true` because it can satisfy the distribution after + * grouping. 
However, `EnsureRequirements` must add `GroupPartitionsExec` to produce grouped keys + * `[1, 2, 3]` before the distribution is actually satisfied. * - * On the other hand, in this example `[0, 1, 2, 2]` is the value of `originalPartitionValues` - * which is calculated from the original input splits. + * Similarly, for `OrderedDistribution`, even if `satisfies()` returns true, `GroupPartitionsExec` + * must be added to both group the partitions AND sort the partition keys according to the + * ordering requirement. * - * @param expressions partition expressions for the partitioning. - * @param numPartitions the number of partitions - * @param partitionValues the values for the final cluster keys (that is, after applying grouping - * on the input splits according to `expressions`) of the distribution, - * must be in ascending order, and must NOT contain duplicated values. - * @param originalPartitionValues the original input partition values before any grouping has been - * applied, must be in ascending order, and may contain duplicated - * values + * == Example == + * Consider a data source with partition transform `[years(ts_col)]` and 4 input splits: + * + * '''Before GroupPartitionsExec''' (ungrouped): + * {{{ + * expressions: [years(ts_col)] + * partitionKeys: [0, 1, 2, 2] // partitions 2 and 3 have the same key + * numPartitions: 4 + * isGrouped: false + * satisfies(ClusteredDistribution(...)) == true // CAN satisfy after grouping + * }}} + * + * '''After GroupPartitionsExec''' (grouped): + * {{{ + * expressions: [years(ts_col)] + * partitionKeys: [0, 1, 2] // duplicates removed, partitions coalesced + * numPartitions: 3 + * isGrouped: true + * satisfies(ClusteredDistribution(...)) == true // ACTUALLY satisfies now + * }}} + * + * @param expressions Partition transform expressions (e.g., `years(col)`, `bucket(10, col)`). + * @param partitionKeys Partition keys wrapped in InternalRowComparableWrapper for efficient + * comparison and grouping. One per partition. 
When used as outputPartitioning, + * always in sorted order. When used in KeyGroupedShuffleSpec, may be unsorted + * after projection. May contain duplicates when ungrouped. + * @param isGrouped Whether partition keys are unique (no duplicates). Computed on first + * creation, then preserved through copy operations to avoid recomputation. */ -case class KeyGroupedPartitioning( +case class KeyedPartitioning( expressions: Seq[Expression], - numPartitions: Int, - partitionValues: Seq[InternalRow] = Seq.empty, - originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike { + @transient partitionKeys: Seq[InternalRowComparableWrapper], + isGrouped: Boolean) extends Expression with Partitioning with Unevaluable { + override val numPartitions = partitionKeys.length + + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: DataType = IntegerType + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): KeyedPartitioning = + copy(expressions = newChildren) + + @transient lazy val expressionDataTypes: Seq[DataType] = expressions.map(_.dataType) + + @transient lazy val keyRowOrdering = + RowOrdering.createNaturalAscendingOrdering(expressionDataTypes) + + @transient lazy val keyOrdering = keyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) + + def toGrouped: KeyedPartitioning = { + val groupedPartitionKeys = partitionKeys.distinct.sorted(keyOrdering) + + new KeyedPartitioning(expressions, groupedPartitionKeys, isGrouped = true) + } + + /** + * Projects this partitioning's partition keys by selecting only the specified positions. + * Returns the data types of the selected positions together with the projected keys. 
+ */ + def projectKeys(positions: Seq[Int]): (Seq[DataType], Seq[InternalRowComparableWrapper]) = + KeyedPartitioning.projectKeys(partitionKeys, expressionDataTypes, positions) + + /** + * Reduces this partitioning's partition keys by applying the given reducers. + * Returns the distinct reduced keys. + */ + def reduceKeys(reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRowComparableWrapper] = + KeyedPartitioning.reduceKeys(partitionKeys, expressionDataTypes, reducers).distinct override def satisfies0(required: Distribution): Boolean = { - super.satisfies0(required) || { - required match { - case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => - if (requireAllClusterKeys) { - // Checks whether this partitioning is partitioned on exactly same clustering keys of - // `ClusteredDistribution`. - c.areAllClusterKeysMatched(expressions) + nonGroupedSatisfies(required) || groupedSatisfies(required) + } + + def nonGroupedSatisfies(required: Distribution): Boolean = super.satisfies0(required) + + def groupedSatisfies(required: Distribution): Boolean = { + required match { + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + if (requireAllClusterKeys) { + // Checks whether this partitioning is partitioned on exactly same clustering keys of + // `ClusteredDistribution`. + c.areAllClusterKeysMatched(expressions) + } else { + // We'll need to find leaf attributes from the partition expressions first. + val attributes = expressions.flatMap(_.collectLeaves()) + + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // check that join keys (required clustering keys) + // overlap with partition keys (KeyedPartitioning attributes) + requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) } else { - // We'll need to find leaf attributes from the partition expressions first. 
- val attributes = expressions.flatMap(_.collectLeaves()) - - if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - // check that join keys (required clustering keys) - // overlap with partition keys (KeyGroupedPartitioning attributes) - requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && - expressions.forall(_.collectLeaves().size == 1) - } else { - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) - } + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) } + } - case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting => - o.areAllClusterKeysMatched(expressions) + case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting => + o.areAllClusterKeysMatched(expressions) - case _ => - false - } + case _ => + false } } @@ -416,61 +512,45 @@ case class KeyGroupedPartitioning( val result = KeyGroupedShuffleSpec(this, distribution) if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { // If allowing join keys to be subset of clustering keys, we should create a new - // `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as + // `KeyedPartitioning` here that is grouped on the join keys instead, and use that as // the returned shuffle spec. 
val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) - val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions, - partitionValues, originalPartitionValues) + val projectedExpressions = joinKeyPositions.map(expressions) + val projectedKeys = projectKeys(joinKeyPositions)._2 + val distinctProjectedKeys = projectedKeys.distinct + val projectedPartitioning = + KeyedPartitioning(projectedExpressions, distinctProjectedKeys, isGrouped = true) result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) } else { result } } - lazy val uniquePartitionValues: Seq[InternalRow] = { - val internalRowComparableFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - expressions.map(_.dataType)) - partitionValues - .map(internalRowComparableFactory) - .distinct - .map(_.row) + override def equals(that: Any): Boolean = that match { + case k: KeyedPartitioning if this.expressions == k.expressions => + this.partitionKeys == k.partitionKeys + + case _ => false } - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = - copy(expressions = newChildren) + override def hashCode(): Int = + Objects.hash(expressions, partitionKeys) } -object KeyGroupedPartitioning { +object KeyedPartitioning { + /** + * Creates a KeyedPartitioning with isGrouped computed from the partition keys. + * Use this when creating a new KeyedPartitioning from scratch (e.g., from a data source). 
+ */ def apply( expressions: Seq[Expression], - projectionPositions: Seq[Int], - partitionValues: Seq[InternalRow], - originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { - val projectedExpressions = projectionPositions.map(expressions(_)) - val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) - val projectedOriginalPartitionValues = - originalPartitionValues.map(project(expressions, projectionPositions, _)) - val internalRowComparableFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - projectedExpressions.map(_.dataType)) - - val finalPartitionValues = projectedPartitionValues - .map(internalRowComparableFactory) - .distinct - .map(_.row) - - KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length, - finalPartitionValues, projectedOriginalPartitionValues) - } - - def project( - expressions: Seq[Expression], - positions: Seq[Int], - input: InternalRow): InternalRow = { - val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType)) - .toArray - new GenericInternalRow(projectedValues) + partitionKeys: Seq[InternalRow]): KeyedPartitioning = { + val dataTypes = expressions.map(_.dataType) + val comparableKeyWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + val comparablePartitionKeys = partitionKeys.map(comparableKeyWrapperFactory) + val isGrouped = comparablePartitionKeys.distinct.size == comparablePartitionKeys.size + new KeyedPartitioning(expressions, comparablePartitionKeys, isGrouped) } def supportsExpressions(expressions: Seq[Expression]): Boolean = { @@ -491,6 +571,46 @@ object KeyGroupedPartitioning { case _ => false } } + + /** + * Projects a sequence of partition keys by selecting only the specified positions. 
+ */ + def projectKeys( + keys: Seq[InternalRowComparableWrapper], + dataTypes: Seq[DataType], + positions: Seq[Int]): (Seq[DataType], Seq[InternalRowComparableWrapper]) = { + val projectedDataTypes = positions.map(dataTypes) + val comparableKeyWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) + val positionsWithTypes = positions.zip(projectedDataTypes) + val projectedKeys = keys.map { key => + val projectedKey = positionsWithTypes.map { + case (position, dataType) => key.row.get(position, dataType) + }.toArray[Any] + comparableKeyWrapperFactory(new GenericInternalRow(projectedKey)) + } + + (projectedDataTypes, projectedKeys) + } + + /** + * Reduces a sequence of partition keys by applying reducers to each position. + */ + def reduceKeys( + keys: Seq[InternalRowComparableWrapper], + dataTypes: Seq[DataType], + reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRowComparableWrapper] = { + val comparableKeyWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + keys.map { key => + val keyValues = key.row.toSeq(dataTypes) + val reducedKey = keyValues.zip(reducers).map { + case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) + case (v, _) => v + }.toArray + comparableKeyWrapperFactory(new GenericInternalRow(reducedKey)) + } + } } /** @@ -827,7 +947,7 @@ case class CoalescedHashShuffleSpec( } /** - * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]]. + * [[ShuffleSpec]] created by [[KeyedPartitioning]]. * * @param partitioning key grouped partitioning * @param distribution distribution @@ -835,7 +955,7 @@ case class CoalescedHashShuffleSpec( * This is set if joining on a subset of cluster keys is allowed. 
*/ case class KeyGroupedShuffleSpec( - partitioning: KeyGroupedPartitioning, + partitioning: KeyedPartitioning, distribution: ClusteredDistribution, joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec { @@ -873,15 +993,9 @@ case class KeyGroupedShuffleSpec( // transform functions. // 4. the partition values from both sides are following the same order. case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => - lazy val internalRowComparableFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - partitioning.expressions.map(_.dataType)) distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && - partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { - case (left, right) => - internalRowComparableFactory(left).equals(internalRowComparableFactory(right)) - } + partitioning.partitionKeys == otherPartitioning.partitionKeys case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith) case _ => false @@ -959,25 +1073,7 @@ case class KeyGroupedShuffleSpec( te.copy(children = te.children.map(_ => clustering(positionSet.head))) case (_, positionSet) => clustering(positionSet.head) } - KeyGroupedPartitioning(newExpressions, - partitioning.numPartitions, - partitioning.partitionValues) - } -} - -object KeyGroupedShuffleSpec { - def reducePartitionValue( - row: InternalRow, - reducers: Seq[Option[Reducer[_, _]]], - dataTypes: Seq[DataType], - internalRowComparableWrapperFactory: InternalRow => InternalRowComparableWrapper - ): InternalRowComparableWrapper = { - val partitionVals = row.toSeq(dataTypes) - val reducedRow = partitionVals.zip(reducers).map{ - case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) - case (v, _) => v - }.toArray - internalRowComparableWrapperFactory(new GenericInternalRow(reducedRow)) + KeyedPartitioning(newExpressions, partitioning.partitionKeys, 
partitioning.isGrouped) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala index b9935d40ed985..217d12710a6a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.util -import scala.collection.mutable - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BaseOrdering, Expression, Murmur3HashFunction, RowOrdering} import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition} @@ -101,31 +99,6 @@ object InternalRowComparableWrapper { new InternalRowComparableWrapper(partitionRow, partitionExpression.map(_.dataType)) } - def mergePartitions( - leftPartitioning: Seq[InternalRow], - rightPartitioning: Seq[InternalRow], - partitionExpression: Seq[Expression], - intersect: Boolean = false): Seq[InternalRowComparableWrapper] = { - val partitionDataTypes = partitionExpression.map(_.dataType) - val leftPartitionSet = new mutable.HashSet[InternalRowComparableWrapper] - val internalRowComparableWrapperFactory = - getInternalRowComparableWrapperFactory(partitionDataTypes) - leftPartitioning - .map(internalRowComparableWrapperFactory) - .foreach(partition => leftPartitionSet.add(partition)) - val rightPartitionSet = new mutable.HashSet[InternalRowComparableWrapper] - rightPartitioning - .map(internalRowComparableWrapperFactory) - .foreach(partition => rightPartitionSet.add(partition)) - - val result = if (intersect) { - leftPartitionSet.intersect(rightPartitionSet) - } else { - leftPartitionSet.union(rightPartitionSet) - } - result.toSeq - } - /** Creates a shared factory method for a given row schema to avoid excessive cache lookups. 
*/ def getInternalRowComparableWrapperFactory( dataTypes: Seq[DataType]): InternalRow => InternalRowComparableWrapper = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala index 764dac35f6736..4f431e6171b28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.util +import scala.collection.mutable + import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} -import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning import org.apache.spark.sql.connector.catalog.PartitionInternalRow import org.apache.spark.sql.types.IntegerType @@ -41,30 +41,28 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase { val partitionNum = 200_000 val bucketNum = 4096 val day = 20240401 - val partitions = (0 until partitionNum).map { i => + val partitionKeys = (0 until partitionNum).map { i => val bucketId = i % bucketNum PartitionInternalRow.apply(Array(day, bucketId)); } val benchmark = new Benchmark("internal row comparable wrapper", partitionNum, output = output) + val comparableKeyWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + Seq(IntegerType, IntegerType)) + val comparablePartitionKeys = partitionKeys.map(comparableKeyWrapperFactory) + benchmark.addCase("toSet") { _ => - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - Seq(IntegerType, IntegerType)) - val distinct = partitions - .map(internalRowComparableWrapperFactory) - .toSet + val distinct = comparablePartitionKeys.toSet + 
assert(distinct.size == bucketNum) } benchmark.addCase("mergePartitions") { _ => - // just to mock the data types - val expressions = (Seq(Literal(day, IntegerType), Literal(0, IntegerType))) + val leftKeySet = mutable.HashSet.from(comparablePartitionKeys) + val rightKeySet = mutable.HashSet.from(comparablePartitionKeys) + val merged = leftKeySet.union(rightKeySet) - val leftPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions) - val rightPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions) - val merged = InternalRowComparableWrapper.mergePartitions( - leftPartitioning.partitionValues, rightPartitioning.partitionValues, expressions) assert(merged.size == bucketNum) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala deleted file mode 100644 index cac4a9bc852f6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.RowOrdering -import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec} -import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper -import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams - -/** Base trait for a data source scan capable of producing a key-grouped output. */ -trait KeyGroupedPartitionedScan[T] { - /** - * The output partitioning of this scan after applying any pushed-down SPJ parameters. - * - * @param basePartitioning The original key-grouped partitioning of the scan. - * @param spjParams SPJ parameters for the scan. - */ - def getOutputKeyGroupedPartitioning( - basePartitioning: KeyGroupedPartitioning, - spjParams: StoragePartitionJoinParams): KeyGroupedPartitioning = { - val projectedExpressions = spjParams.joinKeyPositions match { - case Some(projectionPositions) => - projectionPositions.map(i => basePartitioning.expressions(i)) - case _ => basePartitioning.expressions - } - - val newPartValues = spjParams.commonPartitionValues match { - case Some(commonPartValues) => - // We allow duplicated partition values if - // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true - commonPartValues.flatMap { - case (partValue, numSplits) => Seq.fill(numSplits)(partValue) - } - case None => - spjParams.joinKeyPositions match { - case Some(projectionPositions) => - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - projectedExpressions.map(_.dataType)) - basePartitioning.partitionValues.map { r => - val projectedRow = KeyGroupedPartitioning.project(basePartitioning.expressions, - projectionPositions, r) - internalRowComparableWrapperFactory(projectedRow) - }.distinct.map(_.row) - case _ => basePartitioning.partitionValues - } - } - 
basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length, - partitionValues = newPartValues) - } - - /** - * Re-groups the input partitions for this scan based on the provided SPJ params, returning a list - * of partitions to be scanned by each scan task. - * - * @param p The output KeyGroupedPartitioning of this scan. - * @param spjParams SPJ parameters for the scan. - * @param filteredPartitions The input partitions (after applying filtering) to be - * re-grouped for this scan, initially grouped by partition value. - * @param partitionValueAccessor Accessor for the partition values (as an [[InternalRow]]) - */ - def getInputPartitionGrouping( - p: KeyGroupedPartitioning, - spjParams: StoragePartitionJoinParams, - filteredPartitions: Seq[Seq[T]], - partitionValueAccessor: T => InternalRow): Seq[Seq[T]] = { - assert(spjParams.keyGroupedPartitioning.isDefined) - val expressions = spjParams.keyGroupedPartitioning.get - - // Re-group the input partitions if we are projecting on a subset of join keys - val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match { - case Some(projectPositions) => - val projectedExpressions = projectPositions.map(i => expressions(i)) - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - projectedExpressions.map(_.dataType)) - val parts = filteredPartitions.flatten.groupBy(part => { - val row = partitionValueAccessor(part) - val projectedRow = KeyGroupedPartitioning.project( - expressions, projectPositions, row) - internalRowComparableWrapperFactory(projectedRow) - }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq - (parts, projectedExpressions) - case _ => - val groupedParts = filteredPartitions.map(splits => { - assert(splits.nonEmpty) - (partitionValueAccessor(splits.head), splits) - }) - (groupedParts, expressions) - } - - // Also re-group the partitions if we are reducing compatible partition expressions 
- val partitionDataTypes = partExpressions.map(_.dataType) - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(partitionDataTypes) - val finalGroupedPartitions = spjParams.reducers match { - case Some(reducers) => - val result = groupedPartitions.groupBy { case (row, _) => - KeyGroupedShuffleSpec.reducePartitionValue( - row, reducers, partitionDataTypes, internalRowComparableWrapperFactory) - }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq - val rowOrdering = RowOrdering.createNaturalAscendingOrdering( - partExpressions.map(_.dataType)) - result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) - case _ => groupedPartitions - } - - // When partially clustered, the input partitions are not grouped by partition - // values. Here we'll need to check `commonPartitionValues` and decide how to group - // and replicate splits within a partition. - if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) { - // A mapping from the common partition values to how many splits the partition - // should contain. - val commonPartValuesMap = spjParams.commonPartitionValues - .get - .map(t => (internalRowComparableWrapperFactory(t._1), t._2)) - .toMap - val filteredGroupedPartitions = finalGroupedPartitions.filter { - case (partValues, _) => - commonPartValuesMap.keySet.contains(internalRowComparableWrapperFactory(partValues)) - } - val nestGroupedPartitions = filteredGroupedPartitions.map { case (partValue, splits) => - // `commonPartValuesMap` should contain the part value since it's the super set. 
- val numSplits = commonPartValuesMap.get(internalRowComparableWrapperFactory(partValue)) - assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + - "common partition values from Spark plan") - - val newSplits = if (spjParams.replicatePartitions) { - // We need to also replicate partitions according to the other side of join - Seq.fill(numSplits.get)(splits) - } else { - // Not grouping by partition values: this could be the side with partially - // clustered distribution. Because of dynamic filtering, we'll need to check if - // the final number of splits of a partition is smaller than the original - // number, and fill with empty splits if so. This is necessary so that both - // sides of a join will have the same number of partitions & splits. - splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) - } - (internalRowComparableWrapperFactory(partValue), newSplits) - } - - // Now fill missing partition keys with empty partitions - val partitionMapping = nestGroupedPartitions.toMap - spjParams.commonPartitionValues.get.flatMap { - case (partValue, numSplits) => - // Use empty partition for those partition values that are not present. - partitionMapping.getOrElse( - internalRowComparableWrapperFactory(partValue), - Seq.fill(numSplits)(Seq.empty)) - } - } else { - // either `commonPartitionValues` is not defined, or it is defined but - // `applyPartialClustering` is false. - val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) => - internalRowComparableWrapperFactory(partValue) -> splits - }.toMap - - // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there - // could exist duplicated partition values, as partition grouping is not done - // at the beginning and postponed to this method. It is important to use unique - // partition values here so that grouped partitions won't get duplicated. 
- p.uniquePartitionValues.map { partValue => - // Use empty partition for those partition values that are not present - partitionMapping.getOrElse(internalRowComparableWrapperFactory(partValue), Seq.empty) - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 55866cc858405..bdecea2d4d085 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -24,12 +24,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read._ -import org.apache.spark.sql.execution.KeyGroupedPartitionedScan -import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams import org.apache.spark.util.ArrayImplicits._ /** @@ -41,8 +39,8 @@ case class BatchScanExec( runtimeFilters: Seq[Expression], ordering: Option[Seq[SortOrder]] = None, @transient table: Table, - spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams() - ) extends DataSourceV2ScanExecBase with KeyGroupedPartitionedScan[InputPartition] { + keyGroupedPartitioning: Option[Seq[Expression]] = None + ) extends DataSourceV2ScanExecBase { @transient lazy val batch: Batch = if (scan == null) null else scan.toBatch @@ -51,7 +49,7 @@ case class BatchScanExec( case other: BatchScanExec => this.batch != null && 
this.batch == other.batch && this.runtimeFilters == other.runtimeFilters && - this.spjParams == other.spjParams + this.keyGroupedPartitioning == other.keyGroupedPartitioning case _ => false } @@ -61,15 +59,15 @@ case class BatchScanExec( @transient override lazy val inputPartitions: Seq[InputPartition] = batch.planInputPartitions().toImmutableArraySeq - @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = { + // Visible for testing + @transient private[sql] lazy val filteredPartitions: Seq[Option[InputPartition]] = { val dataSourceFilters = runtimeFilters.flatMap { case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) case _ => None } + val originalPartitioning = outputPartitioning if (dataSourceFilters.nonEmpty) { - val originalPartitioning = outputPartitioning - // the cast is safe as runtime filters are only assigned if the scan can be filtered val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] filterableScan.filter(dataSourceFilters.toArray) @@ -78,49 +76,53 @@ case class BatchScanExec( val newPartitions = scan.toBatch.planInputPartitions() originalPartitioning match { - case p: KeyGroupedPartitioning => + case k: KeyedPartitioning => if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) { throw new SparkException("Data source must have preserved the original partitioning " + "during runtime filtering: not all partitions implement HasPartitionKey after " + "filtering") } - val newPartitionValues = newPartitions.map(partition => - InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], p.expressions)) - .toSet - val oldPartitionValues = p.partitionValues - .map(partition => InternalRowComparableWrapper(partition, p.expressions)).toSet - // We require the new number of partition values to be equal or less than the old number - // of partition values here. 
In the case of less than, empty partitions will be added for - // those missing values that are not present in the new input partitions. - if (oldPartitionValues.size < newPartitionValues.size) { - throw new SparkException("During runtime filtering, data source must either report " + - "the same number of partition values, or a subset of partition values from the " + - s"original. Before: ${oldPartitionValues.size} partition values. " + - s"After: ${newPartitionValues.size} partition values") - } - if (!newPartitionValues.forall(oldPartitionValues.contains)) { + val inputMap = k.partitionKeys.groupBy(identity).view.mapValues(_.size) + val comparableKeyWrapperFactory = InternalRowComparableWrapper + .getInternalRowComparableWrapperFactory(k.expressionDataTypes) + val filteredMap = newPartitions.groupBy( + p => comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) + ) + + if (!filteredMap.keySet.subsetOf(inputMap.keySet)) { throw new SparkException("During runtime filtering, data source must not report new " + - "partition values that are not present in the original partitioning.") + "partition keys that are not present in the original partitioning.") } - groupPartitions(newPartitions.toImmutableArraySeq) - .map(_.groupedParts.map(_.parts)).getOrElse(Seq.empty) + inputMap.toSeq + .sortBy(_._1)(k.keyOrdering) + .flatMap { case (key, size) => + // We require the new number of partitions to be equal or less than the old number of + // partitions for a given key. In the case of less than, empty partitions are added. + val fps = filteredMap.getOrElse(key, Array.empty) + + if (fps.size > size) { + throw new SparkException("During runtime filtering, data source must not report " + + s"new partitions for a given key. Before: $size partitions. 
" + + s"After: ${fps.size} partitions") + } + + fps.map(Some).padTo(size, None) + } case _ => // no validation is needed as the data source did not report any specific partitioning - newPartitions.map(Seq(_)).toImmutableArraySeq + newPartitions.toSeq.map(Some) } } else { - partitions - } - } + (originalPartitioning match { + case k: KeyedPartitioning => + inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering) - override def outputPartitioning: Partitioning = { - super.outputPartitioning match { - case k: KeyGroupedPartitioning => getOutputKeyGroupedPartitioning(k, spjParams) - case p => p + case _ => inputPartitions + }).map(Some) } } @@ -131,28 +133,20 @@ case class BatchScanExec( // return an empty RDD with 1 partition if dynamic filtering removed the only split sparkContext.parallelize(Array.empty[InternalRow].toImmutableArraySeq, 1) } else { - val finalPartitions = outputPartitioning match { - case p: KeyGroupedPartitioning => getInputPartitionGrouping( - p, spjParams, filteredPartitions, p => p.asInstanceOf[HasPartitionKey].partitionKey()) - case _ => filteredPartitions - } - new DataSourceRDD( - sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics) + sparkContext, filteredPartitions, readerFactory, supportsColumnar, customMetrics) } postDriverMetrics() rdd } - override def keyGroupedPartitioning: Option[Seq[Expression]] = - spjParams.keyGroupedPartitioning - override def doCanonicalize(): BatchScanExec = { this.copy( output = output.map(QueryPlan.normalizeExpressions(_, output)), runtimeFilters = QueryPlan.normalizePredicates( runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)), - output)) + output), + keyGroupedPartitioning = keyGroupedPartitioning.map(QueryPlan.normalizePredicates(_, output))) } override def simpleString(maxFields: Int): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala index 288233e691453..e9e5f0f3175cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala @@ -52,7 +52,7 @@ case class ContinuousScanExec( } override lazy val inputRDD: RDD[InternalRow] = { - assert(partitions.forall(_.length == 1), "should only contain a single partition") + assert(partitions.forall(_.isDefined), "should contain a partition") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) @@ -61,7 +61,7 @@ case class ContinuousScanExec( sparkContext, conf.continuousStreamingExecutorQueueSize, conf.continuousStreamingExecutorPollIntervalMs, - partitions.map(_.head), + partitions.map(_.get), schema, readerFactory, customMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 19a057c72506b..2fedb97e8461e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ArrayImplicits._ -class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition]) +class DataSourceRDDPartition(val index: Int, val inputPartition: Option[InputPartition]) extends Partition with Serializable /** @@ -50,9 +50,22 @@ private case class ReaderState(reader: PartitionReader[_], iterator: MetricsIter // TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for // 
columnar scan. +/** + * An RDD that reads data from a V2 data source. + * + * This RDD handles both row-based and columnar reads, tracks custom metrics from the data source, + * and ensures that task completion listeners are added only once per thread to avoid duplicate + * metric updates and resource cleanup. + * + * @param sc The Spark context + * @param inputPartitions The input partitions to read from + * @param partitionReaderFactory Factory for creating partition readers + * @param columnarReads Whether to use columnar reads + * @param customMetrics Custom metrics defined by the data source + */ class DataSourceRDD( sc: SparkContext, - @transient private val inputPartitions: Seq[Seq[InputPartition]], + @transient private val inputPartitions: Seq[Option[InputPartition]], partitionReaderFactory: PartitionReaderFactory, columnarReads: Boolean, customMetrics: Map[String, SQLMetric]) @@ -65,7 +78,7 @@ class DataSourceRDD( override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { - case (inputPartitions, index) => new DataSourceRDDPartition(index, inputPartitions) + case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition) }.toArray } @@ -98,62 +111,39 @@ class DataSourceRDD( } } - val iterator = new Iterator[Object] { - private val inputPartitions = castPartition(split).inputPartitions - private var currentIter: Option[Iterator[Object]] = None - private var currentIndex: Int = 0 - - override def hasNext: Boolean = currentIter.exists(_.hasNext) || advanceToNextIter() - - override def next(): Object = { - if (!hasNext) throw new NoSuchElementException("No more elements") - currentIter.get.next() + castPartition(split).inputPartition.iterator.flatMap { inputPartition => + val (iter, reader) = if (columnarReads) { + val batchReader = partitionReaderFactory.createColumnarReader(inputPartition) + val iter = new MetricsBatchIterator( + new PartitionIterator[ColumnarBatch](batchReader, customMetrics)) + (iter, 
batchReader) + } else { + val rowReader = partitionReaderFactory.createReader(inputPartition) + val iter = new MetricsRowIterator( + new PartitionIterator[InternalRow](rowReader, customMetrics)) + (iter, rowReader) } - private def advanceToNextIter(): Boolean = { - if (currentIndex >= inputPartitions.length) { - false - } else { - val inputPartition = inputPartitions(currentIndex) - currentIndex += 1 - - // TODO: SPARK-25083 remove the type erasure hack in data source scan - val (iter, reader) = if (columnarReads) { - val batchReader = partitionReaderFactory.createColumnarReader(inputPartition) - val iter = new MetricsBatchIterator( - new PartitionIterator[ColumnarBatch](batchReader, customMetrics)) - (iter, batchReader) - } else { - val rowReader = partitionReaderFactory.createReader(inputPartition) - val iter = new MetricsRowIterator( - new PartitionIterator[InternalRow](rowReader, customMetrics)) - (iter, rowReader) - } - - // Flush metrics and close the previous reader before advancing to the next one. - // Pass the accumulated metrics to the new reader so they carry forward correctly. - val prevState = taskReaderStates.get(taskAttemptId) - if (prevState != null) { - val metrics = prevState.reader.currentMetricsValues - CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics) - reader.initMetricsValues(metrics) - prevState.reader.close() - } + // Flush metrics and close the previous reader before advancing to the next one. + // Pass the accumulated metrics to the new reader so they carry forward correctly. + val prevState = taskReaderStates.get(taskAttemptId) + if (prevState != null) { + val metrics = prevState.reader.currentMetricsValues + CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics) + reader.initMetricsValues(metrics) + prevState.reader.close() + } - // Update the map so the completion listener always references the latest reader. 
- taskReaderStates.put(taskAttemptId, ReaderState(reader, iter)) + // Update the map so the completion listener always references the latest reader. + taskReaderStates.put(taskAttemptId, ReaderState(reader, iter)) - currentIter = Some(iter) - hasNext - } - } + // TODO: SPARK-25083 remove the type erasure hack in data source scan + new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]]) } - - new InterruptibleIterator(context, iterator).asInstanceOf[Iterator[InternalRow]] } override def getPreferredLocations(split: Partition): Seq[String] = { - castPartition(split).inputPartitions.flatMap(_.preferredLocations()) + castPartition(split).inputPartition.toSeq.flatMap(_.preferredLocations()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index 95d85dab5cedc..877e65341c1c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -21,12 +21,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering, SortOrder} import org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning -import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} +import org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan} import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics -import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.SupportsMetadata import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ArrayImplicits._ @@ -63,9 +62,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { redact(result) } - def partitions: Seq[Seq[InputPartition]] = { - groupedPartitions.map(_.groupedParts.map(_.parts)).getOrElse(inputPartitions.map(Seq(_))) - } + def partitions: Seq[Option[InputPartition]] = inputPartitions.map(Some) /** * Shorthand for calling redact() without specifying redacting rules @@ -94,76 +91,24 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { override def outputPartitioning: physical.Partitioning = { keyGroupedPartitioning match { - case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) => - groupedPartitions - .map { keyGroupedPartsInfo => - val keyGroupedParts = keyGroupedPartsInfo.groupedParts - KeyGroupedPartitioning(exprs, keyGroupedParts.size, keyGroupedParts.map(_.value), - keyGroupedPartsInfo.originalParts.map(_.partitionKey())) - } - .getOrElse(super.outputPartitioning) + case Some(exprs) if conf.v2BucketingEnabled && KeyedPartitioning.supportsExpressions(exprs) && + inputPartitions.nonEmpty && inputPartitions.forall(_.isInstanceOf[HasPartitionKey]) => + val dataTypes = exprs.map(_.dataType) + val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) + val partitionKeys = + inputPartitions.map(_.asInstanceOf[HasPartitionKey].partitionKey()).sorted(rowOrdering) + KeyedPartitioning(exprs, partitionKeys) case _ => super.outputPartitioning } } - @transient lazy val groupedPartitions: Option[KeyGroupedPartitionInfo] = { - // Early check if we actually need to materialize the input partitions. - keyGroupedPartitioning match { - case Some(_) => groupPartitions(inputPartitions) - case _ => None - } - } - /** - * Group partition values for all the input partitions. 
This returns `Some` iff: - * - [[SQLConf.V2_BUCKETING_ENABLED]] is turned on - * - all input partitions implement [[HasPartitionKey]] - * - `keyGroupedPartitioning` is set - * - * The result, if defined, is a [[KeyGroupedPartitionInfo]] which contains a list of - * [[KeyGroupedPartition]], as well as a list of partition values from the original input splits, - * sorted according to the partition keys in ascending order. - * - * A non-empty result means each partition is clustered on a single key and therefore eligible - * for further optimizations to eliminate shuffling in some operations such as join and aggregate. + * Returns the output ordering from the data source if available, otherwise falls back + * to the default (no ordering). This allows data sources to report their natural ordering + * through `SupportsReportOrdering`. */ - def groupPartitions(inputPartitions: Seq[InputPartition]): Option[KeyGroupedPartitionInfo] = { - if (!SQLConf.get.v2BucketingEnabled) return None - - keyGroupedPartitioning.flatMap { expressions => - val results = inputPartitions.takeWhile { - case _: HasPartitionKey => true - case _ => false - }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p.asInstanceOf[HasPartitionKey])) - - if (results.length != inputPartitions.length || inputPartitions.isEmpty) { - // Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here. - None - } else { - // also sort the input partitions according to their partition key order. This ensures - // a canonical order from both sides of a bucketed join, for example. 
- val partitionDataTypes = expressions.map(_.dataType) - val rowOrdering = RowOrdering.createNaturalAscendingOrdering(partitionDataTypes) - val sortedKeyToPartitions = results.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) - val sortedGroupedPartitions = sortedKeyToPartitions - .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2)) - .groupBy(_._1) - .toSeq - .map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) } - .sorted(rowOrdering.on((k: KeyGroupedPartition) => k.value)) - - Some(KeyGroupedPartitionInfo(sortedGroupedPartitions, sortedKeyToPartitions.map(_._2))) - } - } - } - - override def outputOrdering: Seq[SortOrder] = { - // when multiple partitions are grouped together, ordering inside partitions is not preserved - val partitioningPreservesOrdering = groupedPartitions - .forall(_.groupedParts.forall(_.parts.length <= 1)) - ordering.filter(_ => partitioningPreservesOrdering).getOrElse(super.outputOrdering) - } + override def outputOrdering: Seq[SortOrder] = ordering.getOrElse(super.outputOrdering) override def supportsColumnar: Boolean = { scan.columnarSupportMode() match { @@ -210,19 +155,3 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } } } - -/** - * A key-grouped Spark partition, which could consist of multiple input splits - * - * @param value the partition value shared by all the input splits - * @param parts the input splits that are grouped into a single Spark partition - */ -private[v2] case class KeyGroupedPartition(value: InternalRow, parts: Seq[InputPartition]) - -/** - * Information about key-grouped partitions, which contains a list of grouped partitions as well - * as the original input partitions before the grouping. 
- */ -private[v2] case class KeyGroupedPartitionInfo( - groupedParts: Seq[KeyGroupedPartition], - originalParts: Seq[HasPartitionKey]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 014b43c915437..556e37eba1c79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -45,7 +45,6 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan, SparkStrategy => Strategy} import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelationWithTable, PushableColumnAndNestedColumn} -import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH @@ -161,8 +160,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case _ => false } val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters, - relation.ordering, relation.relation.table, - StoragePartitionJoinParams(relation.keyGroupedPartitioning)) + relation.ordering, relation.relation.table, relation.keyGroupedPartitioning) DataSourceV2Strategy.withProjectAndFilter( project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala new file mode 100644 index 0000000000000..9910c4eb788cc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Partition, SparkException} +import org.apache.spark.rdd.{CoalescedRDD, PartitionCoalescer, PartitionGroup, RDD} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.connector.catalog.functions.Reducer +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Physical operator that groups input partitions by their partition keys. 
+ * + * This operator is used to coalesce partitions from bucketed/partitioned data sources + * where multiple input partitions share the same partition key. It's commonly used in + * storage-partitioned joins to align partitions from different sides of the join. + * + * @param child The child plan providing bucketed/partitioned input + * @param joinKeyPositions Optional projection to select a subset of the partitioning key + * for join compatibility (e.g., when join keys are a subset of + * partition keys) + * @param expectedPartitionKeys Optional sequence of expected partition key values and their + * split counts + * @param reducers Optional reducers to apply to partition keys for grouping compatibility + * @param applyPartialClustering Whether to apply partial clustering for skewed data + * @param replicatePartitions Whether to replicate partitions across multiple keys + */ +case class GroupPartitionsExec( + child: SparkPlan, + @transient joinKeyPositions: Option[Seq[Int]] = None, + @transient expectedPartitionKeys: Option[Seq[(InternalRowComparableWrapper, Int)]] = None, + @transient reducers: Option[Seq[Option[Reducer[_, _]]]] = None, + @transient applyPartialClustering: Boolean = false, + @transient replicatePartitions: Boolean = false + ) extends UnaryExecNode { + + override def outputPartitioning: Partitioning = { + child.outputPartitioning match { + case p: Partitioning with Expression => + // There can be multiple `KeyedPartitioning` in an output partitioning of a join, but they + // can only differ in `expressions`. `partitionKeys` must match so we can calculate it only + // once via `groupedPartitions`. 
+ + val keyedPartitionings = p.collect { case k: KeyedPartitioning => k } + if (keyedPartitionings.size > 1) { + val first = keyedPartitionings.head + keyedPartitionings.tail.foreach { k => + assert(k.partitionKeys == first.partitionKeys, + "All KeyedPartitioning nodes must have identical partition keys") + } + } + + p.transform { + case k: KeyedPartitioning => + val projectedExpressions = joinKeyPositions.fold(k.expressions)(_.map(k.expressions)) + KeyedPartitioning(projectedExpressions, groupedPartitions.map(_._1), + isGrouped = isGrouped) + }.asInstanceOf[Partitioning] + case o => o + } + } + + /** + * Aligns partitions based on `expectedPartitionKeys` and clustering mode. + */ + private def alignToExpectedKeys(keyMap: Map[InternalRowComparableWrapper, Seq[Int]]) = { + var isGrouped = true + val alignedPartitions = expectedPartitionKeys.get.flatMap { case (key, numSplits) => + if (numSplits > 1) isGrouped = false + val splits = keyMap.getOrElse(key, Seq.empty) + if (applyPartialClustering && !replicatePartitions) { + // Distribute splits across expected partitions, padding with empty sequences + val paddedSplits = splits.map(Seq(_)).padTo(numSplits, Seq.empty) + paddedSplits.map((key, _)) + } else { + // Replicate all splits to each expected partition + Seq.fill(numSplits)((key, splits)) + } + } + (alignedPartitions, isGrouped) + } + + /** + * Groups and sorts partitions by their keys in ascending order. + */ + private def groupAndSortByKeys( + keyMap: Map[InternalRowComparableWrapper, Seq[Int]], + dataTypes: Seq[DataType]) = { + val keyOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) + keyMap.toSeq.sorted(keyOrdering.on((t: (InternalRowComparableWrapper, _)) => t._1.row)) + } + + /** + * Computes the grouped partitions by: + * 1. Projecting partition keys if joinKeyPositions is specified + * 2. Reducing keys if reducers are specified + * 3. Grouping input partition indices by their (possibly projected/reduced) keys + * 4. 
Sorting or distributing based on whether partial clustering is enabled + * + * Returns a tuple of (partitions, isGrouped) where: + * - partitions: sequence of (partitionKey, inputPartitionIndices) pairs representing + * how input partitions should be grouped together + * - isGrouped: whether the output partitioning is grouped (no duplicates in partition keys) + */ + @transient private lazy val groupedPartitionsTuple = { + // There must be a `KeyedPartitioning` in child's output partitioning as a + // `GroupPartitionsExec` node is added to a plan only in that case. + val keyedPartitioning = child.outputPartitioning + .asInstanceOf[Partitioning with Expression] + .collectFirst { case k: KeyedPartitioning => k } + .getOrElse( + throw new SparkException("GroupPartitionsExec requires a child with KeyedPartitioning")) + + // Project partition keys if join key positions are specified + val (projectedDataTypes, projectedKeys) = + joinKeyPositions.fold( + (keyedPartitioning.expressionDataTypes, keyedPartitioning.partitionKeys) + )(keyedPartitioning.projectKeys) + + // Reduce keys if reducers are specified + val reducedKeys = reducers.fold(projectedKeys)( + KeyedPartitioning.reduceKeys(projectedKeys, projectedDataTypes, _)) + + val keyToPartitionIndices = reducedKeys.zipWithIndex.groupMap(_._1)(_._2) + + if (expectedPartitionKeys.isDefined) { + alignToExpectedKeys(keyToPartitionIndices) + } else { + (groupAndSortByKeys(keyToPartitionIndices, projectedDataTypes), true) + } + } + + @transient lazy val groupedPartitions: Seq[(InternalRowComparableWrapper, Seq[Int])] = + groupedPartitionsTuple._1 + + @transient lazy val isGrouped: Boolean = groupedPartitionsTuple._2 + + override protected def doExecute(): RDD[InternalRow] = { + if (groupedPartitions.isEmpty) { + sparkContext.emptyRDD + } else { + val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) + new CoalescedRDD(child.execute(), groupedPartitions.size, Some(partitionCoalescer)) + } + } + + 
override def supportsColumnar: Boolean = child.supportsColumnar + + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { + if (groupedPartitions.isEmpty) { + sparkContext.emptyRDD + } else { + val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) + new CoalescedRDD(child.executeColumnar(), groupedPartitions.size, Some(partitionCoalescer)) + } + } + + override def output: Seq[Attribute] = child.output + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + override def outputOrdering: Seq[SortOrder] = { + // when multiple partitions are grouped together, ordering inside partitions is not preserved + if (groupedPartitions.forall(_._2.size <= 1)) { + child.outputOrdering + } else { + super.outputOrdering + } + } +} + +/** + * A PartitionCoalescer that groups partitions according to a pre-computed grouping plan. + * + * Unlike Spark's default coalescer which tries to minimize data movement, this coalescer + * groups partitions based on their partition keys to maintain the grouping semantics + * required for storage-partitioned operations. 
+ * + * @param groupedPartitions Sequence where each element is a sequence of input partition + * indices that should be grouped together + */ +class GroupedPartitionCoalescer( + val groupedPartitions: Seq[Seq[Int]] + ) extends PartitionCoalescer with Serializable { + + override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = { + groupedPartitions.map { partitionIndices => + val partitions = new ArrayBuffer[Partition](partitionIndices.size) + val preferredLocations = new ArrayBuffer[String](partitionIndices.size) + partitionIndices.foreach { partitionIndex => + val partition = parent.partitions(partitionIndex) + partitions += partition + preferredLocations ++= parent.preferredLocations(partition) + } + // Select the most common location as the preferred location + val preferredLocation = preferredLocations + .groupBy(identity) + .view.mapValues(_.size) + .maxByOption(_._2) + .map(_._1) + val partitionGroup = new PartitionGroup(preferredLocation) + partitionGroup.partitions ++= partitions + partitionGroup + }.toArray + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e239174e40ad4..39da546256132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -21,7 +21,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.{LogKeys} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -29,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper import 
org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.datasources.v2.GroupPartitionsExec import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf @@ -63,22 +62,62 @@ case class EnsureRequirements( assert(requiredChildOrderings.length == originalChildren.length) // Ensure that the operator's children satisfy their output distribution requirements. var children = originalChildren.zip(requiredChildDistributions).map { - case (child, distribution) if child.outputPartitioning.satisfies(distribution) => - ensureOrdering(child, distribution) - case (child, BroadcastDistribution(mode)) => - BroadcastExchangeExec(mode, child) case (child, distribution) => - val numPartitions = distribution.requiredNumPartitions - .getOrElse(conf.numShufflePartitions) - distribution match { - case _: StatefulOpClusteredDistribution => - ShuffleExchangeExec( - distribution.createPartitioning(numPartitions), child, - REQUIRED_BY_STATEFUL_OPERATOR) - - case _ => - ShuffleExchangeExec( - distribution.createPartitioning(numPartitions), child, shuffleOrigin) + // Split child's partitioning into categories + val (other, grouped, nonGrouped) = splitKeyedPartitionings(child.outputPartitioning) + + // If non-KeyedPartitioning already satisfies, no changes needed + if (other.exists(_.satisfies(distribution))) { + child + } else { + // Check KeyedPartitioning satisfaction conditions + val groupedSatisfies = grouped.exists(_.satisfies(distribution)) + val nonGroupedSatisfiesAsIs = nonGrouped.exists(_.nonGroupedSatisfies(distribution)) + val nonGroupedSatisfiesWhenGrouped = nonGrouped.exists(_.groupedSatisfies(distribution)) + + // Check if any KeyedPartitioning satisfies the distribution + if (groupedSatisfies || nonGroupedSatisfiesAsIs || nonGroupedSatisfiesWhenGrouped) { + 
distribution match { + case o: OrderedDistribution => + // OrderedDistribution requires grouped KeyedPartitioning with sorted keys. + // Find any KeyedPartitioning that satisfies via groupedSatisfies. + val satisfyingKeyedPartitioning = + (grouped ++ nonGrouped).find(_.groupedSatisfies(distribution)).get + val attrs = satisfyingKeyedPartitioning.expressions.flatMap(_.collectLeaves()) + .map(_.asInstanceOf[Attribute]) + val keyRowOrdering = RowOrdering.create(o.ordering, attrs) + val keyOrdering = keyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) + val sorted = satisfyingKeyedPartitioning.partitionKeys.sorted(keyOrdering) + GroupPartitionsExec(child, expectedPartitionKeys = Some(sorted.map((_, 1)))) + + case _ if groupedSatisfies => + // Grouped KeyedPartitioning already satisfies + child + + case _ if nonGroupedSatisfiesAsIs => + // Non-grouped KeyedPartitioning satisfies without grouping + child + + case _ => + // Non-grouped KeyedPartitioning satisfies only after grouping + GroupPartitionsExec(child) + } + } else { + // No partitioning satisfies - need broadcast or shuffle + val numPartitions = distribution.requiredNumPartitions + .getOrElse(conf.numShufflePartitions) + distribution match { + case BroadcastDistribution(mode) => + BroadcastExchangeExec(mode, child) + case _: StatefulOpClusteredDistribution => + ShuffleExchangeExec( + distribution.createPartitioning(numPartitions), child, + REQUIRED_BY_STATEFUL_OPERATOR) + case _ => + ShuffleExchangeExec( + distribution.createPartitioning(numPartitions), child, shuffleOrigin) + } + } } } @@ -138,17 +177,11 @@ case class EnsureRequirements( !p._2.canCreatePartitioning || children(p._1).isInstanceOf[ShuffleExchangeLike] ) // Choose all the specs that can be used to shuffle other children - val candidateSpecs = specs - .filter(_._2.canCreatePartitioning) - .filter { - // To choose a KeyGroupedShuffleSpec, we must be able to push down SPJ parameters into - // the scan (for join key positions). 
If these parameters can't be pushed down, this - // spec can't be used to shuffle other children. - case (idx, _: KeyGroupedShuffleSpec) => canPushDownSPJParamsToScan(children(idx)) - case _ => true - } - .filter(p => !shouldConsiderMinParallelism || - children(p._1).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions) + val candidateSpecs = specs.filter { case (index, spec) => + spec.canCreatePartitioning && + (!shouldConsiderMinParallelism || + children(index).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions) + } val bestSpecOpt = if (candidateSpecs.isEmpty) { None } else { @@ -200,13 +233,13 @@ case class EnsureRequirements( case ((child, dist), idx) => if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) { bestSpecOpt match { - // If keyGroupCompatible = false, we can still perform SPJ + // If `areChildrenCompatible` is false, we can still perform SPJ // by shuffling the other side based on join keys (see the else case below). 
// Hence we need to ensure that after this call, the outputPartitioning of the // partitioned side's BatchScanExec is grouped by join keys to match, // and we do that by pushing down the join keys case Some(KeyGroupedShuffleSpec(_, _, Some(joinKeyPositions))) => - populateJoinKeyPositions(child, Some(joinKeyPositions)) + withJoinKeyPositions(child, joinKeyPositions) case _ => child } } else { @@ -225,6 +258,7 @@ case class EnsureRequirements( child match { case ShuffleExchangeExec(_, c, so, ps) => ShuffleExchangeExec(newPartitioning, c, so, ps) + case GroupPartitionsExec(c, _, _, _, _, _) => ShuffleExchangeExec(newPartitioning, c) case _ => ShuffleExchangeExec(newPartitioning, child) } } @@ -305,23 +339,6 @@ case class EnsureRequirements( } } - private def ensureOrdering(plan: SparkPlan, distribution: Distribution) = { - (plan.outputPartitioning, distribution) match { - case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _), - d @ OrderedDistribution(ordering)) if p.satisfies(d) => - val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute]) - val partitionOrdering: Ordering[InternalRow] = { - RowOrdering.create(ordering, attrs) - } - // Sort 'commonPartitionValues' and use this mechanism to ensure BatchScan's - // output partitions are ordered - val sorted = partitionValues.sorted(partitionOrdering) - populateCommonPartitionInfo(plan, sorted.map((_, 1)), - None, None, applyPartialClustering = false, replicatePartitions = false) - case _ => plan - } - } - /** * Recursively reorders the join keys based on partitioning. It starts reordering the * join keys to match HashPartitioning on either side, followed by PartitioningCollection. 
@@ -340,12 +357,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) => + case (Some(KeyedPartitioning(clustering, _, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) => + case (_, Some(KeyedPartitioning(clustering, _, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( @@ -389,47 +406,9 @@ case class EnsureRequirements( } } - /** - * Whether partial clustering can be applied to a given child query plan. This is true if the plan - * consists only of a sequence of unary nodes where each node does not use the scan's key-grouped - * partitioning to satisfy its required distribution. Otherwise, partially clustering could be - * applied to a key-grouped partitioning unrelated to this join. - */ - private def canApplyPartialClusteredDistribution(plan: SparkPlan): Boolean = { - !plan.exists { - // Unary nodes are safe as long as they don't have a required distribution (for example, a - // project or filter). If they have a required distribution, then we should assume that this - // plan can't be partially clustered (since the key-grouped partitioning may be needed to - // satisfy this distribution unrelated to this JOIN). - case u if u.children.length == 1 => - u.requiredChildDistribution.head != UnspecifiedDistribution - // Only allow a non-unary node if it's a leaf node - key-grouped partitionings other binary - // nodes (like another JOIN) aren't safe to partially cluster. 
- case other => other.children.nonEmpty - } - } - - /** - * Whether SPJ params can be pushed down to the leaf nodes of a physical plan. For a plan to be - * eligible for SPJ parameter pushdown, all leaf nodes must be a KeyGroupedPartitioning-aware - * scan. - * - * Notably, if the leaf of `plan` is an [[RDDScanExec]] created by checkpointing a DSv2 scan, the - * reported partitioning will be a [[KeyGroupedPartitioning]], but this plan will _not_ be - * eligible for SPJ parameter pushdown (as the partitioning is static and can't be easily - * re-grouped or padded with empty partitions according to the partition values on the other side - * of the join). - */ - private def canPushDownSPJParamsToScan(plan: SparkPlan): Boolean = { - plan.collectLeaves().forall { - case _: KeyGroupedPartitionedScan[_] => true - case _ => false - } - } - /** * Checks whether two children, `left` and `right`, of a join operator have compatible - * `KeyGroupedPartitioning`, and can benefit from storage-partitioned join. + * `KeyedPartitioning`, and can benefit from storage-partitioned join. * * Returns the updated new children if the check is successful, otherwise `None`. */ @@ -438,12 +417,6 @@ case class EnsureRequirements( left: SparkPlan, right: SparkPlan, requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = { - // If SPJ params can't be pushed down to either the left or right side, it's unsafe to do an - // SPJ. 
- if (!canPushDownSPJParamsToScan(left) || !canPushDownSPJParamsToScan(right)) { - return None - } - parent match { case smj: SortMergeJoinExec => checkKeyGroupCompatible(left, right, smj.joinType, requiredChildDistribution) @@ -474,12 +447,21 @@ case class EnsureRequirements( val leftSpec = specs.head val rightSpec = specs(1) - - var isCompatible = false - if (!conf.v2BucketingPushPartValuesEnabled && - !conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - isCompatible = leftSpec.isCompatibleWith(rightSpec) - } else { + val leftPartitioning = leftSpec.partitioning + val rightPartitioning = rightSpec.partitioning + + // We don't need to alter the existing or add new `GroupPartitionsExec` when the child + // partitionings are not modified (projected) in specs and left and right side partitionings are + // compatible with each other. + // Left and right `outputPartitioning` is a `PartitioningCollection` or a `KeyedPartitioning` + // otherwise `createKeyGroupedShuffleSpec()` would have returned `None`. + var isCompatible = + left.outputPartitioning.asInstanceOf[Expression].exists(_ == leftPartitioning) && + right.outputPartitioning.asInstanceOf[Expression].exists(_ == rightPartitioning) && + leftSpec.isCompatibleWith(rightSpec) + if ((!isCompatible || conf.v2BucketingPartiallyClusteredDistributionEnabled) && + (conf.v2BucketingPushPartValuesEnabled || + conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys)) { logInfo("Pushing common partition values for storage-partitioned join") isCompatible = leftSpec.areKeysCompatible(rightSpec) @@ -499,38 +481,33 @@ case class EnsureRequirements( // just push the common set of partition values: `[0, 1, 2, 3]` down to the two data // sources. 
if (isCompatible) { - val leftPartValues = leftSpec.partitioning.partitionValues - val rightPartValues = rightSpec.partitioning.partitionValues + val leftPartKeys = leftPartitioning.partitionKeys + val rightPartKeys = rightPartitioning.partitionKeys - val numLeftPartValues = MDC(LogKeys.NUM_LEFT_PARTITION_VALUES, leftPartValues.size) - val numRightPartValues = MDC(LogKeys.NUM_RIGHT_PARTITION_VALUES, rightPartValues.size) + val numLeftPartKeys = MDC(LogKeys.NUM_LEFT_PARTITION_VALUES, leftPartKeys.size) + val numRightPartKeys = MDC(LogKeys.NUM_RIGHT_PARTITION_VALUES, rightPartKeys.size) logInfo( log""" - |Left side # of partitions: $numLeftPartValues - |Right side # of partitions: $numRightPartValues + |Left side # of partitions: $numLeftPartKeys + |Right side # of partitions: $numRightPartKeys |""".stripMargin) - // As partition keys are compatible, we can pick either left or right as partition - // expressions - val partitionExprs = leftSpec.partitioning.expressions - // in case of compatible but not identical partition expressions, we apply 'reduce' // transforms to group one side's partitions as well as the common partition values val leftReducers = leftSpec.reducers(rightSpec) - val leftParts = reducePartValues(leftSpec.partitioning.partitionValues, - partitionExprs, - leftReducers) + val leftReducedKeys = + leftReducers.fold(leftPartitioning.partitionKeys)(leftPartitioning.reduceKeys) val rightReducers = rightSpec.reducers(leftSpec) - val rightParts = reducePartValues(rightSpec.partitioning.partitionValues, - partitionExprs, - rightReducers) + val rightReducedKeys = + rightReducers.fold(rightPartitioning.partitionKeys)(rightPartitioning.reduceKeys) // merge values on both sides - var mergedPartValues = mergePartitions(leftParts, rightParts, partitionExprs, joinType) - .map(v => (v, 1)) + var mergedPartitionKeys = + mergePartitions(leftReducedKeys, rightReducedKeys, joinType, leftPartitioning.keyOrdering) + .map((_, 1)) logInfo(log"After merging, there are " + - 
log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartValues.size)} partitions") + log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartitionKeys.size)} partitions") var replicateLeftSide = false var replicateRightSide = false @@ -549,23 +526,19 @@ case class EnsureRequirements( // whether partially clustered distribution can be applied. For instance, the // optimization cannot be applied to a left outer join, where the left hand // side is chosen as the side to replicate partitions according to stats. - // Similarly, the partially clustered distribution cannot be applied if the - // partially clustered side must use the scan's key-grouped partitioning to - // satisfy some unrelated required distribution in its plan (for example, for an aggregate - // or window function), as this will give incorrect results (for example, duplicate - // row_number() values). // Otherwise, query result could be incorrect. - val canReplicateLeft = canReplicateLeftSide(joinType) && - canApplyPartialClusteredDistribution(right) - val canReplicateRight = canReplicateRightSide(joinType) && - canApplyPartialClusteredDistribution(left) + val canReplicateLeft = canReplicateLeftSide(joinType) + val canReplicateRight = canReplicateRightSide(joinType) if (!canReplicateLeft && !canReplicateRight) { logInfo(log"Skipping partially clustered distribution as it cannot be applied for " + log"join type '${MDC(LogKeys.JOIN_TYPE, joinType)}'") } else { - val leftLink = left.logicalLink - val rightLink = right.logicalLink + val unwrappedLeft = unwrapGroupPartitions(left) + val unwrappedRight = unwrapGroupPartitions(right) + + val leftLink = unwrappedLeft.logicalLink + val rightLink = unwrappedRight.logicalLink replicateLeftSide = if ( leftLink.isDefined && rightLink.isDefined && @@ -588,7 +561,7 @@ case class EnsureRequirements( // to apply the grouping & replication of partitions logInfo("Using number of partitions to determine which side of join " + "to fully cluster partition values") - leftPartValues.size < 
rightPartValues.size + leftPartKeys.size < rightPartKeys.size } replicateRightSide = !replicateLeftSide @@ -608,33 +581,43 @@ case class EnsureRequirements( replicateRightSide = false } else { // In partially clustered distribution, we should use un-grouped partition values - val spec = if (replicateLeftSide) rightSpec else leftSpec - val partValues = spec.partitioning.originalPartitionValues - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - partitionExprs.map(_.dataType)) - - val numExpectedPartitions = partValues - .map(internalRowComparableWrapperFactory) - .groupBy(identity) - .transform((_, v) => v.size) - - mergedPartValues = mergedPartValues.map { case (partVal, numParts) => - (partVal, numExpectedPartitions.getOrElse( - internalRowComparableWrapperFactory(partVal), numParts)) + val (partiallyClusteredChild, partiallyClusteredSpec) = if (replicateLeftSide) { + (unwrappedRight, rightSpec) + } else { + (unwrappedLeft, leftSpec) + } + // Original `KeyedPartitioning` can be obtained from the child directly if the child + // satisfied the distribution requirement; or from the child's child if it didn't as + // the child must be a `GroupPartitionsExec` inserted by `EnsureRequirements` + // to satisfy the distribution requirement. + val originalPartitioning = + partiallyClusteredChild.outputPartitioning.asInstanceOf[Expression] + // `outputPartitioning` is either a `PartitioningCollection` or a `KeyedPartitioning` + // otherwise `createKeyGroupedShuffleSpec()` would have returned `None`. 
+ val originalKeyedPartitioning = + originalPartitioning.collectFirst { case k: KeyedPartitioning => k }.get + val projectedOriginalPartitionKeys = partiallyClusteredSpec.joinKeyPositions + .fold(originalKeyedPartitioning.partitionKeys)( + originalKeyedPartitioning.projectKeys(_)._2) + + val numExpectedPartitions = + projectedOriginalPartitionKeys.groupBy(identity).view.mapValues(_.size) + + mergedPartitionKeys = mergedPartitionKeys.map { case (key, numParts) => + (key, numExpectedPartitions.getOrElse(key, numParts)) } logInfo(log"After applying partially clustered distribution, there are " + - log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartValues.map(_._2).sum)} partitions.") + log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartitionKeys.map(_._2).sum)} partitions.") applyPartialClustering = true } } } - // Now we need to push-down the common partition information to the scan in each child - newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions, + // Now we need to push-down the common partition information to the `GroupPartitionsExec`s. + newLeft = applyGroupPartitions(left, leftSpec.joinKeyPositions, mergedPartitionKeys, leftReducers, applyPartialClustering, replicateLeftSide) - newRight = populateCommonPartitionInfo(right, mergedPartValues, rightSpec.joinKeyPositions, + newRight = applyGroupPartitions(right, rightSpec.joinKeyPositions, mergedPartitionKeys, rightReducers, applyPartialClustering, replicateRightSide) } } @@ -670,75 +653,66 @@ case class EnsureRequirements( joinType == LeftAnti || joinType == LeftOuter } - // Populate the common partition information down to the scan nodes - private def populateCommonPartitionInfo( + /** + * Unwraps a GroupPartitionsExec to get the underlying child plan. + */ + private def unwrapGroupPartitions(plan: SparkPlan): SparkPlan = plan match { + case g: GroupPartitionsExec => g.child + case other => other + } + + /** + * Applies or updates `GroupPartitionsExec` with the given parameters. 
+ * + * `GroupPartitionsExec` can be either the given plan node (child of the join inserted by + * `EnsureRequirements`) if the original child didn't satisfy the distribution requirement; or we + * can create a new one specifically for this join. + */ + private def applyGroupPartitions( plan: SparkPlan, - values: Seq[(InternalRow, Int)], joinKeyPositions: Option[Seq[Int]], + mergedPartitionKeys: Seq[(InternalRowComparableWrapper, Int)], reducers: Option[Seq[Option[Reducer[_, _]]]], applyPartialClustering: Boolean, - replicatePartitions: Boolean): SparkPlan = plan match { - case scan: BatchScanExec => - val newScan = scan.copy( - spjParams = scan.spjParams.copy( - commonPartitionValues = Some(values), + replicatePartitions: Boolean): SparkPlan = { + plan match { + case g: GroupPartitionsExec => + val newGroupPartitions = g.copy( joinKeyPositions = joinKeyPositions, + expectedPartitionKeys = Some(mergedPartitionKeys), reducers = reducers, applyPartialClustering = applyPartialClustering, - replicatePartitions = replicatePartitions - ) - ) - newScan.copyTagsFrom(scan) - newScan - case node => - node.mapChildren(child => populateCommonPartitionInfo( - child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions)) - } - - - private def populateJoinKeyPositions( - plan: SparkPlan, - joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match { - case scan: BatchScanExec => - val newScan = scan.copy( - spjParams = scan.spjParams.copy( - joinKeyPositions = joinKeyPositions - ) - ) - newScan.copyTagsFrom(scan) - newScan - case node => - node.mapChildren(child => populateJoinKeyPositions( - child, joinKeyPositions)) + replicatePartitions = replicatePartitions) + newGroupPartitions.copyTagsFrom(g) + newGroupPartitions + case _ => + GroupPartitionsExec(plan, joinKeyPositions, Some(mergedPartitionKeys), reducers, + applyPartialClustering, replicatePartitions) + } } - private def reducePartValues( - partValues: Seq[InternalRow], - expressions: 
Seq[Expression], - reducers: Option[Seq[Option[Reducer[_, _]]]]) = { - reducers match { - case Some(reducers) => - val partitionDataTypes = expressions.map(_.dataType) - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - partitionDataTypes) - partValues.map { row => - KeyGroupedShuffleSpec.reducePartitionValue( - row, reducers, partitionDataTypes, internalRowComparableWrapperFactory) - }.distinct.map(_.row) - case _ => partValues + /** + * Applies join key positions to a plan by wrapping or updating GroupPartitionsExec. + */ + private def withJoinKeyPositions(plan: SparkPlan, positions: Seq[Int]): SparkPlan = { + plan match { + case g: GroupPartitionsExec => + val newGroupPartitions = g.copy(joinKeyPositions = Some(positions)) + newGroupPartitions.copyTagsFrom(g) + newGroupPartitions + case _ => GroupPartitionsExec(plan, joinKeyPositions = Some(positions)) } } /** * Tries to create a [[KeyGroupedShuffleSpec]] from the input partitioning and distribution, if - * the partitioning is a [[KeyGroupedPartitioning]] (either directly or indirectly), and + * the partitioning is a [[KeyedPartitioning]] (either directly or indirectly), and * satisfies the given distribution. 
*/ private def createKeyGroupedShuffleSpec( partitioning: Partitioning, distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = { - def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = { + def tryCreate(partitioning: KeyedPartitioning): Option[KeyGroupedShuffleSpec] = { val attributes = partitioning.expressions.flatMap(_.collectLeaves()) val clustering = distribution.clustering @@ -758,53 +732,94 @@ case class EnsureRequirements( } partitioning match { - case p: KeyGroupedPartitioning => tryCreate(p) + case p: KeyedPartitioning => tryCreate(p) case PartitioningCollection(partitionings) => - val specs = partitionings.map(p => createKeyGroupedShuffleSpec(p, distribution)) - specs.filter(_.isDefined).map(_.get).headOption + partitionings.collectFirst(Function.unlift(createKeyGroupedShuffleSpec(_, distribution))) case _ => None } } /** - * Merge and sort partitions values for SPJ and optionally enable partition filtering. - * Both sides must have - * matching partition expressions. - * @param leftPartitioning left side partition values - * @param rightPartitioning right side partition values - * @param partitionExpression partition expressions + * Merge and sort partition keys for SPJ and optionally enable partition filtering. + * Both sides must have matching partition expressions. 
+ * @param leftPartitionKeys left side partition keys + * @param rightPartitionKeys right side partition keys * @param joinType join type for optional partition filtering + * @param keyOrdering ordering to sort partition keys * @return merged and sorted partition values */ - private def mergePartitions( - leftPartitioning: Seq[InternalRow], - rightPartitioning: Seq[InternalRow], - partitionExpression: Seq[Expression], - joinType: JoinType): Seq[InternalRow] = { - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - partitionExpression.map(_.dataType)) - + def mergePartitions( + leftPartitionKeys: Seq[InternalRowComparableWrapper], + rightPartitionKeys: Seq[InternalRowComparableWrapper], + joinType: JoinType, + keyOrdering: Ordering[InternalRowComparableWrapper]): Seq[InternalRowComparableWrapper] = { val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) { joinType match { - case Inner => InternalRowComparableWrapper.mergePartitions( - leftPartitioning, rightPartitioning, partitionExpression, intersect = true) - case LeftOuter => leftPartitioning.map(internalRowComparableWrapperFactory) - case RightOuter => rightPartitioning.map(internalRowComparableWrapperFactory) - case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning, - rightPartitioning, partitionExpression) + case Inner => mergePartitionKeys(leftPartitionKeys, rightPartitionKeys, intersect = true) + case LeftOuter => leftPartitionKeys + case RightOuter => rightPartitionKeys + case _ => mergePartitionKeys(leftPartitionKeys, rightPartitionKeys) } } else { - InternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning, - partitionExpression) + mergePartitionKeys(leftPartitionKeys, rightPartitionKeys) } // SPARK-41471: We keep to order of partitions to make sure the order of // partitions is deterministic in different case. 
- val partitionOrdering: Ordering[InternalRow] = { - RowOrdering.createNaturalAscendingOrdering(partitionExpression.map(_.dataType)) + merged.sorted(keyOrdering) + } + + private def mergePartitionKeys( + leftPartitionKeys: Seq[InternalRowComparableWrapper], + rightPartitionKeys: Seq[InternalRowComparableWrapper], + intersect: Boolean = false) = { + val leftKeySet = mutable.HashSet.from(leftPartitionKeys) + val rightKeySet = mutable.HashSet.from(rightPartitionKeys) + val result = if (intersect) { + leftKeySet.intersect(rightKeySet) + } else { + leftKeySet.union(rightKeySet) } - merged.map(_.row).sorted(partitionOrdering) + result.toSeq + } + + /** + * Splits a partitioning into three categories: + * 1. Non-KeyedPartitioning (HashPartitioning, RangePartitioning, etc.) + * 2. Grouped KeyedPartitioning (isGrouped = true) + * 3. Non-grouped KeyedPartitioning (isGrouped = false) + * + * @param partitioning The partitioning to split + * @return A tuple of (other, grouped, nonGrouped) where: + * - other: Option containing non-KeyedPartitioning(s) + * - grouped: Seq of grouped KeyedPartitionings + * - nonGrouped: Seq of non-grouped KeyedPartitionings + */ + private def splitKeyedPartitionings(partitioning: Partitioning) = { + val otherPartitionings = ArrayBuffer.empty[Partitioning] + val groupedKeyedPartitionings = ArrayBuffer.empty[KeyedPartitioning] + val nonGroupedKeyedPartitionings = ArrayBuffer.empty[KeyedPartitioning] + + def split(p: Partitioning): Unit = p match { + case c: PartitioningCollection => c.partitionings.foreach(split) + case k: KeyedPartitioning => + if (k.isGrouped) { + groupedKeyedPartitionings += k + } else { + nonGroupedKeyedPartitionings += k + } + case o => otherPartitionings += o + } + + split(partitioning) + + val other = otherPartitionings.length match { + case 0 => None + case 1 => Some(otherPartitionings.head) + case _ => Some(PartitioningCollection(otherPartitionings.toSeq)) + } + + (other, groupedKeyedPartitionings.toSeq, 
nonGroupedKeyedPartitionings.toSeq) } def apply(plan: SparkPlan): SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 95120039a6f94..7dcbf3779b93d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -370,11 +370,12 @@ object ShuffleExchangeExec { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new ConstantPartitioner - case k @ KeyGroupedPartitioning(expressions, n, _, _) => - val valueMap = k.uniquePartitionValues.zipWithIndex.map { - case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index) + case k: KeyedPartitioning => + val keyGroupedPartitioning = k.toGrouped + val valueMap = keyGroupedPartitioning.partitionKeys.zipWithIndex.map { + case (key, index) => (key.row.toSeq(keyGroupedPartitioning.expressionDataTypes), index) }.toMap - new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), n) + new KeyGroupedPartitioner(mutable.Map.from(valueMap), keyGroupedPartitioning.numPartitions) case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. 
} @@ -401,7 +402,7 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) case SinglePartition => identity - case KeyGroupedPartitioning(expressions, _, _, _) => + case KeyedPartitioning(expressions, _, _) => row => bindReferences(expressions, outputAttributes).map(_.eval(row)) case s: ShufflePartitionIdPassThrough => // For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/StoragePartitionJoinParams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/StoragePartitionJoinParams.scala deleted file mode 100644 index a28eafc5cae5b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/StoragePartitionJoinParams.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.joins - -import java.util.Objects - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.connector.catalog.functions.Reducer - -case class StoragePartitionJoinParams( - keyGroupedPartitioning: Option[Seq[Expression]] = None, - joinKeyPositions: Option[Seq[Int]] = None, - commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, - reducers: Option[Seq[Option[Reducer[_, _]]]] = None, - applyPartialClustering: Boolean = false, - replicatePartitions: Boolean = false) { - override def equals(other: Any): Boolean = other match { - case other: StoragePartitionJoinParams => - this.commonPartitionValues == other.commonPartitionValues && - this.replicatePartitions == other.replicatePartitions && - this.applyPartialClustering == other.applyPartialClustering && - this.joinKeyPositions == other.joinKeyPositions - case _ => - false - } - - override def hashCode(): Int = Objects.hash( - joinKeyPositions: Option[Seq[Int]], - commonPartitionValues: Option[Seq[(InternalRow, Int)]], - applyPartialClustering: java.lang.Boolean, - replicatePartitions: java.lang.Boolean) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index 1a0efa7c4aafb..d88a610f94b6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,9 +51,8 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) => - 
KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues, - originalPartValues) + case KeyedPartitioning(expressions, partitionKeys, isGrouped) => + KeyedPartitioning(expressions.map(resolveAttrs(_, plan)), partitionKeys, isGrouped) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 56bd028464e54..61384bf9f1fca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -30,8 +30,7 @@ import org.apache.spark.sql.connector.distributions.Distributions import org.apache.spark.sql.connector.expressions._ import org.apache.spark.sql.connector.expressions.Expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.datasources.v2.BatchScanExec -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation, GroupPartitionsExec} import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions.{col, max} @@ -76,7 +75,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Column.create("dept_id", IntegerType), Column.create("data", StringType)) - test("clustered distribution: output partitioning should be KeyGroupedPartitioning") { + test("clustered distribution: output partitioning should be KeyedPartitioning") { val partitions: Array[Transform] = Array(Expressions.years("ts")) // create a table 
with 3 partitions, partitioned by `years` transform @@ -89,18 +88,15 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { var df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY ts") val catalystDistribution = physical.ClusteredDistribution( Seq(TransformExpression(YearsFunction, Seq(attr("ts"))))) - val partitionValues = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) - val projectedPositions = catalystDistribution.clustering.indices + val partitionKeys = Seq(50L, 51L, 52L).map(v => InternalRow.fromSeq(Seq(v))) checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, - partitionValues, partitionValues)) + physical.KeyedPartitioning(catalystDistribution.clustering, partitionKeys)) // multiple group keys should work too as long as partition keys are subset of them df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, - partitionValues, partitionValues)) + physical.KeyedPartitioning(catalystDistribution.clustering, partitionKeys)) } test("non-clustered distribution: no partition") { @@ -124,9 +120,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) // Has exactly one partition. 
- val partitionValues = Seq(0).map(v => InternalRow.fromSeq(Seq(v))) + val partitionKeys = Seq(0).map(v => InternalRow.fromSeq(Seq(v))) checkQueryPlan(df, distribution, - physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues, partitionValues)) + physical.KeyedPartitioning(distribution.clustering, partitionKeys)) } test("non-clustered distribution: no V2 catalog") { @@ -275,7 +271,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { private def testWithCustomersAndOrders( customers_partitions: Array[Transform], orders_partitions: Array[Transform], - expectedNumOfShuffleExecs: Int): Unit = { + expectedNumOfShuffleExecs: Int, + expectedGroupPartitionsExecs: Int): Unit = { createTable(customers, customersColumns, customers_partitions) sql(s"INSERT INTO testcat.ns.$customers VALUES " + s"('aaa', 10, 1), ('bbb', 20, 2), ('ccc', 30, 3)") @@ -295,6 +292,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.length == expectedNumOfShuffleExecs) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.length == expectedGroupPartitionsExecs) + checkAnswer(df, Seq(Row("aaa", 10, 100.0), Row("aaa", 10, 200.0), Row("bbb", 20, 150.0), Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50))) @@ -306,6 +306,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + protected def collectAllGroupPartitions(plan: SparkPlan): Seq[GroupPartitionsExec] = { + collect(plan) { + case g: GroupPartitionsExec => g + } + } + protected def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = { // here we skip collecting shuffle operators that are not associated with SMJ collect(plan) { @@ -314,17 +320,56 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { collect(smj) { case s: ShuffleExchangeExec => s }) - } + 
}.toSet.toSeq + + protected def collectGroupPartitions(plan: SparkPlan): Seq[GroupPartitionsExec] = { + // here we skip collecting group partitions operators that are not associated with SMJ + collect(plan) { + case s: SortMergeJoinExec => s + }.flatMap(smj => + collect(smj) { + case g: GroupPartitionsExec => g + }) + }.toSet.toSeq private def collectScans(plan: SparkPlan): Seq[BatchScanExec] = { collect(plan) { case s: BatchScanExec => s } } + /** + * Helper method to verify that filteredPartitions contains the expected number of + * Some and None values. This is used to verify that dynamic partition filtering + * properly fills filtered-out partitions with None. + */ + private def assertFilteredPartitions( + scans: Seq[BatchScanExec], + expectedTotalPartitions: Seq[Int], + expectedFilteredOutPartitions: Seq[Int]): Unit = { + assert(scans.size === expectedTotalPartitions.size, + s"Expected ${expectedTotalPartitions.size} scans but got ${scans.size}") + + scans.zip(expectedTotalPartitions).zip(expectedFilteredOutPartitions).foreach { + case ((scan, expectedTotal), expectedFiltered) => + val filtered = scan.filteredPartitions + assert(filtered.size === expectedTotal, + s"Expected $expectedTotal total partitions but got ${filtered.size}") + + val noneCount = filtered.count(_.isEmpty) + assert(noneCount === expectedFiltered, + s"Expected $expectedFiltered None values but got $noneCount") + + val someCount = filtered.count(_.isDefined) + assert(someCount === (expectedTotal - expectedFiltered), + s"Expected ${expectedTotal - expectedFiltered} Some values but got $someCount") + } + } + + test("partitioned join: exact distribution (same number of buckets) from both sides") { val customers_partitions = Array(bucket(4, "customer_id")) val orders_partitions = Array(bucket(4, "customer_id")) - testWithCustomersAndOrders(customers_partitions, orders_partitions, 0) + testWithCustomersAndOrders(customers_partitions, orders_partitions, 0, 1) } test("partitioned join: number of buckets 
mismatch should trigger shuffle") { @@ -332,13 +377,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val orders_partitions = Array(bucket(2, "customer_id")) // should shuffle both sides when number of buckets are not the same - testWithCustomersAndOrders(customers_partitions, orders_partitions, 2) + testWithCustomersAndOrders(customers_partitions, orders_partitions, 2, 0) } test("partitioned join: only one side reports partitioning") { val customers_partitions = Array(bucket(4, "customer_id")) - testWithCustomersAndOrders(customers_partitions, Array.empty, 2) + testWithCustomersAndOrders(customers_partitions, Array.empty, 2, 0) } private val items: String = "items" @@ -354,6 +399,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Column.create("price", FloatType), Column.create("time", TimestampType)) + private val details: String = "details" + private val detailsColumns: Array[Column] = Array( + Column.create("item_id", LongType), + Column.create("description", StringType), + Column.create("updated", TimestampType)) + test("SPARK-48655: group by on partition keys should not introduce additional shuffle") { val items_partitions = Array(identity("id")) createTable(items, itemsColumns, items_partitions) @@ -366,7 +417,10 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val df = sql(s"SELECT MAX(price) AS res FROM testcat.ns.$items GROUP BY id") val shuffles = collectAllShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, - "should contain shuffle when not grouping by partition values") + "should not contain shuffle when grouping by partition values") + val groupPartitions = collectAllGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.size == 1, + "should contain group partitions when grouping by partition values") checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0))) } @@ -390,9 +444,13 @@ class KeyGroupedPartitioningSuite 
extends DistributionAndOrderingSuiteBase { if (sortingEnabled) { assert(collectAllShuffles(df.queryExecution.executedPlan).isEmpty, "should contain no shuffle when sorting by partition values") + assert(collectAllGroupPartitions(df.queryExecution.executedPlan).size == 1, + "should contain partition grouping when sorting by partition values") } else { assert(collectAllShuffles(df.queryExecution.executedPlan).size == 1, "should contain one shuffle when optimization is disabled") + assert(collectAllGroupPartitions(df.queryExecution.executedPlan).isEmpty, + "should contain no partition grouping when optimization is disabled") } checkAnswer(df, answer) }: Unit @@ -446,6 +504,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { |""".stripMargin) checkAnswer(df, Seq(Row(1, 1, "aa"), Row(2, 2, "bb"), Row(3, 3, "cc"))) assert(collectShuffles(df.queryExecution.executedPlan).isEmpty) + assert(collectGroupPartitions(df.queryExecution.executedPlan).isEmpty) } } @@ -473,6 +532,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.size === 2, + "should contain group partitions on both sides of the join") checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(1, "aa", 41.0, 44.0), Row(1, "aa", 41.0, 45.0), Row(2, "bb", 10.0, 11.0), Row(2, "bb", 10.5, 11.0), Row(3, "cc", 15.5, 19.5)) @@ -505,6 +567,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + val groupPartitions = 
collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.size === 2, + "should contain group partitions on both sides of the join") checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(1, "aa", 41.0, 44.0), Row(1, "aa", 41.0, 45.0), Row(2, "bb", 10.0, 11.0), Row(2, "bb", 10.5, 11.0), Row(3, "cc", 15.5, 19.5)) @@ -532,11 +597,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) { val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) val shuffles = collectShuffles(df.queryExecution.executedPlan) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not add shuffle when partition values mismatch") + assert(groupPartitions.size === 2, + "should add group partitions when partition values mismatch") } else { assert(shuffles.nonEmpty, "should add shuffle when partition values mismatch, and " + "pushing down partition values is not enabled") + assert(groupPartitions.isEmpty, "should not add group partition when partition values " + + "mismatch, and pushing down partition values is not enabled") } checkAnswer(df, @@ -566,11 +636,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not add shuffle when partition values mismatch") + assert(groupPartitions.size === 2, + "should add group partitions when partition values mismatch") } else { assert(shuffles.nonEmpty, "should add shuffle when partition values mismatch, and " + "pushing down partition values is not enabled") + 
assert(groupPartitions.isEmpty, "should not add group partition when partition values " + + "mismatch, and pushing down partition values is not enabled") } checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5))) @@ -598,11 +673,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not add shuffle when partition values mismatch") + assert(groupPartitions.size === 2, + "should add group partitions when partition values mismatch") } else { assert(shuffles.nonEmpty, "should add shuffle when partition values mismatch, and " + "pushing down partition values is not enabled") + assert(groupPartitions.isEmpty, "should not add group partition when partition values " + + "mismatch, and pushing down partition values is not enabled") } checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(2, "bb", 10.0, 19.5))) @@ -629,11 +709,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not add shuffle when partition values mismatch") + assert(groupPartitions.size === 2, + "should add group partitions when partition values mismatch") } else { assert(shuffles.nonEmpty, "should add shuffle when partition values mismatch, and " + "pushing down partition values is not enabled") + assert(groupPartitions.isEmpty, "should not add group partition when 
partition values " + + "mismatch, and pushing down partition values is not enabled") } checkAnswer(df, Seq.empty) @@ -641,7 +726,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } - test("SPARK-49205: KeyGroupedPartitioning should inherit HashPartitioningLike") { + test("SPARK-49205: KeyedPartitioning should be an Expression") { val items_partitions = Array(days("arrive_time")) createTable(items, itemsColumns, items_partitions) sql(s"INSERT INTO testcat.ns.$items VALUES " + @@ -717,8 +802,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not contain any shuffle") if (pushDownValues) { - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == expected)) } checkAnswer(df, Seq(Row(1, "aa", 40.0, 45.0), Row(1, "aa", 40.0, 50.0), Row(2, "bb", 10.0, 15.0), Row(2, "bb", 10.0, 20.0), Row(3, "cc", 15.5, 20.0))) @@ -758,8 +843,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not contain any shuffle") if (pushDownValues) { - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } checkAnswer(df, Seq( Row(1, "aa", 40.0, 45.0), Row(1, "aa", 40.0, 50.0), Row(1, "aa", 40.0, 55.0), @@ -806,8 +891,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) 
{ assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -857,8 +942,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -903,8 +988,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -950,9 +1035,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) 
- assert(scans.forall(_.inputRDD.partitions.length == expected), - s"Expected $expected but got ${scans.head.inputRDD.partitions.length}") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -999,10 +1083,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.map(_.inputRDD.partitions.length).toSet.size == 1) - assert(scans.forall(_.inputRDD.partitions.length == expected), - s"Expected $expected but got ${scans.head.inputRDD.partitions.length}") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -1047,10 +1129,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.map(_.inputRDD.partitions.length).toSet.size == 1) - assert(scans.forall(_.inputRDD.partitions.length == expected), - s"Expected $expected but got ${scans.head.inputRDD.partitions.length}") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -1123,8 +1203,8 @@ class KeyGroupedPartitioningSuite 
extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not contain any shuffle") if (pushDownValues) { - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length === 3)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === 3)) } } } @@ -1161,15 +1241,40 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { // with empty partitions and the job should still succeed var df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p " + "WHERE i.id = p.item_id AND i.price > 40.0") + + var shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + var scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.outputPartitioning.numPartitions === 5)) + var groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === 3)) + checkAnswer(df, Seq(Row(131))) + // Verify that filteredPartitions contains None for filtered-out partitions. + // After DPF with filter i.price > 40.0, only id=1 survives on items side. + // The purchases side should be pruned to only item_id=1. 
+ // purchases: 5 total partitions (3 for id=1, 1 for id=2, 1 for id=3) + // After DPF: 3 Some (id=1), 2 None (id=2, id=3) + assertFilteredPartitions(scans, Seq(5, 5), Seq(0, 2)) + // dynamic filtering doesn't change partitioning so storage-partitioned join should kick // in df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p " + "WHERE i.id = p.item_id AND i.price >= 10.0") - val shuffles = collectShuffles(df.queryExecution.executedPlan) + + shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.outputPartitioning.numPartitions === 5)) + groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === 3)) + checkAnswer(df, Seq(Row(303.5))) + + // With filter i.price >= 10.0, all ids (1, 2, 3) survive, + // so no partitions should be filtered out + assertFilteredPartitions(scans, Seq(5, 5), Seq(0, 0)) } } } @@ -1224,14 +1329,25 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { checkAnswer(df, Seq(Row(213.5))) val shuffles = collectShuffles(df.queryExecution.executedPlan) + val scans = collectScans(df.queryExecution.executedPlan) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(scans.map(_.outputPartitioning.numPartitions) === Seq(14, 6)) if (pushDownValues) { assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") + assert(groupPartitions.isEmpty) } + + // Verify filteredPartitions for DPF. 
+ // After filter p.price < 45.0, purchases has item_ids {1, 2, 3, 5}. + // Items side should be pruned to these ids. Since items has {1, 2, 3, 4}, + // id=4 should be filtered out. + // purchases: 14 total, all kept (0 None) - no DPF on probe side + // items: 6 total, id=4 filtered (1 None) + assertFilteredPartitions(scans, Seq(14, 6), Seq(0, 1)) } } } @@ -1495,12 +1611,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "SPJ should be triggered") - val scans = collectScans(df.queryExecution.executedPlan) - .map(_.inputRDD.partitions.length) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) if (partiallyClustered) { - assert(scans == Seq(8, 8)) + assert(groupPartitions == Seq(8, 8)) } else { - assert(scans == Seq(4, 4)) + assert(groupPartitions == Seq(4, 4)) } checkAnswer(df, Seq( Row(3, "dd", "dd"), @@ -1564,23 +1680,23 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(shuffles.nonEmpty, "SPJ should not be triggered") } - val scannedPartitions = collectScans(df.queryExecution.executedPlan) - .map(_.inputRDD.partitions.length) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered, filter) match { // SPJ, partially-clustered, with filter - case (true, true, true) => assert(scannedPartitions == Seq(6, 6)) + case (true, true, true) => assert(groupPartitions == Seq(6, 6)) // SPJ, partially-clustered, no filter - case (true, true, false) => assert(scannedPartitions == Seq(8, 8)) + case (true, true, false) => assert(groupPartitions == Seq(8, 8)) // SPJ and not partially-clustered, with filter - case (true, false, true) => assert(scannedPartitions == Seq(2, 2)) + case (true, false, true) => assert(groupPartitions == Seq(2, 
2)) // SPJ and not partially-clustered, no filter - case (true, false, false) => assert(scannedPartitions == Seq(4, 4)) + case (true, false, false) => assert(groupPartitions == Seq(4, 4)) // No SPJ - case _ => assert(scannedPartitions == Seq(5, 4)) + case _ => assert(groupPartitions == Seq.empty) } checkAnswer(df, Seq( @@ -1703,8 +1819,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "SPJ should be triggered") - val partions = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. - partitions.length) + val partions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) val expectedBuckets = Math.min(table1buckets1, table2buckets1) * Math.min(table1buckets2, table2buckets2) assert(partions == Seq(expectedBuckets, expectedBuckets)) @@ -1863,13 +1979,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "SPJ should be triggered") - val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. - partitions.length) - + val partitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt val expectedPartitions = gcd(table1buckets1, table2buckets1) * gcd(table1buckets2, table2buckets2) - assert(scans == Seq(expectedPartitions, expectedPartitions)) + assert(partitions == Seq(expectedPartitions, expectedPartitions)) checkAnswer(df, Seq( Row(0, 0, "aa", "aa"), @@ -2041,12 +2156,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "SPJ should be triggered") - val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
- partitions.length) + val partitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) val expectedBuckets = Math.min(table1buckets, table2buckets) - assert(scans == Seq(expectedBuckets, expectedBuckets)) + assert(partitions == Seq(expectedBuckets, expectedBuckets)) checkAnswer(df, Seq( Row(0, 6, 0, 0, "aa", "01"), @@ -2105,16 +2220,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { |""".stripMargin) val shuffles = collectShuffles(df.queryExecution.executedPlan) - val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. - partitions.length) + val partitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) (allowPushDown, partiallyClustered) match { case (true, false) => assert(shuffles.isEmpty, "SPJ should be triggered") - assert(scans == Seq(2, 2)) + assert(partitions == Seq(2, 2)) case (_, _) => assert(shuffles.nonEmpty, "SPJ should not be triggered") - assert(scans == Seq(3, 2)) + assert(partitions.isEmpty) } checkAnswer(df, Seq( @@ -2172,13 +2287,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(shuffles.nonEmpty, "SPJ should not be triggered") } - val scans = collectScans(df.queryExecution.executedPlan) - .map(_.inputRDD.partitions.length) + val partitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) (pushDownValues, allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { // SPJ and partially-clustered - case (true, true, true) => assert(scans == Seq(3, 3)) + case (_, true, _) => assert(partitions == Seq(3, 3)) // non-SPJ or SPJ/partially-clustered - case _ => assert(scans == Seq(3, 3)) + case _ => assert(partitions.isEmpty) } } } @@ -2226,15 +2341,15 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(shuffles.nonEmpty, "SPJ should not be triggered") } - val scans = 
collectScans(df.queryExecution.executedPlan) - .map(_.inputRDD.partitions.length) + val partitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { // SPJ and partially-clustered - case (true, true) => assert(scans == Seq(5, 5)) + case (true, true) => assert(partitions == Seq(5, 5)) // SPJ and not partially-clustered - case (true, false) => assert(scans == Seq(3, 3)) + case (true, false) => assert(partitions == Seq(3, 3)) // No SPJ - case _ => assert(scans == Seq(4, 4)) + case _ => assert(partitions.isEmpty) } checkAnswer(df, @@ -2466,8 +2581,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(5, "cc", 44.5, 44.0)) ) - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 2)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 2)) } } @@ -2491,8 +2606,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") assert(df.collect().isEmpty, "should return no results") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 0)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 0)) } } @@ -2523,8 +2638,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row(1, "aa", 40.0, 40.0)) ) - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 3)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + 
assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 3)) } } @@ -2556,8 +2671,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row(1, "aa", 40.0, 40.0)) ) - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 4)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 4)) } } @@ -2588,8 +2703,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row(4, "aa", 40.0, 42.0)) ) - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 3)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 3)) } } @@ -2623,8 +2738,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(5, "cc", 44.5, 44.0)) ) - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 2)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 2)) } } @@ -2646,10 +2761,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectAllShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not contain shuffle when not grouping by partition values") + val groupPartitions = collectAllGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.size === 1) + assert(groupPartitions.head.outputPartitioning.numPartitions == 3) } } - test("SPARK-53322: checkpointed scans aren't used for SPJ") { + test("SPARK-53322: checkpointed scans are used for SPJ") { withTempDir { dir => spark.sparkContext.setCheckpointDir(dir.getPath) val itemsPartitions = 
Array(identity("id")) @@ -2688,14 +2806,21 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { df, Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) ) - // 1 shuffle for SORT and 2 shuffles for JOIN are expected. - assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3) + if (pushdownValues) { + // 1 shuffle for SORT and 2 group partitions for JOIN are expected. + assert(collectAllShuffles(df.queryExecution.executedPlan).length === 1) + assert(collectAllGroupPartitions(df.queryExecution.executedPlan).length === 2) + } else { + // 1 shuffle for SORT and 2 shuffles for JOIN are expected. + assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3) + assert(collectAllGroupPartitions(df.queryExecution.executedPlan).length === 0) + } } } } } - test("SPARK-53322: checkpointed scans can't shuffle other children on SPJ") { + test("SPARK-53322: checkpointed scans can shuffle other children on SPJ") { withTempDir { dir => spark.sparkContext.setCheckpointDir(dir.getPath) val itemsPartitions = Array(identity("id")) @@ -2727,52 +2852,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { df, Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) ) - // 1 shuffle for SORT and 2 shuffles for JOIN are expected. - assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3) + // 1 shuffle for SORT and 1 shuffle for JOIN are expected. 
+ assert(collectAllShuffles(df.queryExecution.executedPlan).length === 2) + // 0 group partitions are expected because both sides of the join are clustered from scans + assert(collectAllGroupPartitions(df.queryExecution.executedPlan).length === 0) } } } } - test("SPARK-53322: checkpointed scans can be shuffled by children on SPJ") { - withTempDir { dir => - spark.sparkContext.setCheckpointDir(dir.getPath) - val itemsPartitions = Array(identity("id")) - createTable(items, itemsColumns, itemsPartitions) - sql(s"INSERT INTO testcat.ns.$items VALUES " + - s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " + - s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " + - s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))") - - createTable(purchases, purchasesColumns, Array(identity("item_id"))) - sql(s"INSERT INTO testcat.ns.$purchases VALUES " + - s"(1, 40.0, cast('2020-01-01' as timestamp)), " + - s"(3, 25.5, cast('2020-01-03' as timestamp)), " + - s"(4, 20.0, cast('2020-01-04' as timestamp))") - - withSQLConf( - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", - SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") { - val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i") - val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p") - - val df = scanDF1 - .join(scanDF2, col("id") === col("item_id")) - .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price") - .orderBy("id", "purchase_price", "sale_price") - checkAnswer( - df, - Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) - ) - - // One shuffle for the sort and one shuffle for one side of the JOIN are expected. 
- assert(collectAllShuffles(df.queryExecution.executedPlan).length === 2) - } - } - } - - test("SPARK-54439: KeyGroupedPartitioning and join key size mismatch") { + test("SPARK-54439: KeyedPartitioning and join key size mismatch") { val items_partitions = Array(identity("id")) createTable(items, itemsColumns, items_partitions) @@ -2797,7 +2886,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } - test("SPARK-54439: KeyGroupedPartitioning with transform and join key size mismatch") { + test("SPARK-54439: KeyedPartitioning with transform and join key size mismatch") { val items_partitions = Array(years("arrive_time")) createTable(items, itemsColumns, items_partitions) @@ -2832,10 +2921,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))") val metrics = runAndFetchMetrics { - val df = sql(s"SELECT * FROM testcat.ns.$items") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans(0).inputRDD.partitions.length === 2, "items scan should have 2 partition groups") + val df = sql(s"SELECT id, count(*) FROM testcat.ns.$items GROUP BY id") df.collect() + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans(0).inputRDD.partitions.length === 3, "items scan should have 3 partitions") + val groupPartitions = collectAllGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions(0).outputPartitioning.numPartitions === 2, + "group partitions should have 2 partition groups") } assert(metrics("number of rows read") == "3") } @@ -2892,4 +2984,187 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row("ccc", 30, 400.50))) } } + + test("SPARK-55092: Scans should not group partitions") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(4, 
'bb', 10.0, cast('2021-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))") + + val purchases_partitions = Array(years("time")) + createTable(purchases, purchasesColumns, purchases_partitions) + + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + val df = sql(s"SELECT * FROM testcat.ns.$items") + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans(0).inputRDD.partitions.length === 3, + "items scan should not group partitions") + + Seq((true, 1), (false, 2)).foreach { case (bucketingShuffle, expectedShuffleCount) => + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> bucketingShuffle.toString) { + val df = createJoinTestDF(Seq("id" -> "item_id")) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == expectedShuffleCount) + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans(0).inputRDD.partitions.length === 3, + "items scan should not group partitions") + assert(scans(1).inputRDD.partitions.length === 2, + "purchases scan should not group partitions") + + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0))) + } + } + } + + test("SPARK-55535: Multi table join granular partition grouping") { + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val items_partitions = Array(identity("id"), years("arrive_time")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 10.0, cast('2021-01-01' as timestamp)), " + + "(1, 'aa', 20.0, cast('2022-01-01' as timestamp)), " + + "(2, 'aa', 30.0, cast('2021-01-01' as timestamp)), " + + "(2, 'aa', 40.0, cast('2022-01-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id"), 
years("time")) + createTable(purchases, purchasesColumns, purchases_partitions) + + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(2, 10.0, cast('2021-01-01' as timestamp)), " + + "(2, 20.0, cast('2022-01-01' as timestamp)), " + + "(3, 30.0, cast('2021-01-01' as timestamp)), " + + "(3, 40.0, cast('2022-01-01' as timestamp))") + + val details_partitions = Array(identity("item_id")) + createTable(details, detailsColumns, details_partitions) + + sql(s"INSERT INTO testcat.ns.$details VALUES " + + "(2, 'cc', cast('2021-01-01' as timestamp)), " + + "(3, 'cc', cast('2022-01-01' as timestamp))") + + val df = sql( + s""" + |SELECT i.id, i.arrive_time, p.item_id, d.item_id + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON p.item_id = i.id AND p.time = i.arrive_time + |JOIN testcat.ns.$details d ON d.item_id = i.id + |""".stripMargin) + + checkAnswer(df, Seq( + Row(2, Timestamp.valueOf("2021-01-01 00:00:00"), 2, 2), + Row(2, Timestamp.valueOf("2022-01-01 00:00:00"), 2, 2))) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not contain any shuffle") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + // Expect 6 partitions in the inner join node legs because partitioning uses 2 attributes. + // Expect 3 partitions in the outer join node legs because partitioning uses 1 attributes. 
+ assert(groupPartitions.map(_.outputPartitioning.numPartitions) === Seq(3, 6, 6, 3)) + } + } + + test("SPARK-55535: Multi table join partial clustering") { + withSQLConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 10.0, cast('2021-01-01' as timestamp)), " + + "(1, 'aa', 20.0, cast('2022-01-01' as timestamp)), " + + "(2, 'aa', 30.0, cast('2021-01-01' as timestamp)), " + + "(2, 'aa', 40.0, cast('2022-01-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(2, 10.0, cast('2021-01-01' as timestamp)), " + + "(3, 20.0, cast('2022-01-01' as timestamp))") + + val details_partitions = Array(identity("item_id")) + createTable(details, detailsColumns, details_partitions) + + sql(s"INSERT INTO testcat.ns.$details VALUES " + + "(2, 'cc', cast('2021-01-01' as timestamp)), " + + "(4, 'cc', cast('2022-01-01' as timestamp))") + + val df = sql( + s""" + |SELECT i.id, i.price, p.price, d.description + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON p.item_id = i.id + |JOIN testcat.ns.$details d ON d.item_id = i.id + |""".stripMargin) + + checkAnswer(df, Seq( + Row(2, 30.0, 10.0, "cc"), + Row(2, 40.0, 10.0, "cc"))) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not contain any shuffle") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + // Expect 5 partitions in the inner join node legs because 4 from the partially clustered + // items table and 1 new from clustered purchases table. + // Expect 6 partitions in the outer join node legs because 5 from the partially clustered + // inner join result and 1 new from clustered details table. 
+ assert(groupPartitions.map(_.outputPartitioning.numPartitions) === Seq(6, 5, 5, 6)) + } + } + + test("SPARK-55535: Empty partitioned table") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + + val df = createJoinTestDF(Seq("id" -> "item_id")) + checkAnswer(df, Seq.empty) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size === 2, + "both legs should be shuffled as empty tables should not report KeyedPartitioning") + + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.isEmpty, + "no legs should be grouped as empty tables should not report KeyedPartitioning") + } + } + + test("SPARK-55535: Empty group partitions due filtered partitions") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 39.0, cast('2020-01-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(2, 42.0, cast('2020-01-01' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id")) + checkAnswer(df, Seq.empty) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "no legs should be shuffled") + + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 0), + "group partitions should not have any (common) partitions") + } + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 1cc0d795d74f8..7512cbe7f90b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, GroupPartitionsExec} import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec import org.apache.spark.sql.execution.window.WindowExec @@ -91,15 +91,15 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } - test("reorder should handle KeyGroupedPartitioning") { + test("reorder should handle KeyedPartitioning") { // partitioning on the left val plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(Seq( - years(exprA), bucket(4, exprB), days(exprC)), 4) + outputPartitioning = + KeyedPartitioning(Seq(years(exprA), bucket(4, exprB), days(exprC)), Seq.empty) ) val plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(Seq( - years(exprB), bucket(4, exprA), days(exprD)), 4) + outputPartitioning = + KeyedPartitioning(Seq(years(exprB), bucket(4, exprA), days(exprD)), Seq.empty) ) val smjExec = SortMergeJoinExec( exprB :: exprC :: exprA :: Nil, exprA :: exprD :: exprB :: Nil, @@ -107,8 +107,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { ) 
EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, DummySparkPlan(_, _, _: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, _: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, _: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: KeyedPartitioning, _, _), _), _) => assert(leftKeys === Seq(exprA, exprB, exprC)) assert(rightKeys === Seq(exprB, exprA, exprD)) case other => fail(other.toString) @@ -116,8 +116,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { // partitioning on the right val plan3 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(Seq( - bucket(4, exprD), days(exprA), years(exprC)), 4) + outputPartitioning = + KeyedPartitioning(Seq(bucket(4, exprD), days(exprA), years(exprC)), Seq.empty) ) val smjExec2 = SortMergeJoinExec( exprB :: exprD :: exprC :: Nil, exprA :: exprC :: exprD :: Nil, @@ -777,18 +777,18 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } - test("Check with KeyGroupedPartitioning") { + test("Check with KeyedPartitioning") { // simplest case: identity transforms var plan1 = new DummySparkPlanWithBatchScanChild( - KeyGroupedPartitioning(exprA :: exprB :: Nil, 5)) + KeyedPartitioning(exprA :: exprB :: Nil, Seq.empty)) var plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(exprA :: exprC :: Nil, 5)) + outputPartitioning = KeyedPartitioning(exprA :: exprC :: Nil, Seq.empty)) var smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, 
DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(exprA, exprB)) assert(right.expressions === Seq(exprA, exprC)) case other => fail(other.toString) @@ -796,19 +796,19 @@ class EnsureRequirementsSuite extends SharedSparkSession { // matching bucket transforms from both sides plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4) + outputPartitioning = + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4) + outputPartitioning = + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(bucket(4, exprA), bucket(16, exprB))) assert(right.expressions === Seq(bucket(4, exprA), bucket(16, exprC))) case other => fail(other.toString) @@ -816,20 +816,20 @@ class EnsureRequirementsSuite extends SharedSparkSession { // partition collections plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4) + outputPartitioning = + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = PartitioningCollection(Seq( - KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) 
:: Nil, 4), - KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4)) + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty), + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty)) ) ) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => assert(left.expressions === Seq(bucket(4, exprA), bucket(16, exprB))) case other => fail(other.toString) @@ -839,24 +839,24 @@ class EnsureRequirementsSuite extends SharedSparkSession { EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(right.expressions === Seq(bucket(4, exprA), bucket(16, exprB))) case other => fail(other.toString) } // bucket + years transforms from both sides plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, 
left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(bucket(4, exprA), years(exprB))) assert(right.expressions === Seq(bucket(4, exprA), years(exprC))) case other => fail(other.toString) @@ -865,12 +865,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // by default spark.sql.requireAllClusterKeysForCoPartition is true, so when there isn't // exact match on all partition keys, Spark will fallback to shuffle. plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(4, exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: bucket(4, exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -884,14 +882,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } - test(s"KeyGroupedPartitioning with ${REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key} = false") { + test(s"KeyedPartitioning with ${REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key} = false") { var plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprB) :: years(exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprB) :: years(exprC) :: Nil, Seq.empty) ) var plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprC) :: years(exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, 
exprC) :: years(exprB) :: Nil, Seq.empty) ) // simple case @@ -899,8 +895,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(bucket(4, exprB), years(exprC))) assert(right.expressions === Seq(bucket(4, exprC), years(exprB))) case other => fail(other.toString) @@ -908,19 +904,17 @@ class EnsureRequirementsSuite extends SharedSparkSession { // should also work with distributions with duplicated keys plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: years(exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: years(exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(bucket(4, exprA), 
years(exprB))) assert(right.expressions === Seq(bucket(4, exprA), years(exprC))) case other => fail(other.toString) @@ -928,17 +922,17 @@ class EnsureRequirementsSuite extends SharedSparkSession { // both partitioning and distribution have duplicated keys plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - years(exprA) :: bucket(4, exprB) :: days(exprA) :: Nil, 5)) + outputPartitioning = + KeyedPartitioning(years(exprA) :: bucket(4, exprB) :: days(exprA) :: Nil, Seq.empty)) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - years(exprA) :: bucket(4, exprC) :: days(exprA) :: Nil, 5)) + outputPartitioning = + KeyedPartitioning(years(exprA) :: bucket(4, exprC) :: days(exprA) :: Nil, Seq.empty)) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(years(exprA), bucket(4, exprB), days(exprA))) assert(right.expressions === Seq(years(exprA), bucket(4, exprC), days(exprA))) case other => fail(other.toString) @@ -946,12 +940,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: partitioning key positions don't match plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, 
exprB) :: bucket(4, exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprB) :: bucket(4, exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( @@ -967,12 +959,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: different number of buckets (we don't support coalescing/repartitioning yet) plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(8, exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: bucket(8, exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -987,10 +977,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: partition key positions match but with different transforms plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(years(exprA) :: bucket(4, exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning(years(exprA) :: bucket(4, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(days(exprA) :: bucket(4, exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning(days(exprA) :: bucket(4, exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -1006,12 +996,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: multiple references in transform plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4) + outputPartitioning = + 
KeyedPartitioning(years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4) + outputPartitioning = + KeyedPartitioning(years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -1032,12 +1022,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { .map(new GenericInternalRow(_)) var plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, - leftPartValues.length, leftPartValues) + outputPartitioning = + KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues) ) var plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues.length, rightPartValues) + outputPartitioning = + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues) ) // simple case @@ -1045,8 +1035,13 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, left: KeyedPartitioning, _, _), + _, _, _, _, _), _), + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, right: KeyedPartitioning, _, _), + _, _, _, _, _), _), + _) => assert(left.expressions === Seq(bucket(4, exprB), bucket(8, exprC))) assert(right.expressions === Seq(bucket(4, exprC), bucket(8, exprB))) case other => 
fail(other.toString) @@ -1055,10 +1050,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { // With partition collections plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = PartitioningCollection( - Seq(KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, - leftPartValues.length, leftPartValues), - KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, - leftPartValues.length, leftPartValues)) + Seq(KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues), + KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues)) ) ) @@ -1066,11 +1059,16 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: PartitioningCollection, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, left: PartitioningCollection, _, _), + _, _, _, _, _), _), + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, right: KeyedPartitioning, _, _), + _, _, _, _, _), _), + _) => assert(left.partitionings.length == 2) - assert(left.partitionings.head.isInstanceOf[KeyGroupedPartitioning]) - assert(left.partitionings.head.asInstanceOf[KeyGroupedPartitioning].expressions == + assert(left.partitionings.head.isInstanceOf[KeyedPartitioning]) + assert(left.partitionings.head.asInstanceOf[KeyedPartitioning].expressions == Seq(bucket(4, exprB), bucket(8, exprC))) assert(right.expressions === Seq(bucket(4, exprC), bucket(8, exprB))) case other => fail(other.toString) @@ -1082,16 +1080,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { Seq( PartitioningCollection( Seq( - KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues.length, 
rightPartValues), - KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues.length, rightPartValues))), + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues), + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues))), PartitioningCollection( Seq( - KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues.length, rightPartValues), - KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues.length, rightPartValues))) + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues), + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues))) ) ) ) @@ -1100,11 +1094,16 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: PartitioningCollection, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: PartitioningCollection, _, _), _), _) => + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, left: PartitioningCollection, _, _), + _, _, _, _, _), _), + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, right: PartitioningCollection, _, _), + _, _, _, _, _), _), + _) => assert(left.partitionings.length == 2) - assert(left.partitionings.head.isInstanceOf[KeyGroupedPartitioning]) - assert(left.partitionings.head.asInstanceOf[KeyGroupedPartitioning].expressions == + assert(left.partitionings.head.isInstanceOf[KeyedPartitioning]) + assert(left.partitionings.head.asInstanceOf[KeyedPartitioning].expressions == Seq(bucket(4, exprB), bucket(8, exprC))) assert(right.partitionings.length == 2) assert(right.partitionings.head.isInstanceOf[PartitioningCollection]) @@ -1119,21 +1118,21 @@ class EnsureRequirementsSuite extends SharedSparkSession { val a1 = 
AttributeReference("a1", IntegerType)() - val partitionValue = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) - val plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = KeyGroupedPartitioning( - identity(a1) :: Nil, 4, partitionValue)) + val partitionKeys = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) + val plan1 = new DummySparkPlanWithBatchScanChild( + outputPartitioning = KeyedPartitioning(identity(a1) :: Nil, partitionKeys)) val plan2 = DummySparkPlan(outputPartitioning = SinglePartition) val smjExec = ShuffledHashJoinExec( a1 :: Nil, a1 :: Nil, Inner, BuildRight, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case ShuffledHashJoinExec(_, _, _, _, _, - DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), - ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _), + DummySparkPlan(_, _, left: KeyedPartitioning, _, _), + ShuffleExchangeExec(KeyedPartitioning(attrs, pks, _), DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) => assert(left.expressions == a1 :: Nil) assert(attrs == a1 :: Nil) - assert(partitionValue == pv) + assert(partitionKeys == pks.map(_.row)) case other => fail(other.toString) } }