From 6ca93fa977dbe1940c54a58e3bcc451c564359a1 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Tue, 28 Apr 2020 10:12:40 +0300 Subject: [PATCH] Revert "[SPARK-29048] Improve performance on Column.isInCollection() with a large size collection" This reverts commit 5631a96367d2576e1e0f95d7ae529468da8f5fa8. --- .../scala/org/apache/spark/sql/Column.scala | 10 +---- .../spark/sql/ColumnExpressionSuite.scala | 45 +++++++------------ 2 files changed, 18 insertions(+), 37 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 50bc7ec5f2afa..6913d4e676d66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ private[sql] object Column { @@ -827,14 +826,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.4.0 */ - def isInCollection(values: scala.collection.Iterable[_]): Column = withExpr { - val hSet = values.toSet[Any] - if (hSet.size > SQLConf.get.optimizerInSetConversionThreshold) { - InSet(expr, hSet) - } else { - In(expr, values.toSeq.map(lit(_).expr)) - } - } + def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) /** * A boolean expression that is evaluated to true if the value of this expression is contained diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b72d92b9e2a83..8d3b56242ec5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ -import org.apache.spark.sql.catalyst.expressions.{In, InSet, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{InSet, Literal, NamedExpression} import org.apache.spark.sql.execution.ProjectExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -455,36 +455,25 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("isInCollection: Scala Collection") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - Seq(1, 2).foreach { conf => - withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> conf.toString) { - if (conf <= 1) { - assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[InSet], "Expect expr to be InSet") - } else { - assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[In], "Expect expr to be In") - } - - // Test with different types of collections - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - - val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - val e = intercept[AnalysisException] { - df2.filter($"a".isInCollection(Seq($"b"))) - } - Seq("cannot resolve", - "due to data type mismatch: Arguments must be same type but were").foreach { s => - assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) - } - } + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b"))) } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } } test("&&") {