diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 8bcc9fe5d1b85..a329b8c4e8825 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -24,16 +24,11 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} -import org.apache.spark.ml.feature.{IndexToString, RFormula} -import org.apache.spark.ml.regression._ -import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.feature.{IndexToString, RFormula, SQLTransformer} import org.apache.spark.ml.r.RWrapperUtils._ +import org.apache.spark.ml.regression._ import org.apache.spark.ml.util._ import org.apache.spark.sql._ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ private[r] class GeneralizedLinearRegressionWrapper private ( val pipeline: PipelineModel, @@ -114,9 +109,9 @@ private[r] object GeneralizedLinearRegressionWrapper .setLabelCol(rFormula.getLabelCol) val pipeline = if (family == "binomial") { // Convert prediction from probability to label index. - val probToPred = new ProbabilityToPrediction() - .setInputCol(PREDICTED_LABEL_PROB_COL) - .setOutputCol(PREDICTED_LABEL_INDEX_COL) + val statement = + s"SELECT *, ROUND($PREDICTED_LABEL_PROB_COL) AS $PREDICTED_LABEL_INDEX_COL FROM __THIS__" + val probToPred = new SQLTransformer().setStatement(statement) // Convert prediction from label index to original label. val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) .asInstanceOf[NominalAttribute] @@ -248,27 +243,3 @@ private[r] object GeneralizedLinearRegressionWrapper } } } - -/** - * This utility transformer converts the predicted value of GeneralizedLinearRegressionModel - * with "binomial" family from probability to prediction according to threshold 0.5. - */ -private[r] class ProbabilityToPrediction private[r] (override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { - - def this() = this(Identifiable.randomUID("probToPred")) - - def setInputCol(value: String): this.type = set(inputCol, value) - - def setOutputCol(value: String): this.type = set(outputCol, value) - - override def transformSchema(schema: StructType): StructType = { - StructType(schema.fields :+ StructField($(outputCol), DoubleType)) - } - - override def transform(dataset: Dataset[_]): DataFrame = { - dataset.withColumn($(outputCol), round(col($(inputCol)))) - } - - override def copy(extra: ParamMap): ProbabilityToPrediction = defaultCopy(extra) -}