Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

package org.apache.spark.sql.execution.datasources.v2

import java.util.concurrent.ConcurrentHashMap

import scala.language.existentials

import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
Expand All @@ -34,6 +35,19 @@ import org.apache.spark.util.ArrayImplicits._
/**
 * A partition of `DataSourceRDD`.
 *
 * @param index the position of this partition within the RDD
 * @param inputPartitions the input partitions that `DataSourceRDD.compute` iterates over, in
 *                        order, when it is invoked for this partition
 */
class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition])
extends Partition with Serializable

/**
 * Holds the most recently opened reader of a task together with the metrics iterator that wraps
 * it. `DataSourceRDD` keeps one entry per task attempt ID and replaces it every time a new
 * reader is opened, so the task completion listener can do the final metrics update and cleanup
 * on the latest reader.
 *
 * When `compute()` is called multiple times for the same task (e.g., when DataSourceRDD is
 * coalesced), the same entry keeps being replaced, so the state always refers to the reader
 * that was opened last, regardless of which `compute()` call opened it.
 *
 * @param reader The partition reader opened most recently by the task. It is closed either when
 *               the next reader replaces it or by the task completion listener.
 * @param iterator The metrics iterator wrapping `reader`. The task completion listener calls
 *                 `forceUpdateMetrics()` on it before closing the reader.
 */
private case class ReaderState(reader: PartitionReader[_], iterator: MetricsIterator[_])

// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for
// columnar scan.
class DataSourceRDD(
Expand All @@ -44,6 +58,11 @@ class DataSourceRDD(
customMetrics: Map[String, SQLMetric])
extends RDD[InternalRow](sc, Nil) {

// Map from task attempt ID to the most recently created ReaderState for that task.
// When compute() is called multiple times for the same task (due to coalescing), the map entry
// is updated each time so the completion listener always closes the last reader.
@transient private lazy val taskReaderStates = new ConcurrentHashMap[Long, ReaderState]()

override protected def getPartitions: Array[Partition] = {
inputPartitions.zipWithIndex.map {
case (inputPartitions, index) => new DataSourceRDDPartition(index, inputPartitions)
Expand All @@ -56,20 +75,34 @@ class DataSourceRDD(
}

override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val taskAttemptId = context.taskAttemptId()

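// Add completion listener only once per task attempt. A ReaderState is stored in
// taskReaderStates only once an iterator returned by compute() has opened a reader, so when
// compute() is called again for the same task (e.g., due to coalescing) after that point,
// containsKey returns true and we skip this block. If an earlier call never opened a reader,
// a second listener is registered; that is harmless because the listener does nothing when
// it finds no ReaderState for the task.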
if (!taskReaderStates.containsKey(taskAttemptId)) {
context.addTaskCompletionListener[Unit] { ctx =>
// In case of early stopping before consuming the entire iterator,
// we need to do one more metric update at the end of the task.
try {
val readerState = taskReaderStates.get(ctx.taskAttemptId())
if (readerState != null) {
CustomMetrics.updateMetrics(
readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
readerState.iterator.forceUpdateMetrics()
readerState.reader.close()
}
} finally {
taskReaderStates.remove(ctx.taskAttemptId())
}
}
}

val iterator = new Iterator[Object] {
private val inputPartitions = castPartition(split).inputPartitions
private var currentIter: Option[Iterator[Object]] = None
private var currentIndex: Int = 0

private val partitionMetricCallback = new PartitionMetricCallback(customMetrics)

// In case of early stopping before consuming the entire iterator,
// we need to do one more metric update at the end of the task.
context.addTaskCompletionListener[Unit] { _ =>
partitionMetricCallback.execute()
}

override def hasNext: Boolean = currentIter.exists(_.hasNext) || advanceToNextIter()

override def next(): Object = {
Expand Down Expand Up @@ -97,9 +130,18 @@ class DataSourceRDD(
(iter, rowReader)
}

// Once we advance to the next partition, update the metric callback for early finish
val previousMetrics = partitionMetricCallback.advancePartition(iter, reader)
previousMetrics.foreach(reader.initMetricsValues)
// Flush metrics and close the previous reader before advancing to the next one.
// Pass the accumulated metrics to the new reader so they carry forward correctly.
val prevState = taskReaderStates.get(taskAttemptId)
if (prevState != null) {
val metrics = prevState.reader.currentMetricsValues
CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
reader.initMetricsValues(metrics)
prevState.reader.close()
}

// Update the map so the completion listener always references the latest reader.
taskReaderStates.put(taskAttemptId, ReaderState(reader, iter))

currentIter = Some(iter)
hasNext
Expand All @@ -115,35 +157,6 @@ class DataSourceRDD(
}
}

/**
 * Tracks the reader and metrics iterator of the partition that is currently being read so that
 * their metrics can be flushed and the reader closed, either when the scan advances to the next
 * partition or when the task finishes early.
 *
 * @param customMetrics the SQL metrics, keyed by name, that receive the reader's metric values
 */
private class PartitionMetricCallback
    (customMetrics: Map[String, SQLMetric]) {
  // Both are null until the first call to advancePartition().
  private var iter: MetricsIterator[_] = null
  private var reader: PartitionReader[_] = null

  /**
   * Flushes and closes the previously tracked reader (if any) and starts tracking the given one.
   *
   * @param iter the metrics iterator wrapping the new reader
   * @param reader the reader of the partition that is about to be read
   * @return the final metric values of the previous reader, or `None` if there was none
   */
  def advancePartition(
      iter: MetricsIterator[_],
      reader: PartitionReader[_]): Option[Array[CustomTaskMetric]] = {
    val metrics = execute()

    this.iter = iter
    this.reader = reader

    metrics
  }

  /**
   * Publishes the current reader's metric values into `customMetrics`, forces a final update of
   * the metrics iterator and closes the reader.
   *
   * @return the metric values that were read from the reader, or `None` if no reader is tracked
   */
  def execute(): Option[Array[CustomTaskMetric]] = {
    if (iter != null && reader != null) {
      try {
        val metrics = reader.currentMetricsValues
        CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
        iter.forceUpdateMetrics()
        Some(metrics)
      } finally {
        // Always release the reader, even if reading or publishing the metrics failed;
        // otherwise the underlying resources would leak.
        reader.close()
      }
    } else {
      None
    }
  }
}

private class PartitionIterator[T](
reader: PartitionReader[T],
customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2840,6 +2840,21 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
assert(metrics("number of rows read") == "3")
}

test("SPARK-55619: Custom metrics of coalesced partitions") {
  // Create the items table partitioned by identity(id).
  val partitioning = Array(identity("id"))
  createTable(items, itemsColumns, partitioning)

  // Insert two rows with distinct ids.
  val insertStatement = s"INSERT INTO testcat.ns.$items VALUES " +
    "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
    "(2, 'bb', 10.0, cast('2021-01-01' as timestamp))"
  sql(insertStatement)

  // Scan the table through coalesce(1) and capture the metrics reported for the query.
  val reportedMetrics = runAndFetchMetrics {
    sql(s"SELECT * FROM testcat.ns.$items").coalesce(1).collect()
  }

  // Both inserted rows must be counted.
  assert(reportedMetrics("number of rows read") == "2")
}

test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
"are less than cluster keys") {
withSQLConf(
Expand Down