diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index aeacdaec7a8de..d2bb12d2053aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -796,21 +796,44 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) checkKeyedPartitioningInvariant() + /** + * First [[KeyedPartitioning]] reachable from this collection through direct members or nested + * collections, if any. Since every collection validates the invariant on construction, this + * single representative stands for all [[KeyedPartitioning]]s in the subtree: they all share + * its `partitionKeys` reference and expression arity. The invariant check forces this lazy val + * during construction, so it is only recomputed after deserialization. + */ + @transient private[physical] lazy val firstKeyedPartitioning: Option[KeyedPartitioning] = + partitionings.view.map { + case k: KeyedPartitioning => Some(k) + case pc: PartitioningCollection => pc.firstKeyedPartitioning + case _ => None + }.collectFirst { case Some(k) => k } + + /** + * Nested collections already enforced the invariant on their own construction, so comparing one + * representative per direct member against [[firstKeyedPartitioning]] validates the whole + * subtree without walking it. Keeping this check O(partitionings.size) matters: join + * `outputPartitioning` builds these collections afresh on every call, and plans chaining many + * same-key joins nest them linearly deep. + */ private def checkKeyedPartitioningInvariant(): Unit = { - var first: KeyedPartitioning = null - foreach { - case k: KeyedPartitioning => - if (first == null) { - first = k - } else { - require(k.expressions.length == first.expressions.length, + firstKeyedPartitioning.foreach { first => + partitionings.foreach { p => + val representative = p match { + case k: KeyedPartitioning => k + case pc: PartitioningCollection => pc.firstKeyedPartitioning.orNull + case _ => null + } + if (representative != null && (representative ne first)) { + require(representative.expressions.length == first.expressions.length, "All KeyedPartitionings in a PartitioningCollection must have matching expression " + "arity") - require(k.partitionKeys eq first.partitionKeys, + require(representative.partitionKeys eq first.partitionKeys, "All KeyedPartitionings in a PartitioningCollection must share the same " + "partitionKeys reference") } - case _ => + } } } @@ -868,7 +891,24 @@ object PartitioningCollection { } else { k } - case pc: PartitioningCollection => new PartitioningCollection(pc.partitionings.map(intern)) + case pc: PartitioningCollection => + pc.firstKeyedPartitioning match { + // No KeyedPartitioning anywhere in this subtree: nothing to intern. Returning the + // collection as-is keeps repeated outputPartitioning computations over deeply nested + // collections (e.g. chains of same-key joins) O(1) per level. + case None => pc + case Some(representative) if canonicalKeys == null => + canonicalKeys = representative.partitionKeys + pc + // The collection's own invariant guarantees all its KeyedPartitionings share the + // representative's `partitionKeys` reference, so reference-equality of the + // representative's keys means the whole subtree is already interned. + case Some(representative) if representative.partitionKeys eq canonicalKeys => pc + case Some(representative) => + require(representative.partitionKeys == canonicalKeys, + "All KeyedPartitionings in a PartitioningCollection must have equal partitionKeys") + new PartitioningCollection(pc.partitionings.map(intern)) + } case other => other } new PartitioningCollection(partitionings.map(intern)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 88faf70103186..2e9e1270b3fc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -410,4 +410,60 @@ class DistributionSuite extends SparkFunSuite { assert(groupedKP.isGrouped) checkSatisfied(groupedKP, ClusteredDistribution(Seq(x)), true) } + + test("SPARK-56877: fromPartitionings reuses already-consistent nested collections") { + val x = AttributeReference("x", IntegerType)() + val y = AttributeReference("y", IntegerType)() + + // No KeyedPartitioning anywhere in the subtree: the nested collection is returned as-is. + // Rebuilding it would make outputPartitioning of deeply nested collections (e.g. chains of + // same-key shuffle joins) quadratic in the nesting depth. + val hashCollection = PartitioningCollection.fromPartitionings( + Seq(HashPartitioning(Seq(x), 10), HashPartitioning(Seq(y), 10))) + val wrapped = PartitioningCollection.fromPartitionings( + Seq(hashCollection, HashPartitioning(Seq(y), 10))) + assert(wrapped.partitionings.head eq hashCollection) + + // KeyedPartitionings already share the canonical partitionKeys reference: also as-is. + val kpX = KeyedPartitioning(Seq(x), Seq(InternalRow(1), InternalRow(2), InternalRow(3))) + val kpY = kpX.copy(expressions = Seq(y)) + val keyedCollection = PartitioningCollection.fromPartitionings(Seq(kpX, kpY)) + val keyedWrapped = PartitioningCollection.fromPartitionings(Seq(keyedCollection, kpX)) + assert(keyedWrapped.partitionings.head eq keyedCollection) + } + + test("SPARK-56877: fromPartitionings interns partitionKeys across nested collections") { + val x = AttributeReference("x", IntegerType)() + val y = AttributeReference("y", IntegerType)() + + val kpX = KeyedPartitioning(Seq(x), Seq(InternalRow(1), InternalRow(2))) + val nested = PartitioningCollection.fromPartitionings(Seq(kpX)) + // Structurally equal but reference-distinct partitionKeys. + val kpY = KeyedPartitioning(Seq(y), Seq(InternalRow(1), InternalRow(2))) + assert(kpX.partitionKeys ne kpY.partitionKeys) + + val combined = PartitioningCollection.fromPartitionings(Seq(nested, kpY)) + val interned = combined.partitionings.last.asInstanceOf[KeyedPartitioning] + assert(interned.partitionKeys eq kpX.partitionKeys) + } + + test("SPARK-56877: PartitioningCollection enforces the invariant through nesting") { + val x = AttributeReference("x", IntegerType)() + val y = AttributeReference("y", IntegerType)() + + val kpX = KeyedPartitioning(Seq(x), Seq(InternalRow(1), InternalRow(2))) + val nested = PartitioningCollection.fromPartitionings(Seq(kpX)) + + val kpY = KeyedPartitioning(Seq(y), Seq(InternalRow(1), InternalRow(2))) + val refMismatch = intercept[IllegalArgumentException] { + PartitioningCollection(Seq(nested, kpY)) + } + assert(refMismatch.getMessage.contains("share the same partitionKeys reference")) + + val kpXY = KeyedPartitioning(Seq(x, y), Seq(InternalRow(1, 1), InternalRow(2, 2))) + val arityMismatch = intercept[IllegalArgumentException] { + PartitioningCollection(Seq(nested, kpXY)) + } + assert(arityMismatch.getMessage.contains("matching expression arity")) + } }