diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index fbf5c06fe051b..19a057c72506b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.concurrent.ConcurrentHashMap + import scala.language.existentials import org.apache.spark._ @@ -24,7 +26,6 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.metric.CustomTaskMetric import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} @@ -34,6 +35,19 @@ import org.apache.spark.util.ArrayImplicits._ class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition]) extends Partition with Serializable +/** + * Holds the state for a reader in a task, used by the completion listener to access the most + * recently created reader and iterator for final metrics updates and cleanup. + * + * When `compute()` is called multiple times for the same task (e.g., when DataSourceRDD is + * coalesced), this state is updated on each call to track the most recent reader. The task + * completion listener then uses this most recent reader for final cleanup and metrics reporting. 
+ * + * @param reader The partition reader + * @param iterator The metrics iterator wrapping the reader + */ +private case class ReaderState(reader: PartitionReader[_], iterator: MetricsIterator[_]) + // TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for // columnar scan. class DataSourceRDD( @@ -44,6 +58,11 @@ class DataSourceRDD( customMetrics: Map[String, SQLMetric]) extends RDD[InternalRow](sc, Nil) { + // Map from task attempt ID to the most recently created ReaderState for that task. + // When compute() is called multiple times for the same task (due to coalescing), the map entry + // is updated each time so the completion listener always closes the last reader. + @transient private lazy val taskReaderStates = new ConcurrentHashMap[Long, ReaderState]() + override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { case (inputPartitions, index) => new DataSourceRDDPartition(index, inputPartitions) @@ -56,20 +75,34 @@ class DataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val taskAttemptId = context.taskAttemptId() + + // Add completion listener only once per task attempt. A ReaderState is only put into + // taskReaderStates once the returned iterator advances to a reader; if that has happened + // before compute() runs again for this task (e.g., coalescing), we skip this block. + if (!taskReaderStates.containsKey(taskAttemptId)) { + context.addTaskCompletionListener[Unit] { ctx => + // In case of early stopping before consuming the entire iterator, + // we need to do one more metric update at the end of the task. 
+ try { + val readerState = taskReaderStates.get(ctx.taskAttemptId()) + if (readerState != null) { + CustomMetrics.updateMetrics( + readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics) + readerState.iterator.forceUpdateMetrics() + readerState.reader.close() + } + } finally { + taskReaderStates.remove(ctx.taskAttemptId()) + } + } + } val iterator = new Iterator[Object] { private val inputPartitions = castPartition(split).inputPartitions private var currentIter: Option[Iterator[Object]] = None private var currentIndex: Int = 0 - private val partitionMetricCallback = new PartitionMetricCallback(customMetrics) - - // In case of early stopping before consuming the entire iterator, - // we need to do one more metric update at the end of the task. - context.addTaskCompletionListener[Unit] { _ => - partitionMetricCallback.execute() - } - override def hasNext: Boolean = currentIter.exists(_.hasNext) || advanceToNextIter() override def next(): Object = { @@ -97,9 +130,18 @@ class DataSourceRDD( (iter, rowReader) } - // Once we advance to the next partition, update the metric callback for early finish - val previousMetrics = partitionMetricCallback.advancePartition(iter, reader) - previousMetrics.foreach(reader.initMetricsValues) + // Flush metrics and close the previous reader before advancing to the next one. + // Pass the accumulated metrics to the new reader so they carry forward correctly. + val prevState = taskReaderStates.get(taskAttemptId) + if (prevState != null) { + val metrics = prevState.reader.currentMetricsValues + CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics) + reader.initMetricsValues(metrics) + prevState.reader.close() + } + + // Update the map so the completion listener always references the latest reader. 
+ taskReaderStates.put(taskAttemptId, ReaderState(reader, iter)) currentIter = Some(iter) hasNext @@ -115,35 +157,6 @@ class DataSourceRDD( } } -private class PartitionMetricCallback - (customMetrics: Map[String, SQLMetric]) { - private var iter: MetricsIterator[_] = null - private var reader: PartitionReader[_] = null - - def advancePartition( - iter: MetricsIterator[_], - reader: PartitionReader[_]): Option[Array[CustomTaskMetric]] = { - val metrics = execute() - - this.iter = iter - this.reader = reader - - metrics - } - - def execute(): Option[Array[CustomTaskMetric]] = { - if (iter != null && reader != null) { - val metrics = reader.currentMetricsValues - CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics) - iter.forceUpdateMetrics() - reader.close() - Some(metrics) - } else { - None - } - } -} - private class PartitionIterator[T]( reader: PartitionReader[T], customMetrics: Map[String, SQLMetric]) extends Iterator[T] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 8cd55304d71c0..56bd028464e54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -2840,6 +2840,21 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(metrics("number of rows read") == "3") } + test("SPARK-55619: Custom metrics of coalesced partitions") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(2, 'bb', 10.0, cast('2021-01-01' as timestamp))") + + val metrics = runAndFetchMetrics { + val df = sql(s"SELECT * FROM testcat.ns.$items").coalesce(1) + df.collect() + } + 
assert(metrics("number of rows read") == "2") + } + test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " + "are less than cluster keys") { withSQLConf(