From 86abf481949d61cbcc726bcadfa91d12686846d6 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 19 Sep 2017 21:12:42 +0800 Subject: [PATCH 1/3] init pr --- .../spark/ml/tuning/CrossValidator.scala | 15 ++++++------- .../ml/tuning/TrainValidationSplit.scala | 16 ++++++-------- .../spark/ml/tuning/ValidatorParams.scala | 22 +++++++------------ .../org/apache/spark/ml/util/ReadWrite.scala | 11 ++++++---- .../spark/ml/tuning/CrossValidatorSuite.scala | 3 +++ .../ml/tuning/TrainValidationSplitSuite.scala | 4 +++- 6 files changed, 35 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index ce2a3a2e40411..18cf978138080 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tuning +import java.io.IOException import java.util.{List => JList} import scala.collection.JavaConverters._ @@ -212,14 +213,12 @@ object CrossValidator extends MLReadable[CrossValidator] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) - val numFolds = (metadata.params \ "numFolds").extract[Int] - val seed = (metadata.params \ "seed").extract[Long] - new CrossValidator(metadata.uid) + val cv = new CrossValidator(metadata.uid) .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - .setNumFolds(numFolds) - .setSeed(seed) + DefaultParamsReader.getAndSetParams(cv, metadata, skipParams = List("estimatorParamMaps")) + cv } } } @@ -303,16 +302,16 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) val numFolds = (metadata.params \ "numFolds").extract[Int] - val seed = (metadata.params \ "seed").extract[Long] val 
bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - .set(model.numFolds, numFolds) - .set(model.seed, seed) + DefaultParamsReader.getAndSetParams(model, metadata, skipParams = List("estimatorParamMaps")) + model } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 16db0f5f12c77..811ee4da59e15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tuning +import java.io.IOException import java.util.{List => JList} import scala.collection.JavaConverters._ @@ -207,14 +208,12 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) - val trainRatio = (metadata.params \ "trainRatio").extract[Double] - val seed = (metadata.params \ "seed").extract[Long] - new TrainValidationSplit(metadata.uid) + val tvs = new TrainValidationSplit(metadata.uid) .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - .setTrainRatio(trainRatio) - .setSeed(seed) + DefaultParamsReader.getAndSetParams(tvs, metadata, skipParams = List("estimatorParamMaps")) + tvs } } } @@ -295,17 +294,16 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) - val trainRatio = 
(metadata.params \ "trainRatio").extract[Double] - val seed = (metadata.params \ "seed").extract[Long] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray + val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - .set(model.trainRatio, trainRatio) - .set(model.seed, seed) + DefaultParamsReader.getAndSetParams(model, metadata, skipParams = List("estimatorParamMaps")) + model } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 0ab6eed959381..363304ef10147 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -150,20 +150,14 @@ private[ml] object ValidatorParams { }.toSeq )) - val validatorSpecificParams = instance match { - case cv: CrossValidatorParams => - List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds))) - case tvs: TrainValidationSplitParams => - List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio))) - case _ => - // This should not happen. 
- throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " + - instance.getClass.getCanonicalName) - } - - val jsonParams = validatorSpecificParams ++ List( - "estimatorParamMaps" -> parse(estimatorParamMapsJson), - "seed" -> parse(instance.seed.jsonEncode(instance.getSeed))) + val params = instance.extractParamMap().toSeq + val skipParams = List("estimator", "evaluator", "estimatorParamMaps") + val jsonParams = render(params + .filter { case ParamPair(p, v) => !skipParams.contains(p.name)} + .map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson)) + ) DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 65f142cfbbcb6..dcbba18ec26d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -399,14 +399,17 @@ private[ml] object DefaultParamsReader { * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. 
* TODO: Move to [[Metadata]] method */ - def getAndSetParams(instance: Params, metadata: Metadata): Unit = { + def getAndSetParams(instance: Params, metadata: Metadata, + skipParams: List[String] = null): Unit = { implicit val format = DefaultFormats metadata.params match { case JObject(pairs) => pairs.foreach { case (paramName, jsonValue) => - val param = instance.getParam(paramName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) + if (skipParams == null || !skipParams.contains(paramName)) { + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + instance.set(param, value) + } } case _ => throw new IllegalArgumentException( diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index a8d4377cff2d1..a01744f7b67fd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -159,12 +159,15 @@ class CrossValidatorSuite .setEvaluator(evaluator) .setNumFolds(20) .setEstimatorParamMaps(paramMaps) + .setSeed(42L) + .setParallelism(2) val cv2 = testDefaultReadWrite(cv, testParams = false) assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) assert(cv.getSeed === cv2.getSeed) + assert(cv.getParallelism === cv2.getParallelism) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 74801733381c1..2ed4fbb601b61 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ 
-23,7 +23,7 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.param.{ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -160,11 +160,13 @@ class TrainValidationSplitSuite .setTrainRatio(0.5) .setEstimatorParamMaps(paramMaps) .setSeed(42L) + .setParallelism(2) val tvs2 = testDefaultReadWrite(tvs, testParams = false) assert(tvs.getTrainRatio === tvs2.getTrainRatio) assert(tvs.getSeed === tvs2.getSeed) + assert(tvs.getParallelism === tvs2.getParallelism) ValidatorParamsSuiteHelpers .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps) From cc30578d2d25d3345821793fcf2ce030cf991a92 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 20 Sep 2017 09:44:58 +0800 Subject: [PATCH 2/3] update --- .../org/apache/spark/ml/tuning/CrossValidator.scala | 7 ++++--- .../apache/spark/ml/tuning/TrainValidationSplit.scala | 6 ++++-- .../scala/org/apache/spark/ml/util/ReadWrite.scala | 10 +++++++--- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 18cf978138080..1363d4d849953 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.tuning -import java.io.IOException import java.util.{List => JList} import scala.collection.JavaConverters._ @@ -217,7 +216,8 @@ object CrossValidator extends 
MLReadable[CrossValidator] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(cv, metadata, skipParams = List("estimatorParamMaps")) + DefaultParamsReader.getAndSetParams(cv, metadata, + skipParams = Option(List("estimatorParamMaps"))) cv } } @@ -310,7 +310,8 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, skipParams = List("estimatorParamMaps")) + DefaultParamsReader.getAndSetParams(model, metadata, + skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 811ee4da59e15..6e3ad40706803 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -212,7 +212,8 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(tvs, metadata, skipParams = List("estimatorParamMaps")) + DefaultParamsReader.getAndSetParams(tvs, metadata, + skipParams = Option(List("estimatorParamMaps"))) tvs } } @@ -302,7 +303,8 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, skipParams = List("estimatorParamMaps")) + DefaultParamsReader.getAndSetParams(model, metadata, + skipParams = Option(List("estimatorParamMaps"))) model } } diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index dcbba18ec26d4..571a1d9469115 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -396,16 +396,20 @@ private[ml] object DefaultParamsReader { /** * Extract Params from metadata, and set them in the instance. - * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * This works if all Params (except params included by `skipParams` list) implement + * [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * + * The params included in `skipParams` won't be set. This is useful if some params don't + * implement [[org.apache.spark.ml.param.Param.jsonDecode()]] and need special handling. * TODO: Move to [[Metadata]] method */ def getAndSetParams(instance: Params, metadata: Metadata, - skipParams: List[String] = null): Unit = { + skipParams: Option[List[String]] = None): Unit = { implicit val format = DefaultFormats metadata.params match { case JObject(pairs) => pairs.foreach { case (paramName, jsonValue) => - if (skipParams == null || !skipParams.contains(paramName)) { + if (skipParams.forall(!_.contains(paramName))) { val param = instance.getParam(paramName) val value = param.jsonDecode(compact(render(jsonValue))) instance.set(param, value) From 8f78f596473877f3e8a0169f998f16a6bf1a8f5a Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 22 Sep 2017 10:12:45 +0800 Subject: [PATCH 3/3] update --- .../org/apache/spark/ml/tuning/CrossValidator.scala | 1 - .../main/scala/org/apache/spark/ml/util/ReadWrite.scala | 12 ++++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 1363d4d849953..7c81cb96e07f2 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -301,7 +301,6 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) - val numFolds = (metadata.params \ "numFolds").extract[Int] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 571a1d9469115..7188da3531267 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -399,11 +399,15 @@ private[ml] object DefaultParamsReader { * This works if all Params (except params included by `skipParams` list) implement * [[org.apache.spark.ml.param.Param.jsonDecode()]]. * - * The params included in `skipParams` won't be set. This is useful if some params don't - * implement [[org.apache.spark.ml.param.Param.jsonDecode()]] and need special handling. - * TODO: Move to [[Metadata]] method + * TODO: Move to [[Metadata]] method + * + * @param skipParams The params included in `skipParams` won't be set. This is useful if some + * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] + * and need special handling. */ - def getAndSetParams(instance: Params, metadata: Metadata, + def getAndSetParams( + instance: Params, + metadata: Metadata, skipParams: Option[List[String]] = None): Unit = { implicit val format = DefaultFormats metadata.params match {