diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 26f4069994943..1d101d120de8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression, SortOrder} +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataColumns, Histogram, HistogramBin, LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils} @@ -163,6 +164,15 @@ case class DataSourceV2ScanRelation( Statistics(sizeInBytes = conf.defaultSizeInBytes) } } + + override def doCanonicalize(): DataSourceV2ScanRelation = { + this.copy( + relation = this.relation.copy( + output = this.relation.output.map(QueryPlan.normalizeExpressions(_, this.relation.output)) + ), + output = this.output.map(QueryPlan.normalizeExpressions(_, this.output)) + ) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index ca82a8c612099..01fa2b13b86f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -26,6 +26,8 @@ import test.org.apache.spark.sql.connector._ import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform} @@ -36,6 +38,7 @@ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, import org.apache.spark.sql.execution.SortExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.expressions.Window @@ -976,6 +979,79 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(result.length == 1) } } + + test("SPARK-53809: scan canonicalization") { + val table = new SimpleDataSourceV2().getTable(CaseInsensitiveStringMap.empty()) + + def createDsv2ScanRelation(): DataSourceV2ScanRelation = { + val relation = DataSourceV2Relation.create( + table, None, None, CaseInsensitiveStringMap.empty()) + val scan = relation.table.asReadable.newScanBuilder(relation.options).build() + DataSourceV2ScanRelation(relation, scan, relation.output) + } + + // Create two DataSourceV2ScanRelation instances, representing the scan of the same table + val scanRelation1 = createDsv2ScanRelation() + val scanRelation2 = createDsv2ScanRelation() + + // the two instances should not be the same, as they should have different attribute IDs + assert(scanRelation1 != scanRelation2, + "Two created DataSourceV2ScanRelation instances should not be the same") + assert(scanRelation1.output.map(_.exprId).toSet != scanRelation2.output.map(_.exprId).toSet, + "Output attributes should have different expression IDs before canonicalization") + assert(scanRelation1.relation.output.map(_.exprId).toSet != + scanRelation2.relation.output.map(_.exprId).toSet, + "Relation output attributes should have different expression IDs before canonicalization") + + // After canonicalization, the two instances should be equal + assert(scanRelation1.canonicalized == scanRelation2.canonicalized, + "Canonicalized DataSourceV2ScanRelation instances should be equal") + } + + test("SPARK-53809: check mergeScalarSubqueries is effective for DataSourceV2ScanRelation") { + val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load() + df.createOrReplaceTempView("df") + + val query = sql("select (select max(i) from df) as max_i, (select min(i) from df) as min_i") + val optimizedPlan = query.queryExecution.optimizedPlan + + // check optimizedPlan merged scalar subqueries `select max(i), min(i) from df` + val sub1 = optimizedPlan.asInstanceOf[Project].projectList.head.collect { + case s: ScalarSubquery => s + } + val sub2 = optimizedPlan.asInstanceOf[Project].projectList(1).collect { + case s: ScalarSubquery => s + } + + // Both subqueries should reference the same merged plan `select max(i), min(i) from df` + assert(sub1.nonEmpty && sub2.nonEmpty, "Both scalar subqueries should exist") + assert(sub1.head.plan == sub2.head.plan, + "Both subqueries should reference the same merged plan") + + // Extract the aggregate from the merged plan sub1 + val agg = sub1.head.plan.collect { + case a: Aggregate => a + }.head + + // Check that the aggregate contains both max(i) and min(i) + val aggFunctionSet = agg.aggregateExpressions.flatMap { expr => + expr.collect { + case ae: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression => + ae.aggregateFunction + } + }.toSet + + assert(aggFunctionSet.size == 2, "Aggregate should contain exactly two aggregate functions") + assert(aggFunctionSet + .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Max]), + "Aggregate should contain max(i)") + assert(aggFunctionSet + .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Min]), + "Aggregate should contain min(i)") + + // Verify the query produces correct results + checkAnswer(query, Row(9, 0)) + } } case class RangeInputPartition(start: Int, end: Int) extends InputPartition @@ -1081,6 +1157,18 @@ class SimpleDataSourceV2 extends TestingV2Source { override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) } + + override def equals(obj: Any): Boolean = { + obj match { + case s: Scan => + this.readSchema() == s.readSchema() + case _ => false + } + } + + override def hashCode(): Int = { + this.readSchema().hashCode() + } } override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {