From affbf03d028eb521d187823caaf1522f8ce6d714 Mon Sep 17 00:00:00 2001 From: GideonPotok Date: Sat, 13 Apr 2024 09:07:19 +0300 Subject: [PATCH 1/5] right impl --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 9c862581bfe47..c5b2ec8c06b0f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1987,10 +1987,10 @@ case class Right(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = If( IsNull(str), - Literal(null, StringType), + Literal(null, str.dataType), If( LessThanOrEqual(len, Literal(0)), - Literal(UTF8String.EMPTY_UTF8, StringType), + Literal(UTF8String.EMPTY_UTF8, str.dataType), new Substring(str, UnaryMinus(len)) ) ) From f5f0d35bfe5ff2bee51e67e41bcebccf9423ad09 Mon Sep 17 00:00:00 2001 From: GideonPotok Date: Sat, 13 Apr 2024 09:09:25 +0300 Subject: [PATCH 2/5] left impl --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c5b2ec8c06b0f..bddaf574cd3ab 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1995,7 +1995,7 @@ case class Right(str: Expression, len: Expression) extends RuntimeReplaceable ) ) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, IntegerType) override def left: Expression = str override def right: Expression = len override protected def withNewChildrenInternal( @@ -2026,7 +2026,7 @@ case class Left(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = Substring(str, Literal(1), len) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(StringType, BinaryType), IntegerType) + Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType) } override def left: Expression = str From 9e4ac4844fd2aa4a335768ca2545127e29a4ce0d Mon Sep 17 00:00:00 2001 From: GideonPotok Date: Sat, 13 Apr 2024 09:35:12 +0300 Subject: [PATCH 3/5] test impl to make more in line with refactor. next add struct test, maybe. --- .../sql/CollationStringExpressionsSuite.scala | 136 +++++++++++++++++- 1 file changed, 134 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 97dea66975410..99d6beb909168 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql import scala.collection.immutable.Seq - import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{SQLConf, SqlApiConf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{BooleanType, StringType} @@ -163,6 +162,139 @@ class CollationStringExpressionsSuite }) } + test("substring check output type on explicitly collated string") { + case class SubstringTestCase[R](args: Seq[String], collation: String, result: R) + val checks = Seq( + SubstringTestCase(Seq("Spark", "2"), "UTF8_BINARY", "park"), + SubstringTestCase(Seq("Spark", "2"), "UTF8_BINARY_LCASE", "park") + ) + checks.foreach(ct => { + val query = s"SELECT substr(collate('${ct.args.head}', '${ct.collation}')," + + s" ${ct.args.tail.head})" + // Result & data type + checkAnswer(sql(query), Row(ct.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(ct.collation))) + }) + } + + test("left/right/substr on collated proper string returns proper value") { // scalastyle:ignore + case class QTestCase(query: String, collation: String, result: Row) + val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( + c => Seq( + QTestCase("select left('abc' collate " + c + ", 1)", c, Row("a")), + QTestCase("select right('def' collate " + c + ", 1)", c, Row("f")), + QTestCase("select substr('abc' collate " + c + ", 2)", c, Row("bc")), + QTestCase("select substr('example' collate " + c + ", 0, 2)", c, Row("ex")), + QTestCase("select substr('example' collate " + c + ", 1, 2)", c, Row("ex")), + QTestCase("select substr('example' collate " + c + ", 0, 7)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 1, 7)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 0, 100)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 1, 100)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 2, 2)", c, Row("xa")), + QTestCase("select substr('example' collate " + c + ", 1, 6)", c, Row("exampl")), + QTestCase("select substr('example' collate " + c + ", 2, 100)", c, Row("xample")), + QTestCase("select substr('example' collate " + c + ", 0, 0)", c, Row("")), + QTestCase("select substr('example' collate " + c + ", 100, 4)", c, Row("")), + QTestCase("select substr('example' collate " + c + ", 0, 100)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 1, 100)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 2, 100)", c, Row("xample")), + QTestCase("select substr('example' collate " + c + ", -3, 2)", c, Row("pl")), + QTestCase("select substr('example' collate " + c + ", -100, 4)", c, Row("")), + QTestCase("select substr('example' collate " + c + ", -2147483648, 6)", c, Row("")), + QTestCase("select substr(' a世a ' collate " + c + ", 2, 3)", c, Row("a世a")), // scalastyle:ignore + QTestCase("select left(' a世a ' collate " + c + ", 3)", c, Row(" a世")), // scalastyle:ignore + QTestCase("select right(' a世a ' collate " + c + ", 3)", c, Row("世a ")), // scalastyle:ignore + QTestCase("select substr('AaAaAaAa000000' collate " + c + ", 2, 3)", c, Row("aAa")), + QTestCase("select left('AaAaAaAa000000' collate " + c + ", 3)", c, Row("AaA")), + QTestCase("select right('AaAaAaAa000000' collate " + c + ", 3)", c, Row("000")), + QTestCase("select substr('' collate " + c + ", 1, 1)", c, Row("")), + QTestCase("select left('' collate " + c + ", 1)", c, Row("")), + QTestCase("select right('' collate " + c + ", 1)", c, Row("")), + QTestCase("select left('ghi' collate " + c + ", 1)", c, Row("g")) + ) + ) + + checks.foreach { check => + // Result & data type + checkAnswer(sql(check.query), Row(check.result)) + assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) + } + } + + test("left/right/substr on collated improper string returns proper value") { + case class QTestCase(query: String, collation: String, result: Row) + val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( + c => Seq( + QTestCase("select left(null collate " + c + ", 1)", c, Row(null)), + QTestCase("select right(null collate " + c + ", 1)", c, Row(null)), + QTestCase("select substr(null collate " + c + ", 1)", c, Row(null)), + QTestCase("select substr(null collate " + c + ", 1, 1)", c, Row(null)), + QTestCase("select left('' collate " + c + ", null)", c, Row(null)), + QTestCase("select right('' collate " + c + ", null)", c, Row(null)), + QTestCase("select substr('' collate " + c + ", null)", c, Row(null)), + QTestCase("select substr('' collate " + c + ", null, null)", c, Row(null)) + ) + ) + checks.foreach(check => { + // Result & data type + checkAnswer(sql(check.query), Row(check.result)) + assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) + }) + } + + test("left/right/substr on collated improper length and position returns proper value") { + case class QTestCase(query: String, collation: String, result: Row) + val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( + c => Seq( + QTestCase("select left(' a世a ' collate " + c + ", '3')", c, Row(" a世")), // scalastyle:ignore + QTestCase("select right(' a世a ' collate " + c + ", '3')", c, Row("世a ")), // scalastyle:ignore + QTestCase("select right('' collate " + c + ", null)", c, Row(null)), + QTestCase("select substr('' collate " + c + ", null)", c, Row(null)), + QTestCase("select substr('' collate " + c + ", null, null)", c, Row(null)), + QTestCase("select left('' collate " + c + ", null)", c, Row(null)) + ) + ) + checks.foreach(check => { + // Result & data type + checkAnswer(sql(check.query), Row(check.result)) + assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) + }) + } + + test("left/right/substr on session-collated string returns proper type") { + Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").foreach { collationName => + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + assert(sql("select left('abc', 1)") + .schema.fields.head.dataType.sameType(StringType(collationName))) + assert(sql("select right('def', 1)") + .schema.fields.head.dataType.sameType(StringType(collationName))) + assert(sql("select substr('ghi', 1)") + .schema.fields.head.dataType.sameType(StringType(collationName))) + } + } + } + + test("left/right/substr on collated improper string returns proper type") { + Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").foreach { collationName => + assert(sql(s"select left(null collate $collationName, 1)") + .schema.fields.head.dataType.sameType(StringType(collationName))) + assert(sql(s"select right(null collate $collationName, 1)") + .schema.fields.head.dataType.sameType(StringType(collationName))) + assert(sql(s"select substr(null collate $collationName, 1)") + .schema.fields.head.dataType.sameType(StringType(collationName))) + } + } + test("left/right/substr on collated proper string returns proper type") { + Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").foreach { collationName => + assert(sql(s"select left('hij' collate $collationName, 1)") + .schema.fields.head.dataType.sameType(StringType(collationName))) + assert(sql(s"select right('klm' collate $collationName, 1)") + .schema.fields.head.dataType.sameType(StringType(collationName))) + assert(sql(s"select substr('nop' collate $collationName, 1)") + .schema.fields.head.dataType.sameType(StringType(collationName))) + } + } + // TODO: Add more tests for other string expressions } From 1911b9f56b2566874f775ae582ccded8b52d74c5 Mon Sep 17 00:00:00 2001 From: GideonPotok Date: Sat, 13 Apr 2024 09:35:46 +0300 Subject: [PATCH 4/5] format --- .../org/apache/spark/sql/CollationStringExpressionsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 99d6beb909168..4ad6b996c90e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql import scala.collection.immutable.Seq + import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper -import org.apache.spark.sql.internal.{SQLConf, SqlApiConf} +import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{BooleanType, StringType} From ea6d192e39a3e0d581b4d1e427633fd50f8f64ec Mon Sep 17 00:00:00 2001 From: GideonPotok Date: Sat, 13 Apr 2024 10:41:07 +0300 Subject: [PATCH 5/5] tests pass locally --- .../sql/CollationStringExpressionsSuite.scala | 79 ++++++++++--------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 4ad6b996c90e1..dde97751d0135 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -178,7 +178,27 @@ class CollationStringExpressionsSuite }) } - test("left/right/substr on collated proper string returns proper value") { // scalastyle:ignore + test("left/right/substr on implicitly collated string returns proper value and type") { + case class QTestCase(query: String, collation: String, result: Row) + val longString = "In the course of human events" + val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( + c => Seq( + QTestCase(s"select left(left('$longString' collate " + c + ", 5), 1)", c, Row("I")), + QTestCase(s"select right(right('$longString' collate " + c + ", 5), 1)", c, Row("s")), + QTestCase( + s"select substr(substr('$longString' collate " + c + ", 4), 2)", c, + Row("he course of human events")) + ) + ) + + checks.foreach { check => + // Result & data type + checkAnswer(sql(check.query), check.result) + assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) + } + } + + test("left/right/substr on explicitly collated proper string returns proper value and type") { case class QTestCase(query: String, collation: String, result: Row) val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( c => Seq( @@ -217,12 +237,12 @@ class CollationStringExpressionsSuite checks.foreach { check => // Result & data type - checkAnswer(sql(check.query), Row(check.result)) + checkAnswer(sql(check.query), check.result) assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) } } - test("left/right/substr on collated improper string returns proper value") { + test("left/right/substr on explicitly collated improper string returns proper value and type") { case class QTestCase(query: String, collation: String, result: Row) val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( c => Seq( @@ -238,12 +258,12 @@ class CollationStringExpressionsSuite ) checks.foreach(check => { // Result & data type - checkAnswer(sql(check.query), Row(check.result)) + checkAnswer(sql(check.query), check.result) assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) }) } - test("left/right/substr on collated improper length and position returns proper value") { + test("left/right/substr on explicitly collated improper length & position") { case class QTestCase(query: String, collation: String, result: Row) val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( c => Seq( @@ -257,43 +277,28 @@ class CollationStringExpressionsSuite ) checks.foreach(check => { // Result & data type - checkAnswer(sql(check.query), Row(check.result)) + checkAnswer(sql(check.query), check.result) assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) }) } - test("left/right/substr on session-collated string returns proper type") { - Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").foreach { collationName => - withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { - assert(sql("select left('abc', 1)") - .schema.fields.head.dataType.sameType(StringType(collationName))) - assert(sql("select right('def', 1)") - .schema.fields.head.dataType.sameType(StringType(collationName))) - assert(sql("select substr('ghi', 1)") - .schema.fields.head.dataType.sameType(StringType(collationName))) + test("left/right/substr on session-collated string returns proper and value type") { + case class QTestCase(query: String, collation: String, result: Row) + val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci") + .flatMap { c => + Seq( + QTestCase("select left('abc', 1)", c, Row("a")), + QTestCase("select right('def', 1)", c, Row("f")), + QTestCase("select substr('ghi', 1)", c, Row("ghi")) + ) } - } - } - - test("left/right/substr on collated improper string returns proper type") { - Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").foreach { collationName => - assert(sql(s"select left(null collate $collationName, 1)") - .schema.fields.head.dataType.sameType(StringType(collationName))) - assert(sql(s"select right(null collate $collationName, 1)") - .schema.fields.head.dataType.sameType(StringType(collationName))) - assert(sql(s"select substr(null collate $collationName, 1)") - .schema.fields.head.dataType.sameType(StringType(collationName))) - } - } - test("left/right/substr on collated proper string returns proper type") { - Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").foreach { collationName => - assert(sql(s"select left('hij' collate $collationName, 1)") - .schema.fields.head.dataType.sameType(StringType(collationName))) - assert(sql(s"select right('klm' collate $collationName, 1)") - .schema.fields.head.dataType.sameType(StringType(collationName))) - assert(sql(s"select substr('nop' collate $collationName, 1)") - .schema.fields.head.dataType.sameType(StringType(collationName))) - } + checks.foreach(check => { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> check.collation) { + // Result & data type + checkAnswer(sql(check.query), check.result) + assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) + } + }) } // TODO: Add more tests for other string expressions