From c41421c6116e8c35c246b2ff79f7ba5f9f4a3731 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 19 Jul 2024 10:55:39 +0800 Subject: [PATCH 1/4] init Signed-off-by: Weichen Xu --- python/pyspark/ml/util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index b9a2829a1ca0b..a7e7eb03df9f0 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -464,7 +464,8 @@ def saveMetadata( metadataJson = DefaultParamsWriter._get_metadata_to_save( instance, sc, extraMetadata, paramMap ) - sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath) + spark = SparkSession.getActiveSession() + spark.createDataFrame([(metadataJson,)], schema=["value"]).repartition(1).write.text(metadataPath) @staticmethod def _get_metadata_to_save( @@ -577,7 +578,8 @@ def loadMetadata(path: str, sc: "SparkContext", expectedClassName: str = "") -> If non empty, this is checked against the loaded metadata. """ metadataPath = os.path.join(path, "metadata") - metadataStr = sc.textFile(metadataPath, 1).first() + spark = SparkSession.getActiveSession() + metadataStr = spark.read.text(metadataPath).first()[0] loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName) return loadedVals From ada650cb29d8aa85c68eb1491c88498f3769159b Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 19 Jul 2024 14:58:12 +0800 Subject: [PATCH 2/4] format Signed-off-by: Weichen Xu --- python/pyspark/ml/util.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index a7e7eb03df9f0..e12c6bd2650aa 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -465,7 +465,9 @@ def saveMetadata( instance, sc, extraMetadata, paramMap ) spark = SparkSession.getActiveSession() - spark.createDataFrame([(metadataJson,)], schema=["value"]).repartition(1).write.text(metadataPath) + spark.createDataFrame([(metadataJson,)], schema=["value"]).repartition(1).write.text( + metadataPath + ) @staticmethod def _get_metadata_to_save( From 6e3be0503f7b5f92a701572bfccb1d08eec22063 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 19 Jul 2024 20:41:48 +0800 Subject: [PATCH 3/4] coalesce Signed-off-by: Weichen Xu --- python/pyspark/ml/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index e12c6bd2650aa..95896ddde2dc7 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -465,7 +465,7 @@ def saveMetadata( instance, sc, extraMetadata, paramMap ) spark = SparkSession.getActiveSession() - spark.createDataFrame([(metadataJson,)], schema=["value"]).repartition(1).write.text( + spark.createDataFrame([(metadataJson,)], schema=["value"]).coalesce(1).write.text( metadataPath ) From 22e5c7b5a09aaf6f4c4371d41322a0e4f19fad9c Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 22 Jul 2024 18:07:39 +0800 Subject: [PATCH 4/4] format Signed-off-by: Weichen Xu --- python/pyspark/ml/util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 95896ddde2dc7..5e7965554d825 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -465,9 +465,9 @@ def saveMetadata( instance, sc, extraMetadata, paramMap ) spark = SparkSession.getActiveSession() - spark.createDataFrame([(metadataJson,)], schema=["value"]).coalesce(1).write.text( - metadataPath - ) + spark.createDataFrame( # type: ignore[union-attr] + [(metadataJson,)], schema=["value"] + ).coalesce(1).write.text(metadataPath) @staticmethod def _get_metadata_to_save( @@ -581,7 +581,7 @@ def loadMetadata(path: str, sc: "SparkContext", expectedClassName: str = "") -> """ metadataPath = os.path.join(path, "metadata") spark = SparkSession.getActiveSession() - metadataStr = spark.read.text(metadataPath).first()[0] + metadataStr = spark.read.text(metadataPath).first()[0] # type: ignore[union-attr,index] loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName) return loadedVals