From 9871d0440ef77a3ece8c32280af7cf3fbd638efd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 20 Feb 2026 10:03:23 -0800 Subject: [PATCH] [SPARK-55619][SQL] Fix custom metrics in case of coalesced partitions Replace PartitionMetricCallback with a ConcurrentHashMap keyed by task attempt ID to correctly track reader state across multiple compute() calls when DataSourceRDD is coalesced. The completion listener is registered only once per task attempt, and metrics are flushed and carried forward between readers as partitions are advanced. Co-Authored-By: Peter Toth Co-Authored-By: Claude Sonnet 4.6 --- .../datasources/v2/DataSourceRDD.scala | 95 +++++++++++-------- .../KeyGroupedPartitioningSuite.scala | 15 +++ 2 files changed, 69 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index fbf5c06fe051b..19a057c72506b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.concurrent.ConcurrentHashMap + import scala.language.existentials import org.apache.spark._ @@ -24,7 +26,6 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.metric.CustomTaskMetric import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} @@ -34,6 +35,19 @@ import org.apache.spark.util.ArrayImplicits._ class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition]) extends Partition with Serializable +/** + * Holds the state for a reader in a task, used by the completion listener to access the most + * recently created reader and iterator for final metrics updates and cleanup. + * + * When `compute()` is called multiple times for the same task (e.g., when DataSourceRDD is + * coalesced), this state is updated on each call to track the most recent reader. The task + * completion listener then uses this most recent reader for final cleanup and metrics reporting. + * + * @param reader The partition reader + * @param iterator The metrics iterator wrapping the reader + */ +private case class ReaderState(reader: PartitionReader[_], iterator: MetricsIterator[_]) + // TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for // columnar scan. class DataSourceRDD( @@ -44,6 +58,11 @@ class DataSourceRDD( customMetrics: Map[String, SQLMetric]) extends RDD[InternalRow](sc, Nil) { + // Map from task attempt ID to the most recently created ReaderState for that task. + // When compute() is called multiple times for the same task (due to coalescing), the map entry + // is updated each time so the completion listener always closes the last reader. + @transient private lazy val taskReaderStates = new ConcurrentHashMap[Long, ReaderState]() + override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { case (inputPartitions, index) => new DataSourceRDDPartition(index, inputPartitions) @@ -56,20 +75,34 @@ class DataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val taskAttemptId = context.taskAttemptId() + + // Add completion listener only once per task attempt. When compute() is called a second time + // for the same task (e.g., due to coalescing), the first call will have already put a + // ReaderState into taskReaderStates, so containsKey returns true and we skip this block. + if (!taskReaderStates.containsKey(taskAttemptId)) { + context.addTaskCompletionListener[Unit] { ctx => + // In case of early stopping before consuming the entire iterator, + // we need to do one more metric update at the end of the task. + try { + val readerState = taskReaderStates.get(ctx.taskAttemptId()) + if (readerState != null) { + CustomMetrics.updateMetrics( + readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics) + readerState.iterator.forceUpdateMetrics() + readerState.reader.close() + } + } finally { + taskReaderStates.remove(ctx.taskAttemptId()) + } + } + } val iterator = new Iterator[Object] { private val inputPartitions = castPartition(split).inputPartitions private var currentIter: Option[Iterator[Object]] = None private var currentIndex: Int = 0 - private val partitionMetricCallback = new PartitionMetricCallback(customMetrics) - - // In case of early stopping before consuming the entire iterator, - // we need to do one more metric update at the end of the task. - context.addTaskCompletionListener[Unit] { _ => - partitionMetricCallback.execute() - } - override def hasNext: Boolean = currentIter.exists(_.hasNext) || advanceToNextIter() override def next(): Object = { @@ -97,9 +130,18 @@ class DataSourceRDD( (iter, rowReader) } - // Once we advance to the next partition, update the metric callback for early finish - val previousMetrics = partitionMetricCallback.advancePartition(iter, reader) - previousMetrics.foreach(reader.initMetricsValues) + // Flush metrics and close the previous reader before advancing to the next one. + // Pass the accumulated metrics to the new reader so they carry forward correctly. + val prevState = taskReaderStates.get(taskAttemptId) + if (prevState != null) { + val metrics = prevState.reader.currentMetricsValues + CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics) + reader.initMetricsValues(metrics) + prevState.reader.close() + } + + // Update the map so the completion listener always references the latest reader. + taskReaderStates.put(taskAttemptId, ReaderState(reader, iter)) currentIter = Some(iter) hasNext @@ -115,35 +157,6 @@ class DataSourceRDD( } } -private class PartitionMetricCallback - (customMetrics: Map[String, SQLMetric]) { - private var iter: MetricsIterator[_] = null - private var reader: PartitionReader[_] = null - - def advancePartition( - iter: MetricsIterator[_], - reader: PartitionReader[_]): Option[Array[CustomTaskMetric]] = { - val metrics = execute() - - this.iter = iter - this.reader = reader - - metrics - } - - def execute(): Option[Array[CustomTaskMetric]] = { - if (iter != null && reader != null) { - val metrics = reader.currentMetricsValues - CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics) - iter.forceUpdateMetrics() - reader.close() - Some(metrics) - } else { - None - } - } -} - private class PartitionIterator[T]( reader: PartitionReader[T], customMetrics: Map[String, SQLMetric]) extends Iterator[T] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 8cd55304d71c0..56bd028464e54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -2840,6 +2840,21 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(metrics("number of rows read") == "3") } + test("SPARK-55619: Custom metrics of coalesced partitions") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(2, 'bb', 10.0, cast('2021-01-01' as timestamp))") + + val metrics = runAndFetchMetrics { + val df = sql(s"SELECT * FROM testcat.ns.$items").coalesce(1) + df.collect() + } + assert(metrics("number of rows read") == "2") + } + test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " + "are less than cluster keys") { withSQLConf(