Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ final class ChiSqSelector(override val uid: String)
@Since("2.0.0")
override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(labelCol), $(featuresCol)).rdd.map {
val input = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}
Expand All @@ -90,7 +90,7 @@ final class ChiSqSelector(override val uid: String)

override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
SchemaUtils.checkNumericType(schema, $(labelCol))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ class RFormulaModel private[feature](
val columnNames = schema.map(_.name)
require(!columnNames.contains($(featuresCol)), "Features column already exists.")
require(
!columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
"Label column already exists and is not of type DoubleType.")
!columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType],

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the || not be &&?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so no. What do you think @yanboliang ?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 @BenFradet It should be ||

@MLnick MLnick Apr 22, 2016

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g. before this PR this works (and I don't believe it's supposed to?).

scala> val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
original: org.apache.spark.sql.DataFrame = [x: int, y: int]

scala> formula.fit(original).transform(original).show
+---+---+--------+-----+
|  x|  y|features|label|
+---+---+--------+-----+
|  0|  1|   [0.0]|  1.0|
|  2|  2|   [2.0]|  2.0|
+---+---+--------+-----+

And to make it clear that this check is not actually being performed:

scala> val original = sqlContext.createDataFrame(Seq((0, Seq(1)), (2, Seq(2)))).toDF("x", "y")
original: org.apache.spark.sql.DataFrame = [x: int, y: array<int>]

scala> formula.fit(original).transform(original).show
java.lang.IllegalArgumentException: Unsupported type for label: ArrayType(IntegerType,false)
  at org.apache.spark.ml.feature.RFormulaModel.transformLabel(RFormula.scala:246)
  at org.apache.spark.ml.feature.RFormulaModel.transform(RFormula.scala:211)
  ... 48 elided

... so it's catching it, but at L244 not here.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see now, never mind.

"Label column already exists and is not of type NumericType.")
}

@Since("2.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ class DecisionTreeClassifierSuite
test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
dt, isClassification = true, spark) { (expected, actual) =>
dt, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") {
val gbt = new GBTClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
gbt, isClassification = true, spark) { (expected, actual) =>
gbt, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ class LogisticRegressionSuite
test("should support all NumericType labels and not support other types") {
val lr = new LogisticRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
lr, isClassification = true, spark) { (expected, actual) =>
lr, spark) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients.toArray === actual.coefficients.toArray)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class MultilayerPerceptronClassifierSuite
val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
MLTestingUtils.checkNumericTypes[
MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
mpc, isClassification = true, spark) { (expected, actual) =>
mpc, spark) { (expected, actual) =>
assert(expected.layers === actual.layers)
assert(expected.weights === actual.weights)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
test("should support all NumericType labels and not support other types") {
val nb = new NaiveBayes()
MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
nb, isClassification = true, spark) { (expected, actual) =>
nb, spark) { (expected, actual) =>
assert(expected.pi === actual.pi)
assert(expected.theta === actual.theta)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("should support all NumericType labels and not support other types") {
val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
ovr, isClassification = true, spark) { (expected, actual) =>
ovr, spark) { (expected, actual) =>
val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
assert(expectedModels.length === actualModels.length)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class RandomForestClassifierSuite
test("should support all NumericType labels and not support other types") {
val rf = new RandomForestClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
rf, isClassification = true, spark) { (expected, actual) =>
rf, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
Expand Down Expand Up @@ -81,4 +81,12 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.selectedFeatures === instance.selectedFeatures)
}

test("should support all NumericType labels and not support other types") {
val css = new ChiSqSelector()
MLTestingUtils.checkNumericTypes[ChiSqSelectorModel, ChiSqSelector](
css, spark) { (expected, actual) =>
assert(expected.selectedFeatures === actual.selectedFeatures)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.DoubleType

class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
Expand Down Expand Up @@ -68,9 +68,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(resultSchema.toString == model.transform(original).schema.toString)
}

test("label column already exists but is not double type") {
test("label column already exists but is not numeric type") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
val original = spark.createDataFrame(Seq((0, true), (2, false))).toDF("x", "y")
val model = formula.fit(original)
intercept[IllegalArgumentException] {
model.transformSchema(original.schema)
Expand Down Expand Up @@ -134,7 +134,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = spark.createDataFrame(
Seq(
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
Expand Down Expand Up @@ -188,7 +187,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
"vec2",
Array[Attribute](
NumericAttribute.defaultAttr,
NumericAttribute.defaultAttr)).toMetadata
NumericAttribute.defaultAttr)).toMetadata()
val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata))
val model = formula.fit(original)
val result = model.transform(original)
Expand Down Expand Up @@ -309,4 +308,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val newModel = testDefaultReadWrite(model)
checkModelData(model, newModel)
}

test("should support all NumericType labels") {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use MLTestingUtils.checkNumericTypes to test this? It will eliminate some redundant code.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd work expect for the expected exception when dealing with a dataframe containing string labels because the label column gets indexed by RFormula's fit.
Consequently, an exception is thrown by StringIndexer.

What I could do is add a validateSchema to RFormula (called at the beginiing of the fit method) checking that the label column is of numeric type, then I could use MLTestingUtils.checkNumericTypes.

What do you think?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or simply just:

val schema = dataset.schema
SchemaUtils.checkNumericType(schema, $(labelCol))

at the beginning of RFormula's fit method.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reviewing the suite, I don't think the same tests apply since RFormula also accepts string labels.

Consequently, I think it's best as is.

@yanboliang yanboliang Apr 28, 2016

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BenFradet Sorry for late response.
I'm OK with what you have done here for the issue mentioned above. Could you add more tests for RFormulaModel equality check? Here you have checked resolvedFormula which is produced by RFormulaParser rather than the entire RFormula. It's better also check the equality of pipelineModel of RFormulaModel.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, thanks for your input.

val formula = new RFormula().setFormula("label ~ features")
.setLabelCol("x")
.setFeaturesCol("y")
val dfs = MLTestingUtils.genRegressionDFWithNumericLabelCol(spark)
val expected = formula.fit(dfs(DoubleType))
val actuals = dfs.keys.filter(_ != DoubleType).map(t => formula.fit(dfs(t)))
actuals.foreach { actual =>
assert(expected.pipelineModel.stages.length === actual.pipelineModel.stages.length)
expected.pipelineModel.stages.zip(actual.pipelineModel.stages).foreach {
case (exTransformer, acTransformer) =>
assert(exTransformer.params === acTransformer.params)
}
assert(expected.resolvedFormula.label === actual.resolvedFormula.label)
assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms)
assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
}

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yanboliang is this what you had in mind?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ class AFTSurvivalRegressionSuite
test("should support all NumericType labels") {
val aft = new AFTSurvivalRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
aft, isClassification = false, spark) { (expected, actual) =>
aft, spark, isClassification = false) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class DecisionTreeRegressorSuite
test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
dt, isClassification = false, spark) { (expected, actual) =>
dt, spark, isClassification = false) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") {
val gbt = new GBTRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor](
gbt, isClassification = false, spark) { (expected, actual) =>
gbt, spark, isClassification = false) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ class GeneralizedLinearRegressionSuite
val glr = new GeneralizedLinearRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[
GeneralizedLinearRegressionModel, GeneralizedLinearRegression](
glr, isClassification = false, spark) { (expected, actual) =>
glr, spark, isClassification = false) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class IsotonicRegressionSuite
test("should support all NumericType labels and not support other types") {
val ir = new IsotonicRegression()
MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression](
ir, isClassification = false, spark) { (expected, actual) =>
ir, spark, isClassification = false) { (expected, actual) =>
assert(expected.boundaries === actual.boundaries)
assert(expected.predictions === actual.predictions)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ class LinearRegressionSuite
test("should support all NumericType labels and not support other types") {
val lr = new LinearRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
lr, isClassification = false, spark) { (expected, actual) =>
lr, spark, isClassification = false) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
test("should support all NumericType labels and not support other types") {
val rf = new RandomForestRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor](
rf, isClassification = false, spark) { (expected, actual) =>
rf, spark, isClassification = false) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ object MLTestingUtils extends SparkFunSuite {

def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
estimator: T,
isClassification: Boolean,
spark: SparkSession)(check: (M, M) => Unit): Unit = {
spark: SparkSession,
isClassification: Boolean = true)(check: (M, M) => Unit): Unit = {
val dfs = if (isClassification) {
genClassifDFWithNumericLabelCol(spark)
} else {
Expand Down