From 4eaef85d3bd6e2ad37b0a6ccdd75e7083dbd16ed Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 20 Jun 2019 18:24:09 -0700 Subject: [PATCH 1/2] Change FlatMapGroupsInPandasExec and AggregateInPandasExec to skip empty partitions --- .../python/AggregateInPandasExec.scala | 84 ++++++++++--------- .../python/FlatMapGroupsInPandasExec.scala | 64 +++++++------- 2 files changed, 81 insertions(+), 67 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 0c78cca086ed3..a73b7f0ba15f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -106,50 +106,56 @@ case class AggregateInPandasExec( }) inputRDD.mapPartitionsInternal { iter => - val prunedProj = UnsafeProjection.create(allInputs, child.output) - val grouped = if (groupingExpressions.isEmpty) { - // Use an empty unsafe row as a place holder for the grouping key - Iterator((new UnsafeRow(), iter)) - } else { - GroupedIterator(iter, groupingExpressions, child.output) - }.map { case (key, rows) => - (key, rows.map(prunedProj)) - } + // Only execute on non-empty partitions + if (iter.nonEmpty) { + val prunedProj = UnsafeProjection.create(allInputs, child.output) - val context = TaskContext.get() + val grouped = if (groupingExpressions.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, groupingExpressions, child.output) + }.map { case (key, rows) => + (key, rows.map(prunedProj)) + } - // The queue used to buffer input rows so we can drain it to - // combine input with output from Python. 
- val queue = HybridRowQueue(context.taskMemoryManager(), - new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) - context.addTaskCompletionListener[Unit] { _ => - queue.close() - } + val context = TaskContext.get() - // Add rows to queue to join later with the result. - val projectedRowIter = grouped.map { case (groupingKey, rows) => - queue.add(groupingKey.asInstanceOf[UnsafeRow]) - rows - } + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) + context.addTaskCompletionListener[Unit] { _ => + queue.close() + } + + // Add rows to queue to join later with the result. + val projectedRowIter = grouped.map { case (groupingKey, rows) => + queue.add(groupingKey.asInstanceOf[UnsafeRow]) + rows + } - val columnarBatchIter = new ArrowPythonRunner( - pyFuncs, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, - argOffsets, - aggInputSchema, - sessionLocalTimeZone, - pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) - - val joinedAttributes = - groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) - val joined = new JoinedRow - val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes) - - columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow => - val leftRow = queue.remove() - val joinedRow = joined(leftRow, aggOutputRow) - resultProj(joinedRow) + val columnarBatchIter = new ArrowPythonRunner( + pyFuncs, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + argOffsets, + aggInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) + + val joinedAttributes = + groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(resultExpressions, 
joinedAttributes) + + columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, aggOutputRow) + resultProj(joinedRow) + } + } else { + Iterator.empty } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 7b0e014f9ca48..c267e9b670f08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -126,36 +126,44 @@ case class FlatMapGroupsInPandasExec( val dedupSchema = StructType.fromAttributes(dedupAttributes) inputRDD.mapPartitionsInternal { iter => - val grouped = if (groupingAttributes.isEmpty) { - Iterator(iter) - } else { - val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) - groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dedupProj) + + // Only execute on non-empty partitions + if (iter.nonEmpty) { + + val grouped = if (groupingAttributes.isEmpty) { + Iterator(iter) + } else { + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) + groupedIter.map { + case (_, groupedRowIter) => groupedRowIter.map(dedupProj) + } } - } - val context = TaskContext.get() - - val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - argOffsets, - dedupSchema, - sessionLocalTimeZone, - pythonRunnerConf).compute(grouped, context.partitionId(), context) - - val unsafeProj = UnsafeProjection.create(output, output) - - columnarBatchIter.flatMap { batch => - // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here - val 
structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = output.indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - flattenedBatch.rowIterator.asScala - }.map(unsafeProj) + val context = TaskContext.get() + + val columnarBatchIter = new ArrowPythonRunner( + chainedFunc, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + argOffsets, + dedupSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(grouped, context.partitionId(), context) + + val unsafeProj = UnsafeProjection.create(output, output) + + columnarBatchIter.flatMap { batch => + // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(unsafeProj) + + } else { + Iterator.empty + } } } } From 9be0110c7b7dee443d30e3188d9f51eaaf7fd3e1 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 21 Jun 2019 12:25:13 -0700 Subject: [PATCH 2/2] Added tests and cleanup --- .../sql/tests/test_pandas_udf_grouped_agg.py | 13 +++ .../sql/tests/test_pandas_udf_grouped_map.py | 12 +++ .../python/AggregateInPandasExec.scala | 89 +++++++++---------- .../python/FlatMapGroupsInPandasExec.scala | 69 +++++++------- 4 files changed, 98 insertions(+), 85 deletions(-) diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py index 9eda1aa610105..f5fd725b9ade3 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -18,6 +18,7 @@ import unittest from pyspark.rdd import PythonEvalType +from pyspark.sql import Row from pyspark.sql.functions import array, 
explode, col, lit, mean, sum, \ udf, pandas_udf, PandasUDFType from pyspark.sql.types import * @@ -461,6 +462,18 @@ def test_register_vectorized_udf_basic(self): expected = [1, 5] self.assertEqual(actual, expected) + def test_grouped_with_empty_partition(self): + data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)] + expected = [Row(id=1, sum=5), Row(id=2, sum=4)] + num_parts = len(data) + 1 + df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts)) + + f = pandas_udf(lambda x: x.sum(), + 'int', PandasUDFType.GROUPED_AGG) + + result = df.groupBy('id').agg(f(df['x']).alias('sum')).collect() + self.assertEqual(result, expected) + if __name__ == "__main__": from pyspark.sql.tests.test_pandas_udf_grouped_agg import * diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py index 1d87c636ab34e..32d6720b2c127 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py @@ -504,6 +504,18 @@ def test_mixed_scalar_udfs_followed_by_grouby_apply(self): self.assertEquals(result.collect()[0]['sum'], 165) + def test_grouped_with_empty_partition(self): + data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)] + expected = [Row(id=1, x=5), Row(id=1, x=5), Row(id=2, x=4)] + num_parts = len(data) + 1 + df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts)) + + f = pandas_udf(lambda pdf: pdf.assign(x=pdf['x'].sum()), + 'id long, x int', PandasUDFType.GROUPED_MAP) + + result = df.groupBy('id').apply(f).collect() + self.assertEqual(result, expected) + if __name__ == "__main__": from pyspark.sql.tests.test_pandas_udf_grouped_map import * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index a73b7f0ba15f8..fcbd0b19515b1 100644 ---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -105,58 +105,53 @@ case class AggregateInPandasExec( StructField(s"_$i", dt) }) - inputRDD.mapPartitionsInternal { iter => + // Map grouped rows to ArrowPythonRunner results. Only execute if partition is not empty + inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { + val prunedProj = UnsafeProjection.create(allInputs, child.output) - // Only execute on non-empty partitions - if (iter.nonEmpty) { - val prunedProj = UnsafeProjection.create(allInputs, child.output) - - val grouped = if (groupingExpressions.isEmpty) { - // Use an empty unsafe row as a place holder for the grouping key - Iterator((new UnsafeRow(), iter)) - } else { - GroupedIterator(iter, groupingExpressions, child.output) - }.map { case (key, rows) => - (key, rows.map(prunedProj)) - } + val grouped = if (groupingExpressions.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, groupingExpressions, child.output) + }.map { case (key, rows) => + (key, rows.map(prunedProj)) + } - val context = TaskContext.get() + val context = TaskContext.get() - // The queue used to buffer input rows so we can drain it to - // combine input with output from Python. - val queue = HybridRowQueue(context.taskMemoryManager(), - new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) - context.addTaskCompletionListener[Unit] { _ => - queue.close() - } + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python.
+ val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) + context.addTaskCompletionListener[Unit] { _ => + queue.close() + } - // Add rows to queue to join later with the result. - val projectedRowIter = grouped.map { case (groupingKey, rows) => - queue.add(groupingKey.asInstanceOf[UnsafeRow]) - rows - } + // Add rows to queue to join later with the result. + val projectedRowIter = grouped.map { case (groupingKey, rows) => + queue.add(groupingKey.asInstanceOf[UnsafeRow]) + rows + } - val columnarBatchIter = new ArrowPythonRunner( - pyFuncs, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, - argOffsets, - aggInputSchema, - sessionLocalTimeZone, - pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) - - val joinedAttributes = - groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) - val joined = new JoinedRow - val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes) - - columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow => - val leftRow = queue.remove() - val joinedRow = joined(leftRow, aggOutputRow) - resultProj(joinedRow) - } - } else { - Iterator.empty + val columnarBatchIter = new ArrowPythonRunner( + pyFuncs, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + argOffsets, + aggInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) + + val joinedAttributes = + groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes) + + columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, aggOutputRow) + resultProj(joinedRow) } - } + }} } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index c267e9b670f08..267698d1bca50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -125,45 +125,38 @@ case class FlatMapGroupsInPandasExec( val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes val dedupSchema = StructType.fromAttributes(dedupAttributes) - inputRDD.mapPartitionsInternal { iter => - - // Only execute on non-empty partitions - if (iter.nonEmpty) { - - val grouped = if (groupingAttributes.isEmpty) { - Iterator(iter) - } else { - val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) - groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dedupProj) - } - } - - val context = TaskContext.get() - - val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - argOffsets, - dedupSchema, - sessionLocalTimeZone, - pythonRunnerConf).compute(grouped, context.partitionId(), context) - - val unsafeProj = UnsafeProjection.create(output, output) - - columnarBatchIter.flatMap { batch => - // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = output.indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - flattenedBatch.rowIterator.asScala - }.map(unsafeProj) - + // Map grouped rows to ArrowPythonRunner results. Only execute if partition is not empty + inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { + val grouped = if (groupingAttributes.isEmpty) { + Iterator(iter) } else { - Iterator.empty + val groupedIter =
GroupedIterator(iter, groupingAttributes, child.output) + val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) + groupedIter.map { + case (_, groupedRowIter) => groupedRowIter.map(dedupProj) + } } - } + + val context = TaskContext.get() + + val columnarBatchIter = new ArrowPythonRunner( + chainedFunc, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + argOffsets, + dedupSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(grouped, context.partitionId(), context) + + val unsafeProj = UnsafeProjection.create(output, output) + + columnarBatchIter.flatMap { batch => + // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(unsafeProj) + }} } }