diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b0de3c85aaef8..7b903a3f7f148 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ private[sql] object Column { @@ -808,7 +809,14 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.4.0 */ - def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) + def isInCollection(values: scala.collection.Iterable[_]): Column = withExpr { + val hSet = values.toSet[Any] + if (hSet.size > SQLConf.get.optimizerInSetConversionThreshold) { + InSet(expr, hSet) + } else { + In(expr, values.toSeq.map(lit(_).expr)) + } + } /** * A boolean expression that is evaluated to true if the value of this expression is contained diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index a52c6d503d147..c346c8946a972 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ -import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.expressions.{In, InSet, NamedExpression} import org.apache.spark.sql.execution.ProjectExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -454,25 +454,36 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("isInCollection: Scala Collection") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - // Test with different types of collections - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + Seq(1, 2).foreach { conf => + withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> conf.toString) { + if (conf <= 1) { + assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[InSet], "Expect expr to be InSet") + } else { + assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[In], "Expect expr to be In") + } - val e = intercept[AnalysisException] { - df2.filter($"a".isInCollection(Seq($"b"))) - } - Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") - .foreach { s => - assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b"))) + } + Seq("cannot resolve", + "due to data type mismatch: Arguments must be same type but were").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } } + } } test("&&") {