From b53f42619f9c0a41fbdac2d07c81e5a1909722f0 Mon Sep 17 00:00:00 2001 From: harris233 <1657417742@qq.com> Date: Fri, 1 Aug 2025 13:52:56 +0800 Subject: [PATCH 1/4] Fix cube-related data quality problem --- .../org/apache/spark/sql/SQLQuerySuite.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 107514edbc876..0daac5ef90cc3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -5057,6 +5057,28 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } + + test("SPARK-53094: Fix cube-related data quality problem") { + withTable("table1") { + withSQLConf() { + sql( + """CREATE TABLE table1(product string, amount bigint, + |region string) using csv""".stripMargin) + + sql("INSERT INTO table1 " + "VALUES('a', 100, 'east')") + sql("INSERT INTO table1 " + "VALUES('b', 200, 'east')") + sql("INSERT INTO table1 " + "VALUES('a', 150, 'west')") + sql("INSERT INTO table1 " + "VALUES('b', 250, 'west')") + sql("INSERT INTO table1 " + "VALUES('a', 120, 'east')") + + checkAnswer( + sql("select product, region, sum(amount) as s " + + "from table1 group by product, region with cube having count(product) > 2 " + + "order by s desc"), + Seq(Row(null, null, 820), Row(null, "east", 420), Row("a", null, 370))) + } + } + } } case class Foo(bar: Option[String]) From 4b4fc2b09b521395634807c3cdc8ce8113cf5ad2 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 4 Aug 2025 20:39:07 +0200 Subject: [PATCH 2/4] defer `UnresolvedHaving` resolution if it has unresolved aggregate functions --- .../sql/catalyst/analysis/Analyzer.scala | 46 +++++++++++-------- .../analyzer-results/grouping_set.sql.out | 2 +- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 12ba41145c208..787bbd9ef8cab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -792,27 +792,33 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } else { colResolved.havingCondition } - // Try resolving the condition of the filter as though it is in the aggregate clause - val (extraAggExprs, Seq(resolvedHavingCond)) = - ResolveAggregateFunctions.resolveExprsWithAggregate(Seq(cond), aggForResolving) - - // Push the aggregate expressions into the aggregate (if any). - val newChild = constructAggregate(h, selectedGroupByExprs, groupByExprs, - aggregate.aggregateExpressions ++ extraAggExprs, aggregate.child) - - // Since the output exprId will be changed in the constructed aggregate, here we build an - // attrMap to resolve the condition again. - val attrMap = AttributeMap((aggForResolving.output ++ extraAggExprs.map(_.toAttribute)) - .zip(newChild.output)) - val newCond = resolvedHavingCond.transform { - case a: AttributeReference => attrMap.getOrElse(a, a) - } - - if (extraAggExprs.isEmpty) { - Filter(newCond, newChild) + // `cond` might contain unresolved aggregate functions so defer its resolution to + // `ResolveAggregateFunctions` rule if needed. + if (!cond.resolved) { + colResolved } else { - Project(newChild.output.dropRight(extraAggExprs.length), - Filter(newCond, newChild)) + // Try resolving the condition of the filter as though it is in the aggregate clause + val (extraAggExprs, Seq(resolvedHavingCond)) = + ResolveAggregateFunctions.resolveExprsWithAggregate(Seq(cond), aggForResolving) + + // Push the aggregate expressions into the aggregate (if any). + val newChild = constructAggregate(h, selectedGroupByExprs, groupByExprs, + aggregate.aggregateExpressions ++ extraAggExprs, aggregate.child) + + // Since the output exprId will be changed in the constructed aggregate, here we build an + // attrMap to resolve the condition again. + val attrMap = AttributeMap((aggForResolving.output ++ extraAggExprs.map(_.toAttribute)) + .zip(newChild.output)) + val newCond = resolvedHavingCond.transform { + case a: AttributeReference => attrMap.getOrElse(a, a) + } + + if (extraAggExprs.isEmpty) { + Filter(newCond, newChild) + } else { + Project(newChild.output.dropRight(extraAggExprs.length), + Filter(newCond, newChild)) + } } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out index 254f9d0785408..2c63fb1525a46 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out @@ -116,7 +116,7 @@ FROM (VALUES ('x', 'a', 10), ('y', 'b', 20) ) AS t (c1, c2, c3) GROUP BY GROUPING SETS ( ( c1 ), ( c2 ) ) HAVING GROUPING__ID > 1 -- !query analysis -Filter (grouping__id#xL > cast(1 as bigint)) +Filter (GROUPING__ID#xL > cast(1 as bigint)) +- Aggregate [c1#x, c2#x, spark_grouping_id#xL], [c1#x, c2#x, sum(c3#x) AS sum(c3)#xL, spark_grouping_id#xL AS grouping__id#xL] +- Expand [[c1#x, c2#x, c3#x, c1#x, null, 1], [c1#x, c2#x, c3#x, null, c2#x, 2]], [c1#x, c2#x, c3#x, c1#x, c2#x, spark_grouping_id#xL] +- Project [c1#x, c2#x, c3#x, c1#x AS c1#x, c2#x AS c2#x] From 4b956be46f0fe3c245aa6f9cc243c6149cfdbc18 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 4 Aug 2025 21:26:49 +0200 Subject: [PATCH 3/4] simplify test --- .../org/apache/spark/sql/SQLQuerySuite.scala | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0daac5ef90cc3..905f34cd7d340 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -5059,25 +5059,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } test("SPARK-53094: Fix cube-related data quality problem") { - withTable("table1") { - withSQLConf() { - sql( - """CREATE TABLE table1(product string, amount bigint, - |region string) using csv""".stripMargin) - - sql("INSERT INTO table1 " + "VALUES('a', 100, 'east')") - sql("INSERT INTO table1 " + "VALUES('b', 200, 'east')") - sql("INSERT INTO table1 " + "VALUES('a', 150, 'west')") - sql("INSERT INTO table1 " + "VALUES('b', 250, 'west')") - sql("INSERT INTO table1 " + "VALUES('a', 120, 'east')") - - checkAnswer( - sql("select product, region, sum(amount) as s " + - "from table1 group by product, region with cube having count(product) > 2 " + - "order by s desc"), - Seq(Row(null, null, 820), Row(null, "east", 420), Row("a", null, 370))) - } - } + val df = sql( + """SELECT product, region, sum(amount) AS s + |FROM VALUES + | ('a', 'east', 100), + | ('b', 'east', 200), + | ('a', 'west', 150), + | ('b', 'west', 250), + | ('a', 'east', 120) AS t(product, region, amount) + |GROUP BY product, region WITH CUBE + |HAVING count(product) > 2 + |ORDER BY s DESC""".stripMargin) + + checkAnswer(df, Seq(Row(null, null, 820), Row(null, "east", 420), Row("a", null, 370))) } } From 54ccd0feb50cb2cd3a778d07c963ef72dd192656 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 5 Aug 2025 16:26:34 +0200 Subject: [PATCH 4/4] review fix --- .../sql/catalyst/analysis/Analyzer.scala | 46 +++++++++---------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 787bbd9ef8cab..e49e6aa7f0448 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -794,31 +794,29 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } // `cond` might contain unresolved aggregate functions so defer its resolution to // `ResolveAggregateFunctions` rule if needed. - if (!cond.resolved) { - colResolved - } else { - // Try resolving the condition of the filter as though it is in the aggregate clause - val (extraAggExprs, Seq(resolvedHavingCond)) = - ResolveAggregateFunctions.resolveExprsWithAggregate(Seq(cond), aggForResolving) - - // Push the aggregate expressions into the aggregate (if any). - val newChild = constructAggregate(h, selectedGroupByExprs, groupByExprs, - aggregate.aggregateExpressions ++ extraAggExprs, aggregate.child) - - // Since the output exprId will be changed in the constructed aggregate, here we build an - // attrMap to resolve the condition again. - val attrMap = AttributeMap((aggForResolving.output ++ extraAggExprs.map(_.toAttribute)) - .zip(newChild.output)) - val newCond = resolvedHavingCond.transform { - case a: AttributeReference => attrMap.getOrElse(a, a) - } + if (!cond.resolved) return colResolved + + // Try resolving the condition of the filter as though it is in the aggregate clause + val (extraAggExprs, Seq(resolvedHavingCond)) = + ResolveAggregateFunctions.resolveExprsWithAggregate(Seq(cond), aggForResolving) + + // Push the aggregate expressions into the aggregate (if any). + val newChild = constructAggregate(h, selectedGroupByExprs, groupByExprs, + aggregate.aggregateExpressions ++ extraAggExprs, aggregate.child) + + // Since the output exprId will be changed in the constructed aggregate, here we build an + // attrMap to resolve the condition again. + val attrMap = AttributeMap((aggForResolving.output ++ extraAggExprs.map(_.toAttribute)) + .zip(newChild.output)) + val newCond = resolvedHavingCond.transform { + case a: AttributeReference => attrMap.getOrElse(a, a) + } - if (extraAggExprs.isEmpty) { - Filter(newCond, newChild) - } else { - Project(newChild.output.dropRight(extraAggExprs.length), - Filter(newCond, newChild)) - } + if (extraAggExprs.isEmpty) { + Filter(newCond, newChild) + } else { + Project(newChild.output.dropRight(extraAggExprs.length), + Filter(newCond, newChild)) } }