From 93130102651cf7cf46b1a4490dcbc8188f0ab1e7 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Wed, 29 Apr 2020 15:06:44 +0300 Subject: [PATCH 1/2] Add a test for isInCollection - collection element types --- .../spark/sql/ColumnExpressionSuite.scala | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 8d3b56242ec5a..1ec5a24ba1a25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.sql.Date +import java.sql.{Date, Timestamp} import java.util.Locale import scala.collection.JavaConverters._ @@ -476,6 +476,48 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } } + test("SPARK-31553: isInCollection - collection element types") { + val expected = Seq(Row(true), Row(false)) + Seq(0, 1, 10).foreach { threshold => + withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> threshold.toString) { + checkAnswer(Seq(0).toDS.select($"value".isInCollection(Seq(null))), Seq(Row(null))) + checkAnswer( + Seq(true).toDS.select($"value".isInCollection(Seq(true, false))), + Seq(Row(true))) + checkAnswer( + Seq(0.toByte, 1.toByte).toDS.select($"value".isInCollection(Seq(0.toByte, 2.toByte))), + expected) + checkAnswer( + Seq(0.toShort, 1.toShort).toDS.select($"value".isInCollection(Seq(0.toShort, 2.toShort))), + expected) + checkAnswer(Seq(0, 1).toDS.select($"value".isInCollection(Seq(0, 2))), expected) + checkAnswer(Seq(0L, 1L).toDS.select($"value".isInCollection(Seq(0L, 2L))), expected) + checkAnswer(Seq(0.0f, 1.0f).toDS.select($"value".isInCollection(Seq(0.0f, 2.0f))), expected) + checkAnswer(Seq(0.0D, 1.0D).toDS.select($"value".isInCollection(Seq(0.0D, 2.0D))), expected) + checkAnswer( + Seq(BigDecimal(0), BigDecimal(2)).toDS + .select($"value".isInCollection(Seq(BigDecimal(0), BigDecimal(1)))), + expected) + checkAnswer( + Seq("abc", "def").toDS.select($"value".isInCollection(Seq("abc", "xyz"))), + expected) + checkAnswer( + Seq(Date.valueOf("2020-04-29"), Date.valueOf("2020-05-01")).toDS + .select($"value".isInCollection( + Seq(Date.valueOf("2020-04-29"), Date.valueOf("2020-04-30")))), + expected) + checkAnswer( + Seq(new Timestamp(0), new Timestamp(2)).toDS + .select($"value".isInCollection(Seq(new Timestamp(0), new Timestamp(1)))), + expected) + checkAnswer( + Seq(Array("a", "b"), Array("c", "d")).toDS + .select($"value".isInCollection(Seq(Array("a", "b"), Array("x", "z")))), + expected) + } + } + } + test("&&") { checkAnswer( booleanData.filter($"a" && true), From 5bed4ede5d94f214123c1fc7a8bc57924efb1efa Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Wed, 29 Apr 2020 15:14:53 +0300 Subject: [PATCH 2/2] Test switch thresholds --- .../spark/sql/ColumnExpressionSuite.scala | 123 ++++++++++-------- 1 file changed, 69 insertions(+), 54 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 1ec5a24ba1a25..4bf19532edd94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -454,66 +454,81 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("isInCollection: Scala Collection") { - val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - // Test with different types of collections - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + Seq(0, 1, 10).foreach { optThreshold => + Seq(0, 1, 10).foreach { switchThreshold => + withSQLConf( + SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> optThreshold.toString, + SQLConf.OPTIMIZER_INSET_SWITCH_THRESHOLD.key -> switchThreshold.toString) { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - - val e = intercept[AnalysisException] { - df2.filter($"a".isInCollection(Seq($"b"))) - } - Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") - .foreach { s => - assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b"))) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } } + } } test("SPARK-31553: isInCollection - collection element types") { val expected = Seq(Row(true), Row(false)) - Seq(0, 1, 10).foreach { threshold => - withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> threshold.toString) { - checkAnswer(Seq(0).toDS.select($"value".isInCollection(Seq(null))), Seq(Row(null))) - checkAnswer( - Seq(true).toDS.select($"value".isInCollection(Seq(true, false))), - Seq(Row(true))) - checkAnswer( - Seq(0.toByte, 1.toByte).toDS.select($"value".isInCollection(Seq(0.toByte, 2.toByte))), - expected) - checkAnswer( - Seq(0.toShort, 1.toShort).toDS.select($"value".isInCollection(Seq(0.toShort, 2.toShort))), - expected) - checkAnswer(Seq(0, 1).toDS.select($"value".isInCollection(Seq(0, 2))), expected) - checkAnswer(Seq(0L, 1L).toDS.select($"value".isInCollection(Seq(0L, 2L))), expected) - checkAnswer(Seq(0.0f, 1.0f).toDS.select($"value".isInCollection(Seq(0.0f, 2.0f))), expected) - checkAnswer(Seq(0.0D, 1.0D).toDS.select($"value".isInCollection(Seq(0.0D, 2.0D))), expected) - checkAnswer( - Seq(BigDecimal(0), BigDecimal(2)).toDS - .select($"value".isInCollection(Seq(BigDecimal(0), BigDecimal(1)))), - expected) - checkAnswer( - Seq("abc", "def").toDS.select($"value".isInCollection(Seq("abc", "xyz"))), - expected) - checkAnswer( - Seq(Date.valueOf("2020-04-29"), Date.valueOf("2020-05-01")).toDS - .select($"value".isInCollection( - Seq(Date.valueOf("2020-04-29"), Date.valueOf("2020-04-30")))), - expected) - checkAnswer( - Seq(new Timestamp(0), new Timestamp(2)).toDS - .select($"value".isInCollection(Seq(new Timestamp(0), new Timestamp(1)))), - expected) - checkAnswer( - Seq(Array("a", "b"), Array("c", "d")).toDS - .select($"value".isInCollection(Seq(Array("a", "b"), Array("x", "z")))), - expected) + Seq(0, 1, 10).foreach { optThreshold => + Seq(0, 1, 10).foreach { switchThreshold => + withSQLConf( + SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> optThreshold.toString, + SQLConf.OPTIMIZER_INSET_SWITCH_THRESHOLD.key -> switchThreshold.toString) { + checkAnswer(Seq(0).toDS.select($"value".isInCollection(Seq(null))), Seq(Row(null))) + checkAnswer( + Seq(true).toDS.select($"value".isInCollection(Seq(true, false))), + Seq(Row(true))) + checkAnswer( + Seq(0.toByte, 1.toByte).toDS.select($"value".isInCollection(Seq(0.toByte, 2.toByte))), + expected) + checkAnswer( + Seq(0.toShort, 1.toShort).toDS + .select($"value".isInCollection(Seq(0.toShort, 2.toShort))), + expected) + checkAnswer(Seq(0, 1).toDS.select($"value".isInCollection(Seq(0, 2))), expected) + checkAnswer(Seq(0L, 1L).toDS.select($"value".isInCollection(Seq(0L, 2L))), expected) + checkAnswer(Seq(0.0f, 1.0f).toDS + .select($"value".isInCollection(Seq(0.0f, 2.0f))), expected) + checkAnswer(Seq(0.0D, 1.0D).toDS + .select($"value".isInCollection(Seq(0.0D, 2.0D))), expected) + checkAnswer( + Seq(BigDecimal(0), BigDecimal(2)).toDS + .select($"value".isInCollection(Seq(BigDecimal(0), BigDecimal(1)))), + expected) + checkAnswer( + Seq("abc", "def").toDS.select($"value".isInCollection(Seq("abc", "xyz"))), + expected) + checkAnswer( + Seq(Date.valueOf("2020-04-29"), Date.valueOf("2020-05-01")).toDS + .select($"value".isInCollection( + Seq(Date.valueOf("2020-04-29"), Date.valueOf("2020-04-30")))), + expected) + checkAnswer( + Seq(new Timestamp(0), new Timestamp(2)).toDS + .select($"value".isInCollection(Seq(new Timestamp(0), new Timestamp(1)))), + expected) + checkAnswer( + Seq(Array("a", "b"), Array("c", "d")).toDS + .select($"value".isInCollection(Seq(Array("a", "b"), Array("x", "z")))), + expected) + } } } }