diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index b4bd02773edfb..7542599793e19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} -import org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeMap, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.optimizer.CollapseProject.{buildCleanedProjectList, canCollapseExpressions} import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule @@ -29,11 +29,12 @@ import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType} import org.apache.spark.sql.util.SchemaUtils._ -object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { +object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper { import DataSourceV2Implicits._ def apply(plan: LogicalPlan): LogicalPlan = { @@ -88,27 +89,43 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder) } - def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform { + private def collapseProject(plan: LogicalPlan): LogicalPlan = { + val alwaysInline = conf.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE) + plan transformUp { + case agg @ Aggregate(_, aggregateExpressions, p: Project) + if canCollapseExpressions(aggregateExpressions, p.projectList, alwaysInline) => + agg.copy(aggregateExpressions = buildCleanedProjectList( + aggregateExpressions, p.projectList)) + } + } + + def pushDownAggregates(plan: LogicalPlan): LogicalPlan = collapseProject(plan).transform { // update the scan builder with agg pushdown and return a new plan with agg pushed case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => child match { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) - if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) => + if filters.isEmpty && project.forall(_.deterministic) => sHolder.builder match { case r: SupportsPushDownAggregates => + val aliasMap = getAliasMap(project) + val newResultExpressions = resultExpressions.map(replaceAliasWithAttr(_, aliasMap)) + val newGroupingExpressions = groupingExpressions.map { + case e: NamedExpression => replaceAliasWithAttr(e, aliasMap) + case other => other + } val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal) + val aggregates = collectAggregates(newResultExpressions, aggExprToOutputOrdinal) val normalizedAggregates = DataSourceStrategy.normalizeExprs( aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs( - groupingExpressions, sHolder.relation.output) + newGroupingExpressions, sHolder.relation.output) val translatedAggregates = DataSourceStrategy.translateAggregation( normalizedAggregates, normalizedGroupingExpressions) - val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = { + val (selectedResultExpressions, selectedAggregates, selectedTranslatedAggregates) = { if (translatedAggregates.isEmpty || r.supportCompletePushDown(translatedAggregates.get) || translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) { - (resultExpressions, aggregates, translatedAggregates) + (newResultExpressions, aggregates, translatedAggregates) } else { // scalastyle:off // The data source doesn't support the complete push-down of this aggregation. @@ -156,15 +173,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } - if (finalTranslatedAggregates.isEmpty) { - aggNode // return original plan node - } else if (!r.supportCompletePushDown(finalTranslatedAggregates.get) && - !supportPartialAggPushDown(finalTranslatedAggregates.get)) { - aggNode // return original plan node + if (selectedTranslatedAggregates.isEmpty) { + return plan // return original plan node + } else if (!r.supportCompletePushDown(selectedTranslatedAggregates.get) && + !supportPartialAggPushDown(selectedTranslatedAggregates.get)) { + return plan // return original plan node } else { - val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation) + val pushedAggregates = selectedTranslatedAggregates.filter(r.pushAggregation) if (pushedAggregates.isEmpty) { - aggNode // return original plan node + return plan // return original plan node } else { // No need to do column pruning because only the aggregate columns are used as // DataSourceV2ScanRelation output columns. All the other columns are not @@ -182,7 +199,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] // scalastyle:on val newOutput = scan.readSchema().toAttributes - assert(newOutput.length == groupingExpressions.length + finalAggregates.length) + assert(newOutput.length == groupingExpressions.length + selectedAggregates.length) val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) case (_, b) => b @@ -203,8 +220,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) + val aliasAttrMap = getAttrToAliasMap(aliasMap) + val finalResultExpressions = selectedResultExpressions.map { + case attr: AttributeReference => + aliasAttrMap.getOrElse(attr, attr) + case other => other + } if (r.supportCompletePushDown(pushedAggregates.get)) { - val projectExpressions = resultExpressions.map { expr => + val projectExpressions = finalResultExpressions.map { expr => // TODO At present, only push down group by attribute is supported. // In future, more attribute conversion is extended here. e.g. GetStructField expr.transform { @@ -259,12 +282,31 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } } - case _ => aggNode + case _ => return plan } - case _ => aggNode + case _ => return plan } } + /** + * Replace all alias, with the aliased attribute. + */ + private def replaceAliasWithAttr( + expr: NamedExpression, + aliasMap: AttributeMap[Alias]): NamedExpression = { + replaceAliasButKeepName(expr, aliasMap).transform { + case Alias(attr: Attribute, _) => attr + }.asInstanceOf[NamedExpression] + } + + protected def getAttrToAliasMap(aliasMap: AttributeMap[Alias]): AttributeMap[Alias] = { + val attrToAliasMap = aliasMap.values.toSeq.collect { + case alias @ Alias(originAttr: Attribute, _) => + (originAttr, alias) + } + AttributeMap(attrToAliasMap) + } + private def collectAggregates(resultExpressions: Seq[NamedExpression], aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = { var ordinal = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index 47740c5274616..26dfe1a50971f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -184,7 +184,7 @@ trait FileSourceAggregatePushDownSuite } } - test("aggregate over alias not push down") { + test("aggregate over alias push down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) withDataSourceTable(data, "t") { @@ -194,7 +194,7 @@ trait FileSourceAggregatePushDownSuite query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: []" // aggregate alias not pushed down + "PushedAggregation: [MIN(_1)]" checkKeywordsExistsInExplain(query, expected_plan_fragment) } checkAnswer(query, Seq(Row(-2))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 17bd7f7a6d5bc..33fe25ef9e38f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -263,9 +263,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedFilters: [IsNotNull(ID), GreaterThan(ID,1)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Row("mary", 2)) @@ -410,11 +410,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0))) } @@ -432,11 +432,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(ID), AVG(ID)], " + "PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " + "PushedGroupByColumns: []" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(2, 1.5))) } @@ -463,9 +463,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(12001))) } @@ -475,9 +475,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(*)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(5))) } @@ -487,9 +487,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(DEPT)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(5))) } @@ -499,9 +499,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(DISTINCT DEPT)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(3))) } @@ -523,9 +523,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(53000))) } @@ -535,9 +535,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(DISTINCT SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(31000))) } @@ -547,11 +547,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY)], " + "PushedFilters: [], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) } @@ -561,11 +561,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(DISTINCT SALARY)], " + "PushedFilters: [], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) } @@ -577,11 +577,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT, NAME]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300), Row(10000, 1000), Row(12000, 1200))) @@ -597,11 +597,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df1) df1.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT, NAME]" - checkKeywordsExistsInExplain(df1, expected_plan_fragment) + checkKeywordsExistsInExplain(df1, expectedPlanFragment) } checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000), Row("2#david", 10000), Row("6#jen", 12000))) @@ -615,11 +615,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df2) df2.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT, NAME]" - checkKeywordsExistsInExplain(df2, expected_plan_fragment) + checkKeywordsExistsInExplain(df2, expectedPlanFragment) } checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), Row("2#david", 11300), Row("6#jen", 13200))) @@ -633,9 +633,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df3, false) df3.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " - checkKeywordsExistsInExplain(df3, expected_plan_fragment) + checkKeywordsExistsInExplain(df3, expectedPlanFragment) } checkAnswer(df3, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), Row("2#david", 11300), Row("6#jen", 13200))) @@ -651,11 +651,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200))) } @@ -667,11 +667,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [MIN(SALARY)], " + "PushedFilters: [], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000))) } @@ -691,11 +691,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(query, expected_plan_fragment) + checkKeywordsExistsInExplain(query, expectedPlanFragment) } checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000))) } @@ -707,9 +707,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(query) query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY), SUM(BONUS)]" - checkKeywordsExistsInExplain(query, expected_plan_fragment) + checkKeywordsExistsInExplain(query, expectedPlanFragment) } checkAnswer(query, Seq(Row(47100.0))) } @@ -734,11 +734,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } @@ -750,11 +750,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null))) } @@ -766,11 +766,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } @@ -782,24 +782,28 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [CORR(BONUS, BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) } - test("scan with aggregate push-down: aggregate over alias NOT push down") { + test("scan with aggregate push-down: aggregate over alias push down") { val cols = Seq("a", "b", "c", "d") val df1 = sql("select * from h2.test.employee").toDF(cols: _*) val df2 = df1.groupBy().sum("c") - checkAggregateRemoved(df2, false) + checkAggregateRemoved(df2) df2.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.aggregation.isEmpty) + case relation: DataSourceV2ScanRelation => + val expectedPlanFragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: []" + checkKeywordsExistsInExplain(df2, expectedPlanFragment) + relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.aggregation.nonEmpty) } } checkAnswer(df2, Seq(Row(53000.00))) @@ -847,12 +851,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(CASE WHEN ((SALARY) > (8000.00)) AND ((SALARY) < (10000.00))" + " THEN SALARY ELSE 0.00 END), C..., " + "PushedFilters: [], " + "PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 2, 0d), Row(2, 2, 2, 2, 2, 0d, 10000d, 0d, 10000d, 10000d, 0d, 0d, 2, 0d), @@ -864,7 +868,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee") checkAggregateRemoved(df, ansiMode) - val expected_plan_fragment = if (ansiMode) { + val expectedPlanFragment = if (ansiMode) { "PushedAggregates: [SUM((2147483647) + (DEPT))], " + "PushedFilters: [], PushedGroupByColumns: []" } else { @@ -872,7 +876,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } if (ansiMode) { val e = intercept[SparkException] { @@ -894,9 +898,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(query, false) query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedFilters: []" - checkKeywordsExistsInExplain(query, expected_plan_fragment) + checkKeywordsExistsInExplain(query, expectedPlanFragment) } checkAnswer(query, Seq(Row(47100.0))) } @@ -935,9 +939,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(`dept id`)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(2))) } @@ -949,9 +953,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [COUNT(`名`)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(2))) // scalastyle:on @@ -968,9 +972,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) @@ -985,9 +989,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df2, Seq( Row("alex", 12000.00, 12000.000000, 1), @@ -1008,9 +1012,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df, false) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) @@ -1025,9 +1029,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAggregateRemoved(df, false) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = + val expectedPlanFragment = "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } checkAnswer(df2, Seq( Row("alex", 12000.00, 12000.000000, 1), @@ -1044,4 +1048,76 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel |ON h2.test.view1.`|col1` = h2.test.view2.`|col1`""".stripMargin) checkAnswer(df, Seq.empty[Row]) } + + test("scan with aggregate push-down: complete push-down aggregate with alias") { + val df = spark.table("h2.test.employee") + .select($"DEPT", $"SALARY".as("mySalary")) + .groupBy($"DEPT") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expectedPlanFragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]" + checkKeywordsExistsInExplain(df, expectedPlanFragment) + } + checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) + + val df2 = spark.table("h2.test.employee") + .select($"DEPT".as("myDept"), $"SALARY".as("mySalary")) + .groupBy($"myDept") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df2) + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expectedPlanFragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]" + checkKeywordsExistsInExplain(df2, expectedPlanFragment) + } + checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) + } + + test("scan with aggregate push-down: partial push-down aggregate with alias") { + val df = spark.read + .option("partitionColumn", "DEPT") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .select($"NAME", $"SALARY".as("mySalary")) + .groupBy($"NAME") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expectedPlanFragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]" + checkKeywordsExistsInExplain(df, expectedPlanFragment) + } + checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00), + Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) + + val df2 = spark.read + .option("partitionColumn", "DEPT") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .select($"NAME".as("myName"), $"SALARY".as("mySalary")) + .groupBy($"myName") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df2, false) + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expectedPlanFragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]" + checkKeywordsExistsInExplain(df2, expectedPlanFragment) + } + checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00), + Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) + } }