From 599e9e0b9bd46be798da1274e9ba9839151b2aaf Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 29 Jul 2015 16:05:21 -0500 Subject: [PATCH 01/10] Add pivot to dataframe api --- .../sql/catalyst/analysis/Analyzer.scala | 27 ++++++++++++ .../plans/logical/basicOperators.scala | 10 +++++ .../org/apache/spark/sql/DataFrame.scala | 35 ++++++++++++++++ .../spark/sql/DataFramePivotSuite.scala | 42 +++++++++++++++++++ .../scala/org/apache/spark/sql/TestData.scala | 9 ++++ 5 files changed, 123 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 265f3d1e41765..c1dfb158dd4cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -72,6 +72,7 @@ class Analyzer( ResolveRelations :: ResolveReferences :: ResolveGroupingAnalytics :: + ResolvePivot :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: @@ -166,6 +167,10 @@ class Analyzer( if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) => g.withNewAggs(assignAliases(g.aggregations)) + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregate, child) + if child.resolved && groupByExprs.exists(_.isInstanceOf[UnresolvedAlias]) => + Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregate, child) + case Project(projectList, child) if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) => Project(assignAliases(projectList), child) @@ -249,6 +254,27 @@ class Analyzer( } } + object ResolvePivot extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Pivot if !p.childrenResolved => p + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregate, child) => 
aggregate match { + case u: UnaryExpression if u.isInstanceOf[AggregateExpression] => + val pivotAggregates = pivotValues.map { value => + val filteredAggregate = u.withNewChildren(Seq( + If(EqualTo(pivotColumn, Literal(value)), u.child, Literal(null)) + )) + Alias(filteredAggregate, value)() + } + val newGroupByExprs = groupByExprs.map { + case UnresolvedAlias(e) => e + case e => e + } + Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child) + case unknown => throw new AnalysisException(s"$unknown is not an aggregate expression") + } + } + } + /** * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ @@ -924,6 +950,7 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p: Project => p case f: Filter => f + case p: Pivot => p // todo: It's hard to write a general rule to pull out nondeterministic expressions // from LogicalPlan, currently we only do it for UnaryNode which has same output diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index ad5af19578f33..e4672164cf1a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -373,6 +373,16 @@ case class Rollup( this.copy(aggregations = aggs) } +case class Pivot( + groupByExprs: Seq[NamedExpression], + pivotColumn: Expression, + pivotValues: Seq[String], + aggregate: Expression, + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = + groupByExprs.map(_.toAttribute) ++ pivotValues.map(AttributeReference(_, aggregate.dataType)()) +} + case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3ea0f9ed3bddd..5b197ab9518c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -918,6 +918,41 @@ class DataFrame private[sql]( GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType) } + /** + * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified + * aggregation. + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column. + * df.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings")) + * }}} + * @param groupBy Columns to group by. + * @param pivotColumn Column to pivot + * @param pivotValues Values of pivotColumn that will be translated to columns in the output data + * frame. + * @param aggregate Aggregate expression to preform for each combination of groupBy and + * pivotValues. + * @group dfops + * @since 1.5.0 + */ + def pivot( + groupBy: Seq[Column], + pivotColumn: Column, + pivotValues: Seq[String], + aggregate: Column): DataFrame = { + + val aliasedGroupBy = groupBy.map(_.expr).map { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + + new DataFrame(sqlContext, + Pivot(aliasedGroupBy, pivotColumn.expr, pivotValues, aggregate.expr, this.logicalPlan)) + } + /** * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. 
* {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala new file mode 100644 index 0000000000000..7450b3cfbfe64 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.functions._ + +class DataFramePivotSuite extends QueryTest { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + test("pivot courses") { + checkAnswer( + courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings")), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } + + test("pivot year") { + checkAnswer( + courseSales.pivot(Seq($"course"), $"year", Seq("2012", "2013"), sum($"earnings")), + Row("dotNet", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index e340f54850bcc..67c76ad20e8c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -194,4 +194,13 @@ object TestData { :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false) :: Nil).toDF() complexData.registerTempTable("complexData") + + case class CourseSales(course: String, year: Int, earnings: Double) + val courseSales = TestSQLContext.sparkContext.parallelize( + CourseSales("dotNET", 2012, 10000) :: + CourseSales("Java", 2012, 20000) :: + CourseSales("dotNET", 2012, 5000) :: + CourseSales("dotNET", 2013, 48000) :: + CourseSales("Java", 2013, 30000) :: Nil).toDF() + courseSales.registerTempTable("courseSales") } From 32860d217f42da62cc821c7d39bd8cfaeace267d Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Mon, 3 Aug 2015 19:46:30 -0500 Subject: [PATCH 02/10] fix unit test answer --- .../test/scala/org/apache/spark/sql/DataFramePivotSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 
7450b3cfbfe64..06b987b34fb93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -35,7 +35,7 @@ class DataFramePivotSuite extends QueryTest { test("pivot year") { checkAnswer( courseSales.pivot(Seq($"course"), $"year", Seq("2012", "2013"), sum($"earnings")), - Row("dotNet", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } From e369f153bf46e6dd7e2a2735522385a65d3258d6 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Thu, 22 Oct 2015 17:45:53 -0500 Subject: [PATCH 03/10] Support for pivot as operation on GroupedData like: courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) Also, fixed master merge. --- .../sql/catalyst/analysis/Analyzer.scala | 42 ++-- .../plans/logical/basicOperators.scala | 10 +- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../org/apache/spark/sql/GroupedData.scala | 71 +++++- .../spark/sql/DataFramePivotSuite.scala | 38 +++- .../scala/org/apache/spark/sql/TestData.scala | 206 ------------------ .../apache/spark/sql/test/SQLTestData.scala | 12 + 7 files changed, 141 insertions(+), 242 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TestData.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 385be0b945d7b..259cca522d64d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -166,9 +166,9 @@ class Analyzer( if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) => g.withNewAggs(assignAliases(g.aggregations)) - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregate, child) + case 
Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) if child.resolved && groupByExprs.exists(_.isInstanceOf[UnresolvedAlias]) => - Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregate, child) + Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) case Project(projectList, child) if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) => @@ -256,21 +256,31 @@ class Analyzer( object ResolvePivot extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p: Pivot if !p.childrenResolved => p - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregate, child) => aggregate match { - case u: UnaryExpression if u.isInstanceOf[AggregateExpression] => - val pivotAggregates = pivotValues.map { value => - val filteredAggregate = u.withNewChildren(Seq( - If(EqualTo(pivotColumn, Literal(value)), u.child, Literal(null)) - )) - Alias(filteredAggregate, value)() - } - val newGroupByExprs = groupByExprs.map { - case UnresolvedAlias(e) => e - case e => e + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + val singleAgg = aggregates.size == 1 + val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap{ value => + aggregates.map{ aggregate => + val filteredAggregate = aggregate.transformDown{ + case u: UnaryExpression if u.isInstanceOf[AggregateExpression] => + u.withNewChildren(Seq( + If(EqualTo(pivotColumn, Literal(value)), u.child, Literal(null)) + )) + case other: AggregateExpression => + throw new AnalysisException( + s"Pivot does not support non unary aggregate expressions, found $other") + } + if(filteredAggregate.fastEquals(aggregate)) + throw new AnalysisException( + s"Unary aggregate expression required for pivot, found '$aggregate'") + val name = if(singleAgg) value else value + " " + aggregate.prettyString + Alias(filteredAggregate, name)() } - Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child) - case unknown => 
throw new AnalysisException(s"$unknown is not an aggregate expression") - } + } + val newGroupByExprs = groupByExprs.map { + case UnresolvedAlias(e) => e + case e => e + } + Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 512c1f0d4b9cd..dbe5270875245 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -379,10 +379,14 @@ case class Pivot( groupByExprs: Seq[NamedExpression], pivotColumn: Expression, pivotValues: Seq[String], - aggregate: Expression, + aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = - groupByExprs.map(_.toAttribute) ++ pivotValues.map(AttributeReference(_, aggregate.dataType)()) + override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { + case aggregate :: Nil => pivotValues.map(AttributeReference(_, aggregate.dataType)()) + case _ => pivotValues.flatMap{ value => + aggregates.map(agg => AttributeReference(value + " " + agg.prettyString, agg.dataType)()) + } + } } case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index b2d05dd6d7344..7f542672fef42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -951,7 +951,7 @@ class DataFrame private[sql]( * @param aggregate Aggregate expression to preform for each combination of groupBy and * pivotValues. 
* @group dfops - * @since 1.5.0 + * @since 1.6.0 */ def pivot( groupBy: Seq[Column], @@ -969,7 +969,7 @@ class DataFrame private[sql]( } new DataFrame(sqlContext, - Pivot(aliasedGroupBy, pivotColumn.expr, pivotValues, aggregate.expr, this.logicalPlan)) + Pivot(aliasedGroupBy, pivotColumn.expr, pivotValues, Seq(aggregate.expr), this.logicalPlan)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 102b802ad0a0a..b44a37fbc73e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -23,7 +23,7 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} +import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType /** @@ -34,7 +34,7 @@ private[sql] object GroupedData { df: DataFrame, groupingExprs: Seq[Expression], groupType: GroupType): GroupedData = { - new GroupedData(df, groupingExprs, groupType: GroupType) + new GroupedData(df, groupingExprs, groupType) } /** @@ -56,6 +56,11 @@ private[sql] object GroupedData { * To indicate it's the ROLLUP */ private[sql] object RollupType extends GroupType + + /** + * To indicate it's the PIVOT + */ + private[sql] case class PivotType(pivotCol: Expression, values: Seq[String]) extends GroupType } /** @@ -77,14 +82,8 @@ class GroupedData protected[sql]( aggExprs } - val aliasedAgg = aggregates.map { - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. 
- case u: UnresolvedAttribute => UnresolvedAlias(u) - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } + val aliasedAgg = aggregates.map(alias) + groupType match { case GroupedData.GroupByType => DataFrame( @@ -95,9 +94,22 @@ class GroupedData protected[sql]( case GroupedData.CubeType => DataFrame( df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) + case GroupedData.PivotType(pivotCol, values) => + val aliasedGrps = groupingExprs.map(alias) + DataFrame( + df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) } } + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + private[this] def alias(expr: Expression): NamedExpression = expr match { + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) : DataFrame = { @@ -333,4 +345,43 @@ class GroupedData protected[sql]( def sum(colNames: String*): DataFrame = { aggregateNumericColumns(colNames : _*)(Sum) } + + /** + * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified + * aggregation. + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column. + * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) + * }}} + * @param pivotColumn Column to pivot + * @param values Values of pivotColumn that will be translated to columns in the output data + * frame. 
+ * @since 1.6.0 + */ + @scala.annotation.varargs + def pivot(pivotColumn: Column, values: String*): GroupedData = groupType match { + case _: GroupedData.PivotType => + throw new UnsupportedOperationException("repeated pivots are not supported") + case GroupedData.GroupByType => + new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, values.toSeq)) + case _ => + throw new UnsupportedOperationException("pivot is only supported after a groupBy") + } + + /** + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column. + * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings") + * }}} + * @param pivotColumn Column to pivot + * @param values Values of pivotColumn that will be translated to columns in the output data + * frame. + * @since 1.6.0 + */ + @scala.annotation.varargs + def pivot(pivotColumn: String, values: String*): GroupedData = { + val resolvedPivotColumn = Column(df.resolve(pivotColumn)) + pivot(resolvedPivotColumn, values: _*) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 06b987b34fb93..6fdaa8420dfed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ +//import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFramePivotSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFramePivotSuite extends QueryTest with SharedSQLContext{ + import testImplicits._ test("pivot courses") { checkAnswer( @@ -39,4 +38,33 @@ 
class DataFramePivotSuite extends QueryTest { ) } + test("pivot courses groupBy") { + checkAnswer( + courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } + + test("pivot year groupBy") { + checkAnswer( + courseSales.groupBy($"course").pivot($"year", "2012", "2013").agg(sum($"earnings")), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot courses groupBy multiple") { + checkAnswer( + courseSales.groupBy($"year").pivot($"course", "dotNET", "Java") + .agg(sum($"earnings"), avg($"earnings")), + Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: + Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil + ) + } + + test("pivot year groupBy with strings") { + checkAnswer( + courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala deleted file mode 100644 index 67c76ad20e8c0..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test._ - - -case class TestData(key: Int, value: String) - -object TestData { - val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") - - val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() - negativeData.registerTempTable("negativeData") - - case class LargeAndSmallInts(a: Int, b: Int) - val largeAndSmallInts = - TestSQLContext.sparkContext.parallelize( - LargeAndSmallInts(2147483644, 1) :: - LargeAndSmallInts(1, 2) :: - LargeAndSmallInts(2147483645, 1) :: - LargeAndSmallInts(2, 2) :: - LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toDF() - largeAndSmallInts.registerTempTable("largeAndSmallInts") - - case class TestData2(a: Int, b: Int) - val testData2 = - TestSQLContext.sparkContext.parallelize( - TestData2(1, 1) :: - TestData2(1, 2) :: - TestData2(2, 1) :: - TestData2(2, 2) :: - TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toDF() - testData2.registerTempTable("testData2") - - case class DecimalData(a: BigDecimal, b: BigDecimal) - - val decimalData = - TestSQLContext.sparkContext.parallelize( - DecimalData(1, 1) :: - DecimalData(1, 2) :: - DecimalData(2, 1) :: - DecimalData(2, 2) :: - DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toDF() - decimalData.registerTempTable("decimalData") - - case class BinaryData(a: Array[Byte], b: Int) - val binaryData = - TestSQLContext.sparkContext.parallelize( - BinaryData("12".getBytes(), 1) :: - BinaryData("22".getBytes(), 5) :: - BinaryData("122".getBytes(), 3) :: - BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toDF() - binaryData.registerTempTable("binaryData") - - case 
class TestData3(a: Int, b: Option[Int]) - val testData3 = - TestSQLContext.sparkContext.parallelize( - TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toDF() - testData3.registerTempTable("testData3") - - case class UpperCaseData(N: Int, L: String) - val upperCaseData = - TestSQLContext.sparkContext.parallelize( - UpperCaseData(1, "A") :: - UpperCaseData(2, "B") :: - UpperCaseData(3, "C") :: - UpperCaseData(4, "D") :: - UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toDF() - upperCaseData.registerTempTable("upperCaseData") - - case class LowerCaseData(n: Int, l: String) - val lowerCaseData = - TestSQLContext.sparkContext.parallelize( - LowerCaseData(1, "a") :: - LowerCaseData(2, "b") :: - LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toDF() - lowerCaseData.registerTempTable("lowerCaseData") - - case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) - val arrayData = - TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: - ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) - arrayData.toDF().registerTempTable("arrayData") - - case class MapData(data: scala.collection.Map[Int, String]) - val mapData = - TestSQLContext.sparkContext.parallelize( - MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: - MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: - MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: - MapData(Map(1 -> "a4", 2 -> "b4")) :: - MapData(Map(1 -> "a5")) :: Nil) - mapData.toDF().registerTempTable("mapData") - - case class StringData(s: String) - val repeatedData = - TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.toDF().registerTempTable("repeatedData") - - val nullableRepeatedData = - TestSQLContext.sparkContext.parallelize( - List.fill(2)(StringData(null)) ++ - List.fill(2)(StringData("test"))) - nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData") - - case class NullInts(a: Integer) - val nullInts = 
- TestSQLContext.sparkContext.parallelize( - NullInts(1) :: - NullInts(2) :: - NullInts(3) :: - NullInts(null) :: Nil - ).toDF() - nullInts.registerTempTable("nullInts") - - val allNulls = - TestSQLContext.sparkContext.parallelize( - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: Nil).toDF() - allNulls.registerTempTable("allNulls") - - case class NullStrings(n: Int, s: String) - val nullStrings = - TestSQLContext.sparkContext.parallelize( - NullStrings(1, "abc") :: - NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil).toDF() - nullStrings.registerTempTable("nullStrings") - - case class TableName(tableName: String) - TestSQLContext - .sparkContext - .parallelize(TableName("test") :: Nil) - .toDF() - .registerTempTable("tableName") - - val unparsedStrings = - TestSQLContext.sparkContext.parallelize( - "1, A1, true, null" :: - "2, B2, false, null" :: - "3, C3, true, null" :: - "4, D4, true, 2147483644" :: Nil) - - case class IntField(i: Int) - // An RDD with 4 elements and 8 partitions - val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - withEmptyParts.toDF().registerTempTable("withEmptyParts") - - case class Person(id: Int, name: String, age: Int) - case class Salary(personId: Int, salary: Double) - val person = TestSQLContext.sparkContext.parallelize( - Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil).toDF() - person.registerTempTable("person") - val salary = TestSQLContext.sparkContext.parallelize( - Salary(0, 2000.0) :: - Salary(1, 1000.0) :: Nil).toDF() - salary.registerTempTable("salary") - - case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) - val complexData = - TestSQLContext.sparkContext.parallelize( - ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1), true) - :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false) - :: Nil).toDF() - complexData.registerTempTable("complexData") - - case class CourseSales(course: String, year: Int, 
earnings: Double) - val courseSales = TestSQLContext.sparkContext.parallelize( - CourseSales("dotNET", 2012, 10000) :: - CourseSales("Java", 2012, 20000) :: - CourseSales("dotNET", 2012, 5000) :: - CourseSales("dotNET", 2013, 48000) :: - CourseSales("Java", 2013, 30000) :: Nil).toDF() - courseSales.registerTempTable("courseSales") -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 520dea7f7dd92..abad0d7eaaedf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -242,6 +242,17 @@ private[sql] trait SQLTestData { self => df } + protected lazy val courseSales: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + CourseSales("dotNET", 2012, 10000) :: + CourseSales("Java", 2012, 20000) :: + CourseSales("dotNET", 2012, 5000) :: + CourseSales("dotNET", 2013, 48000) :: + CourseSales("Java", 2013, 30000) :: Nil).toDF() + df.registerTempTable("courseSales") + df + } + /** * Initialize all test data such that all temp tables are properly registered. 
*/ @@ -295,4 +306,5 @@ private[sql] object SQLTestData { case class Person(id: Int, name: String, age: Int) case class Salary(personId: Int, salary: Double) case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) + case class CourseSales(course: String, year: Int, earnings: Double) } From f2827ea82c8f8f7f44552d820c12a693fe25aaa9 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Thu, 22 Oct 2015 22:58:36 -0500 Subject: [PATCH 04/10] fix some style issues and remove commented import --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 5 +++-- .../scala/org/apache/spark/sql/DataFramePivotSuite.scala | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e58b17a3f0967..ae07ed3d96167 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -268,10 +268,11 @@ class Analyzer( throw new AnalysisException( s"Pivot does not support non unary aggregate expressions, found $other") } - if(filteredAggregate.fastEquals(aggregate)) + if (filteredAggregate.fastEquals(aggregate)) { throw new AnalysisException( s"Unary aggregate expression required for pivot, found '$aggregate'") - val name = if(singleAgg) value else value + " " + aggregate.prettyString + } + val name = if (singleAgg) value else value + " " + aggregate.prettyString Alias(filteredAggregate, name)() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 6fdaa8420dfed..5fd0578cc01e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -17,7 +17,6 @@ package 
org.apache.spark.sql -//import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext From 2417548ee88369fc45547346506c8fa4feb2010e Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Mon, 9 Nov 2015 08:00:52 -0600 Subject: [PATCH 05/10] Fix scala style issue from merge --- .../src/main/scala/org/apache/spark/sql/GroupedData.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 1e5d3e10d3ec5..8fc2a9583da02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -286,7 +286,7 @@ class GroupedData protected[sql]( * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified * aggregation. * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column. + * // Compute the sum of earnings for each year by course with each course as a separate column * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) * }}} * @param pivotColumn Column to pivot @@ -307,7 +307,7 @@ class GroupedData protected[sql]( /** * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column. + * // Compute the sum of earnings for each year by course with each course as a separate column * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings") * }}} * @param pivotColumn Column to pivot From 1af796d45f58fccc46f5c49ec200f2bbc97e8d9d Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Mon, 9 Nov 2015 16:44:36 -0600 Subject: [PATCH 06/10] Update pivot to make values optional, call .distinct() on column if not provided. Add unit tests for this scenario. 
--- .../org/apache/spark/sql/GroupedData.scala | 27 ++++++++++++++----- .../spark/sql/DataFramePivotSuite.scala | 15 +++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 8fc2a9583da02..e3827631a1852 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate} -import org.apache.spark.sql.types.NumericType +import org.apache.spark.sql.types.{StringType, NumericType} /** @@ -288,10 +288,13 @@ class GroupedData protected[sql]( * {{{ * // Compute the sum of earnings for each year by course with each course as a separate column * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) + * // Or without specifying column values + * df.groupBy($"year").pivot($"course").agg(sum($"earnings")) * }}} * @param pivotColumn Column to pivot - * @param values Values of pivotColumn that will be translated to columns in the output data - * frame. + * @param values Optional list of values of pivotColumn that will be translated to columns in the + * output data frame. If values are not provided the method with do an immediate + * call to .distinct() on the pivot column. 
* @since 1.6.0 */ @scala.annotation.varargs @@ -299,7 +302,16 @@ class GroupedData protected[sql]( case _: GroupedData.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case GroupedData.GroupByType => - new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, values.toSeq)) + val pivotValues = if (values.nonEmpty) { + values + } else { + // Get the distinct values of the column and sort them so its consistent + df.select(pivotColumn.cast(StringType)) + .distinct() + .map(_.getString(0)) + .collect().sorted.toSeq + } + new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues)) case _ => throw new UnsupportedOperationException("pivot is only supported after a groupBy") } @@ -309,10 +321,13 @@ class GroupedData protected[sql]( * {{{ * // Compute the sum of earnings for each year by course with each course as a separate column * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings") + * // Or without specifying column values + * df.groupBy("year").pivot("course").sum("earnings") * }}} * @param pivotColumn Column to pivot - * @param values Values of pivotColumn that will be translated to columns in the output data - * frame. + * @param values Optional list of values of pivotColumn that will be translated to columns in the + * output data frame. If values are not provided the method with do an immediate + * call to .distinct() on the pivot column. 
* @since 1.6.0 */ @scala.annotation.varargs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 5fd0578cc01e3..03e608caac334 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -66,4 +66,19 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } + + test("pivot courses groupBy with no values") { + // Note Java comes before dotNet in sorted order + checkAnswer( + courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), + Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil + ) + } + + test("pivot year groupBy with no values") { + checkAnswer( + courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } } From 6e3b1337f9da8a5c15b1bef458f25f11fe94dbcd Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Mon, 9 Nov 2015 21:10:17 -0600 Subject: [PATCH 07/10] Remove DataFrame.pivot monster method --- .../org/apache/spark/sql/DataFrame.scala | 35 ------------------- .../spark/sql/DataFramePivotSuite.scala | 14 -------- 2 files changed, 49 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index fafa468081ddb..f2d4db5550273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -988,41 +988,6 @@ class DataFrame private[sql]( GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType) } - /** - * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified - * aggregation. 
- * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column. - * df.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings")) - * }}} - * @param groupBy Columns to group by. - * @param pivotColumn Column to pivot - * @param pivotValues Values of pivotColumn that will be translated to columns in the output data - * frame. - * @param aggregate Aggregate expression to preform for each combination of groupBy and - * pivotValues. - * @group dfops - * @since 1.6.0 - */ - def pivot( - groupBy: Seq[Column], - pivotColumn: Column, - pivotValues: Seq[String], - aggregate: Column): DataFrame = { - - val aliasedGroupBy = groupBy.map(_.expr).map { - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. - case u: UnresolvedAttribute => UnresolvedAlias(u) - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - - new DataFrame(sqlContext, - Pivot(aliasedGroupBy, pivotColumn.expr, pivotValues, Seq(aggregate.expr), this.logicalPlan)) - } - /** * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. 
* {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 03e608caac334..46105176102f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -23,20 +23,6 @@ import org.apache.spark.sql.test.SharedSQLContext class DataFramePivotSuite extends QueryTest with SharedSQLContext{ import testImplicits._ - test("pivot courses") { - checkAnswer( - courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings")), - Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil - ) - } - - test("pivot year") { - checkAnswer( - courseSales.pivot(Seq($"course"), $"year", Seq("2012", "2013"), sum($"earnings")), - Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil - ) - } - test("pivot courses groupBy") { checkAnswer( courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")), From 88dd513d69c7a94274cb37e76aa854ddceb977cb Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 11 Nov 2015 09:33:21 -0600 Subject: [PATCH 08/10] Address comments in Analyzer --- .../sql/catalyst/analysis/Analyzer.scala | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 47eaa9d4c1e1a..730cd3f910166 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -258,20 +258,25 @@ class Analyzer( case p: Pivot if !p.childrenResolved => p case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => val singleAgg = aggregates.size == 1 - val pivotAggregates: Seq[NamedExpression] = 
pivotValues.flatMap{ value => - aggregates.map{ aggregate => - val filteredAggregate = aggregate.transformDown{ - case u: UnaryExpression if u.isInstanceOf[AggregateExpression] => - u.withNewChildren(Seq( - If(EqualTo(pivotColumn, Literal(value)), u.child, Literal(null)) - )) - case other: AggregateExpression => - throw new AnalysisException( - s"Pivot does not support non unary aggregate expressions, found $other") + val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => + def ifExpr(expr: Expression) = { + If(EqualTo(pivotColumn, Literal(value)), expr, Literal(null)) + } + aggregates.map { aggregate => + val filteredAggregate = aggregate.transformDown { + // Assumption is the aggregate function ignores nulls. This is true for all current + // AggregateFunction's with the exception of First and Last in their default mode + // (which we handle) and possibly some Hive UDAF's. + case First(expr, _) => + First(ifExpr(expr), Literal(true)) + case Last(expr, _) => + Last(ifExpr(expr), Literal(true)) + case a: AggregateFunction => + a.withNewChildren(a.children.map(ifExpr)) } if (filteredAggregate.fastEquals(aggregate)) { throw new AnalysisException( - s"Unary aggregate expression required for pivot, found '$aggregate'") + s"Aggregate expression required for pivot, found '$aggregate'") } val name = if (singleAgg) value else value + " " + aggregate.prettyString Alias(filteredAggregate, name)() @@ -1034,7 +1039,6 @@ class Analyzer( case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f - case p: Pivot => p // todo: It's hard to write a general rule to pull out nondeterministic expressions // from LogicalPlan, currently we only do it for UnaryNode which has same output From 12a8270b7592c0953464e2353fafbfbbeda92f3a Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 11 Nov 2015 12:23:15 -0600 Subject: [PATCH 09/10] Address remaining comments - Use Literal's for the pivot column values instead of strings. 
- Change separator when using multiple aggregates to `_` instead of space. - Some additional unit testing --- .../sql/catalyst/analysis/Analyzer.scala | 4 +-- .../plans/logical/basicOperators.scala | 6 ++--- .../org/apache/spark/sql/GroupedData.scala | 23 ++++++++++------ .../spark/sql/DataFramePivotSuite.scala | 26 ++++++++++++------- 4 files changed, 37 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 730cd3f910166..b7fff412b0d74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -260,7 +260,7 @@ class Analyzer( val singleAgg = aggregates.size == 1 val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(expr: Expression) = { - If(EqualTo(pivotColumn, Literal(value)), expr, Literal(null)) + If(EqualTo(pivotColumn, value), expr, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { @@ -278,7 +278,7 @@ class Analyzer( throw new AnalysisException( s"Aggregate expression required for pivot, found '$aggregate'") } - val name = if (singleAgg) value else value + " " + aggregate.prettyString + val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString Alias(filteredAggregate, name)() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index b6e9cff771b16..23add9283ba0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -388,13 +388,13 @@ case class Rollup( case class Pivot( 
groupByExprs: Seq[NamedExpression], pivotColumn: Expression, - pivotValues: Seq[String], + pivotValues: Seq[Literal], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { - case aggregate :: Nil => pivotValues.map(AttributeReference(_, aggregate.dataType)()) + case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) case _ => pivotValues.flatMap{ value => - aggregates.map(agg => AttributeReference(value + " " + agg.prettyString, agg.dataType)()) + aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)()) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 69be836a6c0ce..50ea789defcd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -297,18 +297,25 @@ class GroupedData protected[sql]( * @since 1.6.0 */ @scala.annotation.varargs - def pivot(pivotColumn: Column, values: String*): GroupedData = groupType match { + def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match { case _: GroupedData.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case GroupedData.GroupByType => val pivotValues = if (values.nonEmpty) { - values + values.map { + case Column(literal: Literal) => literal + case other => + throw new UnsupportedOperationException( + s"The values of a pivot must be literals, found $other") + } } else { // Get the distinct values of the column and sort them so its consistent - df.select(pivotColumn.cast(StringType)) + df.select(pivotColumn) .distinct() - .map(_.getString(0)) - .collect().sorted.toSeq + .sort(pivotColumn) + .map(_.get(0)) + .collect() + .map(Literal(_)).toSeq } new GroupedData(df, groupingExprs, 
GroupedData.PivotType(pivotColumn.expr, pivotValues)) case _ => @@ -330,9 +337,9 @@ class GroupedData protected[sql]( * @since 1.6.0 */ @scala.annotation.varargs - def pivot(pivotColumn: String, values: String*): GroupedData = { + def pivot(pivotColumn: String, values: Any*): GroupedData = { val resolvedPivotColumn = Column(df.resolve(pivotColumn)) - pivot(resolvedPivotColumn, values: _*) + pivot(resolvedPivotColumn, values.map(functions.lit): _*) } } @@ -372,5 +379,5 @@ private[sql] object GroupedData { /** * To indicate it's the PIVOT */ - private[sql] case class PivotType(pivotCol: Expression, values: Seq[String]) extends GroupType + private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 46105176102f7..03beb202ce2c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -23,37 +23,45 @@ import org.apache.spark.sql.test.SharedSQLContext class DataFramePivotSuite extends QueryTest with SharedSQLContext{ import testImplicits._ - test("pivot courses groupBy") { + test("pivot courses with literals") { checkAnswer( - courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")), + courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + .agg(sum($"earnings")), Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil ) } - test("pivot year groupBy") { + test("pivot year with literals") { checkAnswer( - courseSales.groupBy($"course").pivot($"year", "2012", "2013").agg(sum($"earnings")), + courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } - test("pivot courses groupBy multiple") { + test("pivot 
courses with literals and multiple aggregations") { checkAnswer( - courseSales.groupBy($"year").pivot($"course", "dotNET", "Java") + courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) .agg(sum($"earnings"), avg($"earnings")), Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil ) } - test("pivot year groupBy with strings") { + test("pivot year with string values (cast)") { checkAnswer( courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } - test("pivot courses groupBy with no values") { + test("pivot year with int values") { + checkAnswer( + courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot courses with no values") { // Note Java comes before dotNet in sorted order checkAnswer( courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), @@ -61,7 +69,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ ) } - test("pivot year groupBy with no values") { + test("pivot year with no values") { checkAnswer( courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil From 676f1accde4ce7cbbb7f274315dc3ab6e679d3db Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 11 Nov 2015 15:31:50 -0600 Subject: [PATCH 10/10] Add configurable maximum number of pivot values when none are given to prevent unintended OOM errors. 
--- .../scala/org/apache/spark/sql/GroupedData.scala | 14 ++++++++++++-- .../main/scala/org/apache/spark/sql/SQLConf.scala | 7 +++++++ .../org/apache/spark/sql/DataFramePivotSuite.scala | 9 +++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 50ea789defcd2..63dd7fbcbe9e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -309,13 +309,23 @@ class GroupedData protected[sql]( s"The values of a pivot must be literals, found $other") } } else { + // This is to prevent unintended OOM errors when the number of distinct values is large + val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) // Get the distinct values of the column and sort them so its consistent - df.select(pivotColumn) + val values = df.select(pivotColumn) .distinct() .sort(pivotColumn) .map(_.get(0)) - .collect() + .take(maxValues + 1) .map(Literal(_)).toSeq + if (values.length > maxValues) { + throw new RuntimeException( + s"The pivot column $pivotColumn has more than $maxValues distinct values, " + + "this could indicate an error. 
" + + "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " + + s"to at least the number of distinct values of the pivot column.") + } + values } new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues)) case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 89e196c066007..1f4d158e2afad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -453,6 +453,13 @@ private[spark] object SQLConf { defaultValue = Some(true), isPublic = false) + val DATAFRAME_PIVOT_MAX_VALUES = intConf( + "spark.sql.pivotMaxValues", + defaultValue = Some(10000), + doc = "When doing a pivot without specifying values for the pivot column this is the maximum " + + "number of (distinct) values that will be collected without error." + ) + val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles", defaultValue = Some(true), isPublic = false, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 03beb202ce2c8..0c23d142670c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -75,4 +75,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } + + test("pivot max values inforced") { + sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) + intercept[RuntimeException]( + courseSales.groupBy($"year").pivot($"course") + ) + sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) + } }