From db6afe2f6f055ac2b86004c83aa16e326e78b433 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Sun, 25 May 2014 13:53:28 -0700 Subject: [PATCH 1/2] Introduce SchemaRDD#aggregate() for simple aggregations rdd.aggregate(Sum('val)) is just shorthand for rdd.groupBy()(Sum('val)), but seems to be more natural than doing a groupBy with no grouping expressions when you really just want an aggregation over all rows. Did not add a JavaSchemaRDD or Python API, as these seem to be lacking in several other methods like groupBy() already -- leaving that cleanup for future patches. --- .../scala/org/apache/spark/sql/SchemaRDD.scala | 17 +++++++++++++++-- .../org/apache/spark/sql/DslQuerySuite.scala | 8 ++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 9883ebc0b3c62..83d39b97ce25d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -59,7 +59,7 @@ import java.util.{Map => JMap} * // Importing the SQL context gives access to all the SQL functions and implicit conversions. * import sqlContext._ * - * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_\$i"))) + * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) * // Any RDD containing case classes can be registered as a table. The schema of the table is * // automatically inferred using scala reflection. * rdd.registerAsTable("records") @@ -204,6 +204,19 @@ class SchemaRDD( new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan)) } + /** + * Performs an aggregation over all Rows in this RDD. + * + * {{{ + * schemaRDD.aggregate(Sum('sales) as 'totalSales) + * }}} + * + * @group Query + */ + def aggregate(aggregateExprs: Expression*): SchemaRDD = { + groupBy()(aggregateExprs: _*) + } + /** * Applies a qualifier to the attributes of this relation. 
Can be used to disambiguate attributes * with the same name, for example, when performing self-joins. @@ -281,7 +294,7 @@ class SchemaRDD( * supports features such as filter pushdown. */ @Experimental - override def count(): Long = groupBy()(Count(Literal(1))).collect().head.getLong(0) + override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0) /** * :: Experimental :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 94ba13b14b33d..692569a73ffcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -39,6 +39,14 @@ class DslQuerySuite extends QueryTest { testData2.groupBy('a)('a, Sum('b)), Seq((1,3),(2,3),(3,3)) ) + checkAnswer( + testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)), + 9 + ) + checkAnswer( + testData2.aggregate(Sum('b)), + 9 + ) } test("select *") { From e9e68ee2f07397eed35e1caa33ca6778dc265562 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Sun, 25 May 2014 16:56:49 -0700 Subject: [PATCH 2/2] Add comment --- sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 83d39b97ce25d..e855f36256bc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -206,6 +206,7 @@ class SchemaRDD( /** * Performs an aggregation over all Rows in this RDD. + * This is equivalent to a groupBy with no grouping expressions. * * {{{ * schemaRDD.aggregate(Sum('sales) as 'totalSales)