Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 40 additions & 12 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,19 @@ private[spark] class DAGScheduler(

private[spark] val jobIdToQueryExecutionId = new ConcurrentHashMap[Int, java.lang.Long]()

// The maps below back the test-only INJECT_SHUFFLE_FETCH_FAILURES machinery. They are always
// allocated rather than gated on `Utils.isTesting`: that helper reads the mutable
// `spark.testing` system property, so it can return a different value when this DAGScheduler is
// constructed than at the later use-sites. A construction-time `else null` would then be
// dereferenced by a use-site that re-checks `Utils.isTesting` and sees `true`, throwing an NPE
// that crashes the event loop. The maps are only ever populated inside the config-gated test
// paths, so in production they stay empty and carry no behavioral cost beyond an empty map.
// The maps below back the test-only INJECT_SHUFFLE_FETCH_FAILURES machinery, keyed by the
// globally-unique (never-reused) shuffleId. They are always allocated rather than gated on
// `Utils.isTesting`: that helper reads the mutable `spark.testing` system property, so it can
// return a different value when this DAGScheduler is constructed than at the later use-sites.
// A construction-time `else null` would then be dereferenced by a use-site that re-checks
// `Utils.isTesting` and sees `true`, throwing an NPE that crashes the event loop. The maps are
// only ever populated inside the config-gated test paths, so in production they stay empty and
// carry no behavioral cost beyond an empty map. Entries are evicted when the shuffle's map
// outputs are unregistered (via the CleanerListener attached lazily in
// ensureInjectShuffleFetchFailuresCleanerListenerForTest), not on stage removal: under AQE each
// Exchange is materialized as its own map-stage job whose stage is removed before the consuming
// stage runs, so evicting on stage removal would drop a pending corruption before its consumer
// is ever submitted.

// For INJECT_SHUFFLE_FETCH_FAILURES: per-shuffleId, the stage attempt whose partition-0 task
// we corrupted. Read to (a) avoid re-corrupting that partition on recompute, and (b) decide
Expand All @@ -208,6 +214,32 @@ private[spark] class DAGScheduler(
private val injectShuffleFetchFailuresDownstreamSuccessCount: ConcurrentHashMap[Int, Int] =
new ConcurrentHashMap[Int, Int]()

// Whether the CleanerListener that evicts the injectShuffleFetchFailures* maps on shuffle
// cleanup has been attached. Attached lazily (not in the constructor) because sc.cleaner is
// created after the DAGScheduler. The flag is flipped to true by
// ensureInjectShuffleFetchFailuresCleanerListenerForTest only after attachListener has been
// called, so while sc.cleaner is still empty it stays false and a later call tries again.
@volatile private var injectShuffleFetchFailuresCleanerAttached = false

// Attaches, at most once, a CleanerListener whose shuffleCleaned callback removes the cleaned
// shuffleId from the three injectShuffleFetchFailures* maps; every other cleanup callback is a
// no-op. The attach is deferred to the first call rather than done in the constructor. While
// sc.cleaner is empty nothing is attached, the flag stays false, and the next call retries.
// Only the test-gated injection path calls this, so it never runs in production. No extra
// locking is taken because the caller runs on the single-threaded event loop.
private def ensureInjectShuffleFetchFailuresCleanerListenerForTest(): Unit = {
  if (!injectShuffleFetchFailuresCleanerAttached) {
    sc.cleaner.foreach { contextCleaner =>
      val evictOnShuffleCleanup = new CleanerListener {
        override def rddCleaned(rddId: Int): Unit = {}

        // Drop every piece of injection bookkeeping recorded for the cleaned shuffle.
        override def shuffleCleaned(shuffleId: Int): Unit = {
          injectShuffleFetchFailuresCorruptedAttempt.remove(shuffleId)
          injectShuffleFetchFailuresPendingDelayedCorruption.remove(shuffleId)
          injectShuffleFetchFailuresDownstreamSuccessCount.remove(shuffleId)
        }

        override def broadcastCleaned(broadcastId: Long): Unit = {}
        override def accumCleaned(accId: Long): Unit = {}
        override def checkpointCleaned(rddId: Long): Unit = {}
      }
      contextCleaner.attachListener(evictOnShuffleCleanup)
      // Mark as attached only after the listener is registered.
      injectShuffleFetchFailuresCleanerAttached = true
    }
  }
}

// Build the bogus BlockManagerId used by INJECT_SHUFFLE_FETCH_FAILURES to mark a corrupted
// MapStatus: keeps the original host/port/topology so the consumer's locality preference
// resolves to a real host; only the executorId is INVALID_EXECUTOR_ID, so any fetch from
Expand Down Expand Up @@ -975,11 +1007,6 @@ private[spark] class DAGScheduler(
}
for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) {
shuffleIdToMapStage.remove(k)
if (Utils.isTesting) {
injectShuffleFetchFailuresCorruptedAttempt.remove(k)
injectShuffleFetchFailuresPendingDelayedCorruption.remove(k)
injectShuffleFetchFailuresDownstreamSuccessCount.remove(k)
}
}
if (waitingStages.contains(stage)) {
logDebug("Removing stage %d from waiting set.".format(stageId))
Expand Down Expand Up @@ -1676,6 +1703,7 @@ private[spark] class DAGScheduler(
*/
private def shouldCorruptShuffleOutputForTest(shuffleId: Int, task: Task[_]): Boolean = {
if (task.partitionId != 0) return false
ensureInjectShuffleFetchFailuresCleanerListenerForTest()
val recorded = injectShuffleFetchFailuresCorruptedAttempt.computeIfAbsent(
shuffleId, _ => task.stageAttemptId)
recorded == task.stageAttemptId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2703,12 +2703,13 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
}

test("metric values are stable across stage retries") {
// The join in the MERGE plan introduces a shuffle (with broadcast disabled), and the
// DAGScheduler corrupts the first attempt of every upstream shuffle map stage. Note:
// the current fetch-failure injection does not retry the MergeRowsExec/writer stage,
// so this test passes equally well with plain SQLMetric — it only exercises the
// SLAM-aware read path. Follow-up #55738 will add infra to actually retry the writer
// stage and exercise the SLAM behavior end-to-end for MERGE.
// INJECT_SHUFFLE_FETCH_FAILURES corrupts the partition-0 task of the first successful
// attempt of every shuffle map stage, so a downstream stage FetchFails and the producer
// re-runs. For the metadata variants of MERGE - where the writer's
// `RequiresDistributionAndOrdering` forces a re-shuffle between MergeRowsExec and the
// writer - MergeRowsExec sits in a non-leaf shuffle map stage and therefore re-runs with
// the same metric instances, double-counting the per-row increments. SQLLastAttemptMetric
// reports only the last attempt, so `MergeSummary` is still correct.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("source") {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
Expand All @@ -2720,17 +2721,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
val sourceDF = Seq(1, 2, 10).toDF("pk")
sourceDF.createOrReplaceTempView("source")

withSparkContextConf(
val mergeExec = withSparkContextConf(
config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") {
sql(
findMergeExec {
s"""MERGE INTO $tableNameAsString t
|USING source s
|ON t.pk = s.pk
|WHEN MATCHED THEN
| UPDATE SET salary = salary + 100
|WHEN NOT MATCHED THEN
| INSERT (pk, salary, dep) VALUES (s.pk, 999, 'unknown')
|""".stripMargin)
|""".stripMargin
}
}

val mergeSummary = getMergeSummary()
Expand All @@ -2743,6 +2745,22 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
assert(mergeSummary.numTargetRowsNotMatchedBySourceUpdated === 0L)
assert(mergeSummary.numTargetRowsNotMatchedBySourceDeleted === 0L)

// For metadata variants, MergeRowsExec lives in a non-leaf shuffle map stage that the
// fetch-failure injection forces to re-run, so the raw per-MergeRowsExec accumulator
// (`metric.value`) overcounts. This doubles as a direct check that a retry actually
// fired. SLAM-aware `MergeSummary` (asserted above) is correct.
// For noMetadata variants, MergeRowsExec is in the result stage and is not re-run by an
// upstream injection, so there is no overcounting metric to assert.
if (!noMetadata) {
val rawUpdated = mergeExec.metrics("numTargetRowsUpdated").value
assert(rawUpdated > 2L,
s"Expected MergeRowsExec.numTargetRowsUpdated to overcount under fetch-failure " +
s"injection (got $rawUpdated)")
val rawMatchedUpdated = mergeExec.metrics("numTargetRowsMatchedUpdated").value
assert(rawMatchedUpdated > 2L,
s"Expected numTargetRowsMatchedUpdated to overcount (got $rawMatchedUpdated)")
}

checkAnswer(
sql(s"SELECT pk, salary FROM $tableNameAsString ORDER BY pk"),
Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ abstract class RowLevelOperationSuiteBase
Collections.emptyMap[String, String]
}

/**
 * True for the *NoMetadata* test variants, i.e. when `extraTableProps` maps "no-metadata" to
 * "true". For those variants the writer doesn't request any required distribution / ordering
 * and so MergeRowsExec / writer can run in the same stage as the preceding join.
 */
protected def noMetadata: Boolean =
  Option(extraTableProps.get("no-metadata")).contains("true")

protected def catalog: InMemoryRowLevelOperationTableCatalog = {
val catalog = spark.sessionState.catalogManager.catalog("cat")
catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,14 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase {
}

test("metric values are stable across stage retries") {
// Force a shuffle in the UPDATE plan via an IN-subquery (with broadcast disabled), then
// have the DAGScheduler corrupt the first attempt of every upstream shuffle map stage.
// Note: the current fetch-failure injection does not retry the writer stage, so this
// test passes equally well with plain SQLMetric — it only exercises the SLAM-aware
// read path. Follow-up #55738 will add infra to actually retry the writer stage and
// exercise the SLAM behavior end-to-end for UPDATE.
// INJECT_SHUFFLE_FETCH_FAILURES corrupts the partition-0 task of the first successful

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing here asserts a retry actually fired. With the writer single-counting by design, the test passes even if the injection silently stops retrying — the vacuous-pass gap this PR closes for metadata MERGE. Consider asserting an upstream raw-metric overcount, or note that the new infra-level AQE test is what guards retry-fires. (Same for the MERGE if (!noMetadata) path.)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, unfortunately it's still impossible to assert overcounts in result-stage metrics, because Spark does not support any retries in this stage. But the infra PR at least added more "interesting" restart scenarios, and it added coverage for MERGE with metadata.
The new infra-level AQE test checks that a retry fires - it asserts an overcount in stage1 and stage2, but not stage3 (again, because there can't be restarts in the result stage).

// attempt of every shuffle map stage, so a downstream stage FetchFails and the producer
// re-runs. UPDATE writer-side metrics live on the result stage (`metric.add(N)` at
// end-of-task in WritingSparkTask), and ResultStage.findMissingPartitions only re-runs
// partitions that haven't successfully completed, so the writer accumulator single-counts;
// this test is regression coverage that retries don't break the SLAM-aware `UpdateSummary`.
// It does not independently assert that a retry fired (there is no overcounting metric to
// observe on the result stage).
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("source") {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,64 @@ class MetricsFailureInjectionSuite
}
}

test("Three stage metrics block failure injection with AQE") {
// Same as the previous test but with AQE enabled. Under AQE each Exchange is materialized

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test drops four assertions its non-AQE sibling has — stage3Metric.value === 5 and the three lastAttemptValueForDataset(finalDf) checks — yet the comment calls it the same test. If they don't hold under AQE (e.g. dataset lookup through AdaptiveSparkPlanExec), say so here; otherwise restore them so the AQE path gets equal coverage.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Told my agent to stop being sloppy and re-add the assertions :-).

// as its own map-stage job, which exercises a different DAGScheduler path than the
// AQE-disabled variant: the injection's deferred corruption must survive across those
// per-shuffle jobs for the downstream FetchFailed (and thus the producer re-run) to fire.
val stage1Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 1 counter")
val stage2Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 2 counter")
val stage3Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 3 counter")
val stage1SLAMetric =
SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 1 SLAM")
val stage2SLAMetric =
SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 2 SLAM")
val stage3SLAMetric =
SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 3 SLAM")

withTable("primary_table", "secondary_table") {
setUpTestTable("primary_table")
setUpTestTable("secondary_table")
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
withSparkContextConf(
config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") {
val stage1MetricsExpr = incrementMetrics(Seq(stage1Metric, stage1SLAMetric))
val stage1 = spark.read.table("primary_table")
.filter(Column(stage1MetricsExpr))
val stage2MetricsExpr = incrementMetrics(Seq(stage2Metric, stage2SLAMetric))
val stage2 = stage1.join(
spark.read.table("secondary_table"),
usingColumn = "id",
joinType = "fullOuter")
.filter(Column(stage2MetricsExpr))
val stage3MetricsExpr = incrementMetrics(Seq(stage3Metric, stage3SLAMetric))
val stage3 = stage2
.groupBy("primary_table.low_cardinality_col")
.count()
.filter(Column(stage3MetricsExpr))
val finalDf = stage3.as[(Int, Long)]
val result = finalDf.collect()
assert(result.toMap === (0 until 5).map(v => (v, 300 / 5)).toMap)

// Both the leaf stage 1 and the non-leaf stage 2 get their first successful attempt
// corrupted and re-run, so their raw counters overcount. SLAM reports only the last
// successful attempt per RDD.
assert(stage1Metric.value > 300, s"stage1Metric=${stage1Metric.value}")
assert(stage2Metric.value > 300, s"stage2Metric=${stage2Metric.value}")
assert(stage3Metric.value === 5)

assert(stage1SLAMetric.lastAttemptValueForHighestRDDId() === Some(300))
assert(stage2SLAMetric.lastAttemptValueForHighestRDDId() === Some(300))
assert(stage3SLAMetric.lastAttemptValueForHighestRDDId() === Some(5))

assert(stage1SLAMetric.lastAttemptValueForDataset(finalDf) === Some(300))
assert(stage2SLAMetric.lastAttemptValueForDataset(finalDf) === Some(300))
assert(stage3SLAMetric.lastAttemptValueForDataset(finalDf) === Some(5))
}
}
}
}

test("Three stage metrics force-checksum-mismatch on recompute") {
// INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE additionally flags the recompute of the
// partition-0 task as a checksum mismatch. The DAGScheduler then runs
Expand Down