From 5d79149d19fa945daed5d19918cfe6068e76acd3 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 11 Sep 2019 16:09:46 +0800 Subject: [PATCH 1/3] init pr --- .../src/main/scala/org/apache/spark/sql/Column.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b0de3c85aaef8..4419b316f338e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ private[sql] object Column { @@ -808,7 +809,14 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.4.0 */ - def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) + def isInCollection(values: scala.collection.Iterable[_]): Column = withExpr { + val hSet = values.toSet[Any] + if (hSet.size > SQLConf.get.optimizerInSetConversionThreshold) { + InSet(expr, hSet) + } else { + In(expr, hSet.toSeq.map(lit(_).expr)) + } + } /** * A boolean expression that is evaluated to true if the value of this expression is contained From ab3e5d4e5119ec05553c9a2f8cdf3b6544f699ed Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 12 Sep 2019 17:42:47 +0800 Subject: [PATCH 2/3] address comments --- .../main/scala/org/apache/spark/sql/Column.scala | 2 +- .../apache/spark/sql/ColumnExpressionSuite.scala | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 4419b316f338e..7b903a3f7f148 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -814,7 +814,7 @@ class Column(val expr: Expression) extends Logging { if (hSet.size > SQLConf.get.optimizerInSetConversionThreshold) { InSet(expr, hSet) } else { - In(expr, hSet.toSeq.map(lit(_).expr)) + In(expr, values.toSeq.map(lit(_).expr)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index a52c6d503d147..0cc3f36d770cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ -import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.expressions.{In, InSet, NamedExpression} import org.apache.spark.sql.execution.ProjectExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -454,6 +454,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("isInCollection: Scala Collection") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + + // Test with different types of collections checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) @@ -464,6 +466,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[In], "Expect expr to be In") + + withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "1") { + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + + assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[InSet], "Expect expr to be InSet") + } + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") val e = intercept[AnalysisException] { From b60cd94c0f650732f7f79cbbe02a083b56d6bd14 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 12 Sep 2019 11:29:02 -0700 Subject: [PATCH 3/3] Update ColumnExpressionSuite.scala --- .../spark/sql/ColumnExpressionSuite.scala | 52 +++++++++---------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 0cc3f36d770cf..c346c8946a972 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -455,37 +455,35 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("isInCollection: Scala Collection") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + Seq(1, 2).foreach { conf => + withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> conf.toString) { + if (conf <= 1) { + assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[InSet], "Expect expr to be InSet") + } else { + assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[In], "Expect expr to be In") + } - // Test with different types of collections - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - - assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[In], "Expect expr to be In") - - withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "1") { - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - - assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[InSet], "Expect expr to be InSet") - } + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - val e = intercept[AnalysisException] { - df2.filter($"a".isInCollection(Seq($"b"))) - } - Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") - .foreach { s => - assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b"))) + } + Seq("cannot resolve", + "due to data type mismatch: Arguments must be same type but were").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } } + } } test("&&") {