From e4b886732d5bb844c66116a4751dc9a92173a928 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 2 May 2018 05:17:22 +0000 Subject: [PATCH 1/2] Add tests and throw exception when we can't parse input. --- python/pyspark/tests.py | 9 +++++++++ python/pyspark/util.py | 34 +++++++++++++++++++--------------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8392d7f29af53..65a829d8cb1cc 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2303,6 +2303,15 @@ def test_py4j_exception_message(self): self.assertTrue('NullPointerException' in _exception_message(context.exception)) + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + + (major, minor) = VersionUtils.majorMinorVersion("2.4.0") + self.assertEqual(major, 2) + self.assertEqual(minor, 4) + + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 04df835bf6717..60c0d8ad2b2e0 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -62,24 +62,28 @@ def _get_argspec(f): return argspec -def majorMinorVersion(version): +class VersionUtils(object): """ - Get major and minor version numbers for given Spark version string. - - >>> version = "2.4.0" - >>> majorMinorVersion(version) - (2, 4) + Provides utility method to determine Spark versions with given input string. + """ + @staticmethod + def majorMinorVersion(version): + """ + Get major and minor version numbers for given Spark version string. - >>> version = "abc" - >>> majorMinorVersion(version) is None - True + >>> version = "2.4.0" + >>> majorMinorVersion(version) + (2, 4) + >>> version = "2.3.0-SNAPSHOT" + >>> majorMinorVersion(version) + (2, 3) - """ - m = re.search('^(\d+)\.(\d+)(\..*)?$', version) - if m is None: - return None - else: - return (int(m.group(1)), int(m.group(2))) + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', version) + if m is None: + raise ValueError("invalid version string: " + version) + else: + return (int(m.group(1)), int(m.group(2))) if __name__ == "__main__": From d2bcfe2a892d627002a4cac678712693986778c2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 2 May 2018 15:17:16 +0000 Subject: [PATCH 2/2] Address comments to make it consistent with Scala API. --- python/pyspark/tests.py | 5 ----- python/pyspark/util.py | 23 +++++++++++++---------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 65a829d8cb1cc..5066b3c4a2252 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2305,11 +2305,6 @@ def test_py4j_exception_message(self): def test_parsing_version_string(self): from pyspark.util import VersionUtils - - (major, minor) = VersionUtils.majorMinorVersion("2.4.0") - self.assertEqual(major, 2) - self.assertEqual(minor, 4) - self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 60c0d8ad2b2e0..59cc2a6329350 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -67,23 +67,26 @@ class VersionUtils(object): Provides utility method to determine Spark versions with given input string. """ @staticmethod - def majorMinorVersion(version): + def majorMinorVersion(sparkVersion): """ - Get major and minor version numbers for given Spark version string. + Given a Spark version string, return the (major version number, minor version number). + E.g., for 2.0.1-SNAPSHOT, return (2, 0). - >>> version = "2.4.0" - >>> majorMinorVersion(version) + >>> sparkVersion = "2.4.0" + >>> VersionUtils.majorMinorVersion(sparkVersion) (2, 4) - >>> version = "2.3.0-SNAPSHOT" - >>> majorMinorVersion(version) + >>> sparkVersion = "2.3.0-SNAPSHOT" + >>> VersionUtils.majorMinorVersion(sparkVersion) (2, 3) """ - m = re.search('^(\d+)\.(\d+)(\..*)?$', version) - if m is None: - raise ValueError("invalid version string: " + version) - else: + m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion) + if m is not None: return (int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Spark tried to parse '%s' as a Spark" % sparkVersion + + " version string, but it could not find the major and minor" + + " version numbers.") if __name__ == "__main__":