Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -796,21 +796,44 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])

checkKeyedPartitioningInvariant()

/**
* First [[KeyedPartitioning]] reachable from this collection through direct members or nested
* collections, if any. Since every collection validates the invariant on construction, this
* single representative stands for all [[KeyedPartitioning]]s in the subtree: they all share
* its `partitionKeys` reference and expression arity. The invariant check forces this lazy val
* during construction, so it is only recomputed after deserialization.
*/
@transient private[physical] lazy val firstKeyedPartitioning: Option[KeyedPartitioning] = {
  // Lazily map each direct member to its own representative, if it has one. A nested collection
  // answers through its own `firstKeyedPartitioning` rather than being walked again here.
  val representatives = partitionings.view.map {
    case keyed: KeyedPartitioning => Some(keyed)
    case nested: PartitioningCollection => nested.firstKeyedPartitioning
    case _ => None
  }
  // The view is lazy, so members after the first defined representative are never inspected.
  representatives.collectFirst { case Some(keyed) => keyed }
}

/**
* Nested collections already enforced the invariant on their own construction, so comparing one
* representative per direct member against [[firstKeyedPartitioning]] validates the whole
* subtree without walking it. Keeping this check O(partitionings.size) matters: join
* `outputPartitioning` builds these collections afresh on every call, and plans chaining many
* same-key joins nest them to a depth that grows linearly with the number of joins.
*/
private def checkKeyedPartitioningInvariant(): Unit = {
var first: KeyedPartitioning = null
foreach {
case k: KeyedPartitioning =>
if (first == null) {
first = k
} else {
require(k.expressions.length == first.expressions.length,
firstKeyedPartitioning.foreach { first =>
partitionings.foreach { p =>
val representative = p match {
case k: KeyedPartitioning => k
case pc: PartitioningCollection => pc.firstKeyedPartitioning.orNull
case _ => null
}
if (representative != null && (representative ne first)) {
require(representative.expressions.length == first.expressions.length,
"All KeyedPartitionings in a PartitioningCollection must have matching expression " +
"arity")
require(k.partitionKeys eq first.partitionKeys,
require(representative.partitionKeys eq first.partitionKeys,
"All KeyedPartitionings in a PartitioningCollection must share the same " +
"partitionKeys reference")
}
case _ =>
}
}
}

Expand Down Expand Up @@ -868,7 +891,24 @@ object PartitioningCollection {
} else {
k
}
case pc: PartitioningCollection => new PartitioningCollection(pc.partitionings.map(intern))
case pc: PartitioningCollection =>
pc.firstKeyedPartitioning match {
// No KeyedPartitioning anywhere in this subtree: nothing to intern. Returning the
// collection as-is keeps repeated outputPartitioning computations over deeply nested
// collections (e.g. chains of same-key joins) O(1) per level.
case None => pc
case Some(representative) if canonicalKeys == null =>
canonicalKeys = representative.partitionKeys
pc
// The collection's own invariant guarantees all its KeyedPartitionings share the
// representative's `partitionKeys` reference, so reference-equality of the
// representative's keys means the whole subtree is already interned.
case Some(representative) if representative.partitionKeys eq canonicalKeys => pc
case Some(representative) =>
require(representative.partitionKeys == canonicalKeys,
"All KeyedPartitionings in a PartitioningCollection must have equal partitionKeys")
new PartitioningCollection(pc.partitionings.map(intern))
}
case other => other
}
new PartitioningCollection(partitionings.map(intern))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,4 +410,60 @@ class DistributionSuite extends SparkFunSuite {
assert(groupedKP.isGrouped)
checkSatisfied(groupedKP, ClusteredDistribution(Seq(x)), true)
}

test("SPARK-56877: fromPartitionings reuses already-consistent nested collections") {
  // Builds a fresh integer attribute with the given name.
  def intAttr(name: String) = AttributeReference(name, IntegerType)()
  val x = intAttr("x")
  val y = intAttr("y")

  // Case 1: the nested collection contains no KeyedPartitioning at all, so wrapping it must hand
  // back the very same instance. Rebuilding it would make outputPartitioning of deeply nested
  // collections (e.g. chains of same-key shuffle joins) quadratic in the nesting depth.
  val innerHash = PartitioningCollection.fromPartitionings(
    Seq(x, y).map(attr => HashPartitioning(Seq(attr), 10)))
  val outerHash = PartitioningCollection.fromPartitionings(
    Seq(innerHash, HashPartitioning(Seq(y), 10)))
  assert(outerHash.partitionings.head eq innerHash)

  // Case 2: the nested KeyedPartitionings already share one partitionKeys reference (the second
  // is a `copy` of the first with different expressions), so the nested collection is likewise
  // reused untouched.
  val keyedOnX =
    KeyedPartitioning(Seq(x), Seq(InternalRow(1), InternalRow(2), InternalRow(3)))
  val keyedOnY = keyedOnX.copy(expressions = Seq(y))
  val innerKeyed = PartitioningCollection.fromPartitionings(Seq(keyedOnX, keyedOnY))
  val outerKeyed = PartitioningCollection.fromPartitionings(Seq(innerKeyed, keyedOnX))
  assert(outerKeyed.partitionings.head eq innerKeyed)
}

test("SPARK-56877: fromPartitionings interns partitionKeys across nested collections") {
  val x = AttributeReference("x", IntegerType)()
  val y = AttributeReference("y", IntegerType)()

  // The nested collection's only member supplies the canonical partitionKeys reference.
  val canonical = KeyedPartitioning(Seq(x), Seq(InternalRow(1), InternalRow(2)))
  val inner = PartitioningCollection.fromPartitionings(Seq(canonical))

  // A sibling whose keys are built from the same values but live in a distinct Seq instance.
  val duplicate = KeyedPartitioning(Seq(y), Seq(InternalRow(1), InternalRow(2)))
  assert(canonical.partitionKeys ne duplicate.partitionKeys)

  // After combining, the sibling must point at the canonical keys by reference.
  val outer = PartitioningCollection.fromPartitionings(Seq(inner, duplicate))
  val internedKeys = outer.partitionings.last.asInstanceOf[KeyedPartitioning].partitionKeys
  assert(internedKeys eq canonical.partitionKeys)
}

test("SPARK-56877: PartitioningCollection enforces the invariant through nesting") {
  val x = AttributeReference("x", IntegerType)()
  val y = AttributeReference("y", IntegerType)()

  val kpX = KeyedPartitioning(Seq(x), Seq(InternalRow(1), InternalRow(2)))
  val nested = PartitioningCollection.fromPartitionings(Seq(kpX))

  // Constructs a collection directly (no interning) next to `nested` and returns the message of
  // the IllegalArgumentException that the constructor is expected to throw.
  def rejectionMessage(offender: KeyedPartitioning): String =
    intercept[IllegalArgumentException] {
      PartitioningCollection(Seq(nested, offender))
    }.getMessage

  // Same key values as kpX, but held in a different Seq instance.
  val kpY = KeyedPartitioning(Seq(y), Seq(InternalRow(1), InternalRow(2)))
  assert(rejectionMessage(kpY).contains("share the same partitionKeys reference"))

  // Two expressions versus the single expression of the KeyedPartitioning inside `nested`.
  val kpXY = KeyedPartitioning(Seq(x, y), Seq(InternalRow(1, 1), InternalRow(2, 2)))
  assert(rejectionMessage(kpXY).contains("matching expression arity"))
}
}