diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4a0db18d328f1..3d615867a9278 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -179,13 +179,19 @@ private[spark] class DAGScheduler( private[spark] val jobIdToQueryExecutionId = new ConcurrentHashMap[Int, java.lang.Long]() - // The maps below back the test-only INJECT_SHUFFLE_FETCH_FAILURES machinery. They are always - // allocated rather than gated on `Utils.isTesting`: that helper reads the mutable - // `spark.testing` system property, so it can return a different value when this DAGScheduler is - // constructed than at the later use-sites. A construction-time `else null` would then be - // dereferenced by a use-site that re-checks `Utils.isTesting` and sees `true`, throwing an NPE - // that crashes the event loop. The maps are only ever populated inside the config-gated test - // paths, so in production they stay empty and carry no behavioral cost beyond an empty map. + // The maps below back the test-only INJECT_SHUFFLE_FETCH_FAILURES machinery, keyed by the + // globally-unique (never-reused) shuffleId. They are always allocated rather than gated on + // `Utils.isTesting`: that helper reads the mutable `spark.testing` system property, so it can + // return a different value when this DAGScheduler is constructed than at the later use-sites. + // A construction-time `else null` would then be dereferenced by a use-site that re-checks + // `Utils.isTesting` and sees `true`, throwing an NPE that crashes the event loop. The maps are + // only ever populated inside the config-gated test paths, so in production they stay empty and + // carry no behavioral cost beyond an empty map. Entries are evicted when the shuffle's map + // outputs are unregistered (via the CleanerListener attached lazily in + // ensureInjectShuffleFetchFailuresCleanerListenerForTest), not on stage removal: under AQE each + // Exchange is materialized as its own map-stage job whose stage is removed before the consuming + // stage runs, so evicting on stage removal would drop a pending corruption before its consumer + // is ever submitted. // For INJECT_SHUFFLE_FETCH_FAILURES: per-shuffleId, the stage attempt whose partition-0 task // we corrupted. Read to (a) avoid re-corrupting that partition on recompute, and (b) decide @@ -208,6 +214,32 @@ private[spark] class DAGScheduler( private val injectShuffleFetchFailuresDownstreamSuccessCount: ConcurrentHashMap[Int, Int] = new ConcurrentHashMap[Int, Int]() + // Whether the CleanerListener that evicts the injectShuffleFetchFailures* maps on shuffle + // cleanup has been attached. Attached lazily (not in the constructor) because sc.cleaner is + // created after the DAGScheduler. + @volatile private var injectShuffleFetchFailuresCleanerAttached = false + + // Lazily attach a CleanerListener that drops a shuffle's injectShuffleFetchFailures* entries + // when its map outputs are unregistered. Called from the test-gated injection path only, so it + // never runs in production. Runs on the single-threaded event loop, hence no extra locking. + private def ensureInjectShuffleFetchFailuresCleanerListenerForTest(): Unit = { + if (injectShuffleFetchFailuresCleanerAttached) return + sc.cleaner.foreach { cleaner => + cleaner.attachListener(new CleanerListener { + override def rddCleaned(rddId: Int): Unit = {} + override def shuffleCleaned(shuffleId: Int): Unit = { + injectShuffleFetchFailuresCorruptedAttempt.remove(shuffleId) + injectShuffleFetchFailuresPendingDelayedCorruption.remove(shuffleId) + injectShuffleFetchFailuresDownstreamSuccessCount.remove(shuffleId) + } + override def broadcastCleaned(broadcastId: Long): Unit = {} + override def accumCleaned(accId: Long): Unit = {} + override def checkpointCleaned(rddId: Long): Unit = {} + }) + injectShuffleFetchFailuresCleanerAttached = true + } + } + // Build the bogus BlockManagerId used by INJECT_SHUFFLE_FETCH_FAILURES to mark a corrupted // MapStatus: keeps the original host/port/topology so the consumer's locality preference // resolves to a real host; only the executorId is INVALID_EXECUTOR_ID, so any fetch from @@ -975,11 +1007,6 @@ private[spark] class DAGScheduler( } for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) { shuffleIdToMapStage.remove(k) - if (Utils.isTesting) { - injectShuffleFetchFailuresCorruptedAttempt.remove(k) - injectShuffleFetchFailuresPendingDelayedCorruption.remove(k) - injectShuffleFetchFailuresDownstreamSuccessCount.remove(k) - } } if (waitingStages.contains(stage)) { logDebug("Removing stage %d from waiting set.".format(stageId)) @@ -1676,6 +1703,7 @@ private[spark] class DAGScheduler( */ private def shouldCorruptShuffleOutputForTest(shuffleId: Int, task: Task[_]): Boolean = { if (task.partitionId != 0) return false + ensureInjectShuffleFetchFailuresCleanerListenerForTest() val recorded = injectShuffleFetchFailuresCorruptedAttempt.computeIfAbsent( shuffleId, _ => task.stageAttemptId) recorded == task.stageAttemptId diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 126b84b507caf..1dbc1001b5c08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -2703,12 +2703,13 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } test("metric values are stable across stage retries") { - // The join in the MERGE plan introduces a shuffle (with broadcast disabled), and the - // DAGScheduler corrupts the first attempt of every upstream shuffle map stage. Note: - // the current fetch-failure injection does not retry the MergeRowsExec/writer stage, - // so this test passes equally well with plain SQLMetric — it only exercises the - // SLAM-aware read path. Follow-up #55738 will add infra to actually retry the writer - // stage and exercise the SLAM behavior end-to-end for MERGE. + // INJECT_SHUFFLE_FETCH_FAILURES corrupts the partition-0 task of the first successful + // attempt of every shuffle map stage, so a downstream stage FetchFails and the producer + // re-runs. For the metadata variants of MERGE - where the writer's + // `RequiresDistributionAndOrdering` forces a re-shuffle between MergeRowsExec and the + // writer - MergeRowsExec sits in a non-leaf shuffle map stage and therefore re-runs with + // the same metric instances, double-counting the per-row increments. SQLLastAttemptMetric + // reports only the last attempt, so `MergeSummary` is still correct. withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { withTempView("source") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -2720,9 +2721,9 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase val sourceDF = Seq(1, 2, 10).toDF("pk") sourceDF.createOrReplaceTempView("source") - withSparkContextConf( + val mergeExec = withSparkContextConf( config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") { - sql( + findMergeExec { s"""MERGE INTO $tableNameAsString t |USING source s |ON t.pk = s.pk @@ -2730,7 +2731,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase | UPDATE SET salary = salary + 100 |WHEN NOT MATCHED THEN | INSERT (pk, salary, dep) VALUES (s.pk, 999, 'unknown') - |""".stripMargin) + |""".stripMargin + } } val mergeSummary = getMergeSummary() @@ -2743,6 +2745,22 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(mergeSummary.numTargetRowsNotMatchedBySourceUpdated === 0L) assert(mergeSummary.numTargetRowsNotMatchedBySourceDeleted === 0L) + // For metadata variants, MergeRowsExec lives in a non-leaf shuffle map stage that the + // fetch-failure injection forces to re-run, so the raw per-MergeRowsExec accumulator + // (`metric.value`) overcounts. This doubles as a direct check that a retry actually + // fired. SLAM-aware `MergeSummary` (asserted above) is correct. + // For noMetadata variants, MergeRowsExec is in the result stage and is not re-run by an + // upstream injection, so there is no overcounting metric to assert. + if (!noMetadata) { + val rawUpdated = mergeExec.metrics("numTargetRowsUpdated").value + assert(rawUpdated > 2L, + s"Expected MergeRowsExec.numTargetRowsUpdated to overcount under fetch-failure " + + s"injection (got $rawUpdated)") + val rawMatchedUpdated = mergeExec.metrics("numTargetRowsMatchedUpdated").value + assert(rawMatchedUpdated > 2L, + s"Expected numTargetRowsMatchedUpdated to overcount (got $rawMatchedUpdated)") + } + checkAnswer( sql(s"SELECT pk, salary FROM $tableNameAsString ORDER BY pk"), Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 0c465969e347c..199b9ecbe0a07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -89,6 +89,12 @@ abstract class RowLevelOperationSuiteBase Collections.emptyMap[String, String] } + /** True for the *NoMetadata* test variants - the writer doesn't request any required + * distribution / ordering and so MergeRowsExec / writer can run in the same stage as the + * preceding join. */ + protected def noMetadata: Boolean = + extraTableProps.getOrDefault("no-metadata", "false") == "true" + protected def catalog: InMemoryRowLevelOperationTableCatalog = { val catalog = spark.sessionState.catalogManager.catalog("cat") catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index 6e9afe7abc97e..8eb314e00df81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -342,12 +342,14 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { } test("metric values are stable across stage retries") { - // Force a shuffle in the UPDATE plan via an IN-subquery (with broadcast disabled), then - // have the DAGScheduler corrupt the first attempt of every upstream shuffle map stage. - // Note: the current fetch-failure injection does not retry the writer stage, so this - // test passes equally well with plain SQLMetric — it only exercises the SLAM-aware - // read path. Follow-up #55738 will add infra to actually retry the writer stage and - // exercise the SLAM behavior end-to-end for UPDATE. + // INJECT_SHUFFLE_FETCH_FAILURES corrupts the partition-0 task of the first successful + // attempt of every shuffle map stage, so a downstream stage FetchFails and the producer + // re-runs. UPDATE writer-side metrics live on the result stage (`metric.add(N)` at + // end-of-task in WritingSparkTask), and ResultStage.findMissingPartitions only re-runs + // partitions that haven't successfully completed, so the writer accumulator single-counts; + // this test is regression coverage that retries don't break the SLAM-aware `UpdateSummary`. + // It does not independently assert that a retry fired (there is no overcounting metric to + // observe on the result stage). withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { withTempView("source") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala index 6fc784f33815f..b704628b13eba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala @@ -420,6 +420,64 @@ class MetricsFailureInjectionSuite } } + test("Three stage metrics block failure injection with AQE") { + // Same as the previous test but with AQE enabled. Under AQE each Exchange is materialized + // as its own map-stage job, which exercises a different DAGScheduler path than the + // AQE-disabled variant: the injection's deferred corruption must survive across those + // per-shuffle jobs for the downstream FetchFailed (and thus the producer re-run) to fire. + val stage1Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 1 counter") + val stage2Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 2 counter") + val stage3Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 3 counter") + val stage1SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 1 SLAM") + val stage2SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 2 SLAM") + val stage3SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 3 SLAM") + + withTable("primary_table", "secondary_table") { + setUpTestTable("primary_table") + setUpTestTable("secondary_table") + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + withSparkContextConf( + config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") { + val stage1MetricsExpr = incrementMetrics(Seq(stage1Metric, stage1SLAMetric)) + val stage1 = spark.read.table("primary_table") + .filter(Column(stage1MetricsExpr)) + val stage2MetricsExpr = incrementMetrics(Seq(stage2Metric, stage2SLAMetric)) + val stage2 = stage1.join( + spark.read.table("secondary_table"), + usingColumn = "id", + joinType = "fullOuter") + .filter(Column(stage2MetricsExpr)) + val stage3MetricsExpr = incrementMetrics(Seq(stage3Metric, stage3SLAMetric)) + val stage3 = stage2 + .groupBy("primary_table.low_cardinality_col") + .count() + .filter(Column(stage3MetricsExpr)) + val finalDf = stage3.as[(Int, Long)] + val result = finalDf.collect() + assert(result.toMap === (0 until 5).map(v => (v, 300 / 5)).toMap) + + // Both the leaf stage 1 and the non-leaf stage 2 get their first successful attempt + // corrupted and re-run, so their raw counters overcount. SLAM reports only the last + // successful attempt per RDD. + assert(stage1Metric.value > 300, s"stage1Metric=${stage1Metric.value}") + assert(stage2Metric.value > 300, s"stage2Metric=${stage2Metric.value}") + assert(stage3Metric.value === 5) + + assert(stage1SLAMetric.lastAttemptValueForHighestRDDId() === Some(300)) + assert(stage2SLAMetric.lastAttemptValueForHighestRDDId() === Some(300)) + assert(stage3SLAMetric.lastAttemptValueForHighestRDDId() === Some(5)) + + assert(stage1SLAMetric.lastAttemptValueForDataset(finalDf) === Some(300)) + assert(stage2SLAMetric.lastAttemptValueForDataset(finalDf) === Some(300)) + assert(stage3SLAMetric.lastAttemptValueForDataset(finalDf) === Some(5)) + } + } + } + } + test("Three stage metrics force-checksum-mismatch on recompute") { // INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE additionally flags the recompute of the // partition-0 task as a checksum mismatch. The DAGScheduler then runs