diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala index 726c52592887f..6a8e2502cd78f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala @@ -29,25 +29,21 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] { private def projectDistinctKeys( keys: Set[ExpressionSet], projectList: Seq[NamedExpression]): Set[ExpressionSet] = { val outputSet = ExpressionSet(projectList.map(_.toAttribute)) - val aliases = projectList.filter(_.isInstanceOf[Alias]) + val aliases = projectList.collect { + // TODO: Expand distinctKeys for redundant aliases on the same expression + case alias: Alias if alias.child.deterministic => alias.child.canonicalized -> alias + }.toMap if (aliases.isEmpty) { keys.filter(_.subsetOf(outputSet)) } else { - val aliasedDistinctKeys = keys.map { expressionSet => - expressionSet.map { expression => - expression transform { - case expr: Expression => - // TODO: Expand distinctKeys for redundant aliases on the same expression - aliases - .collectFirst { case a: Alias if a.child.semanticEquals(expr) => a.toAttribute } - .getOrElse(expr) - } - } - } + val aliasedDistinctKeys = keys.map(_.map(_.transform { + case expr: Expression => + aliases.get(expr.canonicalized).map(_.toAttribute).getOrElse(expr) + })) aliasedDistinctKeys.collect { case es: ExpressionSet if es.subsetOf(outputSet) => ExpressionSet(es) } ++ keys.filter(_.subsetOf(outputSet)) - }.filter(_.nonEmpty) + } } /** @@ -69,7 +65,8 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] { override def default(p: LogicalPlan): Set[ExpressionSet] = Set.empty[ExpressionSet] override def visitAggregate(p: Aggregate): Set[ExpressionSet] = { - val groupingExps = ExpressionSet(p.groupingExpressions) // handle group by a, a + // handle group by a, a and global aggregate + val groupingExps = ExpressionSet(p.groupingExpressions) projectDistinctKeys(addDistinctKey(p.child.distinctKeys, groupingExps), p.aggregateExpressions) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala index 2ffa5a0e594e1..1843c2da478ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala @@ -29,12 +29,6 @@ import org.apache.spark.sql.internal.SQLConf.PROPAGATE_DISTINCT_KEYS_ENABLED */ trait LogicalPlanDistinctKeys { self: LogicalPlan => lazy val distinctKeys: Set[ExpressionSet] = { - if (conf.getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)) { - val keys = DistinctKeyVisitor.visit(self) - require(keys.forall(_.nonEmpty)) - keys - } else { - Set.empty - } + if (conf.getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)) DistinctKeyVisitor.visit(self) else Set.empty } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala index 80342f6dd7a78..88d2e8b78f51e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala @@ -61,7 +61,11 @@ class DistinctKeyVisitorSuite extends PlanTest { checkDistinctAttributes(t1.groupBy($"a")($"a", max($"b")), Set(ExpressionSet(Seq(a)))) checkDistinctAttributes(t1.groupBy($"a", $"b")($"a", $"b", d, e), Set(ExpressionSet(Seq(a, b)), ExpressionSet(Seq(d.toAttribute, e.toAttribute)))) - checkDistinctAttributes(t1.groupBy()(sum($"c")), Set.empty) + checkDistinctAttributes(t1.groupBy()(sum($"c")), Set(ExpressionSet())) + // ExpressionSet() is a subset of anything, so we do not need ExpressionSet(c2) + checkDistinctAttributes(t1.groupBy()(sum($"c") as "c2").groupBy($"c2")("c2"), + Set(ExpressionSet())) + checkDistinctAttributes(t1.groupBy()(), Set(ExpressionSet())) checkDistinctAttributes(t1.groupBy($"a")($"a", $"a" % 10, d, sum($"b")), Set(ExpressionSet(Seq(a)), ExpressionSet(Seq(d.toAttribute)))) checkDistinctAttributes(t1.groupBy(f.child, $"b")(f, $"b", sum($"c")),