From 7306a5c3a72d5573ad22cf5efb995893fb2a921e Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Fri, 24 Apr 2020 14:30:43 +0300 Subject: [PATCH 1/8] Add a test --- .../scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index a9ee25b10dc02..1afc86d03344a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -869,4 +869,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { df.select(typedLit(("a", 2, 1.0))), Row(Row("a", 2, 1.0)) :: Nil) } + + test("fix in set") { + val set = (0 to 20).map(_.toString).toSet + val data = Seq("1").toDF("x") + assert(set.contains("1")) + checkAnswer(data.select($"x".isInCollection(set)), Row(true)) + } } From 08d107df6e10fe7a546f07ea7bdd001fe1df3858 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Fri, 24 Apr 2020 14:31:03 +0300 Subject: [PATCH 2/8] Bug fix --- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 49c9f830fb27e..6597d3190b65c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -828,11 +828,11 @@ class Column(val expr: Expression) extends Logging { * @since 2.4.0 */ def isInCollection(values: scala.collection.Iterable[_]): Column = withExpr { - val hSet = values.toSet[Any] - if (hSet.size > SQLConf.get.optimizerInSetConversionThreshold) { - InSet(expr, hSet) + val exprValues = values.toSeq.map(lit(_).expr) + if (exprValues.size > SQLConf.get.optimizerInSetConversionThreshold) { + InSet(expr, exprValues.map(_.eval()).toSet) } else { - In(expr, values.toSeq.map(lit(_).expr)) + In(expr, exprValues) } } From 7b200bf1d1367846235b1fda423df54dbc41b0a7 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Fri, 24 Apr 2020 14:31:43 +0300 Subject: [PATCH 3/8] Fix the sql() method of InSet --- .../apache/spark/sql/catalyst/expressions/predicates.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index bd190c3e5abc7..ee616d0969e64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.immutable.TreeSet -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -519,7 +519,9 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with override def sql: String = { val valueSQL = child.sql - val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ") + val listSQL = hset.toSeq + .map(CatalystTypeConverters.convertToScala(_, child.dataType)) + .mkString(", ") s"($valueSQL IN ($listSQL))" } } From 47a1e447fa3aa35285d6beaa8038cefec626198d Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Fri, 24 Apr 2020 14:48:39 +0300 Subject: [PATCH 4/8] Add JIRA id --- .../apache/spark/sql/ColumnExpressionSuite.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 1afc86d03344a..dde8289260cab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -870,10 +870,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { Row(Row("a", 2, 1.0)) :: Nil) } - test("fix in set") { - val set = (0 to 20).map(_.toString).toSet - val data = Seq("1").toDF("x") - assert(set.contains("1")) - checkAnswer(data.select($"x".isInCollection(set)), Row(true)) + test("SPARK-31553: isInCollection for collection sizes above a threshold") { + val threshold = 100 + withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> threshold.toString) { + val set = (0 until 2 * threshold).map(_.toString).toSet + val elem = "10" + val data = Seq(elem).toDF("x") + assert(set.contains(elem)) + checkAnswer(data.select($"x".isInCollection(set)), Row(true)) + } } } From 67f34a10a6bafe3a7da9480fc3f79cd041cb12b0 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Sat, 25 Apr 2020 23:49:44 +0300 Subject: [PATCH 5/8] Fix NullType --- .../spark/sql/catalyst/expressions/predicates.scala | 8 ++++++-- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ac492cf227301..d98805d9cece3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -426,7 +426,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { * Optimized version of In clause, when all filter values of In clause are * static. */ -case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate { +case class InSet( + child: Expression, + hset: Set[Any], + hsetElemType: Option[DataType] = None) extends UnaryExpression with Predicate { require(hset != null, "hset could not be null") @@ -520,8 +523,9 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with override def sql: String = { val valueSQL = child.sql + val elemType = hsetElemType.getOrElse(child.dataType) val listSQL = hset.toSeq - .map(elem => Literal(convertToScala(elem, child.dataType)).sql) + .map(elem => Literal(convertToScala(elem, elemType)).sql) .mkString(", ") s"($valueSQL IN ($listSQL))" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 6597d3190b65c..ceac13b1025b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -830,7 +830,7 @@ class Column(val expr: Expression) extends Logging { def isInCollection(values: scala.collection.Iterable[_]): Column = withExpr { val exprValues = values.toSeq.map(lit(_).expr) if (exprValues.size > SQLConf.get.optimizerInSetConversionThreshold) { - InSet(expr, exprValues.map(_.eval()).toSet) + InSet(expr, exprValues.map(_.eval()).toSet, exprValues.headOption.map(_.dataType)) } else { In(expr, exprValues) } From dd69aa6ab568ad7a48550e4b856bf10df78020f7 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Sun, 26 Apr 2020 00:19:49 +0300 Subject: [PATCH 6/8] Fix build --- .../plans/logical/statsEstimation/FilterEstimation.scala | 2 +- .../spark/sql/execution/datasources/DataSourceStrategy.scala | 2 +- .../spark/sql/execution/datasources/FileSourceStrategy.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/client/HiveShim.scala | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 2c5beef43f52a..0bdccef9654dd 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -172,7 +172,7 @@ case class FilterEstimation(plan: Filter) extends Logging { val hSet = expList.map(e => e.eval()) evaluateInSet(ar, HashSet() ++ hSet, update) - case InSet(ar: Attribute, set) => + case InSet(ar: Attribute, set, _) => evaluateInSet(ar, set, update) // In current stage, we don't have advanced statistics such as sketches or histograms. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index a58038d127818..c01cb1933cf1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -479,7 +479,7 @@ object DataSourceStrategy { case expressions.LessThanOrEqual(Literal(v, t), PushableColumn(name)) => Some(sources.GreaterThanOrEqual(name, convertToScala(v, t))) - case expressions.InSet(e @ PushableColumn(name), set) => + case expressions.InSet(e @ PushableColumn(name), set, _) => val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) Some(sources.In(name, set.toArray.map(toScala))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index f45495121a980..57ecfebbb5549 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -89,7 +89,7 @@ object FileSourceStrategy extends Strategy with Logging { case expressions.In(a: Attribute, list) if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) - case expressions.InSet(a: Attribute, hset) + case expressions.InSet(a: Attribute, hset, _) if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow))) case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 2b806609426a1..c1446ec963424 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -740,7 +740,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { if useAdvanced => Some(convertInToOr(name, values)) - case InSet(ExtractAttribute(SupportedAttribute(name)), ExtractableValues(values)) + case InSet(ExtractAttribute(SupportedAttribute(name)), ExtractableValues(values), _) if useAdvanced => Some(convertInToOr(name, values)) From 05ce50ac1d9a45367975c8cb6fb153b07de48d0d Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Sun, 26 Apr 2020 11:41:20 +0300 Subject: [PATCH 7/8] Require elem type --- .../sql/catalyst/expressions/predicates.scala | 11 +++++----- .../sql/catalyst/optimizer/expressions.scala | 2 +- .../catalyst/expressions/PredicateSuite.scala | 22 ++++++++++--------- .../catalyst/optimizer/OptimizeInSuite.scala | 2 +- .../FilterEstimationSuite.scala | 13 ++++++----- .../scala/org/apache/spark/sql/Column.scala | 3 ++- .../apache/spark/sql/execution/subquery.scala | 2 +- .../spark/sql/ColumnExpressionSuite.scala | 2 +- .../datasources/DataSourceStrategySuite.scala | 4 +++- .../spark/sql/sources/BucketedReadSuite.scala | 7 ++++-- .../client/HivePartitionFilteringSuite.scala | 6 ++--- 11 files changed, 41 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d98805d9cece3..6222a0add2f41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -429,7 +429,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { case class InSet( child: Expression, hset: Set[Any], - hsetElemType: Option[DataType] = None) extends UnaryExpression with Predicate { + hsetElemType: DataType) extends UnaryExpression with Predicate { require(hset != null, "hset could not be null") @@ -449,12 +449,12 @@ case class InSet( } } - @transient lazy val set: Set[Any] = child.dataType match { + @transient lazy val set: Set[Any] = hsetElemType match { case t: AtomicType if !t.isInstanceOf[BinaryType] => hset case _: NullType => hset case _ => // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows - TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ (hset - null) + TreeSet.empty(TypeUtils.getInterpretedOrdering(hsetElemType)) ++ (hset - null) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -465,7 +465,7 @@ case class InSet( } } - private def canBeComputedUsingSwitch: Boolean = child.dataType match { + private def canBeComputedUsingSwitch: Boolean = hsetElemType match { case ByteType | ShortType | IntegerType | DateType => true case _ => false } @@ -523,9 +523,8 @@ case class InSet( override def sql: String = { val valueSQL = child.sql - val elemType = hsetElemType.getOrElse(child.dataType) val listSQL = hset.toSeq - .map(elem => Literal(convertToScala(elem, elemType)).sql) + .map(elem => Literal(convertToScala(elem, hsetElemType)).sql) .mkString(", ") s"($valueSQL IN ($listSQL))" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index bd400f86ea2c1..96af221d3fe1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -251,7 +251,7 @@ object OptimizeIn extends Rule[LogicalPlan] { EqualTo(v, newList.head) } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) - InSet(v, HashSet() ++ hSet) + InSet(v, HashSet() ++ hSet, v.dataType) } else if (newList.length < list.length) { expr.copy(list = newList) } else { // newList.length == list.length && newList.length > 1 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 1ad0a8ed758f4..8fd9825f62622 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -130,7 +130,9 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { private def checkInAndInSet(in: In, expected: Any): Unit = { // expecting all in.list are Literal or NonFoldableLiteral. checkEvaluation(in, expected) - checkEvaluation(InSet(in.value, HashSet() ++ in.list.map(_.eval())), expected) + checkEvaluation( + InSet(in.value, HashSet() ++ in.list.map(_.eval()), in.value.dataType), + expected) } test("basic IN/INSET predicate test") { @@ -154,7 +156,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(2)))), true) checkEvaluation( - And(InSet(Literal(1), HashSet(1, 2)), InSet(Literal(2), Set(1, 2))), + And(InSet(Literal(1), HashSet(1, 2), IntegerType), InSet(Literal(2), Set(1, 2), IntegerType)), true) val ns = NonFoldableLiteral.create(null, StringType) @@ -256,12 +258,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { val nullLiteral = Literal(null, presentValue.dataType) - checkEvaluation(InSet(nullLiteral, values), expected = null) - checkEvaluation(InSet(nullLiteral, values + null), expected = null) - checkEvaluation(InSet(presentValue, values), expected = true) - checkEvaluation(InSet(presentValue, values + null), expected = true) - checkEvaluation(InSet(absentValue, values), expected = false) - checkEvaluation(InSet(absentValue, values + null), expected = null) + checkEvaluation(InSet(nullLiteral, values, nullLiteral.dataType), expected = null) + checkEvaluation(InSet(nullLiteral, values + null, nullLiteral.dataType), expected = null) + checkEvaluation(InSet(presentValue, values, presentValue.dataType), expected = true) + checkEvaluation(InSet(presentValue, values + null, presentValue.dataType), expected = true) + checkEvaluation(InSet(absentValue, values, absentValue.dataType), expected = false) + checkEvaluation(InSet(absentValue, values + null, absentValue.dataType), expected = null) } def checkAllTypes(): Unit = { @@ -498,7 +500,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22693: InSet should not use global variables") { val ctx = new CodegenContext - InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) + InSet(Literal(1), Set(1, 2, 3, 4), IntegerType).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } @@ -535,7 +537,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-29100: InSet with empty input set") { val row = create_row(1) - val inSet = InSet(BoundReference(0, IntegerType, true), Set.empty) + val inSet = InSet(BoundReference(0, IntegerType, true), Set.empty, IntegerType) checkEvaluation(inSet, false, row) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index a36083b847043..0fae01d1c8cac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -85,7 +85,7 @@ class OptimizeInSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet)) + .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet, IntegerType)) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 1cf888519077a..75ee510860e4c 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{ColumnStatsM import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * In this test suite, we test predicates containing the following operators: @@ -352,7 +353,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IN (3, 4, 5)") { validateEstimatedStats( - Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), + Filter(InSet(attrInt, Set(3, 4, 5), IntegerType), childStatsTestPlan(Seq(attrInt), 10L)), Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) @@ -360,7 +361,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("evaluateInSet with all zeros") { validateEstimatedStats( - Filter(InSet(attrString, Set(3, 4, 5)), + Filter(InSet(attrString, Set(3, 4, 5), IntegerType), StatsTestPlan(Seq(attrString), 0, AttributeMap(Seq(attrString -> ColumnStat(distinctCount = Some(0), min = None, max = None, @@ -371,7 +372,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("evaluateInSet with string") { validateEstimatedStats( - Filter(InSet(attrString, Set("A0")), + Filter(InSet(attrString, Set(UTF8String.fromString("A0")), StringType), StatsTestPlan(Seq(attrString), 10, AttributeMap(Seq(attrString -> ColumnStat(distinctCount = Some(10), min = None, max = None, @@ -383,14 +384,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( - Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), + Filter(Not(InSet(attrInt, Set(3, 4, 5), IntegerType)), childStatsTestPlan(Seq(attrInt), 10L)), Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cbool IN (true)") { validateEstimatedStats( - Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Filter(InSet(attrBool, Set(true), BooleanType), childStatsTestPlan(Seq(attrBool), 10L)), Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) @@ -510,7 +511,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt)) ) validateEstimatedStats( - Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), + Filter(InSet(attrInt, Set(1, 2, 3, 4, 5), IntegerType), cornerChildStatsTestplan), Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ceac13b1025b7..c4f085c21bf73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -830,7 +830,8 @@ class Column(val expr: Expression) extends Logging { def isInCollection(values: scala.collection.Iterable[_]): Column = withExpr { val exprValues = values.toSeq.map(lit(_).expr) if (exprValues.size > SQLConf.get.optimizerInSetConversionThreshold) { - InSet(expr, exprValues.map(_.eval()).toSet, exprValues.headOption.map(_.dataType)) + val elemType = exprValues.headOption.map(_.dataType).getOrElse(NullType) + InSet(expr, exprValues.map(_.eval()).toSet, elemType) } else { In(expr, exprValues) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index c2270c57eb941..d1b45a239663c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -159,7 +159,7 @@ case class InSubqueryExec( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { prepareResult() - InSet(child, result.toSet).doGenCode(ctx, ev) + InSet(child, result.toSet, child.dataType).doGenCode(ctx, ev) } override lazy val canonicalized: InSubqueryExec = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 04157e7cd3077..4a94c29b7109e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -872,7 +872,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("SPARK-31563: sql of InSet for UTF8String collection") { - val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString)) + val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString), StringType) assert(inSet.sql === "('a' IN ('a', 'b'))") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index a775a97895cfc..385bd9a401991 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -110,7 +110,9 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { testTranslateFilter(LessThanOrEqual(1, attrInt), Some(sources.GreaterThanOrEqual(intColName, 1))) - testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In(intColName, Array(1, 2, 3)))) + testTranslateFilter( + InSet(attrInt, Set(1, 2, 3), IntegerType), + Some(sources.In(intColName, Array(1, 2, 3)))) testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In(intColName, Array(1, 2, 3)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index e153c7168dbf2..7428b6504fe5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} +import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils import org.apache.spark.util.collection.BitSet @@ -188,8 +189,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { df) // Case 4: InSet - val inSetExpr = expressions.InSet($"j".expr, - Set(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3).map(lit(_).expr)) + val inSetExpr = expressions.InSet( + $"j".expr, + Set(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3).map(lit(_).expr), + IntegerType) checkPrunedAnswers( bucketSpec, bucketValues = Seq(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala index 2d615f6fdc261..800f8b3ffa902 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala @@ -213,7 +213,7 @@ class HivePartitionFilteringSuite(version: String) 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { case expr @ In(v, list) if expr.inSetConvertible => - InSet(v, list.map(_.eval(EmptyRow)).toSet) + InSet(v, list.map(_.eval(EmptyRow)).toSet, v.dataType) }) } @@ -225,7 +225,7 @@ class HivePartitionFilteringSuite(version: String) 0 to 4, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { case expr @ In(v, list) if expr.inSetConvertible => - InSet(v, list.map(_.eval(EmptyRow)).toSet) + InSet(v, list.map(_.eval(EmptyRow)).toSet, v.dataType) }) } @@ -244,7 +244,7 @@ class HivePartitionFilteringSuite(version: String) 0 to 4, "ab" :: "ba" :: Nil, { case expr @ In(v, list) if expr.inSetConvertible => - InSet(v, list.map(_.eval(EmptyRow)).toSet) + InSet(v, list.map(_.eval(EmptyRow)).toSet, v.dataType) }) } From 4bc0e269df9320e9bb9244afb58b6d1fbbf0e95e Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Mon, 27 Apr 2020 13:36:35 +0300 Subject: [PATCH 8/8] Check input types of InSet --- .../spark/sql/catalyst/expressions/predicates.scala | 9 +++++++++ .../org/apache/spark/sql/ColumnExpressionSuite.scala | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 6222a0add2f41..4a02ab9f2705e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -433,6 +433,15 @@ case class InSet( require(hset != null, "hset could not be null") + override def checkInputDataTypes(): TypeCheckResult = { + if (!DataType.equalsStructurally(child.dataType, hsetElemType, ignoreNullability = true)) { + TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + + s"${child.dataType.catalogString} != ${hsetElemType.catalogString}") + } else { + TypeUtils.checkForOrderingExpr(child.dataType, s"function $prettyName") + } + } + override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" @transient private[this] lazy val hasNull: Boolean = hset.contains(null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 4a94c29b7109e..5966785813ee7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -483,6 +483,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "due to data type mismatch: Arguments must be same type but were").foreach { s => assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } + val errMsg = intercept[AnalysisException] { + df.select($"a".isInCollection(Seq(0, 1).map(new java.sql.Timestamp(_)))).collect() + }.getMessage + assert(errMsg.contains("Arguments must be same type")) } } }