From fa824dc905e5cb1c2f2cb7047dcbef62ec422e77 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Wed, 20 Apr 2022 18:00:27 +0800 Subject: [PATCH 1/2] Support propagate empty expression set for distinct key --- .../plans/logical/DistinctKeyVisitor.scala | 26 +++++++++---------- .../logical/LogicalPlanDistinctKeys.scala | 8 +----- .../logical/DistinctKeyVisitorSuite.scala | 6 ++++- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala index 726c52592887f..b2e0f0e310a3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala @@ -29,25 +29,21 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] { private def projectDistinctKeys( keys: Set[ExpressionSet], projectList: Seq[NamedExpression]): Set[ExpressionSet] = { val outputSet = ExpressionSet(projectList.map(_.toAttribute)) - val aliases = projectList.filter(_.isInstanceOf[Alias]) + val aliases = projectList.collect { + // TODO: Expand distinctKeys for redundant aliases on the same expression + case alias: Alias if alias.child.deterministic => alias.child.canonicalized -> alias + }.toMap if (aliases.isEmpty) { keys.filter(_.subsetOf(outputSet)) } else { - val aliasedDistinctKeys = keys.map { expressionSet => - expressionSet.map { expression => - expression transform { - case expr: Expression => - // TODO: Expand distinctKeys for redundant aliases on the same expression - aliases - .collectFirst { case a: Alias if a.child.semanticEquals(expr) => a.toAttribute } - .getOrElse(expr) - } - } - } + val aliasedDistinctKeys = keys.map(_.map(_.transform { + case expr: Expression => + aliases.get(expr.canonicalized).map(_.toAttribute).getOrElse(expr) + })) aliasedDistinctKeys.collect { case es: ExpressionSet if es.subsetOf(outputSet) => ExpressionSet(es) } ++ keys.filter(_.subsetOf(outputSet)) - }.filter(_.nonEmpty) + } } /** @@ -69,7 +65,9 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] { override def default(p: LogicalPlan): Set[ExpressionSet] = Set.empty[ExpressionSet] override def visitAggregate(p: Aggregate): Set[ExpressionSet] = { - val groupingExps = ExpressionSet(p.groupingExpressions) // handle group by a, a + // handle group by a, a + // handle global aggregate + val groupingExps = ExpressionSet(p.groupingExpressions) projectDistinctKeys(addDistinctKey(p.child.distinctKeys, groupingExps), p.aggregateExpressions) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala index 2ffa5a0e594e1..1843c2da478ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala @@ -29,12 +29,6 @@ import org.apache.spark.sql.internal.SQLConf.PROPAGATE_DISTINCT_KEYS_ENABLED */ trait LogicalPlanDistinctKeys { self: LogicalPlan => lazy val distinctKeys: Set[ExpressionSet] = { - if (conf.getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)) { - val keys = DistinctKeyVisitor.visit(self) - require(keys.forall(_.nonEmpty)) - keys - } else { - Set.empty - } + if (conf.getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)) DistinctKeyVisitor.visit(self) else Set.empty } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala index 80342f6dd7a78..88d2e8b78f51e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala @@ -61,7 +61,11 @@ class DistinctKeyVisitorSuite extends PlanTest { checkDistinctAttributes(t1.groupBy($"a")($"a", max($"b")), Set(ExpressionSet(Seq(a)))) checkDistinctAttributes(t1.groupBy($"a", $"b")($"a", $"b", d, e), Set(ExpressionSet(Seq(a, b)), ExpressionSet(Seq(d.toAttribute, e.toAttribute)))) - checkDistinctAttributes(t1.groupBy()(sum($"c")), Set.empty) + checkDistinctAttributes(t1.groupBy()(sum($"c")), Set(ExpressionSet())) + // ExpressionSet() is a subset of anything, so we do not need ExpressionSet(c2) + checkDistinctAttributes(t1.groupBy()(sum($"c") as "c2").groupBy($"c2")("c2"), + Set(ExpressionSet())) + checkDistinctAttributes(t1.groupBy()(), Set(ExpressionSet())) checkDistinctAttributes(t1.groupBy($"a")($"a", $"a" % 10, d, sum($"b")), Set(ExpressionSet(Seq(a)), ExpressionSet(Seq(d.toAttribute)))) checkDistinctAttributes(t1.groupBy(f.child, $"b")(f, $"b", sum($"c")), From 1f0185f013bb3cbafd293a416a376045c69cbcac Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Thu, 21 Apr 2022 09:36:39 +0800 Subject: [PATCH 2/2] address comment --- .../sql/catalyst/plans/logical/DistinctKeyVisitor.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala index b2e0f0e310a3e..6a8e2502cd78f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala @@ -37,9 +37,9 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] { keys.filter(_.subsetOf(outputSet)) } else { val aliasedDistinctKeys = keys.map(_.map(_.transform { - case expr: Expression => - aliases.get(expr.canonicalized).map(_.toAttribute).getOrElse(expr) - })) + case expr: Expression => + aliases.get(expr.canonicalized).map(_.toAttribute).getOrElse(expr) + })) aliasedDistinctKeys.collect { case es: ExpressionSet if es.subsetOf(outputSet) => ExpressionSet(es) } ++ keys.filter(_.subsetOf(outputSet)) @@ -65,8 +65,7 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] { override def default(p: LogicalPlan): Set[ExpressionSet] = Set.empty[ExpressionSet] override def visitAggregate(p: Aggregate): Set[ExpressionSet] = { - // handle group by a, a - // handle global aggregate + // handle group by a, a and global aggregate val groupingExps = ExpressionSet(p.groupingExpressions) projectDistinctKeys(addDistinctKey(p.child.distinctKeys, groupingExps), p.aggregateExpressions) }