From 5750807eaad5e1ff1732e35a2c0c1e0aa7911034 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 9 Jun 2026 18:35:36 +0000 Subject: [PATCH 1/2] [SPARK-56877][SQL][FOLLOWUP] Make PartitioningCollection invariant check O(1) per nesting level The constructor invariant added in SPARK-56877 walked the entire partitioning tree via TreeNode.foreach on every PartitioningCollection construction, and fromPartitionings rebuilt every nested collection even when nothing changed. Join outputPartitioning builds these collections afresh on every call, and plans chaining many same-key shuffle joins nest them linearly deep, so planning cost became cubic in the join-chain length (EnsureRequirements planning time grew ~9x on a 125-join benchmark query). Validate the invariant using one cached representative KeyedPartitioning per nested collection instead: since every nested collection already enforced the invariant on its own construction, comparing representatives of direct members covers the whole subtree. fromPartitionings uses the same representative to skip already-consistent subtrees in O(1), only rebuilding when interning is actually needed.
Co-authored-by: Isaac --- .../plans/physical/partitioning.scala | 55 +++++++++++++++--- .../sql/catalyst/DistributionSuite.scala | 56 +++++++++++++++++++ 2 files changed, 104 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index aeacdaec7a8de..596688dfe8c31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -796,21 +796,45 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) checkKeyedPartitioningInvariant() + /** + * First [[KeyedPartitioning]] reachable from this collection through direct members or nested + * collections, if any. Since every collection validates the invariant on construction, this + * single representative stands for all [[KeyedPartitioning]]s in the subtree: they all share + * its `partitionKeys` reference and expression arity. + */ + @transient private[physical] lazy val firstKeyedPartitioning: Option[KeyedPartitioning] = + partitionings.view.map { + case k: KeyedPartitioning => Some(k) + case pc: PartitioningCollection => pc.firstKeyedPartitioning + case _ => None + }.collectFirst { case Some(k) => k } + + /** + * Nested collections already enforced the invariant on their own construction, so comparing one + * representative per direct member validates the whole subtree without walking it. Keeping this + * check O(partitionings.size) matters: join `outputPartitioning` builds these collections afresh + * on every call, and plans chaining many same-key joins nest them linearly deep. 
+ */ private def checkKeyedPartitioningInvariant(): Unit = { var first: KeyedPartitioning = null - foreach { - case k: KeyedPartitioning => + partitionings.foreach { p => + val representative = p match { + case k: KeyedPartitioning => k + case pc: PartitioningCollection => pc.firstKeyedPartitioning.orNull + case _ => null + } + if (representative != null) { if (first == null) { - first = k + first = representative } else { - require(k.expressions.length == first.expressions.length, + require(representative.expressions.length == first.expressions.length, "All KeyedPartitionings in a PartitioningCollection must have matching expression " + "arity") - require(k.partitionKeys eq first.partitionKeys, + require(representative.partitionKeys eq first.partitionKeys, "All KeyedPartitionings in a PartitioningCollection must share the same " + "partitionKeys reference") } - case _ => + } } } @@ -868,7 +892,24 @@ object PartitioningCollection { } else { k } - case pc: PartitioningCollection => new PartitioningCollection(pc.partitionings.map(intern)) + case pc: PartitioningCollection => + pc.firstKeyedPartitioning match { + // No KeyedPartitioning anywhere in this subtree: nothing to intern. Returning the + // collection as-is keeps repeated outputPartitioning computations over deeply nested + // collections (e.g. chains of same-key joins) O(1) per level. + case None => pc + case Some(representative) if canonicalKeys == null => + canonicalKeys = representative.partitionKeys + pc + // The collection's own invariant guarantees all its KeyedPartitionings share the + // representative's `partitionKeys` reference, so reference-equality of the + // representative's keys means the whole subtree is already interned. 
+ case Some(representative) if representative.partitionKeys eq canonicalKeys => pc + case Some(representative) => + require(representative.partitionKeys == canonicalKeys, + "All KeyedPartitionings in a PartitioningCollection must have equal partitionKeys") + new PartitioningCollection(pc.partitionings.map(intern)) + } case other => other } new PartitioningCollection(partitionings.map(intern)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 88faf70103186..2e9e1270b3fc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -410,4 +410,60 @@ class DistributionSuite extends SparkFunSuite { assert(groupedKP.isGrouped) checkSatisfied(groupedKP, ClusteredDistribution(Seq(x)), true) } + + test("SPARK-56877: fromPartitionings reuses already-consistent nested collections") { + val x = AttributeReference("x", IntegerType)() + val y = AttributeReference("y", IntegerType)() + + // No KeyedPartitioning anywhere in the subtree: the nested collection is returned as-is. + // Rebuilding it would make outputPartitioning of deeply nested collections (e.g. chains of + // same-key shuffle joins) quadratic in the nesting depth. + val hashCollection = PartitioningCollection.fromPartitionings( + Seq(HashPartitioning(Seq(x), 10), HashPartitioning(Seq(y), 10))) + val wrapped = PartitioningCollection.fromPartitionings( + Seq(hashCollection, HashPartitioning(Seq(y), 10))) + assert(wrapped.partitionings.head eq hashCollection) + + // KeyedPartitionings already share the canonical partitionKeys reference: also as-is. 
+ val kpX = KeyedPartitioning(Seq(x), Seq(InternalRow(1), InternalRow(2), InternalRow(3))) + val kpY = kpX.copy(expressions = Seq(y)) + val keyedCollection = PartitioningCollection.fromPartitionings(Seq(kpX, kpY)) + val keyedWrapped = PartitioningCollection.fromPartitionings(Seq(keyedCollection, kpX)) + assert(keyedWrapped.partitionings.head eq keyedCollection) + } + + test("SPARK-56877: fromPartitionings interns partitionKeys across nested collections") { + val x = AttributeReference("x", IntegerType)() + val y = AttributeReference("y", IntegerType)() + + val kpX = KeyedPartitioning(Seq(x), Seq(InternalRow(1), InternalRow(2))) + val nested = PartitioningCollection.fromPartitionings(Seq(kpX)) + // Structurally equal but reference-distinct partitionKeys. + val kpY = KeyedPartitioning(Seq(y), Seq(InternalRow(1), InternalRow(2))) + assert(kpX.partitionKeys ne kpY.partitionKeys) + + val combined = PartitioningCollection.fromPartitionings(Seq(nested, kpY)) + val interned = combined.partitionings.last.asInstanceOf[KeyedPartitioning] + assert(interned.partitionKeys eq kpX.partitionKeys) + } + + test("SPARK-56877: PartitioningCollection enforces the invariant through nesting") { + val x = AttributeReference("x", IntegerType)() + val y = AttributeReference("y", IntegerType)() + + val kpX = KeyedPartitioning(Seq(x), Seq(InternalRow(1), InternalRow(2))) + val nested = PartitioningCollection.fromPartitionings(Seq(kpX)) + + val kpY = KeyedPartitioning(Seq(y), Seq(InternalRow(1), InternalRow(2))) + val refMismatch = intercept[IllegalArgumentException] { + PartitioningCollection(Seq(nested, kpY)) + } + assert(refMismatch.getMessage.contains("share the same partitionKeys reference")) + + val kpXY = KeyedPartitioning(Seq(x, y), Seq(InternalRow(1, 1), InternalRow(2, 2))) + val arityMismatch = intercept[IllegalArgumentException] { + PartitioningCollection(Seq(nested, kpXY)) + } + assert(arityMismatch.getMessage.contains("matching expression arity")) + } } From 
b5c7f55d0ace6b01ab2d488e133e6b0adc046b9f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Jun 2026 22:53:23 +0000 Subject: [PATCH 2/2] Compute firstKeyedPartitioning once during construction Address review feedback: the invariant check was duplicating the scan that firstKeyedPartitioning performs. Now the check forces the lazy val and compares each direct member's representative against it, so the representative is computed once and stored at construction time. Co-authored-by: Isaac --- .../plans/physical/partitioning.scala | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 596688dfe8c31..d2bb12d2053aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -800,7 +800,8 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) * First [[KeyedPartitioning]] reachable from this collection through direct members or nested * collections, if any. Since every collection validates the invariant on construction, this * single representative stands for all [[KeyedPartitioning]]s in the subtree: they all share - * its `partitionKeys` reference and expression arity. + * its `partitionKeys` reference and expression arity. The invariant check forces this lazy val + * during construction, so it is only recomputed after deserialization. 
*/ @transient private[physical] lazy val firstKeyedPartitioning: Option[KeyedPartitioning] = partitionings.view.map { @@ -811,22 +812,20 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) /** * Nested collections already enforced the invariant on their own construction, so comparing one - * representative per direct member validates the whole subtree without walking it. Keeping this - * check O(partitionings.size) matters: join `outputPartitioning` builds these collections afresh - * on every call, and plans chaining many same-key joins nest them linearly deep. + * representative per direct member against [[firstKeyedPartitioning]] validates the whole + * subtree without walking it. Keeping this check O(partitionings.size) matters: join + * `outputPartitioning` builds these collections afresh on every call, and plans chaining many + * same-key joins nest them linearly deep. */ private def checkKeyedPartitioningInvariant(): Unit = { - var first: KeyedPartitioning = null - partitionings.foreach { p => - val representative = p match { - case k: KeyedPartitioning => k - case pc: PartitioningCollection => pc.firstKeyedPartitioning.orNull - case _ => null - } - if (representative != null) { - if (first == null) { - first = representative - } else { + firstKeyedPartitioning.foreach { first => + partitionings.foreach { p => + val representative = p match { + case k: KeyedPartitioning => k + case pc: PartitioningCollection => pc.firstKeyedPartitioning.orNull + case _ => null + } + if (representative != null && (representative ne first)) { require(representative.expressions.length == first.expressions.length, "All KeyedPartitionings in a PartitioningCollection must have matching expression " + "arity")