From db0a58350482545996b171c7ccc6603bade938f6 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 4 Oct 2018 23:40:45 +0000
Subject: [PATCH 1/2] Avoid overwriting deserialized accumulator.

---
 python/pyspark/accumulators.py | 12 ++++++++----
 python/pyspark/sql/tests.py    | 25 +++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 30ad04297c68..00ec094e7e3b 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -109,10 +109,14 @@
 
 def _deserialize_accumulator(aid, zero_value, accum_param):
     from pyspark.accumulators import _accumulatorRegistry
-    accum = Accumulator(aid, zero_value, accum_param)
-    accum._deserialized = True
-    _accumulatorRegistry[aid] = accum
-    return accum
+    # If this certain accumulator was deserialized, don't overwrite it.
+    if aid in _accumulatorRegistry:
+        return _accumulatorRegistry[aid]
+    else:
+        accum = Accumulator(aid, zero_value, accum_param)
+        accum._deserialized = True
+        _accumulatorRegistry[aid] = accum
+        return accum
 
 
 class Accumulator(object):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b88a6551f8ae..4dca193ff4a8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3603,6 +3603,31 @@ def test_repr_behaviors(self):
             self.assertEquals(None, df._repr_html_())
             self.assertEquals(expected, df.__repr__())
 
+    # SPARK-25591
+    def test_same_accumulator_in_udfs(self):
+        from pyspark.sql.functions import udf
+
+        data_schema = StructType([StructField("a", DoubleType(), True),
+                                  StructField("b", DoubleType(), True)])
+        data = self.spark.createDataFrame([[1.0, 2.0]], schema=data_schema)
+
+        test_accum = self.sc.accumulator(0.0)
+
+        def first_udf(x):
+            test_accum.add(1.0)
+            return x
+
+        def second_udf(x):
+            test_accum.add(100.0)
+            return x
+
+        func_udf = udf(first_udf, DoubleType())
+        func_udf2 = udf(second_udf, DoubleType())
+        data = data.withColumn("out1", func_udf(data["a"]))
+        data = data.withColumn("out2", func_udf2(data["b"]))
+        data.collect()
+        self.assertEqual(test_accum.value, 101)
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
 

From 08c7223c57d6c2b9536ba311ea4f81b20f37d973 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 8 Oct 2018 05:37:41 +0000
Subject: [PATCH 2/2] Address comment.

---
 python/pyspark/sql/tests.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4dca193ff4a8..c23331112d8e 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3607,22 +3607,22 @@ def test_repr_behaviors(self):
     def test_same_accumulator_in_udfs(self):
         from pyspark.sql.functions import udf
 
-        data_schema = StructType([StructField("a", DoubleType(), True),
-                                  StructField("b", DoubleType(), True)])
-        data = self.spark.createDataFrame([[1.0, 2.0]], schema=data_schema)
+        data_schema = StructType([StructField("a", IntegerType(), True),
+                                  StructField("b", IntegerType(), True)])
+        data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
 
-        test_accum = self.sc.accumulator(0.0)
+        test_accum = self.sc.accumulator(0)
 
         def first_udf(x):
-            test_accum.add(1.0)
+            test_accum.add(1)
             return x
 
         def second_udf(x):
-            test_accum.add(100.0)
+            test_accum.add(100)
             return x
 
-        func_udf = udf(first_udf, DoubleType())
-        func_udf2 = udf(second_udf, DoubleType())
+        func_udf = udf(first_udf, IntegerType())
+        func_udf2 = udf(second_udf, IntegerType())
         data = data.withColumn("out1", func_udf(data["a"]))
         data = data.withColumn("out2", func_udf2(data["b"]))
         data.collect()