From a4d35923262b7dbefd34e49ae834ab6bbf0e925f Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 21 Mar 2024 15:04:07 +0100 Subject: [PATCH 01/46] Add support for instr and unit test in CollationStringExpressionsSuite.scala --- .../apache/spark/unsafe/types/UTF8String.java | 27 +++++++++++ .../expressions/stringExpressions.scala | 25 ++++++++-- .../sql/CollationStringExpressionsSuite.scala | 47 +++++++++++++++++++ 3 files changed, 95 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 6abc8385da5ab..551d0e62fe8a3 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -835,6 +835,33 @@ public int indexOf(UTF8String v, int start) { return -1; } + public int indexOf(UTF8String substring, int start, int collationId) { + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + return this.indexOf(substring, start); + } + if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) { + return this.toLowerCase().indexOf(substring.toLowerCase(), start); + } + return collatedIndexOf(substring, collationId); + } + + private int collatedIndexOf(UTF8String substring, int collationId) { + if (substring.numBytes == 0) { + return 0; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(this, substring, collationId); + + int pos = 0; + while ((pos = stringSearch.next()) != StringSearch.DONE) { + if (stringSearch.getMatchLength() == stringSearch.getPattern().length()) { + return pos; + } + } + + return 0; + } + /** * Find the `str` from left to right. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 742db0ed5a474..b16061f8f4a9c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1377,17 +1377,34 @@ case class StringInstr(str: Expression, substr: Expression) override def left: Expression = str override def right: Expression = substr override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def nullSafeEval(string: Any, sub: Any): Any = { - string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1 + val collationId = left.dataType.asInstanceOf[StringType].collationId + string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0, collationId) + 1 + } + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + return defaultCheck + } + + val collationId = left.dataType.asInstanceOf[StringType].collationId + CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType)) } override def prettyName: String = "instr" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (l, r) => - s"($l).indexOf($r, 0) + 1") + val collationId = left.dataType.asInstanceOf[StringType].collationId + + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") + } else { + defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0, $collationId) + 1") + } } override protected def withNewChildrenInternal( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 04f3781a92cf3..705dd4a184401 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.test.SharedSparkSession class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession { case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) + case class CollationTestFail[R](s1: String, s2: String, collation: String) test("Support ConcatWs string expression with Collation") { @@ -70,6 +71,52 @@ class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession }) } + case class SubstringIndexTestFail[R](s1: String, s2: String, c1: String, c2: String) + + test("Support SubstringIndex with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("The quick brown fox jumps over the dog.", "fox", "UTF8_BINARY", 17), + CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UTF8_BINARY", 0), + CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UTF8_BINARY_LCASE", 17), + CollationTestCase("The quick brown fox jumps over the dog.", "fox", "UNICODE", 17), + CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UNICODE", 0), + CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UNICODE_CI", 17) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT instr(collate('${ct.s1}', '${ct.collation}'), " + + s"collate('${ct.s2}', '${ct.collation}'))"), + Row(ct.expectedResult)) + }) + // Unsupported collation pairs + val fails = Seq( + SubstringIndexTestFail("The quick brown fox jumps over the dog.", + "Fox", "UTF8_BINARY_LCASE", "UTF8_BINARY"), + SubstringIndexTestFail("The quick brown fox jumps over the dog.", + "FOX", "UNICODE_CI", "UNICODE") + ) + fails.foreach(ct => { + val expr = s"instr(collate('${ct.s1}', '${ct.c1}'), collate('${ct.s2}', '${ct.c2}'))" + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT $expr") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"instr(collate(${ct.s1}), collate(${ct.s2}))\"", + "collationNameLeft" -> s"${ct.c1}", + "collationNameRight" -> s"${ct.c2}" + ), + context = ExpectedContext( + fragment = s"$expr", + start = 7, + stop = 45 + ct.s1.length + ct.c1.length + ct.s2.length + ct.c2.length + ) + ) + }) + } + // TODO: Add more tests for other string expressions } From eb2d7c532f6d014d1bd240d38d9533c9e8bb23a0 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 21 Mar 2024 16:55:26 +0100 Subject: [PATCH 02/46] Correct code style --- .../apache/spark/unsafe/types/UTF8String.java | 43 +++++++++++++++++- .../expressions/stringExpressions.scala | 19 ++++++-- .../sql/CollationStringExpressionsSuite.scala | 45 +++++++++++++++++++ 3 files changed, 103 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 551d0e62fe8a3..98becda285a82 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -549,7 +549,48 @@ public int findInSet(UTF8String match) { return 0; } - /** + public int findInSet(UTF8String match, int collationId) { + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + return this.findInSet(match); + } + if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) { + return this.toLowerCase().findInSet(match.toLowerCase()); + } + return collatedFindInSet(match, collationId); +} + + private int collatedFindInSet(UTF8String match, int collationId) { + if (match.contains(COMMA_UTF8)) { + return 0; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(this, match, collationId); + + String setString = this.toString(); + int wordStart = 0; + while ((wordStart = stringSearch.next()) != StringSearch.DONE) { + if (stringSearch.getMatchLength() == stringSearch.getPattern().length()) { + boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ','; + boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length() + || setString.charAt(wordStart + stringSearch.getMatchLength()) == ','; + + if(isValidStart && isValidEnd) { + int pos = 0; + for(int i = 0; i < setString.length() && i < wordStart; i++) { + if(setString.charAt(i) == ',') { + pos++; + } + } + + return pos + 1; + } + } + } + + return 0; + } + + /** * Copy the bytes from the current UTF8String, and make a new UTF8String. * @param start the start position of the current UTF8String in bytes. * @param end the end position of the current UTF8String in bytes. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b16061f8f4a9c..3dbade0267803 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1002,10 +1002,13 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac case class FindInSet(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation) - override protected def nullSafeEval(word: Any, set: Any): Any = - set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) + override protected def nullSafeEval(word: Any, set: Any): Any = { + val collationId = left.dataType.asInstanceOf[StringType].collationId + set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String], collationId) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (word, set) => @@ -1013,6 +1016,16 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi ) } + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + return defaultCheck + } + + val collationId = left.dataType.asInstanceOf[StringType].collationId + CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType)) + } + override def dataType: DataType = IntegerType override def prettyName: String = "find_in_set" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 705dd4a184401..83e511de4ae1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -117,6 +117,51 @@ class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession }) } + test("Support FindInSet with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("a", "abc,b,ab,c,def", "UTF8_BINARY", 0), + CollationTestCase("c", "abc,b,ab,c,def", "UTF8_BINARY", 4), + CollationTestCase("abc", "abc,b,ab,c,def", "UTF8_BINARY", 1), + CollationTestCase("ab", "abc,b,ab,c,def", "UTF8_BINARY", 3), + CollationTestCase("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0), + CollationTestCase("Ab", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 3), + CollationTestCase("ab", "abc,b,ab,c,def", "UNICODE", 3), + CollationTestCase("aB", "abc,b,ab,c,def", "UNICODE", 0), + CollationTestCase("AB", "abc,b,ab,c,def", "UNICODE_CI", 3) + ) + checks.foreach(ct => { + checkAnswer(sql(s"SELECT find_in_set(collate('${ct.s1}', '${ct.collation}'), " + + s"collate('${ct.s2}', '${ct.collation}'))"), + Row(ct.expectedResult)) + }) + // Unsupported collation pairs + val fails = Seq( + SubstringIndexTestFail("a", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", "UTF8_BINARY"), + SubstringIndexTestFail("a", "abc,b,ab,c,def", "UNICODE_CI", "UNICODE") + ) + fails.foreach(ct => { + val expr = s"find_in_set(collate('${ct.s1}', '${ct.c1}'), collate('${ct.s2}', '${ct.c2}'))" + checkError( + exception = intercept[ExtendedAnalysisException] { + sql(s"SELECT $expr") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> s"\"find_in_set(collate(${ct.s1}), collate(${ct.s2}))\"", + "collationNameLeft" -> s"${ct.c1}", + "collationNameRight" -> s"${ct.c2}" + ), + context = ExpectedContext( + fragment = s"$expr", + start = 7, + stop = 51 + ct.s1.length + ct.c1.length + ct.s2.length + ct.c2.length + ) + ) + }) + } + // TODO: Add more tests for other string expressions } From 934083192da991b467491d516e9219ed712b5ee1 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 21 Mar 2024 16:57:19 +0100 Subject: [PATCH 03/46] Remove blank line from CollationStringExpressionsSuite.scala --- .../org/apache/spark/sql/CollationStringExpressionsSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 83e511de4ae1e..cf21e7da408b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.test.SharedSparkSession class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession { case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) - case class CollationTestFail[R](s1: String, s2: String, collation: String) test("Support ConcatWs string expression with Collation") { From 465e81432f7c929e3aa937e002a6b28c11b245eb Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 21 Mar 2024 22:00:00 +0100 Subject: [PATCH 04/46] Correct comment indentation --- .../src/main/java/org/apache/spark/unsafe/types/UTF8String.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 98becda285a82..b9a487adb7c43 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -590,7 +590,7 @@ private int collatedFindInSet(UTF8String match, int collationId) { return 0; } - /** + /** * Copy the bytes from the current UTF8String, and make a new UTF8String. * @param start the start position of the current UTF8String in bytes. * @param end the end position of the current UTF8String in bytes. From f3f30d8942a5f1835884a6b7990666f74c3cfea2 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 22 Mar 2024 12:07:45 +0100 Subject: [PATCH 05/46] Add unit tests for INSTR operation --- .../sql/CollationStringExpressionsSuite.scala | 67 ++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index cf21e7da408b5..c77d5c1f55d97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -21,10 +21,14 @@ import scala.collection.immutable.Seq import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.{Collate, ExpressionEvalHelper, Literal, StringInstr} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType -class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession { +class CollationStringExpressionsSuite extends QueryTest + with SharedSparkSession with ExpressionEvalHelper { case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) case class CollationTestFail[R](s1: String, s2: String, collation: String) @@ -70,6 +74,67 @@ class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession }) } + test("INSTR check result on non-explicit default collation") { + checkEvaluation(StringInstr(Literal("aAads"), Literal("Aa")), 2) + } + + test("INSTR check result on explicitly collated strings") { + // UTF8_BINARY_LCASE + checkEvaluation(StringInstr(Literal.create("aaads", StringType(1)), + Literal.create("Aa", StringType(1))), 1) + checkEvaluation(StringInstr(Collate(Literal("aaads"), "UTF8_BINARY_LCASE"), + Collate(Literal("Aa"), "UTF8_BINARY_LCASE")), 1) + // UNICODE + checkEvaluation(StringInstr(Literal.create("aaads", StringType(2)), + Literal.create("Aa", StringType(2))), 0) + checkEvaluation(StringInstr(Collate(Literal("aaads"), "UNICODE"), + Collate(Literal("Aa"), "UNICODE")), 0) + // UNICODE_CI + checkEvaluation(StringInstr(Literal.create("aaads", StringType(3)), + Literal.create("de", StringType(3))), 0) + checkEvaluation(StringInstr(Collate(Literal("aaads"), "UNICODE_CI"), + Collate(Literal("Aa"), "UNICODE_CI")), 0) + } + + test("INSTR fail mismatched collation types") { + // UNICODE and UNICODE_CI + val expr1 = StringInstr(Collate(Literal("aaads"), "UNICODE"), + Collate(Literal("Aa"), "UNICODE_CI")) + assert(expr1.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "COLLATION_MISMATCH", + messageParameters = Map( + "collationNameLeft" -> "UNICODE", + "collationNameRight" -> "UNICODE_CI" + ) + ) + ) + // DEFAULT(UTF8_BINARY) and UTF8_BINARY_LCASE + val expr2 = StringInstr(Literal("aaads"), + Collate(Literal("Aa"), "UTF8_BINARY_LCASE")) + assert(expr2.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "COLLATION_MISMATCH", + messageParameters = Map( + "collationNameLeft" -> "UTF8_BINARY", + "collationNameRight" -> "UTF8_BINARY_LCASE" + ) + ) + ) + // UTF8_BINARY_LCASE and UNICODE_CI + val expr3 = StringInstr(Collate(Literal("aaads"), "UTF8_BINARY_LCASE"), + Collate(Literal("Aa"), "UNICODE_CI")) + assert(expr3.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "COLLATION_MISMATCH", + messageParameters = Map( + "collationNameLeft" -> "UTF8_BINARY_LCASE", + "collationNameRight" -> "UNICODE_CI" + ) + ) + ) + } + case class SubstringIndexTestFail[R](s1: String, s2: String, c1: String, c2: String) test("Support SubstringIndex with Collation") { From 9cb92d3309c2a28caf39b66040d067f032b18923 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 22 Mar 2024 12:31:03 +0100 Subject: [PATCH 06/46] Add doGenCode for FindInSet --- .../sql/catalyst/expressions/stringExpressions.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 3dbade0267803..b56dfc7ff7e56 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1011,9 +1011,13 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (word, set) => - s"${ev.value} = $set.findInSet($word);" - ) + val collationId = left.dataType.asInstanceOf[StringType].collationId + + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word);") + } else { + nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word, $collationId);") + } } override def checkInputDataTypes(): TypeCheckResult = { From 834be708ea8b9423f3361a2be53671c8462ef7d9 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 22 Mar 2024 12:47:30 +0100 Subject: [PATCH 07/46] Rewrite unit tests for INSTR and FIND_IN_SET --- .../sql/CollationStringExpressionsSuite.scala | 156 ++++++++---------- 1 file changed, 73 insertions(+), 83 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index c77d5c1f55d97..87bc2065f9b41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -22,7 +22,7 @@ import scala.collection.immutable.Seq import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.{Collate, ExpressionEvalHelper, Literal, StringInstr} +import org.apache.spark.sql.catalyst.expressions.{Collate, ExpressionEvalHelper, FindInSet, Literal, StringInstr} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StringType @@ -135,95 +135,85 @@ class CollationStringExpressionsSuite extends QueryTest ) } - case class SubstringIndexTestFail[R](s1: String, s2: String, c1: String, c2: String) + test("FIND_IN_SET check result on non-explicit default collation") { + checkEvaluation(FindInSet(Literal("def"), Literal("abc,b,ab,c,def")), 5) + checkEvaluation(FindInSet(Literal("defg"), Literal("abc,b,ab,c,def")), 0) + } - test("Support SubstringIndex with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("The quick brown fox jumps over the dog.", "fox", "UTF8_BINARY", 17), - CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UTF8_BINARY", 0), - CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UTF8_BINARY_LCASE", 17), - CollationTestCase("The quick brown fox jumps over the dog.", "fox", "UNICODE", 17), - CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UNICODE", 0), - CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UNICODE_CI", 17) - ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT instr(collate('${ct.s1}', '${ct.collation}'), " + - s"collate('${ct.s2}', '${ct.collation}'))"), - Row(ct.expectedResult)) - }) - // Unsupported collation pairs - val fails = Seq( - SubstringIndexTestFail("The quick brown fox jumps over the dog.", - "Fox", "UTF8_BINARY_LCASE", "UTF8_BINARY"), - SubstringIndexTestFail("The quick brown fox jumps over the dog.", - "FOX", "UNICODE_CI", "UNICODE") - ) - fails.foreach(ct => { - val expr = s"instr(collate('${ct.s1}', '${ct.c1}'), collate('${ct.s2}', '${ct.c2}'))" - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT $expr") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"instr(collate(${ct.s1}), collate(${ct.s2}))\"", - "collationNameLeft" -> s"${ct.c1}", - "collationNameRight" -> s"${ct.c2}" - ), - context = ExpectedContext( - fragment = s"$expr", - start = 7, - stop = 45 + ct.s1.length + ct.c1.length + ct.s2.length + ct.c2.length - ) - ) - }) + test("FIND_IN_SET check result on explicitly collated strings") { + // UTF8_BINARY + checkEvaluation(FindInSet(Collate(Literal("a"), "UTF8_BINARY"), + Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 0) + checkEvaluation(FindInSet(Collate(Literal("c"), "UTF8_BINARY"), + Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 4) + checkEvaluation(FindInSet(Collate(Literal("AB"), "UTF8_BINARY"), + Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 0) + checkEvaluation(FindInSet(Collate(Literal("abcd"), "UTF8_BINARY"), + Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 0) + // UTF8_BINARY_LCASE + checkEvaluation(FindInSet(Collate(Literal("aB"), "UTF8_BINARY_LCASE"), + Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 3) + checkEvaluation(FindInSet(Collate(Literal("a"), "UTF8_BINARY_LCASE"), + Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 0) + checkEvaluation(FindInSet(Collate(Literal("abc"), "UTF8_BINARY_LCASE"), + Collate(Literal("aBc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 1) + checkEvaluation(FindInSet(Collate(Literal("abcd"), "UTF8_BINARY_LCASE"), + Collate(Literal("aBc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 0) + // UNICODE + checkEvaluation(FindInSet(Collate(Literal("a"), "UNICODE"), + Collate(Literal("abc,b,ab,c,def"), "UNICODE")), 0) + checkEvaluation(FindInSet(Collate(Literal("ab"), "UNICODE"), + Collate(Literal("abc,b,ab,c,def"), "UNICODE")), 3) + checkEvaluation(FindInSet(Collate(Literal("Ab"), "UNICODE"), + Collate(Literal("abc,b,ab,c,def"), "UNICODE")), 0) + // UNICODE_CI + checkEvaluation(FindInSet(Collate(Literal("a"), "UNICODE_CI"), + Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 0) + checkEvaluation(FindInSet(Collate(Literal("C"), "UNICODE_CI"), + Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 4) + checkEvaluation(FindInSet(Collate(Literal("DeF"), "UNICODE_CI"), + Collate(Literal("abc,b,ab,c,dEf"), "UNICODE_CI")), 5) + checkEvaluation(FindInSet(Collate(Literal("DEFG"), "UNICODE_CI"), + Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 0) } - test("Support FindInSet with Collation") { - // Supported collations - val checks = Seq( - CollationTestCase("a", "abc,b,ab,c,def", "UTF8_BINARY", 0), - CollationTestCase("c", "abc,b,ab,c,def", "UTF8_BINARY", 4), - CollationTestCase("abc", "abc,b,ab,c,def", "UTF8_BINARY", 1), - CollationTestCase("ab", "abc,b,ab,c,def", "UTF8_BINARY", 3), - CollationTestCase("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0), - CollationTestCase("Ab", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 3), - CollationTestCase("ab", "abc,b,ab,c,def", "UNICODE", 3), - CollationTestCase("aB", "abc,b,ab,c,def", "UNICODE", 0), - CollationTestCase("AB", "abc,b,ab,c,def", "UNICODE_CI", 3) + test("FIND_IN_SET fail mismatched collation types") { + // UNICODE and UNICODE_CI + val expr1 = FindInSet(Collate(Literal("a"), "UNICODE"), + Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")) + assert(expr1.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "COLLATION_MISMATCH", + messageParameters = Map( + "collationNameLeft" -> "UNICODE", + "collationNameRight" -> "UNICODE_CI" + ) + ) ) - checks.foreach(ct => { - checkAnswer(sql(s"SELECT find_in_set(collate('${ct.s1}', '${ct.collation}'), " + - s"collate('${ct.s2}', '${ct.collation}'))"), - Row(ct.expectedResult)) - }) - // Unsupported collation pairs - val fails = Seq( - SubstringIndexTestFail("a", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", "UTF8_BINARY"), - SubstringIndexTestFail("a", "abc,b,ab,c,def", "UNICODE_CI", "UNICODE") + // DEFAULT(UTF8_BINARY) and UTF8_BINARY_LCASE + val expr2 = FindInSet(Collate(Literal("a"), "UTF8_BINARY"), + Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY_LCASE")) + assert(expr2.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "COLLATION_MISMATCH", + messageParameters = Map( + "collationNameLeft" -> "UTF8_BINARY", + "collationNameRight" -> "UTF8_BINARY_LCASE" + ) + ) ) - fails.foreach(ct => { - val expr = s"find_in_set(collate('${ct.s1}', '${ct.c1}'), collate('${ct.s2}', '${ct.c2}'))" - checkError( - exception = intercept[ExtendedAnalysisException] { - sql(s"SELECT $expr") - }, - errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", - sqlState = "42K09", - parameters = Map( - "sqlExpr" -> s"\"find_in_set(collate(${ct.s1}), collate(${ct.s2}))\"", - "collationNameLeft" -> s"${ct.c1}", - "collationNameRight" -> s"${ct.c2}" - ), - context = ExpectedContext( - fragment = s"$expr", - start = 7, - stop = 51 + ct.s1.length + ct.c1.length + ct.s2.length + ct.c2.length + // UTF8_BINARY_LCASE and UNICODE_CI + val expr3 = FindInSet(Collate(Literal("a"), "UTF8_BINARY_LCASE"), + Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")) + assert(expr3.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "COLLATION_MISMATCH", + messageParameters = Map( + "collationNameLeft" -> "UTF8_BINARY_LCASE", + "collationNameRight" -> "UNICODE_CI" ) ) - }) + ) } // TODO: Add more tests for other string expressions From db2453af5839db9a1de41316b23e5da5bf592919 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 22 Mar 2024 14:57:14 +0100 Subject: [PATCH 08/46] Correct return value when substr is not found in INSTR method --- .../java/org/apache/spark/unsafe/types/UTF8String.java | 2 +- .../spark/sql/CollationStringExpressionsSuite.scala | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index b9a487adb7c43..69ee0588a207c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -900,7 +900,7 @@ private int collatedIndexOf(UTF8String substring, int collationId) { } } - return 0; + return -1; } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 87bc2065f9b41..fbbb7d3a86877 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -84,16 +84,18 @@ class CollationStringExpressionsSuite extends QueryTest Literal.create("Aa", StringType(1))), 1) checkEvaluation(StringInstr(Collate(Literal("aaads"), "UTF8_BINARY_LCASE"), Collate(Literal("Aa"), "UTF8_BINARY_LCASE")), 1) + checkEvaluation(StringInstr(Collate(Literal("aaads"), "UTF8_BINARY_LCASE"), + Collate(Literal("de"), "UTF8_BINARY_LCASE")), 0) // UNICODE checkEvaluation(StringInstr(Literal.create("aaads", StringType(2)), Literal.create("Aa", StringType(2))), 0) checkEvaluation(StringInstr(Collate(Literal("aaads"), "UNICODE"), - Collate(Literal("Aa"), "UNICODE")), 0) + Collate(Literal("de"), "UNICODE")), 0) // UNICODE_CI checkEvaluation(StringInstr(Literal.create("aaads", StringType(3)), Literal.create("de", StringType(3))), 0) checkEvaluation(StringInstr(Collate(Literal("aaads"), "UNICODE_CI"), - Collate(Literal("Aa"), "UNICODE_CI")), 0) + Collate(Literal("AD"), "UNICODE_CI")), 3) } test("INSTR fail mismatched collation types") { @@ -175,6 +177,8 @@ class CollationStringExpressionsSuite extends QueryTest Collate(Literal("abc,b,ab,c,dEf"), "UNICODE_CI")), 5) checkEvaluation(FindInSet(Collate(Literal("DEFG"), "UNICODE_CI"), Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 0) + checkEvaluation(FindInSet(Collate(Literal("dsf"), "UNICODE_CI"), + Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 0) } test("FIND_IN_SET fail mismatched collation types") { From 91b648a6bbf25314ad56f64bc6dba903e13fa8f0 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 25 Mar 2024 12:16:30 +0100 Subject: [PATCH 09/46] Update unit tests for StringInStr and FindInSet --- .../sql/CollationStringExpressionsSuite.scala | 83 ++++++++----------- 1 file changed, 33 insertions(+), 50 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index fbbb7d3a86877..87bdc1660294e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -74,28 +74,23 @@ class CollationStringExpressionsSuite extends QueryTest }) } - test("INSTR check result on non-explicit default collation") { - checkEvaluation(StringInstr(Literal("aAads"), Literal("Aa")), 2) - } - test("INSTR check result on explicitly collated strings") { + def testInStr(expected: Integer, stringType: Integer, str: String, substr: String): Unit = { + val string = Literal.create(str, StringType(stringType)) + val substring = Literal.create(substr, StringType(stringType)) + + checkEvaluation(StringInstr(string, substring), expected) + } + // UTF8_BINARY_LCASE - checkEvaluation(StringInstr(Literal.create("aaads", StringType(1)), - Literal.create("Aa", StringType(1))), 1) - checkEvaluation(StringInstr(Collate(Literal("aaads"), "UTF8_BINARY_LCASE"), - Collate(Literal("Aa"), "UTF8_BINARY_LCASE")), 1) - checkEvaluation(StringInstr(Collate(Literal("aaads"), "UTF8_BINARY_LCASE"), - Collate(Literal("de"), "UTF8_BINARY_LCASE")), 0) + testInStr(1, 1, "aaads", "Aa") + testInStr(0, 1, "aaaDs", "de") // UNICODE - checkEvaluation(StringInstr(Literal.create("aaads", StringType(2)), - Literal.create("Aa", StringType(2))), 0) - checkEvaluation(StringInstr(Collate(Literal("aaads"), "UNICODE"), - Collate(Literal("de"), "UNICODE")), 0) + testInStr(0, 2, "aaads", "Aa") + testInStr(0, 2, "aaads", "de") // UNICODE_CI - checkEvaluation(StringInstr(Literal.create("aaads", StringType(3)), - Literal.create("de", StringType(3))), 0) - checkEvaluation(StringInstr(Collate(Literal("aaads"), "UNICODE_CI"), - Collate(Literal("AD"), "UNICODE_CI")), 3) + testInStr(3, 3, "aaads", "AD") + testInStr(4, 3, "aaads", "dS") } test("INSTR fail mismatched collation types") { @@ -143,42 +138,30 @@ class CollationStringExpressionsSuite extends QueryTest } test("FIND_IN_SET check result on explicitly collated strings") { + def testFindInSet(expected: Integer, stringType: Integer, word: String, set: String): Unit = { + val w = Literal.create(word, StringType(stringType)) + val s = Literal.create(set, StringType(stringType)) + + checkEvaluation(FindInSet(w, s), expected) + } + // UTF8_BINARY - checkEvaluation(FindInSet(Collate(Literal("a"), "UTF8_BINARY"), - Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 0) - checkEvaluation(FindInSet(Collate(Literal("c"), "UTF8_BINARY"), - Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 4) - checkEvaluation(FindInSet(Collate(Literal("AB"), "UTF8_BINARY"), - Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 0) - checkEvaluation(FindInSet(Collate(Literal("abcd"), "UTF8_BINARY"), - Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY")), 0) + testFindInSet(0, 0, "AB", "abc,b,ab,c,def") // UTF8_BINARY_LCASE - checkEvaluation(FindInSet(Collate(Literal("aB"), "UTF8_BINARY_LCASE"), - Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 3) - checkEvaluation(FindInSet(Collate(Literal("a"), "UTF8_BINARY_LCASE"), - Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 0) - checkEvaluation(FindInSet(Collate(Literal("abc"), "UTF8_BINARY_LCASE"), - Collate(Literal("aBc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 1) - checkEvaluation(FindInSet(Collate(Literal("abcd"), "UTF8_BINARY_LCASE"), - Collate(Literal("aBc,b,ab,c,def"), "UTF8_BINARY_LCASE")), 0) + testFindInSet(0, 1, "a", "abc,b,ab,c,def") + testFindInSet(4, 1, "c", "abc,b,ab,c,def") + testFindInSet(3, 1, "AB", "abc,b,ab,c,def") + testFindInSet(1, 1, "AbC", "abc,b,ab,c,def") + testFindInSet(0, 1, "abcd", "abc,b,ab,c,def") // UNICODE - checkEvaluation(FindInSet(Collate(Literal("a"), "UNICODE"), - Collate(Literal("abc,b,ab,c,def"), "UNICODE")), 0) - checkEvaluation(FindInSet(Collate(Literal("ab"), "UNICODE"), - Collate(Literal("abc,b,ab,c,def"), "UNICODE")), 3) - checkEvaluation(FindInSet(Collate(Literal("Ab"), "UNICODE"), - Collate(Literal("abc,b,ab,c,def"), "UNICODE")), 0) + testFindInSet(0, 2, "a", "abc,b,ab,c,def") + testFindInSet(3, 2, "ab", "abc,b,ab,c,def") + testFindInSet(0, 2, "Ab", "abc,b,ab,c,def") // UNICODE_CI - checkEvaluation(FindInSet(Collate(Literal("a"), "UNICODE_CI"), - Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 0) - checkEvaluation(FindInSet(Collate(Literal("C"), "UNICODE_CI"), - Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 4) - checkEvaluation(FindInSet(Collate(Literal("DeF"), "UNICODE_CI"), - Collate(Literal("abc,b,ab,c,dEf"), "UNICODE_CI")), 5) - checkEvaluation(FindInSet(Collate(Literal("DEFG"), "UNICODE_CI"), - Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 0) - checkEvaluation(FindInSet(Collate(Literal("dsf"), "UNICODE_CI"), - Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")), 0) + testFindInSet(0, 3, "a", "abc,b,ab,c,def") + testFindInSet(4, 3, "C", "abc,b,ab,c,def") + testFindInSet(5, 3, "DeF", "abc,b,ab,c,dEf") + testFindInSet(0, 3, "DEFG", "abc,b,ab,c,def") } test("FIND_IN_SET fail mismatched collation types") { From 1062521ac0b240a4bd926ba4abd7b8eab3780049 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 25 Mar 2024 12:29:53 +0100 Subject: [PATCH 10/46] Remove tests on non-explicit default collation --- .../apache/spark/sql/CollationStringExpressionsSuite.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 87bdc1660294e..1d8207637e62f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -132,11 +132,6 @@ class CollationStringExpressionsSuite extends QueryTest ) } - test("FIND_IN_SET check result on non-explicit default collation") { - checkEvaluation(FindInSet(Literal("def"), Literal("abc,b,ab,c,def")), 5) - checkEvaluation(FindInSet(Literal("defg"), Literal("abc,b,ab,c,def")), 0) - } - test("FIND_IN_SET check result on explicitly collated strings") { def testFindInSet(expected: Integer, stringType: Integer, word: String, set: String): Unit = { val w = Literal.create(word, StringType(stringType)) From 427ea255acfc39836bb177eaafc54befd8762701 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 26 Mar 2024 16:13:22 +0100 Subject: [PATCH 11/46] Improve signature of testInStr --- .../sql/CollationStringExpressionsSuite.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 1d8207637e62f..aebc56a8b6b43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -75,22 +75,22 @@ class CollationStringExpressionsSuite extends QueryTest } test("INSTR check result on explicitly collated strings") { - def testInStr(expected: Integer, stringType: Integer, str: String, substr: String): Unit = { - val string = Literal.create(str, StringType(stringType)) - val substring = Literal.create(substr, StringType(stringType)) + def testInStr(str: String, substr: String, collationId: Integer, expected: Integer): Unit = { + val string = Literal.create(str, StringType(collationId)) + val substring = Literal.create(substr, StringType(collationId)) checkEvaluation(StringInstr(string, substring), expected) } // UTF8_BINARY_LCASE - testInStr(1, 1, "aaads", "Aa") - testInStr(0, 1, "aaaDs", "de") + testInStr("aaads", "Aa", 1, 1) + testInStr("aaaDs", "de", 1, 0) // UNICODE - testInStr(0, 2, "aaads", "Aa") - testInStr(0, 2, "aaads", "de") + testInStr("aaads", "Aa", 2, 0) + testInStr("aaads", "de", 2, 0) // UNICODE_CI - testInStr(3, 3, "aaads", "AD") - testInStr(4, 3, "aaads", "dS") + testInStr("aaads", "AD", 3, 3) + testInStr("aaads", "dS", 3, 4) } test("INSTR fail mismatched collation types") { From 108d707fa55b7e293576d4a6710e6d93948e3f86 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 26 Mar 2024 16:35:27 +0100 Subject: [PATCH 12/46] Remove E2E test for collation mismatch. This will be added in Implicit Casting --- .../sql/CollationStringExpressionsSuite.scala | 39 ------------------- 1 file changed, 39 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 604d5b052ad33..cd5fc84d305d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -93,45 +93,6 @@ class CollationStringExpressionsSuite extends QueryTest testInStr("aaads", "dS", 3, 4) } - test("INSTR fail mismatched collation types") { - // UNICODE and UNICODE_CI - val expr1 = StringInstr(Collate(Literal("aaads"), "UNICODE"), - Collate(Literal("Aa"), "UNICODE_CI")) - assert(expr1.checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> "UNICODE", - "collationNameRight" -> "UNICODE_CI" - ) - ) - ) - // DEFAULT(UTF8_BINARY) and UTF8_BINARY_LCASE - val expr2 = StringInstr(Literal("aaads"), - Collate(Literal("Aa"), "UTF8_BINARY_LCASE")) - assert(expr2.checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> "UTF8_BINARY", - "collationNameRight" -> "UTF8_BINARY_LCASE" - ) - ) - ) - // UTF8_BINARY_LCASE and UNICODE_CI - val expr3 = StringInstr(Collate(Literal("aaads"), "UTF8_BINARY_LCASE"), - Collate(Literal("Aa"), "UNICODE_CI")) - assert(expr3.checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> "UTF8_BINARY_LCASE", - "collationNameRight" -> "UNICODE_CI" - ) - ) - ) - } - test("FIND_IN_SET check result on explicitly collated strings") { def testFindInSet(expected: Integer, stringType: Integer, word: String, set: String): Unit = { val w = Literal.create(word, StringType(stringType)) From 822ecd2f6d6b212db39c1aa94bcd71c6ba53dae6 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 26 Mar 2024 16:38:45 +0100 Subject: [PATCH 13/46] Resolve merge problems with master --- .../apache/spark/unsafe/types/UTF8String.java | 8 ++-- .../expressions/stringExpressions.scala | 4 +- .../sql/CollationStringExpressionsSuite.scala | 40 ------------------- 3 files changed, 6 insertions(+), 46 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 7e2a958602a64..d14e108ef4cd3 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -550,10 +550,10 @@ public int findInSet(UTF8String match) { } public int findInSet(UTF8String match, int collationId) { - if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { return this.findInSet(match); } - if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) { + if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { return this.toLowerCase().findInSet(match.toLowerCase()); } return collatedFindInSet(match, collationId); @@ -877,10 +877,10 @@ public int indexOf(UTF8String v, int start) { } public int indexOf(UTF8String substring, int start, int collationId) { - if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { return this.indexOf(substring, start); } - if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) { + if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { return this.toLowerCase().indexOf(substring.toLowerCase(), start); } return collatedIndexOf(substring, collationId); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 43f2c76853152..32694c66afd23 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1013,7 +1013,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val collationId = left.dataType.asInstanceOf[StringType].collationId - if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word);") } else { nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word, $collationId);") @@ -1417,7 +1417,7 @@ case class StringInstr(str: Expression, substr: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val collationId = left.dataType.asInstanceOf[StringType].collationId - if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") } else { defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0, $collationId) + 1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index cd5fc84d305d2..9e5a5579e127b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -21,7 +21,6 @@ import scala.collection.immutable.Seq import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, FindInSet, Literal, StringInstr, StringRepeat} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -119,45 +118,6 @@ class CollationStringExpressionsSuite extends QueryTest testFindInSet(5, 3, "DeF", "abc,b,ab,c,dEf") testFindInSet(0, 3, "DEFG", "abc,b,ab,c,def") } - - test("FIND_IN_SET fail mismatched collation types") { - // UNICODE and UNICODE_CI - val expr1 = FindInSet(Collate(Literal("a"), "UNICODE"), - Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")) - assert(expr1.checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> "UNICODE", - "collationNameRight" -> "UNICODE_CI" - ) - ) - ) - // DEFAULT(UTF8_BINARY) and UTF8_BINARY_LCASE - val expr2 = FindInSet(Collate(Literal("a"), "UTF8_BINARY"), - Collate(Literal("abc,b,ab,c,def"), "UTF8_BINARY_LCASE")) - assert(expr2.checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> "UTF8_BINARY", - "collationNameRight" -> "UTF8_BINARY_LCASE" - ) - ) - ) - // UTF8_BINARY_LCASE and UNICODE_CI - val expr3 = FindInSet(Collate(Literal("a"), "UTF8_BINARY_LCASE"), - Collate(Literal("abc,b,ab,c,def"), "UNICODE_CI")) - assert(expr3.checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "COLLATION_MISMATCH", - messageParameters = Map( - "collationNameLeft" -> "UTF8_BINARY_LCASE", - "collationNameRight" -> "UNICODE_CI" - ) - ) - ) - } test("REPEAT check output type on explicitly collated string") { def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = { From f730d054d359ce0a222ccea68a2f25efc5168b75 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 26 Mar 2024 17:37:56 +0100 Subject: [PATCH 14/46] Improve scala style --- .../org/apache/spark/sql/CollationStringExpressionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 9e5a5579e127b..74bbce242c6c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -118,7 +118,7 @@ class CollationStringExpressionsSuite extends QueryTest testFindInSet(5, 3, "DeF", "abc,b,ab,c,dEf") testFindInSet(0, 3, "DEFG", "abc,b,ab,c,def") } - + test("REPEAT check output type on explicitly collated string") { def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = { val s = Literal.create(input, StringType(collationId)) From de7b59182701c7b0db3d1c4027ef768a322c292b Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 27 Mar 2024 10:03:30 +0100 Subject: [PATCH 15/46] Solve whitespace scala style problem --- .../org/apache/spark/sql/CollationStringExpressionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 74bbce242c6c8..7abbbaffa71e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -118,7 +118,7 @@ class CollationStringExpressionsSuite extends QueryTest testFindInSet(5, 3, "DeF", "abc,b,ab,c,dEf") testFindInSet(0, 3, "DEFG", "abc,b,ab,c,def") } - + test("REPEAT check output type on explicitly collated string") { def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = { val s = Literal.create(input, StringType(collationId)) From f0ee8fd83f91747d96bfb0de7998e71c8aacc69a Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 1 Apr 2024 15:29:49 +0200 Subject: [PATCH 16/46] Add lazy val collationId --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 32694c66afd23..6fcd61adbcbff 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1002,17 +1002,16 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac case class FindInSet(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, StringTypeAnyCollation) override protected def nullSafeEval(word: Any, set: Any): Any = { - val collationId = left.dataType.asInstanceOf[StringType].collationId set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String], collationId) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val collationId = left.dataType.asInstanceOf[StringType].collationId - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word);") } else { From 4ac688517a5d1721af2dbd20381ca6fe92861cd0 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 1 Apr 2024 15:31:21 +0200 Subject: [PATCH 17/46] Remove repeated code --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 6fcd61adbcbff..b28d0e3d432c2 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1025,7 +1025,6 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi return defaultCheck } - val collationId = left.dataType.asInstanceOf[StringType].collationId CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType)) } @@ -1390,6 +1389,8 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non case class StringInstr(str: Expression, substr: Expression) extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { + final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId + override def left: Expression = str override def right: Expression = substr override def dataType: DataType = IntegerType @@ -1397,7 +1398,6 @@ case class StringInstr(str: Expression, substr: Expression) Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def nullSafeEval(string: Any, sub: Any): Any = { - val collationId = left.dataType.asInstanceOf[StringType].collationId string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0, collationId) + 1 } @@ -1407,15 +1407,12 @@ case class StringInstr(str: Expression, substr: Expression) return defaultCheck } - val collationId = left.dataType.asInstanceOf[StringType].collationId CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType)) } override def prettyName: String = "instr" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val collationId = left.dataType.asInstanceOf[StringType].collationId - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") } else { From b931333f4bbef2f16c7fe693a7982b31a96b57b2 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 1 Apr 2024 15:40:33 +0200 Subject: [PATCH 18/46] Improve test format --- .../sql/CollationStringExpressionsSuite.scala | 84 +++++++++++-------- 1 file changed, 50 insertions(+), 34 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 7abbbaffa71e7..537b84ab79c68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -22,6 +22,7 @@ import scala.collection.immutable.Seq import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, FindInSet, Literal, StringInstr, StringRepeat} +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StringType @@ -81,55 +82,70 @@ class CollationStringExpressionsSuite extends QueryTest checkEvaluation(StringInstr(string, substring), expected) } - // UTF8_BINARY_LCASE - testInStr("aaads", "Aa", 1, 1) - testInStr("aaaDs", "de", 1, 0) - // UNICODE - testInStr("aaads", "Aa", 2, 0) - testInStr("aaads", "de", 2, 0) - // UNICODE_CI - testInStr("aaads", "AD", 3, 3) - testInStr("aaads", "dS", 3, 4) + var collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") + testInStr("aaads", "Aa", collationId, 1) + testInStr("aaaDs", "de", collationId, 0) + + collationId = CollationFactory.collationNameToId("UNICODE") + testInStr("aaads", "Aa", collationId, 0) + testInStr("aaads", "de", collationId, 0) + + collationId = CollationFactory.collationNameToId("UNICODE_CI") + testInStr("aaads", "AD", collationId, 3) + testInStr("aaads", "dS", collationId, 4) } test("FIND_IN_SET check result on explicitly collated strings") { - def testFindInSet(expected: Integer, stringType: Integer, word: String, set: String): Unit = { - val w = Literal.create(word, StringType(stringType)) - val s = Literal.create(set, StringType(stringType)) + def testFindInSet(word: String, set: String, collationId: Integer, expected: Integer): Unit = { + val w = Literal.create(word, StringType(collationId)) + val s = Literal.create(set, StringType(collationId)) checkEvaluation(FindInSet(w, s), expected) } - // UTF8_BINARY - testFindInSet(0, 0, "AB", "abc,b,ab,c,def") - // UTF8_BINARY_LCASE - testFindInSet(0, 1, "a", "abc,b,ab,c,def") - testFindInSet(4, 1, "c", "abc,b,ab,c,def") - testFindInSet(3, 1, "AB", "abc,b,ab,c,def") - testFindInSet(1, 1, "AbC", "abc,b,ab,c,def") - testFindInSet(0, 1, "abcd", "abc,b,ab,c,def") - // UNICODE - testFindInSet(0, 2, "a", "abc,b,ab,c,def") - testFindInSet(3, 2, "ab", "abc,b,ab,c,def") - testFindInSet(0, 2, "Ab", "abc,b,ab,c,def") - // UNICODE_CI - testFindInSet(0, 3, "a", "abc,b,ab,c,def") - testFindInSet(4, 3, "C", "abc,b,ab,c,def") - testFindInSet(5, 3, "DeF", "abc,b,ab,c,dEf") - testFindInSet(0, 3, "DEFG", "abc,b,ab,c,def") + var collationId = CollationFactory.collationNameToId("UTF8_BINARY") + testFindInSet("AB", "abc,b,ab,c,def", collationId, 0) + + collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") + testFindInSet("a", "abc,b,ab,c,def", collationId, 0) + testFindInSet("c", "abc,b,ab,c,def", collationId, 4) + testFindInSet("AB", "abc,b,ab,c,def", collationId, 3) + testFindInSet("AbC", "abc,b,ab,c,def", collationId, 1) + testFindInSet("abcd", "abc,b,ab,c,def", collationId, 0) + + collationId = CollationFactory.collationNameToId("UNICODE") + testFindInSet("a", "abc,b,ab,c,def", collationId, 0) + testFindInSet("ab", "abc,b,ab,c,def", collationId, 3) + testFindInSet("Ab", "abc,b,ab,c,def", collationId, 0) + + collationId = CollationFactory.collationNameToId("UNICODE_CI") + testFindInSet("a", "abc,b,ab,c,def", collationId, 0) + testFindInSet("C", "abc,b,ab,c,def", collationId, 4) + testFindInSet("DeF", "abc,b,ab,c,dEf", collationId, 5) + testFindInSet("DEFG", "abc,b,ab,c,def", collationId, 0) } test("REPEAT check output type on explicitly collated string") { - def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = { + def testRepeat(input: String, n: Int, collationId: Int, expected: String): Unit = { val s = Literal.create(input, StringType(collationId)) checkEvaluation(Collation(StringRepeat(s, Literal.create(n))).replacement, expected) } - testRepeat("UTF8_BINARY", 0, "abc", 2) - testRepeat("UTF8_BINARY_LCASE", 1, "abc", 2) - testRepeat("UNICODE", 2, "abc", 2) - testRepeat("UNICODE_CI", 3, "abc", 2) + // Not important for this test + val repeatNum = 2; + + var collationId = CollationFactory.collationNameToId("UTF8_BINARY") + testRepeat("abc", repeatNum, collationId, "UTF8_BINARY") + + collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") + testRepeat("abc", repeatNum, collationId, "UTF8_BINARY_LCASE") + + collationId = CollationFactory.collationNameToId("UNICODE") + testRepeat("abc", repeatNum, collationId, "UNICODE") + + collationId = CollationFactory.collationNameToId("UNICODE_CI") + testRepeat("abc", repeatNum, collationId, "UNICODE_CI") } // TODO: Add more tests for other string expressions From 0fd51d58300180a082007c769c88ba4aeed09169 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 1 Apr 2024 15:50:21 +0200 Subject: [PATCH 19/46] Improve indexOf method --- .../apache/spark/unsafe/types/UTF8String.java | 39 +++++++++---------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index d14e108ef4cd3..36b7441058e31 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -569,21 +569,19 @@ private int collatedFindInSet(UTF8String match, int collationId) { String setString = this.toString(); int wordStart = 0; while ((wordStart = stringSearch.next()) != StringSearch.DONE) { - if (stringSearch.getMatchLength() == stringSearch.getPattern().length()) { - boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ','; - boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length() - || setString.charAt(wordStart + stringSearch.getMatchLength()) == ','; - - if(isValidStart && isValidEnd) { - int pos = 0; - for(int i = 0; i < setString.length() && i < wordStart; i++) { - if(setString.charAt(i) == ',') { - pos++; - } + boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ','; + boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length() + || setString.charAt(wordStart + stringSearch.getMatchLength()) == ','; + + if(isValidStart && isValidEnd) { + int pos = 0; + for(int i = 0; i < setString.length() && i < wordStart; i++) { + if(setString.charAt(i) == ',') { + pos++; } - - return pos + 1; } + + return pos + 1; } } @@ -883,24 +881,23 @@ public int indexOf(UTF8String substring, int start, int collationId) { if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { return this.toLowerCase().indexOf(substring.toLowerCase(), start); } - return collatedIndexOf(substring, collationId); + return collatedIndexOf(substring, start, collationId); } - private int collatedIndexOf(UTF8String substring, int collationId) { + private int collatedIndexOf(UTF8String substring, int start, int collationId) { if (substring.numBytes == 0) { return 0; } StringSearch stringSearch = CollationFactory.getStringSearch(this, substring, collationId); + stringSearch.setOverlapping(true); - int pos = 0; - while ((pos = stringSearch.next()) != StringSearch.DONE) { - if (stringSearch.getMatchLength() == stringSearch.getPattern().length()) { - return pos; - } + int pos = stringSearch.next(); + while(pos != StringSearch.DONE && pos < start) { + pos = stringSearch.next(); } - return -1; + return pos >= start ? pos : -1; } /** From 28fa7f041e784f4bf59197d66530af3eefc5386f Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 1 Apr 2024 15:52:10 +0200 Subject: [PATCH 20/46] Remove checks in return statement of collatedIndexOf method --- .../src/main/java/org/apache/spark/unsafe/types/UTF8String.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 36b7441058e31..2e923c4d79eaf 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -897,7 +897,7 @@ private int collatedIndexOf(UTF8String substring, int start, int collationId) { pos = stringSearch.next(); } - return pos >= start ? pos : -1; + return pos; } /** From 4666affdb83d24ebc71ca6117f74afb6366737b5 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 2 Apr 2024 11:08:43 +0200 Subject: [PATCH 21/46] Add branch for collated findInSet --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b28d0e3d432c2..dcec7df606d37 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1008,7 +1008,11 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi Seq(StringTypeAnyCollation, StringTypeAnyCollation) override protected def nullSafeEval(word: Any, set: Any): Any = { - set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String], collationId) + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) + } else { + set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String], collationId) + } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { From ca8a37cd60d7cc48f916c844c8b552273e2e1e72 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 2 Apr 2024 11:09:49 +0200 Subject: [PATCH 22/46] Add branch for collation check in StringInstr --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index dcec7df606d37..982468d7d088b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1402,7 +1402,11 @@ case class StringInstr(str: Expression, substr: Expression) Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def nullSafeEval(string: Any, sub: Any): Any = { - string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0, collationId) + 1 + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1 + } else { + string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0, collationId) + 1 + } } override def checkInputDataTypes(): TypeCheckResult = { From 4ffab78bb71cabe8d42bcf32772442851c440bf9 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 2 Apr 2024 12:31:00 +0200 Subject: [PATCH 23/46] Improve naming of collation aware methods --- .../java/org/apache/spark/unsafe/types/UTF8String.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 2e923c4d79eaf..4573a264abbe2 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -556,10 +556,10 @@ public int findInSet(UTF8String match, int collationId) { if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { return this.toLowerCase().findInSet(match.toLowerCase()); } - return collatedFindInSet(match, collationId); + return collationAwareFindInSet(match, collationId); } - private int collatedFindInSet(UTF8String match, int collationId) { + private int collationAwareFindInSet(UTF8String match, int collationId) { if (match.contains(COMMA_UTF8)) { return 0; } @@ -881,10 +881,10 @@ public int indexOf(UTF8String substring, int start, int collationId) { if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { return this.toLowerCase().indexOf(substring.toLowerCase(), start); } - return collatedIndexOf(substring, start, collationId); + return collationAwareIndexOf(substring, start, collationId); } - private int collatedIndexOf(UTF8String substring, int start, int collationId) { + private int collationAwareIndexOf(UTF8String substring, int start, int collationId) { if (substring.numBytes == 0) { return 0; } From 037d6beb129089903035c4baa4c15aacb0a34de1 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 3 Apr 2024 10:12:04 +0200 Subject: [PATCH 24/46] Improve java style --- .../java/org/apache/spark/unsafe/types/UTF8String.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 4573a264abbe2..85d05639b3a99 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -573,10 +573,10 @@ private int collationAwareFindInSet(UTF8String match, int collationId) { boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length() || setString.charAt(wordStart + stringSearch.getMatchLength()) == ','; - if(isValidStart && isValidEnd) { + if (isValidStart && isValidEnd) { int pos = 0; - for(int i = 0; i < setString.length() && i < wordStart; i++) { - if(setString.charAt(i) == ',') { + for (int i = 0; i < setString.length() && i < wordStart; i++) { + if (setString.charAt(i) == ',') { pos++; } } @@ -893,7 +893,7 @@ private int collationAwareIndexOf(UTF8String substring, int start, int collation stringSearch.setOverlapping(true); int pos = stringSearch.next(); - while(pos != StringSearch.DONE && pos < start) { + while (pos != StringSearch.DONE && pos < start) { pos = stringSearch.next(); } From 0a22909106562a363fc0b6b0cdf93e1558dcdc5e Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 3 Apr 2024 10:19:41 +0200 Subject: [PATCH 25/46] Improve collationAwareIndexOf performance --- .../java/org/apache/spark/unsafe/types/UTF8String.java | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 85d05639b3a99..41f80a14027cc 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -890,14 +890,9 @@ private int collationAwareIndexOf(UTF8String substring, int start, int collation } StringSearch stringSearch = CollationFactory.getStringSearch(this, substring, collationId); - stringSearch.setOverlapping(true); + stringSearch.setIndex(start); - int pos = stringSearch.next(); - while (pos != StringSearch.DONE && pos < start) { - pos = stringSearch.next(); - } - - return pos; + return stringSearch.next(); } /** From b3be85d6fe6db524b6985309bda629e2101a792e Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 3 Apr 2024 11:51:35 +0200 Subject: [PATCH 26/46] Fix indentation --- .../src/main/java/org/apache/spark/unsafe/types/UTF8String.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 41f80a14027cc..10a8acd48070d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -557,7 +557,7 @@ public int findInSet(UTF8String match, int collationId) { return this.toLowerCase().findInSet(match.toLowerCase()); } return collationAwareFindInSet(match, collationId); -} + } private int collationAwareFindInSet(UTF8String match, int collationId) { if (match.contains(COMMA_UTF8)) { From c4c0fe75d075c9a898200f6d89929e8f9ba7409c Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 3 Apr 2024 13:53:30 +0200 Subject: [PATCH 27/46] Add more tests for instr --- .../apache/spark/sql/CollationStringExpressionsSuite.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 537b84ab79c68..bde20c99ed78d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -105,6 +105,9 @@ class CollationStringExpressionsSuite extends QueryTest var collationId = CollationFactory.collationNameToId("UTF8_BINARY") testFindInSet("AB", "abc,b,ab,c,def", collationId, 0) + testFindInSet("abc", "abc,b,ab,c,def", collationId, 1) + testFindInSet("def", "abc,b,ab,c,def", collationId, 5) + testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0) collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") testFindInSet("a", "abc,b,ab,c,def", collationId, 0) @@ -112,11 +115,13 @@ class CollationStringExpressionsSuite extends QueryTest testFindInSet("AB", "abc,b,ab,c,def", collationId, 3) testFindInSet("AbC", "abc,b,ab,c,def", collationId, 1) testFindInSet("abcd", "abc,b,ab,c,def", collationId, 0) + testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0) collationId = CollationFactory.collationNameToId("UNICODE") testFindInSet("a", "abc,b,ab,c,def", collationId, 0) testFindInSet("ab", "abc,b,ab,c,def", collationId, 3) testFindInSet("Ab", "abc,b,ab,c,def", collationId, 0) + testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0) collationId = CollationFactory.collationNameToId("UNICODE_CI") testFindInSet("a", "abc,b,ab,c,def", collationId, 0) From 5b29f7601d7062bb232f6e7a8fdcba0e551a749c Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 3 Apr 2024 14:25:56 +0200 Subject: [PATCH 28/46] Add more tests --- .../apache/spark/unsafe/types/UTF8String.java | 6 +++ .../sql/CollationStringExpressionsSuite.scala | 47 ++++++++++++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 10a8acd48070d..5932a476ed9f2 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -559,6 +559,12 @@ public int findInSet(UTF8String match, int collationId) { return collationAwareFindInSet(match, collationId); } + /* + * Works on Strings with collationId other than UTF8_BINARY_COLLATION_ID. Returns the index + * of the string `match` in this String. This string has to be a comma separated + * list. If `match` contains a comma 0 will be returned. If the `match` isn't part of this String, + * 0 will be returned, else the index of match (1-based index) + */ private int collationAwareFindInSet(UTF8String match, int collationId) { if (match.contains(COMMA_UTF8)) { return 0; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index bde20c99ed78d..6dcd8d7414040 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -82,17 +82,45 @@ class CollationStringExpressionsSuite extends QueryTest checkEvaluation(StringInstr(string, substring), expected) } - var collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") + var collationId = CollationFactory.collationNameToId("UTF8_BINARY") + testInStr("aaads", "Aa", collationId, 0) + testInStr("aaaDs", "de", collationId, 0) + testInStr("aaads", "ds", collationId, 4) + testInStr("xxxx", "", collationId, 1) + testInStr("", "xxxx", collationId, 0) + // scalastyle:off + testInStr("test大千世界X大千世界", "大千", collationId, 5) + testInStr("test大千世界X大千世界", "界X", collationId, 8) + // scalastyle:on + + collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") testInStr("aaads", "Aa", collationId, 1) testInStr("aaaDs", "de", collationId, 0) + testInStr("aaaDs", "ds", collationId, 4) + testInStr("xxxx", "", collationId, 1) + testInStr("", "xxxx", collationId, 0) + // scalastyle:off + testInStr("test大千世界X大千世界", "大千", collationId, 5) + testInStr("test大千世界X大千世界", "界x", collationId, 8) + // scalastyle:on collationId = CollationFactory.collationNameToId("UNICODE") testInStr("aaads", "Aa", collationId, 0) + testInStr("aaads", "aa", collationId, 1) testInStr("aaads", "de", collationId, 0) + testInStr("xxxx", "", collationId, 1) + testInStr("", "xxxx", collationId, 0) + // scalastyle:off + testInStr("test大千世界X大千世界", "界x", collationId, 0) + testInStr("test大千世界X大千世界", "界X", collationId, 8) + // scalastyle:on collationId = CollationFactory.collationNameToId("UNICODE_CI") testInStr("aaads", "AD", collationId, 3) testInStr("aaads", "dS", collationId, 4) + // scalastyle:off + testInStr("test大千世界X大千世界", "界x", collationId, 8) + // scalastyle:on } test("FIND_IN_SET check result on explicitly collated strings") { @@ -108,6 +136,7 @@ class CollationStringExpressionsSuite extends QueryTest testFindInSet("abc", "abc,b,ab,c,def", collationId, 1) testFindInSet("def", "abc,b,ab,c,def", collationId, 5) testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0) + testFindInSet("", "abc,b,ab,c,def", collationId, 0) collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") testFindInSet("a", "abc,b,ab,c,def", collationId, 0) @@ -116,18 +145,34 @@ class CollationStringExpressionsSuite extends QueryTest testFindInSet("AbC", "abc,b,ab,c,def", collationId, 1) testFindInSet("abcd", "abc,b,ab,c,def", collationId, 0) testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0) + testFindInSet("XX", "xx", collationId, 1) + testFindInSet("", "abc,b,ab,c,def", collationId, 0) + // scalastyle:off + testFindInSet("界x", "test,大千,世,界X,大,千,世界", collationId, 4) + // scalastyle:on collationId = CollationFactory.collationNameToId("UNICODE") testFindInSet("a", "abc,b,ab,c,def", collationId, 0) testFindInSet("ab", "abc,b,ab,c,def", collationId, 3) testFindInSet("Ab", "abc,b,ab,c,def", collationId, 0) testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0) + testFindInSet("xx", "xx", collationId, 1) + // scalastyle:off + testFindInSet("界x", "test,大千,世,界X,大,千,世界", collationId, 0) + testFindInSet("大", "test,大千,世,界X,大,千,世界", collationId, 5) + // scalastyle:on collationId = CollationFactory.collationNameToId("UNICODE_CI") testFindInSet("a", "abc,b,ab,c,def", collationId, 0) testFindInSet("C", "abc,b,ab,c,def", collationId, 4) testFindInSet("DeF", "abc,b,ab,c,dEf", collationId, 5) testFindInSet("DEFG", "abc,b,ab,c,def", collationId, 0) + testFindInSet("XX", "xx", collationId, 1) + // scalastyle:off + testFindInSet("界x", "test,大千,世,界X,大,千,世界", collationId, 4) + testFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", collationId, 5) + testFindInSet("大", "test,大千,世,界X,大,千,世界", collationId, 5) + // scalastyle:on } test("REPEAT check output type on explicitly collated string") { From 877828e30e4427ab5b616f496ceb76841bec3e99 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 4 Apr 2024 10:14:13 +0200 Subject: [PATCH 29/46] Remove collation match type checks --- .../catalyst/expressions/stringExpressions.scala | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 982468d7d088b..d768a00946d31 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1024,12 +1024,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } override def checkInputDataTypes(): TypeCheckResult = { - val defaultCheck = super.checkInputDataTypes() - if (defaultCheck.isFailure) { - return defaultCheck - } - - CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType)) + super.checkInputDataTypes() } override def dataType: DataType = IntegerType @@ -1410,12 +1405,7 @@ case class StringInstr(str: Expression, substr: Expression) } override def checkInputDataTypes(): TypeCheckResult = { - val defaultCheck = super.checkInputDataTypes() - if (defaultCheck.isFailure) { - return defaultCheck - } - - CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType)) + super.checkInputDataTypes() } override def prettyName: String = "instr" From b35a8aca503331d6ea187f5ea8234793efab7c87 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 4 Apr 2024 10:19:57 +0200 Subject: [PATCH 30/46] Merge with the latest master --- .../apache/spark/sql/CollationStringExpressionsSuite.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 6d4df280192d1..4643742cc27be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -21,10 +21,7 @@ import scala.collection.immutable.Seq import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, Literal, StringRepeat} -import org.apache.spark.sql.catalyst.util.CollationFactory -import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, FindInSet, Literal, StringInstr, StringRepeat} +import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, FindInSet, Literal, StringInstr, StringRepeat} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession From 8b0601477f37fc39b8a534bb72dbc869c3c91fb1 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 4 Apr 2024 10:36:06 +0200 Subject: [PATCH 31/46] Remove checkInputDataTypes --- .../sql/catalyst/expressions/stringExpressions.scala | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 78eb58c65debd..fd44240e28681 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1015,10 +1015,6 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } } - override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() - } - override def dataType: DataType = IntegerType override def prettyName: String = "find_in_set" @@ -1396,10 +1392,6 @@ case class StringInstr(str: Expression, substr: Expression) } } - override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() - } - override def prettyName: String = "instr" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { From fbd1c001d6b8bccb88771d79e7e4d63bc4a44e5f Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 12 Apr 2024 14:48:38 +0200 Subject: [PATCH 32/46] Refactor code and move it to CollationSupport --- .../sql/catalyst/util/CollationSupport.java | 108 ++++++++++++++++++ .../apache/spark/unsafe/types/UTF8String.java | 66 ----------- .../expressions/stringExpressions.scala | 21 +--- .../sql/CollationStringExpressionsSuite.scala | 17 --- 4 files changed, 114 insertions(+), 98 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index fe1952921b7fb..ba826e6435521 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -137,6 +137,74 @@ public static boolean execICU(final UTF8String l, final UTF8String r, } } + public static class FindInSet { + public static int exec(final UTF8String l, final UTF8String r, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(l, r); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(l, r); + } else { + return execICU(l, r, collationId); + } + } + public static String genCode(final String l, final String r, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.FindInSet.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", l, r); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s)", l, r); + } else { + return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); + } + } + public static int execBinary(final UTF8String l, final UTF8String r) { + return l.findInSet(r); + } + public static int execLowercase(final UTF8String l, final UTF8String r) { + return l.toLowerCase().findInSet(r.toLowerCase()); + } + public static int execICU(final UTF8String l, final UTF8String r, + final int collationId) { + return CollationAwareUTF8String.findInSet(l, r, collationId); + } + } + + public static class IndexOf { + public static int exec(final UTF8String l, final UTF8String r, final int start, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(l, r, start); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(l, r, start); + } else { + return execICU(l, r, start, collationId); + } + } + public static String genCode(final String l, final String r, final int start, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.IndexOf.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s, %d)", l, r, start); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s, %d)", l, r, start); + } else { + return String.format(expr + "ICU(%s, %s, %d, %d)", l, r, start, collationId); + } + } + public static int execBinary(final UTF8String l, final UTF8String r, final int start) { + return l.indexOf(r, start); + } + public static int execLowercase(final UTF8String l, final UTF8String r, final int start) { + return l.toLowerCase().indexOf(r.toLowerCase(), start); + } + public static int execICU(final UTF8String l, final UTF8String r, final int start, + final int collationId) { + return CollationAwareUTF8String.indexOf(l, r, start, collationId); + } + } + // TODO: Add more collation-aware string expressions. /** @@ -169,6 +237,46 @@ private static boolean matchAt(final UTF8String target, final UTF8String pattern pos, pos + pattern.numChars()), pattern, collationId).last() == 0; } + private static int findInSet(UTF8String match, UTF8String set, int collationId) { + if (match.contains(UTF8String.fromString(","))) { + return 0; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(set, match, collationId); + + String setString = set.toString(); + int wordStart = 0; + while ((wordStart = stringSearch.next()) != StringSearch.DONE) { + boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ','; + boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length() + || setString.charAt(wordStart + stringSearch.getMatchLength()) == ','; + + if (isValidStart && isValidEnd) { + int pos = 0; + for (int i = 0; i < setString.length() && i < wordStart; i++) { + if (setString.charAt(i) == ',') { + pos++; + } + } + + return pos + 1; + } + } + + return 0; + } + + private static int indexOf(UTF8String target, UTF8String pattern, int start, int collationId) { + if (pattern.numBytes() == 0) { + return 0; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); + stringSearch.setIndex(start); + + return stringSearch.next(); + } + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index dbd35689f16e4..2009f1d20442c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -515,51 +515,6 @@ public int findInSet(UTF8String match) { return 0; } - public int findInSet(UTF8String match, int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - return this.findInSet(match); - } - if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { - return this.toLowerCase().findInSet(match.toLowerCase()); - } - return collationAwareFindInSet(match, collationId); - } - - /* - * Works on Strings with collationId other than UTF8_BINARY_COLLATION_ID. Returns the index - * of the string `match` in this String. This string has to be a comma separated - * list. If `match` contains a comma 0 will be returned. If the `match` isn't part of this String, - * 0 will be returned, else the index of match (1-based index) - */ - private int collationAwareFindInSet(UTF8String match, int collationId) { - if (match.contains(COMMA_UTF8)) { - return 0; - } - - StringSearch stringSearch = CollationFactory.getStringSearch(this, match, collationId); - - String setString = this.toString(); - int wordStart = 0; - while ((wordStart = stringSearch.next()) != StringSearch.DONE) { - boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ','; - boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length() - || setString.charAt(wordStart + stringSearch.getMatchLength()) == ','; - - if (isValidStart && isValidEnd) { - int pos = 0; - for (int i = 0; i < setString.length() && i < wordStart; i++) { - if (setString.charAt(i) == ',') { - pos++; - } - } - - return pos + 1; - } - } - - return 0; - } - /** * Copy the bytes from the current UTF8String, and make a new UTF8String. * @param start the start position of the current UTF8String in bytes. @@ -846,27 +801,6 @@ public int indexOf(UTF8String v, int start) { return -1; } - public int indexOf(UTF8String substring, int start, int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - return this.indexOf(substring, start); - } - if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { - return this.toLowerCase().indexOf(substring.toLowerCase(), start); - } - return collationAwareIndexOf(substring, start, collationId); - } - - private int collationAwareIndexOf(UTF8String substring, int start, int collationId) { - if (substring.numBytes == 0) { - return 0; - } - - StringSearch stringSearch = CollationFactory.getStringSearch(this, substring, collationId); - stringSearch.setIndex(start); - - return stringSearch.next(); - } - /** * Find the `str` from left to right. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 1120ec1d75de5..b7feb7612a562 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -980,19 +980,13 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi Seq(StringTypeAnyCollation, StringTypeAnyCollation) override protected def nullSafeEval(word: Any, set: Any): Any = { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) - } else { - set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String], collationId) - } + CollationSupport.FindInSet. + exec(word.asInstanceOf[UTF8String], set.asInstanceOf[UTF8String], collationId) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word);") - } else { - nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word, $collationId);") - } + nullSafeCodeGen(ctx, ev, (word, set) => + CollationSupport.FindInSet.genCode(word, set, collationId)) } override def dataType: DataType = IntegerType @@ -1365,11 +1359,8 @@ case class StringInstr(str: Expression, substr: Expression) Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def nullSafeEval(string: Any, sub: Any): Any = { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1 - } else { - string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0, collationId) + 1 - } + CollationSupport.IndexOf. + exec(string.asInstanceOf[UTF8String], sub.asInstanceOf[UTF8String], collationId) } override def prettyName: String = "instr" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index e621ddc98a046..de273d8a6496d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -199,9 +199,6 @@ class CollationStringExpressionsSuite // scalastyle:on } - test("REPEAT check output type on explicitly collated string") { - def testRepeat(input: String, n: Int, collationId: Int, expected: String): Unit = { - val s = Literal.create(input, StringType(collationId)) test("Support StartsWith string expression with collation") { // Supported collations case class StartsWithTestCase[R](l: String, r: String, c: String, result: R) @@ -252,20 +249,6 @@ class CollationStringExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } - // Not important for this test - val repeatNum = 2; - - var collationId = CollationFactory.collationNameToId("UTF8_BINARY") - testRepeat("abc", repeatNum, collationId, "UTF8_BINARY") - - collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") - testRepeat("abc", repeatNum, collationId, "UTF8_BINARY_LCASE") - - collationId = CollationFactory.collationNameToId("UNICODE") - testRepeat("abc", repeatNum, collationId, "UNICODE") - - collationId = CollationFactory.collationNameToId("UNICODE_CI") - testRepeat("abc", repeatNum, collationId, "UNICODE_CI") test("Support StringRepeat string expression with collation") { // Supported collations case class StringRepeatTestCase[R](s: String, n: Int, c: String, result: R) From 960af54e98d2b74aaa5c17cf4e279d2b4ba6d0e1 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 12 Apr 2024 15:33:57 +0200 Subject: [PATCH 33/46] Improve codegen and run tests --- .../spark/sql/catalyst/util/CollationSupport.java | 8 +++++--- .../catalyst/expressions/stringExpressions.scala | 13 +++++-------- .../spark/sql/CollationStringExpressionsSuite.scala | 5 +---- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index ba826e6435521..5b861e83d3e5b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -160,10 +160,10 @@ public static String genCode(final String l, final String r, final int collation } } public static int execBinary(final UTF8String l, final UTF8String r) { - return l.findInSet(r); + return r.findInSet(l); } public static int execLowercase(final UTF8String l, final UTF8String r) { - return l.toLowerCase().findInSet(r.toLowerCase()); + return r.toLowerCase().findInSet(l.toLowerCase()); } public static int execICU(final UTF8String l, final UTF8String r, final int collationId) { @@ -274,7 +274,9 @@ private static int indexOf(UTF8String target, UTF8String pattern, int start, int StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); stringSearch.setIndex(start); - return stringSearch.next(); + int result = stringSearch.next(); + + return Math.max(result, 0); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b7feb7612a562..30e287e6846aa 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -985,8 +985,8 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (word, set) => - CollationSupport.FindInSet.genCode(word, set, collationId)) + nullSafeCodeGen(ctx, ev, (l, r) => + s"${ev.value} = " + CollationSupport.FindInSet.genCode(l, r, collationId) + ";") } override def dataType: DataType = IntegerType @@ -1360,17 +1360,14 @@ case class StringInstr(str: Expression, substr: Expression) override def nullSafeEval(string: Any, sub: Any): Any = { CollationSupport.IndexOf. - exec(string.asInstanceOf[UTF8String], sub.asInstanceOf[UTF8String], collationId) + exec(string.asInstanceOf[UTF8String], sub.asInstanceOf[UTF8String], 0, collationId) + 1 } override def prettyName: String = "instr" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") - } else { - defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0, $collationId) + 1") - } + defineCodeGen(ctx, ev, (l, r) => + CollationSupport.IndexOf.genCode(l, r, 0, collationId) + " + 1") } override protected def withNewChildrenInternal( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index de273d8a6496d..8846a729b2069 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql -import scala.collection.immutable.Seq - import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, FindInSet, Literal, StringInstr, StringRepeat} +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, FindInSet, Literal, StringInstr} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession From 05cd6c459e2fb7a351e6da65935e2adcadda646d Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 12 Apr 2024 17:56:00 +0200 Subject: [PATCH 34/46] Unify collationAwareIndexOf for return value to have same semantics as indexOf --- .../apache/spark/sql/catalyst/util/CollationSupport.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 5b861e83d3e5b..f8f1aaa8b98f8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -201,7 +201,7 @@ public static int execLowercase(final UTF8String l, final UTF8String r, final in } public static int execICU(final UTF8String l, final UTF8String r, final int start, final int collationId) { - return CollationAwareUTF8String.indexOf(l, r, start, collationId); + return Math.max(CollationAwareUTF8String.indexOf(l, r, start, collationId), 0); } } @@ -274,9 +274,7 @@ private static int indexOf(UTF8String target, UTF8String pattern, int start, int StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); stringSearch.setIndex(start); - int result = stringSearch.next(); - - return Math.max(result, 0); + return stringSearch.next(); } } From 053efa05802b412516b2f12f599dd82d46ff67c0 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Sat, 13 Apr 2024 12:48:36 +0200 Subject: [PATCH 35/46] Break line at 100 chars --- .../apache/spark/sql/catalyst/util/CollationSupport.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index f8f1aaa8b98f8..481daf63d3e57 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -172,7 +172,8 @@ public static int execICU(final UTF8String l, final UTF8String r, } public static class IndexOf { - public static int exec(final UTF8String l, final UTF8String r, final int start, final int collationId) { + public static int exec(final UTF8String l, final UTF8String r, final int start, + final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); if (collation.supportsBinaryEquality) { return execBinary(l, r, start); @@ -182,7 +183,8 @@ public static int exec(final UTF8String l, final UTF8String r, final int start, return execICU(l, r, start, collationId); } } - public static String genCode(final String l, final String r, final int start, final int collationId) { + public static String genCode(final String l, final String r, final int start, + final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.IndexOf.exec"; if (collation.supportsBinaryEquality) { From c65d68ebdcbc6e47dbe597098049c2f39e812276 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Apr 2024 12:19:36 +0200 Subject: [PATCH 36/46] Add new version of getStringSearch --- .../spark/sql/catalyst/util/CollationFactory.java | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index ff7bc450f851e..79981caabb834 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -202,6 +202,20 @@ public static StringSearch getStringSearch( return new StringSearch(pattern, target, (RuleBasedCollator) collator); } + /** + * Returns a StringSearch object for the given pattern and target strings, under collation + * rules corresponding to the given collationId. The external ICU library StringSearch object can + * be used to find occurrences of the pattern in the target string, while respecting collation. + */ + public static StringSearch getStringSearch( + final String targetString, + final String patternString, + final int collationId) { + CharacterIterator target = new StringCharacterIterator(targetString); + Collator collator = CollationFactory.fetchCollation(collationId).collator; + return new StringSearch(patternString, target, (RuleBasedCollator) collator); + } + /** * Returns a collation-unaware StringSearch object for the given pattern and target strings. * While this object does not respect collation, it can be used to find occurrences of the pattern From ae33a38dabe150682fd07c8c2650e40e4b3caea0 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Apr 2024 12:24:54 +0200 Subject: [PATCH 37/46] Rename StringInstr params and class in CollationSupport --- .../sql/catalyst/util/CollationSupport.java | 30 +++++++++---------- .../expressions/stringExpressions.scala | 6 ++-- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 481daf63d3e57..45321d95b650d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -171,39 +171,39 @@ public static int execICU(final UTF8String l, final UTF8String r, } } - public static class IndexOf { - public static int exec(final UTF8String l, final UTF8String r, final int start, + public static class StringInstr { + public static int exec(final UTF8String string, final UTF8String substring, final int start, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); if (collation.supportsBinaryEquality) { - return execBinary(l, r, start); + return execBinary(string, substring, start); } else if (collation.supportsLowercaseEquality) { - return execLowercase(l, r, start); + return execLowercase(string, substring, start); } else { - return execICU(l, r, start, collationId); + return execICU(string, substring, start, collationId); } } - public static String genCode(final String l, final String r, final int start, + public static String genCode(final String string, final String substring, final int start, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.IndexOf.exec"; if (collation.supportsBinaryEquality) { - return String.format(expr + "Binary(%s, %s, %d)", l, r, start); + return String.format(expr + "Binary(%s, %s, %d)", string, substring, start); } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s, %d)", l, r, start); + return String.format(expr + "Lowercase(%s, %s, %d)", string, substring, start); } else { - return String.format(expr + "ICU(%s, %s, %d, %d)", l, r, start, collationId); + return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId); } } - public static int execBinary(final UTF8String l, final UTF8String r, final int start) { - return l.indexOf(r, start); + public static int execBinary(final UTF8String string, final UTF8String substring, final int start) { + return string.indexOf(substring, start); } - public static int execLowercase(final UTF8String l, final UTF8String r, final int start) { - return l.toLowerCase().indexOf(r.toLowerCase(), start); + public static int execLowercase(final UTF8String string, final UTF8String substring, final int start) { + return string.toLowerCase().indexOf(substring.toLowerCase(), start); } - public static int execICU(final UTF8String l, final UTF8String r, final int start, + public static int execICU(final UTF8String string, final UTF8String substring, final int start, final int collationId) { - return Math.max(CollationAwareUTF8String.indexOf(l, r, start, collationId), 0); + return Math.max(CollationAwareUTF8String.indexOf(string, substring, start, collationId), 0); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 30e287e6846aa..e52156ff2abf1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1359,15 +1359,15 @@ case class StringInstr(str: Expression, substr: Expression) Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def nullSafeEval(string: Any, sub: Any): Any = { - CollationSupport.IndexOf. + CollationSupport.StringInstr. exec(string.asInstanceOf[UTF8String], sub.asInstanceOf[UTF8String], 0, collationId) + 1 } override def prettyName: String = "instr" override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (l, r) => - CollationSupport.IndexOf.genCode(l, r, 0, collationId) + " + 1") + defineCodeGen(ctx, ev, (string, substring) => + CollationSupport.StringInstr.genCode(string, substring, 0, collationId) + " + 1") } override protected def withNewChildrenInternal( From be1b52cb6bacd892b97b7f4c3a804ff496d70498 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Apr 2024 12:31:43 +0200 Subject: [PATCH 38/46] Go from nullSafeCodeGen to defineCodeGen --- .../spark/sql/catalyst/util/CollationSupport.java | 12 +++++++----- .../sql/catalyst/expressions/stringExpressions.scala | 6 ++++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 45321d95b650d..5172c5b959a44 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -186,7 +186,7 @@ public static int exec(final UTF8String string, final UTF8String substring, fina public static String genCode(final String string, final String substring, final int start, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - String expr = "CollationSupport.IndexOf.exec"; + String expr = "CollationSupport.StringInstr.exec"; if (collation.supportsBinaryEquality) { return String.format(expr + "Binary(%s, %s, %d)", string, substring, start); } else if (collation.supportsLowercaseEquality) { @@ -239,14 +239,15 @@ private static boolean matchAt(final UTF8String target, final UTF8String pattern pos, pos + pattern.numChars()), pattern, collationId).last() == 0; } - private static int findInSet(UTF8String match, UTF8String set, int collationId) { + private static int findInSet(final UTF8String match, final UTF8String set, int collationId) { if (match.contains(UTF8String.fromString(","))) { return 0; } - StringSearch stringSearch = CollationFactory.getStringSearch(set, match, collationId); - String setString = set.toString(); + StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(), + collationId); + int wordStart = 0; while ((wordStart = stringSearch.next()) != StringSearch.DONE) { boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ','; @@ -268,7 +269,8 @@ private static int findInSet(UTF8String match, UTF8String set, int collationId) return 0; } - private static int indexOf(UTF8String target, UTF8String pattern, int start, int collationId) { + private static int indexOf(final UTF8String target, final UTF8String pattern, + final int start, final int collationId) { if (pattern.numBytes() == 0) { return 0; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e52156ff2abf1..5b54eaa3698aa 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -985,8 +985,10 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (l, r) => - s"${ev.value} = " + CollationSupport.FindInSet.genCode(l, r, collationId) + ";") + defineCodeGen(ctx, ev, (word, set) => CollationSupport.FindInSet. + genCode(word, set, collationId)) +// nullSafeCodeGen(ctx, ev, (word, set) => +// s"${ev.value} = " + CollationSupport.FindInSet.genCode(word, set, collationId) + ";") } override def dataType: DataType = IntegerType From 5894d2f11195f42146d977f465a1bc475ab14368 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Apr 2024 14:46:03 +0200 Subject: [PATCH 39/46] Refactor testing --- .../sql/catalyst/util/CollationFactory.java | 5 +- .../sql/catalyst/util/CollationSupport.java | 50 +++--- .../unsafe/types/CollationSupportSuite.java | 80 ++++++++++ .../expressions/stringExpressions.scala | 4 +- .../sql/CollationStringExpressionsSuite.scala | 151 ++++++------------ 5 files changed, 160 insertions(+), 130 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 79981caabb834..d6cb20b5e89cd 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -196,10 +196,7 @@ public static StringSearch getStringSearch( final UTF8String targetUTF8String, final UTF8String patternUTF8String, final int collationId) { - String pattern = patternUTF8String.toString(); - CharacterIterator target = new StringCharacterIterator(targetUTF8String.toString()); - Collator collator = CollationFactory.fetchCollation(collationId).collator; - return new StringSearch(pattern, target, (RuleBasedCollator) collator); + return getStringSearch(targetUTF8String.toString(), patternUTF8String.toString(), collationId); } /** diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 5172c5b959a44..3c716eb5be32b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -138,49 +138,49 @@ public static boolean execICU(final UTF8String l, final UTF8String r, } public static class FindInSet { - public static int exec(final UTF8String l, final UTF8String r, final int collationId) { + public static int exec(final UTF8String word, final UTF8String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); if (collation.supportsBinaryEquality) { - return execBinary(l, r); + return execBinary(word, set); } else if (collation.supportsLowercaseEquality) { - return execLowercase(l, r); + return execLowercase(word, set); } else { - return execICU(l, r, collationId); + return execICU(word, set, collationId); } } - public static String genCode(final String l, final String r, final int collationId) { + public static String genCode(final String word, final String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.FindInSet.exec"; if (collation.supportsBinaryEquality) { - return String.format(expr + "Binary(%s, %s)", l, r); + return String.format(expr + "Binary(%s, %s)", word, set); } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s)", l, r); + return String.format(expr + "Lowercase(%s, %s)", word, set); } else { - return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); + return String.format(expr + "ICU(%s, %s, %d)", word, set, collationId); } } - public static int execBinary(final UTF8String l, final UTF8String r) { - return r.findInSet(l); + public static int execBinary(final UTF8String word, final UTF8String set) { + return set.findInSet(word); } - public static int execLowercase(final UTF8String l, final UTF8String r) { - return r.toLowerCase().findInSet(l.toLowerCase()); + public static int execLowercase(final UTF8String word, final UTF8String set) { + return set.toLowerCase().findInSet(word.toLowerCase()); } - public static int execICU(final UTF8String l, final UTF8String r, + public static int execICU(final UTF8String word, final UTF8String set, final int collationId) { - return CollationAwareUTF8String.findInSet(l, r, collationId); + return CollationAwareUTF8String.findInSet(word, set, collationId); } } public static class StringInstr { - public static int exec(final UTF8String string, final UTF8String substring, final int start, + public static int exec(final UTF8String string, final UTF8String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); if (collation.supportsBinaryEquality) { - return execBinary(string, substring, start); + return execBinary(string, substring); } else if (collation.supportsLowercaseEquality) { - return execLowercase(string, substring, start); + return execLowercase(string, substring); } else { - return execICU(string, substring, start, collationId); + return execICU(string, substring, collationId); } } public static String genCode(final String string, final String substring, final int start, @@ -195,15 +195,15 @@ public static String genCode(final String string, final String substring, final return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId); } } - public static int execBinary(final UTF8String string, final UTF8String substring, final int start) { - return string.indexOf(substring, start); + public static int execBinary(final UTF8String string, final UTF8String substring) { + return string.indexOf(substring, 0); } - public static int execLowercase(final UTF8String string, final UTF8String substring, final int start) { - return string.toLowerCase().indexOf(substring.toLowerCase(), start); + public static int execLowercase(final UTF8String string, final UTF8String substring) { + return string.toLowerCase().indexOf(substring.toLowerCase(), 0); } - public static int execICU(final UTF8String string, final UTF8String substring, final int start, - final int collationId) { - return Math.max(CollationAwareUTF8String.indexOf(string, substring, start, collationId), 0); + public static int execICU(final UTF8String string, final UTF8String substring, + final int collationId) { + return Math.max(CollationAwareUTF8String.indexOf(string, substring, 0, collationId), 0); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index bfb696c35fff6..4967183c65cba 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -249,6 +249,86 @@ public void testEndsWith() throws SparkException { assertEndsWith("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false); } + private void assertStringInstr(String string, String substring, String collationName, + Integer value) throws SparkException { + UTF8String str = UTF8String.fromString(string); + UTF8String substr = UTF8String.fromString(substring); + int collationId = CollationFactory.collationNameToId(collationName); + + assertEquals(CollationSupport.StringInstr.exec(str, substr, collationId) + 1, value); + } + + @Test + public void testStringInstr() throws SparkException { + assertStringInstr("aaads", "Aa", "UTF8_BINARY", 0); + assertStringInstr("aaaDs", "de", "UTF8_BINARY", 0); + assertStringInstr("aaads", "ds", "UTF8_BINARY", 4); + assertStringInstr("xxxx", "", "UTF8_BINARY", 1); + assertStringInstr("", "xxxx", "UTF8_BINARY", 0); + assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY", 5); + assertStringInstr("test大千世界X大千世界", "界X", "UTF8_BINARY", 8); + assertStringInstr("aaads", "Aa", "UTF8_BINARY_LCASE", 1); + assertStringInstr("aaaDs", "de", "UTF8_BINARY_LCASE", 0); + assertStringInstr("aaaDs", "ds", "UTF8_BINARY_LCASE", 4); + assertStringInstr("xxxx", "", "UTF8_BINARY_LCASE", 1); + assertStringInstr("", "xxxx", "UTF8_BINARY_LCASE", 0); + assertStringInstr("test大千世界X大千世界", "大千", "UTF8_BINARY_LCASE", 5); + assertStringInstr("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8); + assertStringInstr("aaads", "Aa", "UNICODE", 0); + assertStringInstr("aaads", "aa", "UNICODE", 1); + assertStringInstr("aaads", "de", "UNICODE", 0); + assertStringInstr("xxxx", "", "UNICODE", 1); + assertStringInstr("", "xxxx", "UNICODE", 0); + assertStringInstr("test大千世界X大千世界", "界x", "UNICODE", 0); + assertStringInstr("test大千世界X大千世界", "界X", "UNICODE", 8); + assertStringInstr("aaads", "AD", "UNICODE_CI", 3); + assertStringInstr("aaads", "dS", "UNICODE_CI", 4); + assertStringInstr("test大千世界X大千世界", "界x", "UNICODE_CI", 8); + } + + //word: String, set: String, collationId: Integer, expected: Integer + private void assertFindInSet(String word, String set, String collationName, + Integer value) throws SparkException { + UTF8String w = UTF8String.fromString(word); + UTF8String s = UTF8String.fromString(set); + int collationId = CollationFactory.collationNameToId(collationName); + + assertEquals(CollationSupport.FindInSet.exec(w, s, collationId), value); + } + + @Test + public void testFindInSet() throws SparkException { + assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0); + assertFindInSet("abc", "abc,b,ab,c,def", "UTF8_BINARY", 1); + assertFindInSet("def", "abc,b,ab,c,def", "UTF8_BINARY", 5); + assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY", 0); + assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY", 0); + assertFindInSet("a", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0); + assertFindInSet("c", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4); + assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 3); + assertFindInSet("AbC", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 1); + assertFindInSet("abcd", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0); + assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0); + assertFindInSet("XX", "xx", "UTF8_BINARY_LCASE", 1); + assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 0); + assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UTF8_BINARY_LCASE", 4); + assertFindInSet("a", "abc,b,ab,c,def", "UNICODE", 0); + assertFindInSet("ab", "abc,b,ab,c,def", "UNICODE", 3); + assertFindInSet("Ab", "abc,b,ab,c,def", "UNICODE", 0); + assertFindInSet("d,ef", "abc,b,ab,c,def", "UNICODE", 0); + assertFindInSet("xx", "xx", "UNICODE", 1); + assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE", 0); + assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE", 5); + assertFindInSet("a", "abc,b,ab,c,def", "UNICODE_CI", 0); + assertFindInSet("C", "abc,b,ab,c,def", "UNICODE_CI", 4); + assertFindInSet("DeF", "abc,b,ab,c,dEf", "UNICODE_CI", 5); + assertFindInSet("DEFG", "abc,b,ab,c,def", "UNICODE_CI", 0); + assertFindInSet("XX", "xx", "UNICODE_CI", 1); + assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 4); + assertFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", "UNICODE_CI", 5); + assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 5); + } + // TODO: Test more collation-aware string expressions. /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5b54eaa3698aa..35ace1e23b019 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -987,8 +987,6 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (word, set) => CollationSupport.FindInSet. genCode(word, set, collationId)) -// nullSafeCodeGen(ctx, ev, (word, set) => -// s"${ev.value} = " + CollationSupport.FindInSet.genCode(word, set, collationId) + ";") } override def dataType: DataType = IntegerType @@ -1362,7 +1360,7 @@ case class StringInstr(str: Expression, substr: Expression) override def nullSafeEval(string: Any, sub: Any): Any = { CollationSupport.StringInstr. - exec(string.asInstanceOf[UTF8String], sub.asInstanceOf[UTF8String], 0, collationId) + 1 + exec(string.asInstanceOf[UTF8String], sub.asInstanceOf[UTF8String], collationId) + 1 } override def prettyName: String = "instr" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 8846a729b2069..38244a0e46054 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, FindInSet, Literal, StringInstr} -import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{BooleanType, StringType} +import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType} class CollationStringExpressionsSuite extends QueryTest @@ -95,105 +94,61 @@ class CollationStringExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } - test("INSTR check result on explicitly collated strings") { - def testInStr(str: String, substr: String, collationId: Integer, expected: Integer): Unit = { - val string = Literal.create(str, StringType(collationId)) - val substring = Literal.create(substr, StringType(collationId)) - - checkEvaluation(StringInstr(string, substring), expected) + test("Support StringInStr string expression with collation") { + case class StringInStrTestCase[R](string: String, substring: String, c: String, result: R) + val testCases = Seq( + // scalastyle:off + StringInStrTestCase("test大千世界X大千世界", "大千", "UTF8_BINARY", 5), + StringInStrTestCase("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8), + StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE", 0), + StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8) + // scalastyle:on + ) + testCases.foreach(t => { + val query = s"SELECT instr(collate('${t.string}','${t.c}')," + + s"collate('${t.substring}','${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) + // Implicit casting + checkAnswer(sql(s"SELECT instr(collate('${t.string}','${t.c}')," + + s"'${t.substring}')"), Row(t.result)) + checkAnswer(sql(s"SELECT instr('${t.string}'," + + s"collate('${t.substring}','${t.c}'))"), Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql(s"SELECT instr(collate('aaads','UTF8_BINARY'), collate('Aa','UTF8_BINARY_LCASE'))") } - - var collationId = CollationFactory.collationNameToId("UTF8_BINARY") - testInStr("aaads", "Aa", collationId, 0) - testInStr("aaaDs", "de", collationId, 0) - testInStr("aaads", "ds", collationId, 4) - testInStr("xxxx", "", collationId, 1) - testInStr("", "xxxx", collationId, 0) - // scalastyle:off - testInStr("test大千世界X大千世界", "大千", collationId, 5) - testInStr("test大千世界X大千世界", "界X", collationId, 8) - // scalastyle:on - - collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") - testInStr("aaads", "Aa", collationId, 1) - testInStr("aaaDs", "de", collationId, 0) - testInStr("aaaDs", "ds", collationId, 4) - testInStr("xxxx", "", collationId, 1) - testInStr("", "xxxx", collationId, 0) - // scalastyle:off - testInStr("test大千世界X大千世界", "大千", collationId, 5) - testInStr("test大千世界X大千世界", "界x", collationId, 8) - // scalastyle:on - - collationId = CollationFactory.collationNameToId("UNICODE") - testInStr("aaads", "Aa", collationId, 0) - testInStr("aaads", "aa", collationId, 1) - testInStr("aaads", "de", collationId, 0) - testInStr("xxxx", "", collationId, 1) - testInStr("", "xxxx", collationId, 0) - // scalastyle:off - testInStr("test大千世界X大千世界", "界x", collationId, 0) - testInStr("test大千世界X大千世界", "界X", collationId, 8) - // scalastyle:on - - collationId = CollationFactory.collationNameToId("UNICODE_CI") - testInStr("aaads", "AD", collationId, 3) - testInStr("aaads", "dS", collationId, 4) - // scalastyle:off - testInStr("test大千世界X大千世界", "界x", collationId, 8) - // scalastyle:on + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } - test("FIND_IN_SET check result on explicitly collated strings") { - def testFindInSet(word: String, set: String, collationId: Integer, expected: Integer): Unit = { - val w = Literal.create(word, StringType(collationId)) - val s = Literal.create(set, StringType(collationId)) - - checkEvaluation(FindInSet(w, s), expected) + test("Support FindInSet string expression with collation") { + case class FindInSetTestCase[R](word: String, set: String, c: String, result: R) + val testCases = Seq( + FindInSetTestCase("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0), + FindInSetTestCase("C", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4), + FindInSetTestCase("d,ef", "abc,b,ab,c,def", "UNICODE", 0), + FindInSetTestCase("DeF", "abc,b,ab,c,dEf", "UNICODE_CI", 5) + ) + testCases.foreach(t => { + val query = s"SELECT find_in_set(collate('${t.word}', '${t.c}')," + + s"collate('${t.set}', '${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) + // Implicit casting + checkAnswer(sql(s"SELECT find_in_set(collate('${t.word}', '${t.c}')," + + s"'${t.set}')"), Row(t.result)) + checkAnswer(sql(s"SELECT find_in_set('${t.word}'," + + s"collate('${t.set}', '${t.c}'))"), Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql(s"SELECT find_in_set(collate('AB','UTF8_BINARY')," + + s"collate('ab,xyz,fgh','UTF8_BINARY_LCASE'))") } - - var collationId = CollationFactory.collationNameToId("UTF8_BINARY") - testFindInSet("AB", "abc,b,ab,c,def", collationId, 0) - testFindInSet("abc", "abc,b,ab,c,def", collationId, 1) - testFindInSet("def", "abc,b,ab,c,def", collationId, 5) - testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0) - testFindInSet("", "abc,b,ab,c,def", collationId, 0) - - collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") - testFindInSet("a", "abc,b,ab,c,def", collationId, 0) - testFindInSet("c", "abc,b,ab,c,def", collationId, 4) - testFindInSet("AB", "abc,b,ab,c,def", collationId, 3) - testFindInSet("AbC", "abc,b,ab,c,def", collationId, 1) - testFindInSet("abcd", "abc,b,ab,c,def", collationId, 0) - testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0) - testFindInSet("XX", "xx", collationId, 1) - testFindInSet("", "abc,b,ab,c,def", collationId, 0) - // scalastyle:off - testFindInSet("界x", "test,大千,世,界X,大,千,世界", collationId, 4) - // scalastyle:on - - collationId = CollationFactory.collationNameToId("UNICODE") - testFindInSet("a", "abc,b,ab,c,def", collationId, 0) - testFindInSet("ab", "abc,b,ab,c,def", collationId, 3) - testFindInSet("Ab", "abc,b,ab,c,def", collationId, 0) - testFindInSet("d,ef", "abc,b,ab,c,def", collationId, 0) - testFindInSet("xx", "xx", collationId, 1) - // scalastyle:off - testFindInSet("界x", "test,大千,世,界X,大,千,世界", collationId, 0) - testFindInSet("大", "test,大千,世,界X,大,千,世界", collationId, 5) - // scalastyle:on - - collationId = CollationFactory.collationNameToId("UNICODE_CI") - testFindInSet("a", "abc,b,ab,c,def", collationId, 0) - testFindInSet("C", "abc,b,ab,c,def", collationId, 4) - testFindInSet("DeF", "abc,b,ab,c,dEf", collationId, 5) - testFindInSet("DEFG", "abc,b,ab,c,def", collationId, 0) - testFindInSet("XX", "xx", collationId, 1) - // scalastyle:off - testFindInSet("界x", "test,大千,世,界X,大,千,世界", collationId, 4) - testFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", collationId, 5) - testFindInSet("大", "test,大千,世,界X,大,千,世界", collationId, 5) - // scalastyle:on + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } test("Support StartsWith string expression with collation") { From 75dc0bda8eb4ae4e77e2716bcd6a2953636f4561 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Apr 2024 15:25:20 +0200 Subject: [PATCH 40/46] Remove empty lines --- .../spark/unsafe/types/CollationSupportSuite.java | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 4967183c65cba..4602ce6d19593 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -250,12 +250,11 @@ public void testEndsWith() throws SparkException { } private void assertStringInstr(String string, String substring, String collationName, - Integer value) throws SparkException { + Integer expected) throws SparkException { UTF8String str = UTF8String.fromString(string); UTF8String substr = UTF8String.fromString(substring); int collationId = CollationFactory.collationNameToId(collationName); - - assertEquals(CollationSupport.StringInstr.exec(str, substr, collationId) + 1, value); + assertEquals(expected, CollationSupport.StringInstr.exec(str, substr, collationId) + 1); } @Test @@ -286,14 +285,12 @@ public void testStringInstr() throws SparkException { assertStringInstr("test大千世界X大千世界", "界x", "UNICODE_CI", 8); } - //word: String, set: String, collationId: Integer, expected: Integer private void assertFindInSet(String word, String set, String collationName, - Integer value) throws SparkException { + Integer expected) throws SparkException { UTF8String w = UTF8String.fromString(word); UTF8String s = UTF8String.fromString(set); int collationId = CollationFactory.collationNameToId(collationName); - - assertEquals(CollationSupport.FindInSet.exec(w, s, collationId), value); + assertEquals(expected, CollationSupport.FindInSet.exec(w, s, collationId)); } @Test From c712b4b54707fa7a259fd63cdbdc0090695e6d42 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Apr 2024 15:36:18 +0200 Subject: [PATCH 41/46] Improve CollationAware indexOf to have the same semantics as UTF8String indexOf --- .../org/apache/spark/sql/catalyst/util/CollationSupport.java | 2 +- .../org/apache/spark/unsafe/types/CollationSupportSuite.java | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 3c716eb5be32b..ddb768b0f7a1f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -203,7 +203,7 @@ public static int execLowercase(final UTF8String string, final UTF8String substr } public static int execICU(final UTF8String string, final UTF8String substring, final int collationId) { - return Math.max(CollationAwareUTF8String.indexOf(string, substring, 0, collationId), 0); + return CollationAwareUTF8String.indexOf(string, substring, 0, collationId); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 4602ce6d19593..acdd73ce28b00 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -282,6 +282,7 @@ public void testStringInstr() throws SparkException { assertStringInstr("test大千世界X大千世界", "界X", "UNICODE", 8); assertStringInstr("aaads", "AD", "UNICODE_CI", 3); assertStringInstr("aaads", "dS", "UNICODE_CI", 4); + assertStringInstr("test大千世界X大千世界", "界y", "UNICODE_CI", 0); assertStringInstr("test大千世界X大千世界", "界x", "UNICODE_CI", 8); } From 3c37f35a1db6be77021eb6008f7c41c306514515 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Apr 2024 15:37:29 +0200 Subject: [PATCH 42/46] Add new e2e test --- .../org/apache/spark/sql/CollationStringExpressionsSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 38244a0e46054..4584ee5baf65a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -101,6 +101,7 @@ class CollationStringExpressionsSuite StringInStrTestCase("test大千世界X大千世界", "大千", "UTF8_BINARY", 5), StringInStrTestCase("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8), StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE", 0), + StringInStrTestCase("test大千世界X大千世界", "界y", "UNICODE_CI", 0), StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8) // scalastyle:on ) From cd860b96c0f64859dfd17b8eab299600a6c31285 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Apr 2024 16:10:36 +0200 Subject: [PATCH 43/46] Revert unused import deletion --- .../org/apache/spark/sql/CollationStringExpressionsSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 4584ee5baf65a..8c56406dad7b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.collection.immutable.Seq + import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper import org.apache.spark.sql.internal.SQLConf From 3fa25024812d2c8befe5abe94a78fcadbebec64f Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 16 Apr 2024 11:06:37 +0200 Subject: [PATCH 44/46] Fix codegen --- .../apache/spark/sql/catalyst/util/CollationSupport.java | 8 ++++---- .../apache/spark/unsafe/types/CollationSupportSuite.java | 2 +- .../sql/catalyst/expressions/stringExpressions.scala | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index ddb768b0f7a1f..fe00b7760e005 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -183,16 +183,16 @@ public static int exec(final UTF8String string, final UTF8String substring, return execICU(string, substring, collationId); } } - public static String genCode(final String string, final String substring, final int start, + public static String genCode(final String string, final String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringInstr.exec"; if (collation.supportsBinaryEquality) { - return String.format(expr + "Binary(%s, %s, %d)", string, substring, start); + return String.format(expr + "Binary(%s, %s)", string, substring); } else if (collation.supportsLowercaseEquality) { - return String.format(expr + "Lowercase(%s, %s, %d)", string, substring, start); + return String.format(expr + "Lowercase(%s, %s)", string, substring); } else { - return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId); + return String.format(expr + "ICU(%s, %s, %d)", string, substring, collationId); } } public static int execBinary(final UTF8String string, final UTF8String substring) { diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index acdd73ce28b00..977cb4eda9384 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -256,7 +256,7 @@ private void assertStringInstr(String string, String substring, String collation int collationId = CollationFactory.collationNameToId(collationName); assertEquals(expected, CollationSupport.StringInstr.exec(str, substr, collationId) + 1); } - + @Test public void testStringInstr() throws SparkException { assertStringInstr("aaads", "Aa", "UTF8_BINARY", 0); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 35ace1e23b019..77da9703c4c65 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1367,7 +1367,7 @@ case class StringInstr(str: Expression, substr: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (string, substring) => - CollationSupport.StringInstr.genCode(string, substring, 0, collationId) + " + 1") + CollationSupport.StringInstr.genCode(string, substring, collationId) + " + 1") } override protected def withNewChildrenInternal( From b35d7181d581e3508fd0f46fba184b2d9071c54c Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 17 Apr 2024 17:37:47 +0200 Subject: [PATCH 45/46] Remove unused import --- .../apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala | 2 -- .../org/apache/spark/sql/CollationStringExpressionsSuite.scala | 1 - 2 files changed, 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 795e8a696b017..9868d38c00a41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,9 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable - import scala.annotation.tailrec - import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least} import org.apache.spark.sql.errors.QueryCompilationErrors diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 30d645e429558..4721476f374bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import scala.collection.immutable.Seq import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, IntegerType, StringType} From 1ee5ad65c9463e5b8391888e43345562d0640753 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 18 Apr 2024 17:24:08 +0200 Subject: [PATCH 46/46] Add new tests --- .../apache/spark/unsafe/types/CollationSupportSuite.java | 4 ++++ .../spark/sql/catalyst/analysis/CollationTypeCasts.scala | 2 ++ .../spark/sql/CollationStringExpressionsSuite.scala | 8 ++++++-- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index bab27f24f6279..36acf1c9b7a66 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -296,6 +296,8 @@ public void testStringInstr() throws SparkException { assertStringInstr("aaads", "dS", "UNICODE_CI", 4); assertStringInstr("test大千世界X大千世界", "界y", "UNICODE_CI", 0); assertStringInstr("test大千世界X大千世界", "界x", "UNICODE_CI", 8); + assertStringInstr("abİo12", "i̇o", "UNICODE_CI", 3); + assertStringInstr("abi̇o12", "İo", "UNICODE_CI", 3); } private void assertFindInSet(String word, String set, String collationName, @@ -337,6 +339,8 @@ public void testFindInSet() throws SparkException { assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 4); assertFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", "UNICODE_CI", 5); assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 5); + assertFindInSet("i̇o", "ab,İo,12", "UNICODE_CI", 2); + assertFindInSet("İo", "ab,i̇o,12", "UNICODE_CI", 2); } // TODO: Test more collation-aware string expressions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 9868d38c00a41..795e8a696b017 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable + import scala.annotation.tailrec + import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least} import org.apache.spark.sql.errors.QueryCompilationErrors diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 4721476f374bf..64fcc0b2ae0df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -122,7 +122,8 @@ class CollationStringExpressionsSuite StringInStrTestCase("test大千世界X大千世界", "界x", "UTF8_BINARY_LCASE", 8), StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE", 0), StringInStrTestCase("test大千世界X大千世界", "界y", "UNICODE_CI", 0), - StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8) + StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8), + StringInStrTestCase("abİo12", "i̇o", "UNICODE_CI", 3) // scalastyle:on ) testCases.foreach(t => { @@ -150,7 +151,10 @@ class CollationStringExpressionsSuite FindInSetTestCase("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0), FindInSetTestCase("C", "abc,b,ab,c,def", "UTF8_BINARY_LCASE", 4), FindInSetTestCase("d,ef", "abc,b,ab,c,def", "UNICODE", 0), - FindInSetTestCase("DeF", "abc,b,ab,c,dEf", "UNICODE_CI", 5) + // scalastyle:off + FindInSetTestCase("i̇o", "ab,İo,12", "UNICODE_CI", 2), + FindInSetTestCase("İo", "ab,i̇o,12", "UNICODE_CI", 2) + // scalastyle:on ) testCases.foreach(t => { val query = s"SELECT find_in_set(collate('${t.word}', '${t.c}')," +