From 2653750762c12cf5d17b8a2ac2a7ee9f8d55bfec Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Thu, 8 Dec 2016 21:22:31 -0800 Subject: [PATCH 01/17] [SPARK-14932][SQL] Allow DataFrame.replace() to replace values with None --- python/pyspark/sql/dataframe.py | 22 ++++++++++++++----- .../spark/sql/DataFrameNaFunctions.scala | 13 ++++++----- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b9d90384e3e2c..b6997127d4fea 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1277,10 +1277,10 @@ def replace(self, to_replace, value, subset=None): If the value is a dict, then `value` is ignored and `to_replace` must be a mapping from column name (string) to replacement value. The value to be replaced must be an int, long, float, or string. - :param value: int, long, float, string, or list. + :param value: int, long, float, string, list or None. Value to use to replace holes. - The replacement value must be an int, long, float, or string. If `value` is a - list or tuple, `value` should be of the same length with `to_replace`. + The replacement value must be an int, long, float, string or None. If `value` + is a list or tuple, `value` should be of the same length with `to_replace`. :param subset: optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, @@ -1296,6 +1296,16 @@ def replace(self, to_replace, value, subset=None): |null| null| null| +----+------+-----+ + >>> df4.na.replace('Alice', None).show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80|null| + | 5| null| Bob| + |null| null| Tom| + |null| null|null| + +----+------+----+ + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -1310,8 +1320,8 @@ def replace(self, to_replace, value, subset=None): raise ValueError( "to_replace should be a float, int, long, string, list, tuple, or dict") - if not isinstance(value, (float, int, long, basestring, list, tuple)): - raise ValueError("value should be a float, int, long, string, list, or tuple") + if value is not None and not isinstance(value, (float, int, long, basestring, list, tuple)): + raise ValueError("value should be a float, int, long, string, list, tuple or None") rep_dict = dict() @@ -1328,7 +1338,7 @@ def replace(self, to_replace, value, subset=None): if len(to_replace) != len(value): raise ValueError("to_replace and value lists should be of the same length") rep_dict = dict(zip(to_replace, value)) - elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): + elif isinstance(to_replace, list) and (value is None or isinstance(value, (float, int, long, basestring))): rep_dict = dict([(tr, value) for tr in to_replace]) elif isinstance(to_replace, dict): rep_dict = to_replace diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 28820681cd3a6..777e0ca49dd83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -342,11 +342,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] - val replacementMap: Map[_, _] = replacement.head._2 match { - case v: String => replacement - case v: Boolean => replacement - case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } - } + val replacementMap: Map[_, _] = + if (replacement.head._2 == null) + replacement + else replacement.head._2 match { + case v: String => replacement + case v: Boolean => replacement + case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } + } // targetColumnType is either DoubleType or StringType or BooleanType val targetColumnType = replacement.head._1 match { From 2eac8b9070f12fd7dade103857682669d3017587 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Fri, 9 Dec 2016 09:35:35 -0800 Subject: [PATCH 02/17] Scala test for df.replace with null --- .../spark/sql/DataFrameNaFunctionsSuite.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index fd829846ac332..79a5aaca140f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -208,16 +208,16 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out(4) === Row("Amy", null, null)) assert(out(5) === Row(null, null, null)) - // Replace only the age column - val out1 = input.na.replace("age", Map( - 16 -> 61, - 60 -> 6, - 164.3 -> 461.3 // Alice is really tall + // Replace only the name column + val out1 = input.na.replace("name", Map( + "Bob" -> "Bravo", + "Alice" -> "Jessie", + "David" -> null )).collect() - assert(out1(0) === Row("Bob", 61, 176.5)) - assert(out1(1) === Row("Alice", null, 164.3)) - assert(out1(2) === Row("David", 6, null)) + assert(out1(0) === Row("Bravo", 16, 176.5)) + assert(out1(1) === Row("Jessie", null, 164.3)) + assert(out1(2) === Row(null, 60, null)) assert(out1(3).get(2).asInstanceOf[Double].isNaN) assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null)) From 79492924bfad40f36aa132e58376da98b0727eac Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Fri, 9 Dec 2016 09:44:39 -0800 Subject: [PATCH 03/17] Use pattern matching for null case --- .../scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 777e0ca49dd83..687e588f8450a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -342,10 +342,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] - val replacementMap: Map[_, _] = - if (replacement.head._2 == null) - replacement - else replacement.head._2 match { + val replacementMap: Map[_, _] = replacement.head._2 match { + case null => replacement case v: String => replacement case v: Boolean => replacement case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } From 0b15c8f1d64754432c6585bb88e20d17738d4bbd Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Fri, 9 Dec 2016 09:45:51 -0800 Subject: [PATCH 04/17] Fix indentation --- .../org/apache/spark/sql/DataFrameNaFunctions.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 687e588f8450a..46e14bdbf9493 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -343,11 +343,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] val replacementMap: Map[_, _] = replacement.head._2 match { - case null => replacement - case v: String => replacement - case v: Boolean => replacement - case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } - } + case null => replacement + case v: String => replacement + case v: Boolean => replacement + case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } + } // targetColumnType is either DoubleType or StringType or BooleanType val targetColumnType = replacement.head._1 match { From 2c532c3781087ec8f0c36d5176837c9de568d7ec Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Sat, 10 Dec 2016 09:42:44 -0800 Subject: [PATCH 05/17] Fix Python style check --- python/pyspark/sql/dataframe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b6997127d4fea..7fcf13700c16a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1338,7 +1338,8 @@ def replace(self, to_replace, value, subset=None): if len(to_replace) != len(value): raise ValueError("to_replace and value lists should be of the same length") rep_dict = dict(zip(to_replace, value)) - elif isinstance(to_replace, list) and (value is None or isinstance(value, (float, int, long, basestring))): + elif (isinstance(to_replace, list) and + (value is None or isinstance(value, (float, int, long, basestring)))): rep_dict = dict([(tr, value) for tr in to_replace]) elif isinstance(to_replace, dict): rep_dict = to_replace From 43fb6bd56802f2c20cdd28f7ea384e472183cdc4 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Tue, 23 May 2017 17:46:19 -0700 Subject: [PATCH 06/17] Improve scala doc and pyspark test --- python/pyspark/sql/tests.py | 5 +++++ .../org/apache/spark/sql/DataFrameNaFunctions.scala | 9 ++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index acea9113ee858..509463837a7a1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1851,6 +1851,11 @@ def test_replace(self): .replace(False, True).first()) self.assertTupleEqual(row, (True, True)) + # replace with None + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80), None).first() + self.assertTupleEqual(row, (u'Alice', None, None)) + # should fail if subset is not list, tuple or None with self.assertRaises(ValueError): self.spark.createDataFrame( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 6e7e1cd17bca6..8e2b01417a6ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -319,8 +319,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and - * can only be doubles , strings or booleans. + * Key and value of `replacement` map must satisfy one of: + * 1. keys are String, values are mix of String and null + * 2. keys are Boolean, values are mix of Boolean and null + * 3. keys are Double, values are either all Double or all null * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". @@ -342,7 +344,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { return df } - // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] + // replacementMap is either Map[String, String], Map[Double, Double], Map[Boolean,Boolean] + // or value being null val replacementMap: Map[_, _] = replacement.head._2 match { case null => replacement case v: String => replacement From b5424d9fea56d2e0fb57ebc27d3d35054da6d22b Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Tue, 23 May 2017 19:56:07 -0700 Subject: [PATCH 07/17] Fix python3 dict.values() syntax --- python/pyspark/sql/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index f2675fd7339fb..b57bb97e4e84a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1436,7 +1436,7 @@ def all_of_(xs): # Verify we were not passed in mixed type generics." if not any(all_of_type(rep_dict.keys()) and (all_of_type(rep_dict.values()) - or rep_dict.values().count(None) == len(rep_dict)) + or list(rep_dict.values()).count(None) == len(rep_dict)) for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): raise ValueError("Mixed type replacements are not supported") From a3939ba6e80d26f9d4283da8bea2c36244d876e1 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Wed, 2 Aug 2017 08:51:47 -0700 Subject: [PATCH 08/17] Unify allowed null in Python and Scala --- python/pyspark/sql/dataframe.py | 23 +++++++++--------- python/pyspark/sql/tests.py | 9 +++++-- .../spark/sql/DataFrameNaFunctions.scala | 24 +++++++++++-------- .../spark/sql/DataFrameNaFunctionsSuite.scala | 16 ++++++------- 4 files changed, 40 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b57bb97e4e84a..8c54c5657cd46 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1325,8 +1325,8 @@ def replace(self, to_replace, value=None, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. - Values to_replace and value should contain either all numerics, all booleans, - or all strings. When replacing, the new value will be cast + Values to_replace and value must have the same type and can only be numerics, booleans, + or strings. Value can have None. When replacing, the new value will be cast to the type of the existing column. For numeric replacements all values to be replaced should have unique floating point representation. In case of conflicts (for example with `{42: -1, 42.0: 1}`) @@ -1390,9 +1390,10 @@ def all_of_(xs): return all(isinstance(x, types) for x in xs) return all_of_ - all_of_bool = all_of(bool) - all_of_str = all_of(basestring) - all_of_numeric = all_of((float, int, long)) + # Replacement key and value must have the same type while value can have None + all_of_bool = (all_of(bool), all_of((bool, type(None)))) + all_of_str = (all_of(basestring), all_of((basestring, type(None)))) + all_of_numeric = (all_of((float, int, long)), all_of((float, int, long, type(None)))) # Validate input types valid_types = (bool, float, int, long, basestring, list, tuple) @@ -1401,8 +1402,7 @@ def all_of_(xs): "to_replace should be a float, int, long, string, list, tuple, or dict. " "Got {0}".format(type(to_replace))) - if not isinstance(value, valid_types) and value is not None \ - and not isinstance(to_replace, dict): + if not isinstance(value, valid_types + (type(None), )) and not isinstance(to_replace, dict): raise ValueError("If to_replace is not a dict, value should be " "a float, int, long, string, list, tuple or None. " "Got {0}".format(type(value))) @@ -1420,7 +1420,7 @@ def all_of_(xs): if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(value, (float, int, long, basestring)) or value is None: + if isinstance(value, (float, int, long, basestring, type(None))): value = [value for _ in range(len(to_replace))] if isinstance(to_replace, dict): @@ -1434,10 +1434,9 @@ def all_of_(xs): subset = [subset] # Verify we were not passed in mixed type generics." - if not any(all_of_type(rep_dict.keys()) - and (all_of_type(rep_dict.values()) - or list(rep_dict.values()).count(None) == len(rep_dict)) - for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): + if not any(key_all_of_type(rep_dict.keys()) and value_all_of_type(rep_dict.values()) + for (key_all_of_type, value_all_of_type) + in [all_of_bool, all_of_str, all_of_numeric]): raise ValueError("Mixed type replacements are not supported") if subset is None: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 509463837a7a1..5f4f7444ccfa9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1853,8 +1853,13 @@ def test_replace(self): # replace with None row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace((10, 80), None).first() - self.assertTupleEqual(row, (u'Alice', None, None)) + [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).first() + self.assertTupleEqual(row, (None, 10, 80.0)) + + # replace with numerics and None + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first() + self.assertTupleEqual(row, (u'Alice', 20, None)) # should fail if subset is not list, tuple or None with self.assertRaises(ValueError): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 8e2b01417a6ec..bd1021e38fab6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -238,6 +238,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Replaces values matching keys in `replacement` map with the corresponding values. * Key and value of `replacement` map must have the same type, and * can only be doubles, strings or booleans. + * `replacement` map value can have null. * If `col` is "*", then the replacement is applied on all string columns or numeric columns. * * {{{ @@ -266,6 +267,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Replaces values matching keys in `replacement` map with the corresponding values. * Key and value of `replacement` map must have the same type, and * can only be doubles, strings or booleans. + * `replacement` map value can have null. * * {{{ * import com.google.common.collect.ImmutableMap; @@ -290,6 +292,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * (Scala-specific) Replaces values matching keys in `replacement` map. * Key and value of `replacement` map must have the same type, and * can only be doubles, strings or booleans. + * `replacement` map value can have null. * If `col` is "*", * then the replacement is applied on all string columns , numeric columns or boolean columns. * @@ -319,10 +322,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must satisfy one of: - * 1. keys are String, values are mix of String and null - * 2. keys are Boolean, values are mix of Boolean and null - * 3. keys are Double, values are either all Double or all null + * Key and value of `replacement` map must have the same type, and + * can only be doubles, strings or booleans. + * `replacement` map value can have null. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". @@ -345,12 +347,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } // replacementMap is either Map[String, String], Map[Double, Double], Map[Boolean,Boolean] - // or value being null - val replacementMap: Map[_, _] = replacement.head._2 match { - case null => replacement - case v: String => replacement - case v: Boolean => replacement - case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } + // while value can have null + val replacementMap: Map[_, _] = replacement.map { + case (k, v: String) => (k, v) + case (k, v: Boolean) => (k, v) + case (k: String, null) => (k, null) + case (k: Boolean, null) => (k, null) + case (k, null) => (convertToDouble(k), null) + case _ @(k, v) => (convertToDouble(k), convertToDouble(v)) } // targetColumnType is either DoubleType or StringType or BooleanType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index f1f498b2b9e6d..103ffc77659bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -222,16 +222,16 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out(4) === Row("Amy", null, null)) assert(out(5) === Row(null, null, null)) - // Replace only the name column - val out1 = input.na.replace("name", Map( - "Bob" -> "Bravo", - "Alice" -> "Jessie", - "David" -> null + // Replace only the age column + val out1 = input.na.replace("age", Map[Any, Any]( + 16 -> 61, + 60 -> null, + 164.3 -> 461.3 // Alice is really tall )).collect() - assert(out1(0) === Row("Bravo", 16, 176.5)) - assert(out1(1) === Row("Jessie", null, 164.3)) - assert(out1(2) === Row(null, 60, null)) + assert(out1(0) === Row("Bob", 61, 176.5)) + assert(out1(1) === Row("Alice", null, 164.3)) + assert(out1(2) === Row("David", null, null)) assert(out1(3).get(2).asInstanceOf[Double].isNaN) assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null)) From 37dfaa7e6ead619d6e3ca721327cc8a1abb7e3d0 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Thu, 3 Aug 2017 12:16:21 -0700 Subject: [PATCH 09/17] Simplify all_of_type logic --- python/pyspark/sql/dataframe.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8c54c5657cd46..26da8885bec82 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1390,10 +1390,9 @@ def all_of_(xs): return all(isinstance(x, types) for x in xs) return all_of_ - # Replacement key and value must have the same type while value can have None - all_of_bool = (all_of(bool), all_of((bool, type(None)))) - all_of_str = (all_of(basestring), all_of((basestring, type(None)))) - all_of_numeric = (all_of((float, int, long)), all_of((float, int, long, type(None)))) + all_of_bool = all_of(bool) + all_of_str = all_of(basestring) + all_of_numeric = all_of((float, int, long)) # Validate input types valid_types = (bool, float, int, long, basestring, list, tuple) @@ -1434,9 +1433,9 @@ def all_of_(xs): subset = [subset] # Verify we were not passed in mixed type generics." - if not any(key_all_of_type(rep_dict.keys()) and value_all_of_type(rep_dict.values()) - for (key_all_of_type, value_all_of_type) - in [all_of_bool, all_of_str, all_of_numeric]): + if not any(all_of_type(rep_dict.keys()) + and all_of_type(x for x in rep_dict.values() if x is not None) + for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): raise ValueError("Mixed type replacements are not supported") if subset is None: From fcb617e50e909d41f4c6458e7e2b5d85f8a33832 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Thu, 3 Aug 2017 15:00:40 -0700 Subject: [PATCH 10/17] Throw exception when column is not nullable --- .../spark/sql/DataFrameNaFunctions.scala | 4 ++++ .../spark/sql/DataFrameNaFunctionsSuite.scala | 23 +++++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index bd1021e38fab6..47f68a7116773 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -429,6 +429,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * TODO: This can be optimized to use broadcast join when replacementMap is large. */ private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = { + if (!col.nullable && !replacementMap.values.forall(_ != null)) { + throw new IllegalArgumentException(s"Column '${col.name}' is not nullable " + + s"and can not be replaced to null.") + } val keyExpr = df.col(col.name).expr def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) val branches = replacementMap.flatMap { case (source, target) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 103ffc77659bd..59cb2639158e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { @@ -206,10 +207,10 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { } test("replace") { - val input = createDF() + val input1 = createDF() // Replace two numeric columns: age and height - val out = input.na.replace(Seq("age", "height"), Map( + val out = input1.na.replace(Seq("age", "height"), Map( 16 -> 61, 60 -> 6, 164.3 -> 461.3 // Alice is really tall @@ -222,8 +223,8 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out(4) === Row("Amy", null, null)) assert(out(5) === Row(null, null, null)) - // Replace only the age column - val out1 = input.na.replace("age", Map[Any, Any]( + // Replace only the age column and with null + val out1 = input1.na.replace("age", Map[Any, Any]( 16 -> 61, 60 -> null, 164.3 -> 461.3 // Alice is really tall @@ -235,5 +236,19 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out1(3).get(2).asInstanceOf[Double].isNaN) assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null)) + + // Replace with null on a column that is not nullable + val rows = spark.sparkContext.parallelize(Seq( + Row("Bravo", 28, 183.5), + Row("Jessie", 18, 165.8))) + val schema = StructType(Seq( + StructField("name", StringType, nullable = false), + StructField("age", IntegerType, nullable = true), + StructField("height", DoubleType, nullable = true))) + val input2 = spark.createDataFrame(rows, schema) + val message = intercept[IllegalArgumentException] { + input2.na.replace("name", Map("Bravo" -> null)) + }.getMessage + assert(message === "Column 'name' is not nullable and can not be replaced to null.") } } From dfbcaf3e47126eaf6fd3a0276054cad01dbff71a Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Thu, 3 Aug 2017 15:42:29 -0700 Subject: [PATCH 11/17] Piggybacking a minor improvement on code I recently pushed --- .../test/scala/org/apache/spark/sql/types/DataTypeSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 193826d66be26..1e272338aa659 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -145,8 +145,8 @@ class DataTypeSuite extends SparkFunSuite { val message = intercept[SparkException] { left.merge(right) }.getMessage - assert(message.equals("Failed to merge fields 'b' and 'b'. " + - "Failed to merge incompatible data types FloatType and LongType")) + assert(message === "Failed to merge fields 'b' and 'b'. " + + "Failed to merge incompatible data types FloatType and LongType") } test("existsRecursively") { From 8f7953bd756e30cc3d8a4b0bf56c9ba71ad169a0 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Fri, 4 Aug 2017 08:47:37 -0700 Subject: [PATCH 12/17] Revert "Throw exception when column is not nullable" This reverts commit fcb617e50e909d41f4c6458e7e2b5d85f8a33832. --- .../spark/sql/DataFrameNaFunctions.scala | 4 ---- .../spark/sql/DataFrameNaFunctionsSuite.scala | 23 ++++--------------- 2 files changed, 4 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 04c44432ac9fb..9bfe3311e2888 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -453,10 +453,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * TODO: This can be optimized to use broadcast join when replacementMap is large. */ private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = { - if (!col.nullable && !replacementMap.values.forall(_ != null)) { - throw new IllegalArgumentException(s"Column '${col.name}' is not nullable " + - s"and can not be replaced to null.") - } val keyExpr = df.col(col.name).expr def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType) val branches = replacementMap.flatMap { case (source, target) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index d6d28c2ae40df..300b5f57a16eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -21,7 +21,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -233,10 +232,10 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { } test("replace") { - val input1 = createDF() + val input = createDF() // Replace two numeric columns: age and height - val out = input1.na.replace(Seq("age", "height"), Map( + val out = input.na.replace(Seq("age", "height"), Map( 16 -> 61, 60 -> 6, 164.3 -> 461.3 // Alice is really tall @@ -249,8 +248,8 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out(4) === Row("Amy", null, null)) assert(out(5) === Row(null, null, null)) - // Replace only the age column and with null - val out1 = input1.na.replace("age", Map[Any, Any]( + // Replace only the age column + val out1 = input.na.replace("age", Map[Any, Any]( 16 -> 61, 60 -> null, 164.3 -> 461.3 // Alice is really tall @@ -262,19 +261,5 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out1(3).get(2).asInstanceOf[Double].isNaN) assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null)) - - // Replace with null on a column that is not nullable - val rows = spark.sparkContext.parallelize(Seq( - Row("Bravo", 28, 183.5), - Row("Jessie", 18, 165.8))) - val schema = StructType(Seq( - StructField("name", StringType, nullable = false), - StructField("age", IntegerType, nullable = true), - StructField("height", DoubleType, nullable = true))) - val input2 = spark.createDataFrame(rows, schema) - val message = intercept[IllegalArgumentException] { - input2.na.replace("name", Map("Bravo" -> null)) - }.getMessage - assert(message === "Column 'name' is not nullable and can not be replaced to null.") } } From 3e3823fb493dff2c7e4513f01e6827d52857edee Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Fri, 4 Aug 2017 08:50:36 -0700 Subject: [PATCH 13/17] Improve a comment --- .../scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 300b5f57a16eb..590af91271a04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -248,7 +248,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out(4) === Row("Amy", null, null)) assert(out(5) === Row(null, null, null)) - // Replace only the age column + // Replace only the age column and with null val out1 = input.na.replace("age", Map[Any, Any]( 16 -> 61, 60 -> null, From 2946659887c3c3ad14f8f106f98aa994607a6e11 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Sun, 6 Aug 2017 09:47:44 -0700 Subject: [PATCH 14/17] Check value is None and new scala test --- python/pyspark/sql/dataframe.py | 5 +++-- .../spark/sql/DataFrameNaFunctions.scala | 2 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 21 +++++++++++++++---- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e441f91771fc3..0c55d6a6864b4 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1438,7 +1438,8 @@ def all_of_(xs): "to_replace should be a float, int, long, string, list, tuple, or dict. " "Got {0}".format(type(to_replace))) - if not isinstance(value, valid_types + (type(None), )) and not isinstance(to_replace, dict): + if not isinstance(value, valid_types) and value is not None \ + and not isinstance(to_replace, dict): raise ValueError("If to_replace is not a dict, value should be " "a float, int, long, string, list, tuple or None. " "Got {0}".format(type(value))) @@ -1456,7 +1457,7 @@ def all_of_(xs): if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(value, (float, int, long, basestring, type(None))): + if isinstance(value, (float, int, long, basestring)) or value is None: value = [value for _ in range(len(to_replace))] if isinstance(to_replace, dict): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 9bfe3311e2888..c9554334309e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -378,7 +378,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case (k: String, null) => (k, null) case (k: Boolean, null) => (k, null) case (k, null) => (convertToDouble(k), null) - case _ @(k, v) => (convertToDouble(k), convertToDouble(v)) + case (k, v) => (convertToDouble(k), convertToDouble(v)) } // targetColumnType is either DoubleType or StringType or BooleanType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 590af91271a04..d47338d180d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -248,18 +248,31 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out(4) === Row("Amy", null, null)) assert(out(5) === Row(null, null, null)) - // Replace only the age column and with null - val out1 = input.na.replace("age", Map[Any, Any]( + // Replace only the age column + val out1 = input.na.replace("age", Map( 16 -> 61, - 60 -> null, + 60 -> 6, 164.3 -> 461.3 // Alice is really tall )).collect() assert(out1(0) === Row("Bob", 61, 176.5)) assert(out1(1) === Row("Alice", null, 164.3)) - assert(out1(2) === Row("David", null, null)) + assert(out1(2) === Row("David", 6, null)) assert(out1(3).get(2).asInstanceOf[Double].isNaN) assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null)) + + // Replace with null + val out2 = input.na.replace("name", Map( + "Bob" -> "Bravo", + "Alice" -> null + )).collect() + + assert(out2(0) === Row("Bravo", 16, 176.5)) + assert(out2(1) === Row(null, null, 164.3)) + assert(out2(2) === Row("David", 60, null)) + assert(out2(3).get(2).asInstanceOf[Double].isNaN) + assert(out2(4) === Row("Amy", null, null)) + assert(out2(5) === Row(null, null, null)) } } From 351be99432218545307fe93f2400bdf2d6fe76e2 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Mon, 7 Aug 2017 21:49:04 -0700 Subject: [PATCH 15/17] More tests, better comments --- python/pyspark/sql/dataframe.py | 7 ++- python/pyspark/sql/tests.py | 11 ++-- .../spark/sql/types/DataTypeSuite.scala | 4 +- .../spark/sql/DataFrameNaFunctions.scala | 51 +++++++++---------- .../spark/sql/DataFrameNaFunctionsSuite.scala | 21 +++++++- 5 files changed, 58 insertions(+), 36 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0c55d6a6864b4..bcab7ec7b1fb9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1457,20 +1457,19 @@ def all_of_(xs): if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(value, (float, int, long, basestring)) or value is None: - value = [value for _ in range(len(to_replace))] - if isinstance(to_replace, dict): rep_dict = to_replace if value is not None: warnings.warn("to_replace is a dict and value is not None. value will be ignored.") else: + if isinstance(value, (float, int, long, basestring)) or value is None: + value = [value for _ in range(len(to_replace))] rep_dict = dict(zip(to_replace, value)) if isinstance(subset, basestring): subset = [subset] - # Verify we were not passed in mixed type generics." + # Verify we were not passed in mixed type generics. if not any(all_of_type(rep_dict.keys()) and all_of_type(x for x in rep_dict.values() if x is not None) for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d1af3df14b6ff..cf2c473a1645c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1964,12 +1964,17 @@ def test_replace(self): .replace(False, True).first()) self.assertTupleEqual(row, (True, True)) - # replace with None + # replace list while value is not given (default to None) row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).first() + [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() self.assertTupleEqual(row, (None, 10, 80.0)) - # replace with numerics and None + # replace string with None and then drop None rows + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna() + self.assertEqual(row.count(), 0) + + # replace with number and None row = self.spark.createDataFrame( [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first() self.assertTupleEqual(row, (u'Alice', 20, None)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 1e272338aa659..193826d66be26 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -145,8 +145,8 @@ class DataTypeSuite extends SparkFunSuite { val message = intercept[SparkException] { left.merge(right) }.getMessage - assert(message === "Failed to merge fields 'b' and 'b'. " + - "Failed to merge incompatible data types FloatType and LongType") + assert(message.equals("Failed to merge fields 'b' and 'b'. " + + "Failed to merge incompatible data types FloatType and LongType")) } test("existsRecursively") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index c9554334309e6..e068df3586f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -260,10 +260,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Replaces values matching keys in `replacement` map with the corresponding values. - * Key and value of `replacement` map must have the same type, and - * can only be doubles, strings or booleans. - * `replacement` map value can have null. - * If `col` is "*", then the replacement is applied on all string columns or numeric columns. * * {{{ * import com.google.common.collect.ImmutableMap; @@ -278,8 +274,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * - * @param col name of the column to apply the value replacement - * @param replacement value replacement map, as explained above + * @param col name of the column to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * The map value can have nulls. * * @since 1.3.1 */ @@ -289,9 +288,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Replaces values matching keys in `replacement` map with the corresponding values. - * Key and value of `replacement` map must have the same type, and - * can only be doubles, strings or booleans. - * `replacement` map value can have null. * * {{{ * import com.google.common.collect.ImmutableMap; @@ -303,8 +299,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * - * @param cols list of columns to apply the value replacement - * @param replacement value replacement map, as explained above + * @param cols list of columns to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * The map value can have nulls. * * @since 1.3.1 */ @@ -314,11 +313,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and - * can only be doubles, strings or booleans. - * `replacement` map value can have null. - * If `col` is "*", - * then the replacement is applied on all string columns , numeric columns or boolean columns. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height". @@ -331,8 +325,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace("*", Map("UNKNOWN" -> "unnamed")); * }}} * - * @param col name of the column to apply the value replacement - * @param replacement value replacement map, as explained above + * @param col name of the column to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * The map value can have nulls. * * @since 1.3.1 */ @@ -346,9 +343,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and - * can only be doubles, strings or booleans. - * `replacement` map value can have null. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". @@ -358,8 +352,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); * }}} * - * @param cols list of columns to apply the value replacement - * @param replacement value replacement map, as explained above + * @param cols list of columns to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * The map value can have nulls. * * @since 1.3.1 */ @@ -370,8 +367,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { return df } - // replacementMap is either Map[String, String], Map[Double, Double], Map[Boolean,Boolean] - // while value can have null + // Convert the NumericType in replacement map to DoubleType, + // while leaving StringType, BooleanType and null untouched. val replacementMap: Map[_, _] = replacement.map { case (k, v: String) => (k, v) case (k, v: Boolean) => (k, v) @@ -381,7 +378,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case (k, v) => (convertToDouble(k), convertToDouble(v)) } - // targetColumnType is either DoubleType or StringType or BooleanType + // targetColumnType is either DoubleType, StringType or BooleanType, + // depending on the type of first key in replacement map. + // Only fields of targetColumnType will perform replacement. val targetColumnType = replacement.head._1 match { case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType case _: jl.Boolean => BooleanType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index d47338d180d6c..5a44b6d8c1103 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -262,7 +262,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null)) - // Replace with null + // Replace String with String and null val out2 = input.na.replace("name", Map( "Bob" -> "Bravo", "Alice" -> null @@ -274,5 +274,24 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out2(3).get(2).asInstanceOf[Double].isNaN) assert(out2(4) === Row("Amy", null, null)) assert(out2(5) === Row(null, null, null)) + + // Replace Double with null + val out3 = input.na.replace("age", Map[Any, Any]( + 16 -> null + )).collect() + + assert(out3(0) === Row("Bob", null, 176.5)) + assert(out3(1) === Row("Alice", null, 164.3)) + assert(out3(2) === Row("David", 60, null)) + assert(out3(3).get(2).asInstanceOf[Double].isNaN) + assert(out3(4) === Row("Amy", null, null)) + assert(out3(5) === Row(null, null, null)) + + // Replace String with null and then drop rows containing null + checkAnswer( + input.na.replace("name", Map( + "Bob" -> null + )).na.drop("name" :: Nil).select("name"), + Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil) } } From a09d3e987dda4fe5b97a13e76cf9855f346b3eb8 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Tue, 8 Aug 2017 00:33:54 -0700 Subject: [PATCH 16/17] Add bool to comment and Error text --- python/pyspark/sql/dataframe.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index bcab7ec7b1fb9..edc7ca6f5146f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1373,8 +1373,8 @@ def replace(self, to_replace, value=None, subset=None): Value to be replaced. If the value is a dict, then `value` is ignored and `to_replace` must be a mapping between a value and a replacement. - :param value: int, long, float, string, list or None. - The replacement value must be an int, long, float, string or None. If `value` is a + :param value: bool, int, long, float, string, list or None. + The replacement value must be a bool, int, long, float, string or None. If `value` is a list, `value` should be of the same length and type as `to_replace`. If `value` is a scalar and `to_replace` is a sequence, then `value` is used as a replacement for each item in `to_replace`. @@ -1435,13 +1435,13 @@ def all_of_(xs): valid_types = (bool, float, int, long, basestring, list, tuple) if not isinstance(to_replace, valid_types + (dict, )): raise ValueError( - "to_replace should be a float, int, long, string, list, tuple, or dict. " + "to_replace should be a bool, float, int, long, string, list, tuple, or dict. " "Got {0}".format(type(to_replace))) if not isinstance(value, valid_types) and value is not None \ and not isinstance(to_replace, dict): raise ValueError("If to_replace is not a dict, value should be " - "a float, int, long, string, list, tuple or None. " + "a bool, float, int, long, string, list, tuple or None. " "Got {0}".format(type(value))) if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)): From bc7a231c6aef2ca419d7929bf75e02f529b27da4 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Wed, 9 Aug 2017 08:48:04 -0700 Subject: [PATCH 17/17] Separate test and boolean test --- .../spark/sql/DataFrameNaFunctionsSuite.scala | 55 +++++++++++-------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 5a44b6d8c1103..e6983b6be555a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -261,37 +261,48 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out1(3).get(2).asInstanceOf[Double].isNaN) assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null)) + } - // Replace String with String and null - val out2 = input.na.replace("name", Map( - "Bob" -> "Bravo", - "Alice" -> null - )).collect() + test("replace with null") { + val input = Seq[(String, java.lang.Double, java.lang.Boolean)]( + ("Bob", 176.5, true), + ("Alice", 164.3, false), + ("David", null, true) + ).toDF("name", "height", "married") - assert(out2(0) === Row("Bravo", 16, 176.5)) - assert(out2(1) === Row(null, null, 164.3)) - assert(out2(2) === Row("David", 60, null)) - assert(out2(3).get(2).asInstanceOf[Double].isNaN) - assert(out2(4) === Row("Amy", null, null)) - assert(out2(5) === Row(null, null, null)) + // Replace String with String and null + checkAnswer( + input.na.replace("name", Map( + "Bob" -> "Bravo", + "Alice" -> null + )), + Row("Bravo", 176.5, true) :: + Row(null, 164.3, false) :: + Row("David", null, true) :: Nil) // Replace Double with null - val out3 = input.na.replace("age", Map[Any, Any]( - 16 -> null - )).collect() - - assert(out3(0) === Row("Bob", null, 176.5)) - assert(out3(1) === Row("Alice", null, 164.3)) - assert(out3(2) === Row("David", 60, null)) - assert(out3(3).get(2).asInstanceOf[Double].isNaN) - assert(out3(4) === Row("Amy", null, null)) - assert(out3(5) === Row(null, null, null)) + checkAnswer( + input.na.replace("height", Map[Any, Any]( + 164.3 -> null + )), + Row("Bob", 176.5, true) :: + Row("Alice", null, false) :: + Row("David", null, true) :: Nil) + + // Replace Boolean with null + checkAnswer( + input.na.replace("*", Map[Any, Any]( + false -> null + )), + Row("Bob", 176.5, true) :: + Row("Alice", 164.3, null) :: + Row("David", null, true) :: Nil) // Replace String with null and then drop rows containing null checkAnswer( input.na.replace("name", Map( "Bob" -> null )).na.drop("name" :: Nil).select("name"), - Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil) + Row("Alice") :: Row("David") :: Nil) } }