-
Notifications
You must be signed in to change notification settings - Fork 29.3k
[SPARK-13961][ML] spark.ml ChiSqSelector and RFormula should support other numeric types for label #12467
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-13961][ML] spark.ml ChiSqSelector and RFormula should support other numeric types for label #12467
Changes from all commits
fc494d3
69b6470
8c843ef
a8e5aa0
ce19549
3786ef9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,10 +20,10 @@ package org.apache.spark.ml.feature | |
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.ml.attribute._ | ||
| import org.apache.spark.ml.param.ParamsSuite | ||
| import org.apache.spark.ml.util.DefaultReadWriteTest | ||
| import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} | ||
| import org.apache.spark.mllib.linalg.Vectors | ||
| import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
| import org.apache.spark.sql.Row | ||
| import org.apache.spark.sql.types.DoubleType | ||
|
|
||
| class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { | ||
| test("params") { | ||
|
|
@@ -68,9 +68,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul | |
| assert(resultSchema.toString == model.transform(original).schema.toString) | ||
| } | ||
|
|
||
| test("label column already exists but is not double type") { | ||
| test("label column already exists but is not numeric type") { | ||
| val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") | ||
| val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") | ||
| val original = spark.createDataFrame(Seq((0, true), (2, false))).toDF("x", "y") | ||
| val model = formula.fit(original) | ||
| intercept[IllegalArgumentException] { | ||
| model.transformSchema(original.schema) | ||
|
|
@@ -134,7 +134,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul | |
| ).toDF("id", "a", "b") | ||
| val model = formula.fit(original) | ||
| val result = model.transform(original) | ||
| val resultSchema = model.transformSchema(original.schema) | ||
| val expected = spark.createDataFrame( | ||
| Seq( | ||
| ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), | ||
|
|
@@ -188,7 +187,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul | |
| "vec2", | ||
| Array[Attribute]( | ||
| NumericAttribute.defaultAttr, | ||
| NumericAttribute.defaultAttr)).toMetadata | ||
| NumericAttribute.defaultAttr)).toMetadata() | ||
| val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) | ||
| val model = formula.fit(original) | ||
| val result = model.transform(original) | ||
|
|
@@ -309,4 +308,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul | |
| val newModel = testDefaultReadWrite(model) | ||
| checkModelData(model, newModel) | ||
| } | ||
|
|
||
| test("should support all NumericType labels") { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It'd work expect for the expected exception when dealing with a dataframe containing string labels because the label column gets indexed by RFormula's What I could do is add a What do you think?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or simply just: val schema = dataset.schema
SchemaUtils.checkNumericType(schema, $(labelCol))at the beginning of RFormula's
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After reviewing the suite, I don't think the same tests apply since RFormula also accepts string labels. Consequently, I think it's best as is.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @BenFradet Sorry for late response.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, thanks for your input. |
||
| val formula = new RFormula().setFormula("label ~ features") | ||
| .setLabelCol("x") | ||
| .setFeaturesCol("y") | ||
| val dfs = MLTestingUtils.genRegressionDFWithNumericLabelCol(spark) | ||
| val expected = formula.fit(dfs(DoubleType)) | ||
| val actuals = dfs.keys.filter(_ != DoubleType).map(t => formula.fit(dfs(t))) | ||
| actuals.foreach { actual => | ||
| assert(expected.pipelineModel.stages.length === actual.pipelineModel.stages.length) | ||
| expected.pipelineModel.stages.zip(actual.pipelineModel.stages).foreach { | ||
| case (exTransformer, acTransformer) => | ||
| assert(exTransformer.params === acTransformer.params) | ||
| } | ||
| assert(expected.resolvedFormula.label === actual.resolvedFormula.label) | ||
| assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms) | ||
| assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept) | ||
| } | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @yanboliang is this what you had in mind?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. |
||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should the
||not be&&?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so no. What do you think @yanboliang ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 @BenFradet It should be
||Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
e.g. before this PR this works (and I don't believe it's supposed to?).
And to make it clear that this check is not actually being performed:
... so it's catching it, but at L244 not here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah I see now, never mind.