From 0f41f3cc3910fbb58d7df218abc49adb50bc4f8f Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Mon, 6 Oct 2025 13:21:54 -0700 Subject: [PATCH 01/12] Add canonicalization for dsv2 scan --- .../org/apache/spark/sql/connector/read/Scan.java | 7 +++++++ .../datasources/v2/DataSourceV2Relation.scala | 12 ++++++++++++ .../sql/execution/datasources/v2/BatchScanExec.scala | 1 + 3 files changed, 20 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java index 81b89e5750d83..3f83a8dc4dddc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java @@ -157,4 +157,11 @@ enum ColumnarSupportMode { default ColumnarSupportMode columnarSupportMode() { return ColumnarSupportMode.PARTITION_DEFINED; } + + /** + * Return the canonicalized scan + * + * @since 4.1.0 + */ + default Scan doCanonicalize() {return this;} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 26f4069994943..2dcaf29f980f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression, SortOrder} +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataColumns, Histogram, HistogramBin, LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils} @@ -163,6 +164,17 @@ case class DataSourceV2ScanRelation( Statistics(sizeInBytes = conf.defaultSizeInBytes) } } + + override def doCanonicalize(): LogicalPlan = { + val canonicalized = this.copy( + relation = this.relation.copy( + output = this.relation.output.map(QueryPlan.normalizeExpressions(_, this.relation.output)) + ), + output = this.output.map(QueryPlan.normalizeExpressions(_, this.output)), + scan = this.scan.doCanonicalize() + ) + canonicalized + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 55866cc858405..33dd98d0e37fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -149,6 +149,7 @@ case class BatchScanExec( override def doCanonicalize(): BatchScanExec = { this.copy( + scan = scan.doCanonicalize(), output = output.map(QueryPlan.normalizeExpressions(_, output)), runtimeFilters = QueryPlan.normalizePredicates( runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)), From 3f04dd12d495a4d5d116d8985f37793947a2a650 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Fri, 10 Oct 2025 09:26:19 -0700 Subject: [PATCH 02/12] stash for now --- .../apache/spark/sql/connector/read/Scan.java | 7 -- .../datasources/v2/DataSourceV2Relation.scala | 8 +- .../datasources/v2/BatchScanExec.scala | 1 - .../sql/connector/DataSourceV2Suite.scala | 116 ++++++++++++++++++ 4 files changed, 119 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java index 3f83a8dc4dddc..81b89e5750d83 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java @@ -157,11 +157,4 @@ enum ColumnarSupportMode { default ColumnarSupportMode columnarSupportMode() { return ColumnarSupportMode.PARTITION_DEFINED; } - - /** - * Return the canonicalized scan - * - * @since 4.1.0 - */ - default Scan doCanonicalize() {return this;} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 2dcaf29f980f2..1d101d120de8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -165,15 +165,13 @@ case class DataSourceV2ScanRelation( } } - override def doCanonicalize(): LogicalPlan = { - val canonicalized = this.copy( + override def doCanonicalize(): DataSourceV2ScanRelation = { + this.copy( relation = this.relation.copy( output = this.relation.output.map(QueryPlan.normalizeExpressions(_, this.relation.output)) ), - output = this.output.map(QueryPlan.normalizeExpressions(_, this.output)), - scan = this.scan.doCanonicalize() + output = this.output.map(QueryPlan.normalizeExpressions(_, this.output)) ) - canonicalized } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 33dd98d0e37fa..55866cc858405 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -149,7 +149,6 @@ case class BatchScanExec( override def doCanonicalize(): BatchScanExec = { this.copy( - scan = scan.doCanonicalize(), output = output.map(QueryPlan.normalizeExpressions(_, output)), runtimeFilters = QueryPlan.normalizePredicates( runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index ca82a8c612099..e5e497460da7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -976,6 +976,37 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(result.length == 1) } } + + test("SPARK-53809: scan canonicalization") { + val df = spark.read.format(classOf[CanonicalizedScanDataSourceV2].getName).load() + val q1 = df.select($"i").where($"i" > 1 && $"i" > 2) + val q2 = df.select($"i").where($"i" > 2 && $"i" > 1) + val optimized1 = q1.queryExecution.optimizedPlan + val optimized2 = q2.queryExecution.optimizedPlan + val executed1 = q1.queryExecution.executedPlan + val executed2 = q2.queryExecution.executedPlan + val dsv2ScanRelation1 = optimized1.collect { + case d: DataSourceV2ScanRelation => d + }.head + val dsv2ScanRelation2 = optimized2.collect { + case d: DataSourceV2ScanRelation => d + }.head + val batchScanExec1 = executed1.collect { + case b: BatchScanExec => b + }.head + val batchScanExec2 = executed2.collect { + case b: BatchScanExec => b + }.head + + assert(optimized1.equals(optimized2)) + assert(optimized1.canonicalized == optimized2.canonicalized) + assert(executed1.equals(executed2)) + assert(executed1.canonicalized == executed2.canonicalized) + assert(dsv2ScanRelation1.equals(dsv2ScanRelation2)) + assert(dsv2ScanRelation1.canonicalized == dsv2ScanRelation2.canonicalized) + assert(batchScanExec1.equals(batchScanExec2)) + assert(batchScanExec1.canonicalized == batchScanExec2.canonicalized) + } } case class RangeInputPartition(start: Int, end: Int) extends InputPartition @@ -1072,6 +1103,91 @@ class ScanDefinedColumnarSupport extends TestingV2Source { } +class CanonicalizedScanDataSourceV2 extends TestingV2Source { + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + TestingV2Source.schema + } + + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new CanonicalizedScanBuilder() + } + } +} + +class CanonicalizedScanBuilder extends ScanBuilder + with SupportsPushDownFilters with SupportsPushDownRequiredColumns { + + var requiredSchema: StructType = TestingV2Source.schema + var filters = Array.empty[Filter] + + override def build(): Scan = new ScanWithCanonicalization(requiredSchema, filters) + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false + } + this.filters = supported + unsupported + } + + override def pushedFilters(): Array[Filter] = filters + + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } +} + +class ScanWithCanonicalization(readSchema: StructType, val filters: Array[Filter]) + extends Scan with Batch { + + override def readSchema(): StructType = readSchema + + override def toBatch: Batch = this + + override def equals(obj: Any): Boolean = { + obj match { + case that: ScanWithCanonicalization => + this.readSchema == that.readSchema && + this.filters.sortBy(_.hashCode()).sameElements(that.filters.sortBy(_.hashCode())) + case _ => false + } + } + + override def hashCode(): Int = { + var result = readSchema.hashCode() + result = 31 * result + java.util.Arrays.hashCode( + filters.asInstanceOf[Array[AnyRef]]) + result + } + + override def planInputPartitions(): Array[InputPartition] = { + val lowerBound = filters.collectFirst { + case GreaterThan("i", v: Int) => v + } + + val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] + + if (lowerBound.isEmpty) { + res.append(RangeInputPartition(0, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get < 4) { + res.append(RangeInputPartition(lowerBound.get + 1, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get < 9) { + res.append(RangeInputPartition(lowerBound.get + 1, 10)) + } + + res.toArray + } + + override def createReaderFactory(): PartitionReaderFactory = { + new AdvancedReaderFactory(readSchema) + } +} + // This class is used by pyspark tests. If this class is modified/moved, make sure pyspark // tests still pass. From 23cdc1967b0946065963232f3cefe046fdbe037b Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Mon, 20 Oct 2025 15:57:03 -0700 Subject: [PATCH 03/12] use mock rule in test --- .../sql/connector/DataSourceV2Suite.scala | 69 ++++++++++++------- 1 file changed, 45 insertions(+), 24 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index e5e497460da7e..352d5cba21da3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -26,6 +26,9 @@ import test.org.apache.spark.sql.connector._ import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.And +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform} @@ -41,7 +44,7 @@ import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{Filter, GreaterThan} +import org.apache.spark.sql.sources.{Filter, GreaterThan, LessThan} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -979,33 +982,50 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS test("SPARK-53809: scan canonicalization") { val df = spark.read.format(classOf[CanonicalizedScanDataSourceV2].getName).load() - val q1 = df.select($"i").where($"i" > 1 && $"i" > 2) - val q2 = df.select($"i").where($"i" > 2 && $"i" > 1) + + val q1 = df.select($"i", $"j").where($"i" > 1 && $"i" < 8) + val q2 = df.select($"i", $"j").where($"i" < 8 && $"i" > 1) + val optimized1 = q1.queryExecution.optimizedPlan val optimized2 = q2.queryExecution.optimizedPlan - val executed1 = q1.queryExecution.executedPlan - val executed2 = q2.queryExecution.executedPlan - val dsv2ScanRelation1 = optimized1.collect { - case d: DataSourceV2ScanRelation => d - }.head - val dsv2ScanRelation2 = optimized2.collect { - case d: DataSourceV2ScanRelation => d - }.head - val batchScanExec1 = executed1.collect { - case b: BatchScanExec => b - }.head - val batchScanExec2 = executed2.collect { - case b: BatchScanExec => b - }.head - assert(optimized1.equals(optimized2)) + // Create a rule that reverses the order of DataSourceV2ScanRelation output + val reverseOutputRule = new Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case dsv2 @ DataSourceV2ScanRelation(relation, scan: Scan, output, _, _) => + val reversedOutput = output.reverse + dsv2.copy( + output = reversedOutput + ) + } + } + + // Apply the rule to both queries to ensure they go through different optimization paths + val optimized1Reversed = reverseOutputRule(optimized1) + +// val executed1 = q1.queryExecution.executedPlan +// val executed2 = q2.queryExecution.executedPlan +// val dsv2ScanRelation1 = optimized1.collect { +// case d: DataSourceV2ScanRelation => d +// }.head +// val dsv2ScanRelation2 = optimized2.collect { +// case d: DataSourceV2ScanRelation => d +// }.head +// val batchScanExec1 = executed1.collect { +// case b: BatchScanExec => b +// }.head +// val batchScanExec2 = executed2.collect { +// case b: BatchScanExec => b +// }.head + + assert(!optimized1Reversed.equals(optimized2)) assert(optimized1.canonicalized == optimized2.canonicalized) - assert(executed1.equals(executed2)) - assert(executed1.canonicalized == executed2.canonicalized) - assert(dsv2ScanRelation1.equals(dsv2ScanRelation2)) - assert(dsv2ScanRelation1.canonicalized == dsv2ScanRelation2.canonicalized) - assert(batchScanExec1.equals(batchScanExec2)) - assert(batchScanExec1.canonicalized == batchScanExec2.canonicalized) +// assert(executed1.equals(executed2)) +// assert(executed1.canonicalized == executed2.canonicalized) +// assert(dsv2ScanRelation1.equals(dsv2ScanRelation2)) +// assert(dsv2ScanRelation1.canonicalized == dsv2ScanRelation2.canonicalized) +// assert(batchScanExec1.equals(batchScanExec2)) +// assert(batchScanExec1.canonicalized == batchScanExec2.canonicalized) } } @@ -1127,6 +1147,7 @@ class CanonicalizedScanBuilder extends ScanBuilder override def pushFilters(filters: Array[Filter]): Array[Filter] = { val (supported, unsupported) = filters.partition { case GreaterThan("i", _: Int) => true + case LessThan("i", _: Int) => true case _ => false } this.filters = supported From d6222f8aa47696c229b3abe2c58cccbb1b456592 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Wed, 22 Oct 2025 14:32:36 -0700 Subject: [PATCH 04/12] stash for now --- .../sql/connector/DataSourceV2Suite.scala | 54 ++++++++++++------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 352d5cba21da3..b3712e72e234e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -26,7 +26,6 @@ import test.org.apache.spark.sql.connector._ import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.And import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider} @@ -36,7 +35,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.read.Scan.ColumnarSupportMode import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} -import org.apache.spark.sql.execution.SortExec +import org.apache.spark.sql.execution.{SortExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} @@ -992,9 +991,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS // Create a rule that reverses the order of DataSourceV2ScanRelation output val reverseOutputRule = new Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case dsv2 @ DataSourceV2ScanRelation(relation, scan: Scan, output, _, _) => + case dsv2 @ DataSourceV2ScanRelation(relation, _, output, _, _) => val reversedOutput = output.reverse + val reversedRelationOutput = relation.output.reverse dsv2.copy( + relation = relation.copy(output = reversedRelationOutput), output = reversedOutput ) } @@ -1003,25 +1004,38 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS // Apply the rule to both queries to ensure they go through different optimization paths val optimized1Reversed = reverseOutputRule(optimized1) -// val executed1 = q1.queryExecution.executedPlan -// val executed2 = q2.queryExecution.executedPlan -// val dsv2ScanRelation1 = optimized1.collect { -// case d: DataSourceV2ScanRelation => d -// }.head -// val dsv2ScanRelation2 = optimized2.collect { -// case d: DataSourceV2ScanRelation => d -// }.head -// val batchScanExec1 = executed1.collect { -// case b: BatchScanExec => b -// }.head -// val batchScanExec2 = executed2.collect { -// case b: BatchScanExec => b -// }.head - assert(!optimized1Reversed.equals(optimized2)) assert(optimized1.canonicalized == optimized2.canonicalized) -// assert(executed1.equals(executed2)) -// assert(executed1.canonicalized == executed2.canonicalized) + + + val executed1 = q1.queryExecution.executedPlan + val executed2 = q2.queryExecution.executedPlan + + val reverseOutputInSparkPlan = new Rule[SparkPlan] { + def apply(plan: SparkPlan): SparkPlan = plan transform { + case dsv2 @ BatchScanExec(output, _, _, _, _, _) => + val reversedOutput = output.reverse + dsv2.copy(output = reversedOutput) + } + } + + val executed1Reversed = reverseOutputInSparkPlan(executed1) + assert(!executed1Reversed.equals(executed2)) + assert(executed1Reversed.canonicalized == executed2.canonicalized) + + + // val dsv2ScanRelation1 = optimized1.collect { + // case d: DataSourceV2ScanRelation => d + // }.head + // val dsv2ScanRelation2 = optimized2.collect { + // case d: DataSourceV2ScanRelation => d + // }.head + // val batchScanExec1 = executed1.collect { + // case b: BatchScanExec => b + // }.head + // val batchScanExec2 = executed2.collect { + // case b: BatchScanExec => b + // }.head // assert(dsv2ScanRelation1.equals(dsv2ScanRelation2)) // assert(dsv2ScanRelation1.canonicalized == dsv2ScanRelation2.canonicalized) // assert(batchScanExec1.equals(batchScanExec2)) From 3f9d849e331ce359642f3cf488c295597181c6e4 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Wed, 22 Oct 2025 15:46:53 -0700 Subject: [PATCH 05/12] push test --- .../sql/connector/DataSourceV2Suite.scala | 47 +++++-------------- 1 file changed, 12 insertions(+), 35 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index b3712e72e234e..54f2f78dcb299 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.read.Scan.ColumnarSupportMode import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} -import org.apache.spark.sql.execution.{SortExec, SparkPlan} +import org.apache.spark.sql.execution.SortExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} @@ -983,7 +983,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val df = spark.read.format(classOf[CanonicalizedScanDataSourceV2].getName).load() val q1 = df.select($"i", $"j").where($"i" > 1 && $"i" < 8) - val q2 = df.select($"i", $"j").where($"i" < 8 && $"i" > 1) + val q2 = q1 val optimized1 = q1.queryExecution.optimizedPlan val optimized2 = q2.queryExecution.optimizedPlan @@ -1001,45 +1001,22 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } } - // Apply the rule to both queries to ensure they go through different optimization paths + // Apply the rule to q1 to mimic that QO may generate different output ordering + // for subqueries on a shared scan val optimized1Reversed = reverseOutputRule(optimized1) assert(!optimized1Reversed.equals(optimized2)) assert(optimized1.canonicalized == optimized2.canonicalized) + val dsv2ScanRelation1 = optimized1Reversed.collect { + case d: DataSourceV2ScanRelation => d + }.head + val dsv2ScanRelation2 = optimized2.collect { + case d: DataSourceV2ScanRelation => d + }.head - val executed1 = q1.queryExecution.executedPlan - val executed2 = q2.queryExecution.executedPlan - - val reverseOutputInSparkPlan = new Rule[SparkPlan] { - def apply(plan: SparkPlan): SparkPlan = plan transform { - case dsv2 @ BatchScanExec(output, _, _, _, _, _) => - val reversedOutput = output.reverse - dsv2.copy(output = reversedOutput) - } - } - - val executed1Reversed = reverseOutputInSparkPlan(executed1) - assert(!executed1Reversed.equals(executed2)) - assert(executed1Reversed.canonicalized == executed2.canonicalized) - - - // val dsv2ScanRelation1 = optimized1.collect { - // case d: DataSourceV2ScanRelation => d - // }.head - // val dsv2ScanRelation2 = optimized2.collect { - // case d: DataSourceV2ScanRelation => d - // }.head - // val batchScanExec1 = executed1.collect { - // case b: BatchScanExec => b - // }.head - // val batchScanExec2 = executed2.collect { - // case b: BatchScanExec => b - // }.head -// assert(dsv2ScanRelation1.equals(dsv2ScanRelation2)) -// assert(dsv2ScanRelation1.canonicalized == dsv2ScanRelation2.canonicalized) -// assert(batchScanExec1.equals(batchScanExec2)) -// assert(batchScanExec1.canonicalized == batchScanExec2.canonicalized) + assert(!dsv2ScanRelation1.equals(dsv2ScanRelation2)) + assert(dsv2ScanRelation1.canonicalized == dsv2ScanRelation2.canonicalized) } } From d3928447b5c662eaa1fb63bed237bafe6e354639 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Wed, 22 Oct 2025 16:00:10 -0700 Subject: [PATCH 06/12] clean up test --- .../sql/connector/DataSourceV2Suite.scala | 97 +------------------ 1 file changed, 5 insertions(+), 92 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 54f2f78dcb299..5c414e2636642 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{Filter, GreaterThan, LessThan} +import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -980,9 +980,9 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } test("SPARK-53809: scan canonicalization") { - val df = spark.read.format(classOf[CanonicalizedScanDataSourceV2].getName).load() + val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load() - val q1 = df.select($"i", $"j").where($"i" > 1 && $"i" < 8) + val q1 = df.select($"i", $"j") val q2 = q1 val optimized1 = q1.queryExecution.optimizedPlan @@ -1004,7 +1004,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS // Apply the rule to q1 to mimic that QO may generate different output ordering // for subqueries on a shared scan val optimized1Reversed = reverseOutputRule(optimized1) - + // The two plans are not identified as equal, but their canonicalized forms are equal assert(!optimized1Reversed.equals(optimized2)) assert(optimized1.canonicalized == optimized2.canonicalized) @@ -1014,7 +1014,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val dsv2ScanRelation2 = optimized2.collect { case d: DataSourceV2ScanRelation => d }.head - + // Check the effectiveness of canonicalization on DataSourceV2ScanRelation assert(!dsv2ScanRelation1.equals(dsv2ScanRelation2)) assert(dsv2ScanRelation1.canonicalized == dsv2ScanRelation2.canonicalized) } @@ -1114,93 +1114,6 @@ class ScanDefinedColumnarSupport extends TestingV2Source { } -class CanonicalizedScanDataSourceV2 extends TestingV2Source { - - override def inferSchema(options: CaseInsensitiveStringMap): StructType = { - TestingV2Source.schema - } - - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { - override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new CanonicalizedScanBuilder() - } - } -} - -class CanonicalizedScanBuilder extends ScanBuilder - with SupportsPushDownFilters with SupportsPushDownRequiredColumns { - - var requiredSchema: StructType = TestingV2Source.schema - var filters = Array.empty[Filter] - - override def build(): Scan = new ScanWithCanonicalization(requiredSchema, filters) - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - val (supported, unsupported) = filters.partition { - case GreaterThan("i", _: Int) => true - case LessThan("i", _: Int) => true - case _ => false - } - this.filters = supported - unsupported - } - - override def pushedFilters(): Array[Filter] = filters - - override def pruneColumns(requiredSchema: StructType): Unit = { - this.requiredSchema = requiredSchema - } -} - -class ScanWithCanonicalization(readSchema: StructType, val filters: Array[Filter]) - extends Scan with Batch { - - override def readSchema(): StructType = readSchema - - override def toBatch: Batch = this - - override def equals(obj: Any): Boolean = { - obj match { - case that: ScanWithCanonicalization => - this.readSchema == that.readSchema && - this.filters.sortBy(_.hashCode()).sameElements(that.filters.sortBy(_.hashCode())) - case _ => false - } - } - - override def hashCode(): Int = { - var result = readSchema.hashCode() - result = 31 * result + java.util.Arrays.hashCode( - filters.asInstanceOf[Array[AnyRef]]) - result - } - - override def planInputPartitions(): Array[InputPartition] = { - val lowerBound = filters.collectFirst { - case GreaterThan("i", v: Int) => v - } - - val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] - - if (lowerBound.isEmpty) { - res.append(RangeInputPartition(0, 5)) - res.append(RangeInputPartition(5, 10)) - } else if (lowerBound.get < 4) { - res.append(RangeInputPartition(lowerBound.get + 1, 5)) - res.append(RangeInputPartition(5, 10)) - } else if (lowerBound.get < 9) { - res.append(RangeInputPartition(lowerBound.get + 1, 10)) - } - - res.toArray - } - - override def createReaderFactory(): PartitionReaderFactory = { - new AdvancedReaderFactory(readSchema) - } -} - - // This class is used by pyspark tests. If this class is modified/moved, make sure pyspark // tests still pass. class SimpleDataSourceV2 extends TestingV2Source { From c45db13663e6debbd29c1646d55a1fb59390f587 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Wed, 22 Oct 2025 16:01:12 -0700 Subject: [PATCH 07/12] restore changes --- .../scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 5c414e2636642..628b5c5b7a659 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -1114,6 +1114,7 @@ class ScanDefinedColumnarSupport extends TestingV2Source { } + // This class is used by pyspark tests. If this class is modified/moved, make sure pyspark // tests still pass. class SimpleDataSourceV2 extends TestingV2Source { From 1af4c599e2a5d44ffe5b187fac2c9c9e743dc036 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 23 Oct 2025 10:34:06 -0700 Subject: [PATCH 08/12] fix test --- .../org/apache/spark/sql/connector/DataSourceV2Suite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 628b5c5b7a659..a0cc7fb9cc52d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -1006,7 +1006,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val optimized1Reversed = reverseOutputRule(optimized1) // The two plans are not identified as equal, but their canonicalized forms are equal assert(!optimized1Reversed.equals(optimized2)) - assert(optimized1.canonicalized == optimized2.canonicalized) + assert(optimized1Reversed.canonicalized == optimized2.canonicalized) val dsv2ScanRelation1 = optimized1Reversed.collect { case d: DataSourceV2ScanRelation => d From 6a9efee896a49170f7f4376f27ee728f5fee8717 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Sun, 26 Oct 2025 16:52:42 -0700 Subject: [PATCH 09/12] stash merge subquery test --- .../sql/connector/DataSourceV2Suite.scala | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index a0cc7fb9cc52d..fa7334996539b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -26,7 +26,7 @@ import test.org.apache.spark.sql.connector._ import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ @@ -1018,6 +1018,23 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(!dsv2ScanRelation1.equals(dsv2ScanRelation2)) assert(dsv2ScanRelation1.canonicalized == dsv2ScanRelation2.canonicalized) } + + test("SPARK-53809: check mergeScalarSubqueries") { + val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load() + df.createOrReplaceTempView("df") + + val query = sql("select (select max(i) from df) as max_i, (select min(i) from df) as min_i") + val optimizedPlan = query.queryExecution.optimizedPlan + + // check optimizedPlan merged scalar subqueries select max(i), min(i) from df + val aggregation = optimizedPlan.collect { + case a: Aggregate => a + } + + query.collect() + val plan = query.queryExecution.stringWithStats + assert(1==1) + } } case class RangeInputPartition(start: Int, end: Int) extends InputPartition @@ -1123,6 +1140,18 @@ class SimpleDataSourceV2 extends TestingV2Source { override def planInputPartitions(): Array[InputPartition] = { Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) } + + override def equals(obj: Any): Boolean = { + obj match { + case s: Scan => + this.readSchema() == s.readSchema() + case _ => false + } + } + + override def hashCode(): Int = { + this.readSchema().hashCode() + } } override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { From ccba29e3023fcb399bd1f794bb609592c2432f46 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 30 Oct 2025 12:05:02 -0700 Subject: [PATCH 10/12] add test --- .../sql/connector/DataSourceV2Suite.scala | 90 +++++++++---------- 1 file changed, 44 insertions(+), 46 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index fa7334996539b..25557e82ef1af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -26,8 +26,8 @@ import test.org.apache.spark.sql.connector._ import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform} @@ -979,61 +979,59 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } } - test("SPARK-53809: scan canonicalization") { + test("SPARK-53809: check mergeScalarSubqueries is effective for DataSourceV2ScanRelation") { val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load() + df.createOrReplaceTempView("df") - val q1 = df.select($"i", $"j") - val q2 = q1 - - val optimized1 = q1.queryExecution.optimizedPlan - val optimized2 = q2.queryExecution.optimizedPlan - - // Create a rule that reverses the order of DataSourceV2ScanRelation output - val reverseOutputRule = new Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case dsv2 @ DataSourceV2ScanRelation(relation, _, output, _, _) => - val reversedOutput = output.reverse - val reversedRelationOutput = relation.output.reverse - dsv2.copy( - relation = relation.copy(output = reversedRelationOutput), - output = reversedOutput - ) - } + val query = sql("select (select max(i) from df) as max_i, (select min(i) from df) as min_i") + val optimizedPlan = query.queryExecution.optimizedPlan + + // check optimizedPlan merged scalar subqueries `select max(i), min(i) from df` + val sub1 = optimizedPlan.asInstanceOf[Project].projectList.head.collect { + case s: ScalarSubquery => s + } + val sub2 = optimizedPlan.asInstanceOf[Project].projectList(1).collect { + case s: ScalarSubquery => s } - // Apply the rule to q1 to mimic that QO may generate different output ordering - // for subqueries on a shared scan - val optimized1Reversed = reverseOutputRule(optimized1) - // The two plans are not identified as equal, but their canonicalized forms are equal - assert(!optimized1Reversed.equals(optimized2)) - assert(optimized1Reversed.canonicalized == optimized2.canonicalized) + // Both subqueries should reference the same merged plan `select max(i), min(i) from df` + assert(sub1.nonEmpty && sub2.nonEmpty, "Both scalar subqueries should exist") + assert(sub1.head.plan == sub2.head.plan, + "Both subqueries should reference the same merged plan") - val dsv2ScanRelation1 = optimized1Reversed.collect { - case d: DataSourceV2ScanRelation => d - }.head - val dsv2ScanRelation2 = optimized2.collect { - case d: DataSourceV2ScanRelation => d + // Extract the aggregate from the merged plan + val agg = sub1.head.plan.collect { + case a: Aggregate => a }.head - // Check the effectiveness of canonicalization on DataSourceV2ScanRelation - assert(!dsv2ScanRelation1.equals(dsv2ScanRelation2)) - assert(dsv2ScanRelation1.canonicalized == dsv2ScanRelation2.canonicalized) - } - test("SPARK-53809: check mergeScalarSubqueries") { - val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load() - df.createOrReplaceTempView("df") + // Check that the aggregate contains both max(i) and min(i) + val aggExprs = agg.aggregateExpressions - val query = sql("select (select max(i) from df) as max_i, (select min(i) from df) as min_i") - val optimizedPlan = query.queryExecution.optimizedPlan + val hasMax = aggExprs.exists { expr => + expr.collect { + case ae: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression => + ae.aggregateFunction match { + case _: org.apache.spark.sql.catalyst.expressions.aggregate.Max => true + case _ => false + } + }.nonEmpty + } - // check optimizedPlan merged scalar subqueries select max(i), min(i) from df - val aggregation = optimizedPlan.collect { - case a: Aggregate => a + val hasMin = aggExprs.exists { expr => + expr.collect { + case ae: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression => + ae.aggregateFunction match { + case _: org.apache.spark.sql.catalyst.expressions.aggregate.Min => true + case _ => false + } + }.nonEmpty } - query.collect() - val plan = query.queryExecution.stringWithStats - assert(1==1) + assert(hasMax, "Aggregate should contain max(i)") + assert(hasMin, "Aggregate should contain min(i)") + + // Verify the query produces correct results + checkAnswer(query, Row(9, 0)) } } From 74ac0ad308f6140b3a59cde57c2bf00ad154b144 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 30 Oct 2025 13:22:19 -0700 Subject: [PATCH 11/12] add scan canonicalization test --- .../sql/connector/DataSourceV2Suite.scala | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 25557e82ef1af..3c7a6fe3c0a14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, import org.apache.spark.sql.execution.SortExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.expressions.Window @@ -979,6 +980,34 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } } + test("SPARK-53809: scan canonicalization") { + val table = new SimpleDataSourceV2().getTable(CaseInsensitiveStringMap.empty()) + + def createDsv2ScanRelation(): DataSourceV2ScanRelation = { + val relation = DataSourceV2Relation.create( + table, None, None, CaseInsensitiveStringMap.empty()) + val scan = relation.table.asReadable.newScanBuilder(relation.options).build() + DataSourceV2ScanRelation(relation, scan, relation.output) + } + + // Create two DataSourceV2ScanRelation instances, representing the scan of the same table + val scanRelation1 = createDsv2ScanRelation() + val scanRelation2 = createDsv2ScanRelation() + + // the two instances should not be the same, as they should have different attribute IDs + assert(scanRelation1 != scanRelation2, + "Two created DataSourceV2ScanRelation instances should not be the same") + assert(scanRelation1.output.map(_.exprId).toSet != scanRelation2.output.map(_.exprId).toSet, + "Output attributes should have different expression IDs before canonicalization") + assert(scanRelation1.relation.output.map(_.exprId).toSet != + scanRelation2.relation.output.map(_.exprId).toSet, + "Relation output attributes should have different expression IDs before canonicalization") + + // After canonicalization, the two instances should be equal + assert(scanRelation1.canonicalized == scanRelation2.canonicalized, + "Canonicalized DataSourceV2ScanRelation instances should be equal") + } + test("SPARK-53809: check mergeScalarSubqueries is effective for DataSourceV2ScanRelation") { val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load() df.createOrReplaceTempView("df") From b8692146bd21174d8214c5d934a6c670c0ef9dbb Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Fri, 31 Oct 2025 09:49:45 -0700 Subject: [PATCH 12/12] cleaner test check --- .../sql/connector/DataSourceV2Suite.scala | 36 +++++++------------ 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 3c7a6fe3c0a14..01fa2b13b86f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -1028,36 +1028,26 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(sub1.head.plan == sub2.head.plan, "Both subqueries should reference the same merged plan") - // Extract the aggregate from the merged plan + // Extract the aggregate from the merged plan sub1 val agg = sub1.head.plan.collect { case a: Aggregate => a }.head // Check that the aggregate contains both max(i) and min(i) - val aggExprs = agg.aggregateExpressions - - val hasMax = aggExprs.exists { expr => - expr.collect { - case ae: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression => - ae.aggregateFunction match { - case _: org.apache.spark.sql.catalyst.expressions.aggregate.Max => true - case _ => false - } - }.nonEmpty - } - - val hasMin = aggExprs.exists { expr => + val aggFunctionSet = agg.aggregateExpressions.flatMap { expr => expr.collect { case ae: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression => - ae.aggregateFunction match { - case _: org.apache.spark.sql.catalyst.expressions.aggregate.Min => true - case _ => false - } - }.nonEmpty - } - - assert(hasMax, "Aggregate should contain max(i)") - assert(hasMin, "Aggregate should contain min(i)") + ae.aggregateFunction + } + }.toSet + + assert(aggFunctionSet.size == 2, "Aggregate should contain exactly two aggregate functions") + assert(aggFunctionSet + .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Max]), + "Aggregate should contain max(i)") + assert(aggFunctionSet + .exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.aggregate.Min]), + "Aggregate should contain min(i)") // Verify the query produces correct results checkAnswer(query, Row(9, 0))