From f7f69458989967f0f18f58eaf6ffe051f8eaefdc Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 22 Aug 2015 21:31:13 -0700 Subject: [PATCH 1/2] Use transformDown. Conflicts: sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala --- .../sql/catalyst/planning/patterns.scala | 5 ++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 22 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9c8c643f7d17a..4a7f455d07689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -151,7 +151,10 @@ object PartialAggregation { // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + // transformDown is needed at here because we want to match aggregate function first. + // Otherwise, if a grouping expression is used as an argument of an aggregate function, + // we will match grouping expression first and have a wrong plan. + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => partialEvaluations(new TreeNodeRef(e)).finalEvaluation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 87e7cf8c8af9f..b52b6069d3cb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1099,4 +1099,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } + + test("SPARK-10169: grouping expressions used as arguments of aggregate functions.") { + sqlCtx.sparkContext + .parallelize((1 to 1000), 50) + .map(i => Tuple1(i)) + .toDF("i") + .registerTempTable("t") + + val query = sqlCtx.sql( + """ + |select i % 10, sum(if(i % 10 = 5, 1, 0)), count(i) + |from t + |where i % 10 = 5 + |group by i % 10 + """.stripMargin) + + checkAnswer( + query, + Row(5, 100, 100)) + + dropTempTable("t") + } } From 059c9566a8efd0c65f2dcbbd114a7397c656d22f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 23 Aug 2015 21:49:42 +0800 Subject: [PATCH 2/2] don't trim the top level Alias Conflicts: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala --- .../apache/spark/sql/catalyst/planning/patterns.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 4a7f455d07689..d0ebe24864fbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -162,8 +162,15 @@ object PartialAggregation { // Should trim aliases around `GetField`s. These aliases are introduced while // resolving struct field accesses, because `GetField` is not a `NamedExpression`. // (Should we just turn `GetField` into a `NamedExpression`?) + def trimAliases(e: Expression): Expression = + e.transform { case Alias(g: GetField, _) => g } + val trimmed = e match { + // Don't trim the top level Alias. + case Alias(child, name) => Alias(trimAliases(child), name)() + case _ => trimAliases(e) + } namedGroupingExpressions - .get(e.transform { case Alias(g: GetField, _) => g }) + .get(trimmed) .map(_.toAttribute) .getOrElse(e) }).asInstanceOf[Seq[NamedExpression]]