From 2a5fce76215630a2b507fa18e32ed2bc1d999f93 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 25 Mar 2024 12:51:57 +0100 Subject: [PATCH 01/28] Update StringReplace class --- .../expressions/stringExpressions.scala | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 742db0ed5a474..d98db114b39a3 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -740,18 +740,39 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp } override def nullSafeEval(srcEval: Any, searchEval: Any, replaceEval: Any): Any = { + val collationId = first.dataType.asInstanceOf[StringType].collationId srcEval.asInstanceOf[UTF8String].replace( - searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String]) + searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String], collationId) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (src, search, replace) => { - s"""${ev.value} = $src.replace($search, $replace);""" - }) + val collationId = first.dataType.asInstanceOf[StringType].collationId + + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + nullSafeCodeGen(ctx, ev, (src, search, replace) => { + s"""${ev.value} = $src.replace($search, $replace);""" + }) + } else { + nullSafeCodeGen(ctx, ev, (src, search, replace) => { + s"""${ev.value} = $src.replace($search, $replace, $collationId);""" + }) + } } - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + return defaultCheck + } + + // Only srcExpr and searchExpr are checked for collation compatibility. + val collationId = first.dataType.asInstanceOf[StringType].collationId + CollationTypeConstraints.checkCollationCompatibility(collationId, Seq(second.dataType)) + } + + override def dataType: DataType = srcExpr.dataType + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) override def first: Expression = srcExpr override def second: Expression = searchExpr override def third: Expression = replaceExpr From e0ce699867a95495f35765ae1103dc22309d7c92 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 25 Mar 2024 15:54:07 +0100 Subject: [PATCH 02/28] Add UTF8_BINARY_LCASE collation support using custom function --- .../apache/spark/unsafe/types/UTF8String.java | 98 +++++++++++++++++++ .../sql/CollationStringExpressionsSuite.scala | 45 ++++++++- 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 6abc8385da5ab..ac89a4b4d0b75 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1136,6 +1136,104 @@ public UTF8String replace(UTF8String search, UTF8String replace) { return buf.build(); } + public UTF8String replace(UTF8String search, UTF8String replace, int collationId) { + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + return this.replace(search, replace); + } + if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) { + return lowercaseReplace(search, replace); + } + return collatedReplace(search, replace, collationId); + } + + public UTF8String lowercaseReplace(UTF8String search, UTF8String replace) { + if (numBytes == 0 || search.numBytes == 0) { + return this; + } + UTF8String lowercaseString = this.toLowerCase(); + UTF8String lowercaseSearch = search.toLowerCase(); + + int start = 0; + int end = lowercaseString.indexOf(lowercaseSearch, 0); + if (end == -1) { + // Search string was not found, so string is unchanged. + return this; + } + + // Initialize byte positions + int c = 0; + int byteStart = 0; // position in byte + int byteEnd = 0; // position in byte + while (byteEnd < numBytes && c < end) { + byteEnd += numBytesForFirstByte(getByte(byteEnd)); + c += 1; + } + + // At least one match was found. Estimate space needed for result. + // The 16x multiplier here is chosen to match commons-lang3's implementation. + int increase = Math.max(0, replace.numBytes - search.numBytes) * 16; + final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase); + while (end != -1) { + buf.appendBytes(this.base, this.offset + byteStart, byteEnd - byteStart); + buf.append(replace); + // Update character positions + start = end + lowercaseSearch.numChars(); + end = lowercaseString.indexOf(lowercaseSearch, start); + // Update byte positions + byteStart = byteEnd + search.numBytes; + while (byteEnd < numBytes && c < end) { + byteEnd += numBytesForFirstByte(getByte(byteEnd)); + c += 1; + } + } + buf.appendBytes(this.base, this.offset + byteStart, numBytes - byteStart); + return buf.build(); + } + + private UTF8String collatedReplace(UTF8String search, UTF8String replace, int collationId) { + if (numBytes == 0 || search.numBytes == 0) { + return this; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(this, search, collationId); + + // Find the first occurrence of the search string. + int end = stringSearch.next(); + if (end == StringSearch.DONE) { + // Search string was not found, so string is unchanged. + return this; + } + + // Initialize byte positions + int c = 0; + int byteStart = 0; // position in byte + int byteEnd = 0; // position in byte + while (byteEnd < numBytes && c < end) { + byteEnd += numBytesForFirstByte(getByte(byteEnd)); + c += 1; + } + + // At least one match was found. Estimate space needed for result. + // The 16x multiplier here is chosen to match commons-lang3's implementation. + int increase = Math.max(0, Math.abs(replace.numBytes - search.numBytes)) * 16; + final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase); + while (end != StringSearch.DONE) { + if(stringSearch.getMatchLength() == stringSearch.getPattern().length()) { + buf.appendBytes(this.base, this.offset + byteStart, byteEnd - byteStart); + buf.append(replace); + byteStart = byteEnd + search.numBytes; + } + end = stringSearch.next(); + // Update byte positions + while (byteEnd < numBytes && c < end) { + byteEnd += numBytesForFirstByte(getByte(byteEnd)); + c += 1; + } + } + buf.appendBytes(this.base, this.offset + byteStart, numBytes - byteStart); + return buf.build(); + } + public UTF8String translate(Map dict) { String srcStr = this.toString(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 04f3781a92cf3..af80856d9870f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -21,10 +21,13 @@ import scala.collection.immutable.Seq import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal, StringReplace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType -class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession { +class CollationStringExpressionsSuite extends QueryTest + with SharedSparkSession with ExpressionEvalHelper { case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R) case class CollationTestFail[R](s1: String, s2: String, collation: String) @@ -70,6 +73,46 @@ class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession }) } + test("REPLACE check result on explicitly collated strings") { + def testReplace(expected: String, collationId: Integer, + source: String, search: String, replace: String): Unit = { + val sourceLiteral = Literal.create(source, StringType(collationId)) + val searchLiteral = Literal.create(search, StringType(collationId)) + val replaceLiteral = Literal.create(replace, StringType(collationId)) + + checkEvaluation(StringReplace(sourceLiteral, searchLiteral, replaceLiteral), expected) + } + + // scalastyle:off + // UTF8_BINARY + testReplace("r世e123ace", 0, "r世eplace", "pl", "123") + testReplace("reace", 0, "replace", "pl", "") + testReplace("repl世ace", 0, "repl世ace", "Pl", "") + testReplace("replace", 0, "replace", "", "123") + testReplace("a12ca12c", 0, "abcabc", "b", "12") + testReplace("adad", 0, "abcdabcd", "bc", "") + // UTF8_BINARY_LCASE + testReplace("r世exxace", 1, "r世eplace", "pl", "xx") + testReplace("reAB世ace", 1, "repl世ace", "PL", "AB") + testReplace("Replace", 1, "Replace", "", "123") + testReplace("rexplace", 1, "re世place", "世", "x") + testReplace("a12ca12c", 1, "abcaBc", "B", "12") + testReplace("Adad", 1, "AbcdabCd", "Bc", "") + // UNICODE + testReplace("re世place", 2, "re世place", "plx", "123") + testReplace("世Replace", 2, "世Replace", "re", "") + testReplace("replace世", 2, "replace世", "", "123") + testReplace("aBc世a12c", 2, "aBc世abc", "b", "12") + testReplace("adad", 2, "abcdabcd", "bc", "") + // UNICODE_CI + testReplace("replace", 3, "replace", "plx", "123") + testReplace("place", 3, "Replace", "re", "") + testReplace("replace", 3, "replace", "", "123") + testReplace("a12c世a12c", 3, "aBc世abc", "b", "12") + testReplace("a世dad", 3, "a世Bcdabcd", "bC", "") + // scalastyle:on + } + // TODO: Add more tests for other string expressions } From d2e90f856d4007799754564d8508ffd066355d0e Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 26 Mar 2024 16:17:02 +0100 Subject: [PATCH 03/28] Improve testReplace signature --- .../sql/CollationStringExpressionsSuite.scala | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index af80856d9870f..f247a79323b91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -74,8 +74,8 @@ class CollationStringExpressionsSuite extends QueryTest } test("REPLACE check result on explicitly collated strings") { - def testReplace(expected: String, collationId: Integer, - source: String, search: String, replace: String): Unit = { + def testReplace(source: String, search: String, replace: String, + collationId: Integer, expected: String): Unit = { val sourceLiteral = Literal.create(source, StringType(collationId)) val searchLiteral = Literal.create(search, StringType(collationId)) val replaceLiteral = Literal.create(replace, StringType(collationId)) @@ -85,31 +85,31 @@ class CollationStringExpressionsSuite extends QueryTest // scalastyle:off // UTF8_BINARY - testReplace("r世e123ace", 0, "r世eplace", "pl", "123") - testReplace("reace", 0, "replace", "pl", "") - testReplace("repl世ace", 0, "repl世ace", "Pl", "") - testReplace("replace", 0, "replace", "", "123") - testReplace("a12ca12c", 0, "abcabc", "b", "12") - testReplace("adad", 0, "abcdabcd", "bc", "") + testReplace("r世eplace", "pl", "123", 0, "r世e123ace") + testReplace("replace", "pl", "", 0, "reace") + testReplace("repl世ace", "Pl", "", 0, "repl世ace") + testReplace("replace", "", "123", 0, "replace") + testReplace("abcabc", "b", "12", 0, "a12ca12c") + testReplace("abcdabcd", "bc", "", 0, "adad") // UTF8_BINARY_LCASE - testReplace("r世exxace", 1, "r世eplace", "pl", "xx") - testReplace("reAB世ace", 1, "repl世ace", "PL", "AB") - testReplace("Replace", 1, "Replace", "", "123") - testReplace("rexplace", 1, "re世place", "世", "x") - testReplace("a12ca12c", 1, "abcaBc", "B", "12") - testReplace("Adad", 1, "AbcdabCd", "Bc", "") + testReplace("r世eplace", "pl", "xx", 1, "r世exxace") + testReplace("repl世ace", "PL", "AB", 1, "reAB世ace") + testReplace("Replace", "", "123", 1, "Replace") + testReplace("re世place", "世", "x", 1, "rexplace") + testReplace("abcaBc", "B", "12", 1, "a12ca12c") + testReplace("AbcdabCd", "Bc", "", 1, "Adad") // UNICODE - testReplace("re世place", 2, "re世place", "plx", "123") - testReplace("世Replace", 2, "世Replace", "re", "") - testReplace("replace世", 2, "replace世", "", "123") - testReplace("aBc世a12c", 2, "aBc世abc", "b", "12") - testReplace("adad", 2, "abcdabcd", "bc", "") + testReplace("re世place", "plx", "123", 2, "re世place") + testReplace("世Replace", "re", "", 2, "世Replace") + testReplace("replace世", "", "123", 2, "replace世") + testReplace("aBc世abc", "b", "12", 2, "aBc世a12c") + testReplace("abcdabcd", "bc", "", 2, "adad") // UNICODE_CI - testReplace("replace", 3, "replace", "plx", "123") - testReplace("place", 3, "Replace", "re", "") - testReplace("replace", 3, "replace", "", "123") - testReplace("a12c世a12c", 3, "aBc世abc", "b", "12") - testReplace("a世dad", 3, "a世Bcdabcd", "bC", "") + testReplace("replace", "plx", "123", 3, "replace") + testReplace("Replace", "re", "", 3, "place") + testReplace("replace", "", "123", 3, "replace") + testReplace("aBc世abc", "b", "12", 3, "a12c世a12c") + testReplace("a世Bcdabcd", "bC", "", 3, "a世dad") // scalastyle:on } From 93c6eb7583c2ed3ebcca2a3b933426029074ab7e Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 26 Mar 2024 16:31:51 +0100 Subject: [PATCH 04/28] Resolve merge problems with master --- .../main/java/org/apache/spark/unsafe/types/UTF8String.java | 4 ++-- .../spark/sql/catalyst/expressions/stringExpressions.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 9f02da564ec3e..e8c0b322f447f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1137,10 +1137,10 @@ public UTF8String replace(UTF8String search, UTF8String replace) { } public UTF8String replace(UTF8String search, UTF8String replace, int collationId) { - if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { return this.replace(search, replace); } - if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) { + if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { return lowercaseReplace(search, replace); } return collatedReplace(search, replace, collationId); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index f371eec96dc17..177b10fa74398 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -748,7 +748,7 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val collationId = first.dataType.asInstanceOf[StringType].collationId - if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { nullSafeCodeGen(ctx, ev, (src, search, replace) => { s"""${ev.value} = $src.replace($search, $replace);""" }) From 7a1b24079c280d3940c8854e313d04313ba0c867 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 26 Mar 2024 17:32:55 +0100 Subject: [PATCH 05/28] Improve scala style --- .../org/apache/spark/sql/CollationStringExpressionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 444c6df78ffe9..5f703a4774fe8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -112,7 +112,7 @@ class CollationStringExpressionsSuite extends QueryTest testReplace("a世Bcdabcd", "bC", "", 3, "a世dad") // scalastyle:on } - + test("REPEAT check output type on explicitly collated string") { def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = { val s = Literal.create(input, StringType(collationId)) From c59d71e6a0552fccde9ebd2a907e414415bab364 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 27 Mar 2024 10:06:39 +0100 Subject: [PATCH 06/28] Solve whitespace scala style problem --- .../apache/spark/sql/CollationStringExpressionsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 5f703a4774fe8..b7c48355802b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -75,7 +75,7 @@ class CollationStringExpressionsSuite extends QueryTest test("REPLACE check result on explicitly collated strings") { def testReplace(source: String, search: String, replace: String, - collationId: Integer, expected: String): Unit = { + collationId: Integer, expected: String): Unit = { val sourceLiteral = Literal.create(source, StringType(collationId)) val searchLiteral = Literal.create(search, StringType(collationId)) val replaceLiteral = Literal.create(replace, StringType(collationId)) @@ -112,7 +112,7 @@ class CollationStringExpressionsSuite extends QueryTest testReplace("a世Bcdabcd", "bC", "", 3, "a世dad") // scalastyle:on } - + test("REPEAT check output type on explicitly collated string") { def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = { val s = Literal.create(input, StringType(collationId)) From a5c75b38bb53c5d6fd039afab0489d158e9f6da9 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 1 Apr 2024 15:03:08 +0200 Subject: [PATCH 07/28] Add lowercase StringSearch and remove lowercaseReplace --- .../sql/catalyst/util/CollationFactory.java | 14 ++++++ .../apache/spark/unsafe/types/UTF8String.java | 47 ------------------- 2 files changed, 14 insertions(+), 47 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 119508a37e717..10fe2e50f6c69 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -179,12 +179,26 @@ public static StringSearch getStringSearch( final UTF8String left, final UTF8String right, final int collationId) { + + if(collationId == UTF8_BINARY_LCASE_COLLATION_ID) { + return getStringSearchUTF8LCase(left, right); + } + String pattern = right.toString(); CharacterIterator target = new StringCharacterIterator(left.toString()); Collator collator = CollationFactory.fetchCollation(collationId).collator; return new StringSearch(pattern, target, (RuleBasedCollator) collator); } + private static StringSearch getStringSearchUTF8LCase( + final UTF8String left, + final UTF8String right) { + String pattern = right.toLowerCase().toString(); + String target = left.toLowerCase().toString(); + + return new StringSearch(pattern, target); + } + /** * Returns the collation id for the given collation name. */ diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e8c0b322f447f..2812773eba6f2 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1140,56 +1140,9 @@ public UTF8String replace(UTF8String search, UTF8String replace, int collationId if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { return this.replace(search, replace); } - if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { - return lowercaseReplace(search, replace); - } return collatedReplace(search, replace, collationId); } - public UTF8String lowercaseReplace(UTF8String search, UTF8String replace) { - if (numBytes == 0 || search.numBytes == 0) { - return this; - } - UTF8String lowercaseString = this.toLowerCase(); - UTF8String lowercaseSearch = search.toLowerCase(); - - int start = 0; - int end = lowercaseString.indexOf(lowercaseSearch, 0); - if (end == -1) { - // Search string was not found, so string is unchanged. - return this; - } - - // Initialize byte positions - int c = 0; - int byteStart = 0; // position in byte - int byteEnd = 0; // position in byte - while (byteEnd < numBytes && c < end) { - byteEnd += numBytesForFirstByte(getByte(byteEnd)); - c += 1; - } - - // At least one match was found. Estimate space needed for result. - // The 16x multiplier here is chosen to match commons-lang3's implementation. - int increase = Math.max(0, replace.numBytes - search.numBytes) * 16; - final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase); - while (end != -1) { - buf.appendBytes(this.base, this.offset + byteStart, byteEnd - byteStart); - buf.append(replace); - // Update character positions - start = end + lowercaseSearch.numChars(); - end = lowercaseString.indexOf(lowercaseSearch, start); - // Update byte positions - byteStart = byteEnd + search.numBytes; - while (byteEnd < numBytes && c < end) { - byteEnd += numBytesForFirstByte(getByte(byteEnd)); - c += 1; - } - } - buf.appendBytes(this.base, this.offset + byteStart, numBytes - byteStart); - return buf.build(); - } - private UTF8String collatedReplace(UTF8String search, UTF8String replace, int collationId) { if (numBytes == 0 || search.numBytes == 0) { return this; From 76878b9097ac5db885660fa6246e4296866dc2e1 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 1 Apr 2024 15:22:07 +0200 Subject: [PATCH 08/28] Remove repeated code --- .../expressions/stringExpressions.scala | 5 +- .../sql/CollationStringExpressionsSuite.scala | 56 ++++++++++--------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 177b10fa74398..64afd6d2a3576 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -735,19 +735,18 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId + def this(srcExpr: Expression, searchExpr: Expression) = { this(srcExpr, searchExpr, Literal("")) } override def nullSafeEval(srcEval: Any, searchEval: Any, replaceEval: Any): Any = { - val collationId = first.dataType.asInstanceOf[StringType].collationId srcEval.asInstanceOf[UTF8String].replace( searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String], collationId) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val collationId = first.dataType.asInstanceOf[StringType].collationId - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { nullSafeCodeGen(ctx, ev, (src, search, replace) => { s"""${ev.value} = $src.replace($search, $replace);""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index b7c48355802b5..6d785465b0d05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -22,6 +22,7 @@ import scala.collection.immutable.Seq import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, Literal, StringRepeat, StringReplace} +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StringType @@ -84,32 +85,35 @@ class CollationStringExpressionsSuite extends QueryTest } // scalastyle:off - // UTF8_BINARY - testReplace("r世eplace", "pl", "123", 0, "r世e123ace") - testReplace("replace", "pl", "", 0, "reace") - testReplace("repl世ace", "Pl", "", 0, "repl世ace") - testReplace("replace", "", "123", 0, "replace") - testReplace("abcabc", "b", "12", 0, "a12ca12c") - testReplace("abcdabcd", "bc", "", 0, "adad") - // UTF8_BINARY_LCASE - testReplace("r世eplace", "pl", "xx", 1, "r世exxace") - testReplace("repl世ace", "PL", "AB", 1, "reAB世ace") - testReplace("Replace", "", "123", 1, "Replace") - testReplace("re世place", "世", "x", 1, "rexplace") - testReplace("abcaBc", "B", "12", 1, "a12ca12c") - testReplace("AbcdabCd", "Bc", "", 1, "Adad") - // UNICODE - testReplace("re世place", "plx", "123", 2, "re世place") - testReplace("世Replace", "re", "", 2, "世Replace") - testReplace("replace世", "", "123", 2, "replace世") - testReplace("aBc世abc", "b", "12", 2, "aBc世a12c") - testReplace("abcdabcd", "bc", "", 2, "adad") - // UNICODE_CI - testReplace("replace", "plx", "123", 3, "replace") - testReplace("Replace", "re", "", 3, "place") - testReplace("replace", "", "123", 3, "replace") - testReplace("aBc世abc", "b", "12", 3, "a12c世a12c") - testReplace("a世Bcdabcd", "bC", "", 3, "a世dad") + var collationId = CollationFactory.collationNameToId("UTF8_BINARY") + testReplace("r世eplace", "pl", "123", collationId, "r世e123ace") + testReplace("replace", "pl", "", collationId, "reace") + testReplace("repl世ace", "Pl", "", collationId, "repl世ace") + testReplace("replace", "", "123", collationId, "replace") + testReplace("abcabc", "b", "12", collationId, "a12ca12c") + testReplace("abcdabcd", "bc", "", collationId, "adad") + + collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") + testReplace("r世eplace", "pl", "xx", collationId, "r世exxace") + testReplace("repl世ace", "PL", "AB", collationId, "reAB世ace") + testReplace("Replace", "", "123", collationId, "Replace") + testReplace("re世place", "世", "x", collationId, "rexplace") + testReplace("abcaBc", "B", "12", collationId, "a12ca12c") + testReplace("AbcdabCd", "Bc", "", collationId, "Adad") + + collationId = CollationFactory.collationNameToId("UNICODE") + testReplace("re世place", "plx", "123", collationId, "re世place") + testReplace("世Replace", "re", "", collationId, "世Replace") + testReplace("replace世", "", "123", collationId, "replace世") + testReplace("aBc世abc", "b", "12", collationId, "aBc世a12c") + testReplace("abcdabcd", "bc", "", collationId, "adad") + + collationId = CollationFactory.collationNameToId("UNICODE_CI") + testReplace("replace", "plx", "123", collationId, "replace") + testReplace("Replace", "re", "", collationId, "place") + testReplace("replace", "", "123", collationId, "replace") + testReplace("aBc世abc", "b", "12", collationId, "a12c世a12c") + testReplace("a世Bcdabcd", "bC", "", collationId, "a世dad") // scalastyle:on } From 572bd541275b4463da5d970efa17f86e231fa382 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 2 Apr 2024 12:35:11 +0200 Subject: [PATCH 09/28] Improve naming of collation aware methods --- .../org/apache/spark/sql/catalyst/util/CollationFactory.java | 4 ++-- .../main/java/org/apache/spark/unsafe/types/UTF8String.java | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 10fe2e50f6c69..feda48f3d4369 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -181,7 +181,7 @@ public static StringSearch getStringSearch( final int collationId) { if(collationId == UTF8_BINARY_LCASE_COLLATION_ID) { - return getStringSearchUTF8LCase(left, right); + return getStringSearch(left, right); } String pattern = right.toString(); @@ -190,7 +190,7 @@ public static StringSearch getStringSearch( return new StringSearch(pattern, target, (RuleBasedCollator) collator); } - private static StringSearch getStringSearchUTF8LCase( + private static StringSearch getStringSearch( final UTF8String left, final UTF8String right) { String pattern = right.toLowerCase().toString(); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 2812773eba6f2..f3c13252640b6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1140,10 +1140,10 @@ public UTF8String replace(UTF8String search, UTF8String replace, int collationId if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { return this.replace(search, replace); } - return collatedReplace(search, replace, collationId); + return collationAwareReplace(search, replace, collationId); } - private UTF8String collatedReplace(UTF8String search, UTF8String replace, int collationId) { + private UTF8String collationAwareReplace(UTF8String search, UTF8String replace, int collationId) { if (numBytes == 0 || search.numBytes == 0) { return this; } From e2bea13b560180cbe45ee442890aa5d2f120509f Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 3 Apr 2024 09:58:04 +0200 Subject: [PATCH 10/28] Improve java style --- .../org/apache/spark/sql/catalyst/util/CollationFactory.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index feda48f3d4369..088f95ecc2475 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -191,8 +191,8 @@ public static StringSearch getStringSearch( } private static StringSearch getStringSearch( - final UTF8String left, - final UTF8String right) { + final UTF8String left, + final UTF8String right) { String pattern = right.toLowerCase().toString(); String target = left.toLowerCase().toString(); From a194292b19a118371edb7d9b92ef43850a53a9e5 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 3 Apr 2024 10:04:28 +0200 Subject: [PATCH 11/28] Remove unnecessary check for mathced length --- .../java/org/apache/spark/unsafe/types/UTF8String.java | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f3c13252640b6..4cd556d5b7a64 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1171,11 +1171,10 @@ private UTF8String collationAwareReplace(UTF8String search, UTF8String replace, int increase = Math.max(0, Math.abs(replace.numBytes - search.numBytes)) * 16; final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase); while (end != StringSearch.DONE) { - if(stringSearch.getMatchLength() == stringSearch.getPattern().length()) { - buf.appendBytes(this.base, this.offset + byteStart, byteEnd - byteStart); - buf.append(replace); - byteStart = byteEnd + search.numBytes; - } + buf.appendBytes(this.base, this.offset + byteStart, byteEnd - byteStart); + buf.append(replace); + byteStart = byteEnd + search.numBytes; + // Go to next match end = stringSearch.next(); // Update byte positions while (byteEnd < numBytes && c < end) { From 7b6720b24a865a24acc7d96f32bcba67d93e2b31 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 3 Apr 2024 10:08:06 +0200 Subject: [PATCH 12/28] Improve style in CollationFactory --- .../org/apache/spark/sql/catalyst/util/CollationFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 088f95ecc2475..e734a7400c668 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -180,7 +180,7 @@ public static StringSearch getStringSearch( final UTF8String right, final int collationId) { - if(collationId == UTF8_BINARY_LCASE_COLLATION_ID) { + if (collationId == UTF8_BINARY_LCASE_COLLATION_ID) { return getStringSearch(left, right); } From 3042d7e791f1c410ade3b6fe0758988ced63685e Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 3 Apr 2024 14:33:14 +0200 Subject: [PATCH 13/28] Add doc comment --- .../java/org/apache/spark/unsafe/types/UTF8String.java | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 4cd556d5b7a64..7388e5877db67 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1136,6 +1136,14 @@ public UTF8String replace(UTF8String search, UTF8String replace) { return buf.build(); } + /** + * Replace all occurrences of search in this with replace respecting collation with id = collationId. + * + * @param search the string to be searched + * @param replace the start position of the current string for searching + * @param collationId the id of applied collation + * @return the string with replace instead of search in all places + */ public UTF8String replace(UTF8String search, UTF8String replace, int collationId) { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { return this.replace(search, replace); @@ -1144,6 +1152,7 @@ public UTF8String replace(UTF8String search, UTF8String replace, int collationId } private UTF8String collationAwareReplace(UTF8String search, UTF8String replace, int collationId) { + // This collation aware implementation is based on existing implementation on UTF8String with default collation if (numBytes == 0 || search.numBytes == 0) { return this; } From 84e41a31b9981182532b211cc173a70042eb9712 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 3 Apr 2024 14:33:50 +0200 Subject: [PATCH 14/28] Improve comment style --- .../src/main/java/org/apache/spark/unsafe/types/UTF8String.java | 1 - 1 file changed, 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 7388e5877db67..a7019bd7e8ffb 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1138,7 +1138,6 @@ public UTF8String replace(UTF8String search, UTF8String replace) { /** * Replace all occurrences of search in this with replace respecting collation with id = collationId. - * * @param search the string to be searched * @param replace the start position of the current string for searching * @param collationId the id of applied collation From cc940cbad164e6d3d54de9afb450e3a4f7174c22 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 3 Apr 2024 14:39:45 +0200 Subject: [PATCH 15/28] Improve naming in getStringSearch --- .../sql/catalyst/util/CollationFactory.java | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index e734a7400c668..fc9da458b0977 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -176,25 +176,25 @@ public Collation( */ public static StringSearch getStringSearch( - final UTF8String left, - final UTF8String right, + final UTF8String targetUTF8String, + final UTF8String patternUTF8String, final int collationId) { if (collationId == UTF8_BINARY_LCASE_COLLATION_ID) { - return getStringSearch(left, right); + return getStringSearch(targetUTF8String.toLowerCase(), patternUTF8String.toLowerCase()); } - String pattern = right.toString(); - CharacterIterator target = new StringCharacterIterator(left.toString()); + String pattern = patternUTF8String.toString(); + CharacterIterator target = new StringCharacterIterator(targetUTF8String.toString()); Collator collator = CollationFactory.fetchCollation(collationId).collator; return new StringSearch(pattern, target, (RuleBasedCollator) collator); } - private static StringSearch getStringSearch( - final UTF8String left, - final UTF8String right) { - String pattern = right.toLowerCase().toString(); - String target = left.toLowerCase().toString(); + public static StringSearch getStringSearch( + final UTF8String targetUTF8String, + final UTF8String patternUTF8String) { + String pattern = patternUTF8String.toString(); + String target = targetUTF8String.toString(); return new StringSearch(pattern, target); } From 4e93874b3c561e54a4969cb7918369707355d98f Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 4 Apr 2024 10:32:48 +0200 Subject: [PATCH 16/28] Remove type checks for collation missmatch --- .../spark/sql/catalyst/analysis/CollationTypeCasts.scala | 6 ++---- .../sql/catalyst/expressions/stringExpressions.scala | 9 +-------- .../spark/sql/CollationStringExpressionsSuite.scala | 5 +---- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 1a14b4227de8f..1ee4d1c8fd7bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable - import scala.annotation.tailrec - import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least, StringReplace} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, StringType} @@ -47,7 +45,7 @@ object CollationTypeCasts extends TypeCoercionRule { case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | - _: Coalesce | _: BinaryExpression | _: ConcatWs) => + _: Coalesce | _: BinaryExpression | _: ConcatWs | _: StringReplace) => val newChildren = collateToSingleType(otherExpr.children) otherExpr.withNewChildren(newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index f4221f5952718..57b4e7ff8da26 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -751,14 +751,7 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp } override def checkInputDataTypes(): TypeCheckResult = { - val defaultCheck = super.checkInputDataTypes() - if (defaultCheck.isFailure) { - return defaultCheck - } - - // Only srcExpr and searchExpr are checked for collation compatibility. - val collationId = first.dataType.asInstanceOf[StringType].collationId - CollationTypeConstraints.checkCollationCompatibility(collationId, Seq(second.dataType)) + super.checkInputDataTypes() } override def dataType: DataType = srcExpr.dataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 0547b184d658e..d628598c4902e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -21,10 +21,7 @@ import scala.collection.immutable.Seq import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, Literal, StringRepeat} -import org.apache.spark.sql.catalyst.util.CollationFactory -import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.expressions.{Collation, ExpressionEvalHelper, Literal, StringRepeat, StringReplace} +import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, Literal, StringRepeat, StringReplace} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession From 9cb0944f18b806c8dfd09cc629a10da337953a3c Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 4 Apr 2024 10:34:22 +0200 Subject: [PATCH 17/28] Remove checkInputDataTypes --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 57b4e7ff8da26..bc607388cab4f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -750,10 +750,6 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp } } - override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() - } - override def dataType: DataType = srcExpr.dataType override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) From 41c3872f9cf7d03cb4e22f3abd726214715ef873 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 4 Apr 2024 12:32:02 +0200 Subject: [PATCH 18/28] Add empty lines between imports --- .../apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 1ee4d1c8fd7bb..87682b1603d79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable + import scala.annotation.tailrec + import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Expression, Greatest, If, In, InSubquery, Least, StringReplace} import org.apache.spark.sql.errors.QueryCompilationErrors From 74f69b9c2395723fa8ce4d9fcc2c5d090c199ad5 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 4 Apr 2024 13:31:39 +0200 Subject: [PATCH 19/28] Handle all collationIds in getStringSearch --- .../org/apache/spark/sql/catalyst/util/CollationFactory.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index fc9da458b0977..edcf0863e93a6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -180,7 +180,9 @@ public static StringSearch getStringSearch( final UTF8String patternUTF8String, final int collationId) { - if (collationId == UTF8_BINARY_LCASE_COLLATION_ID) { + if (collationId == UTF8_BINARY_COLLATION_ID) { + return getStringSearch(targetUTF8String, patternUTF8String); + } else if (collationId == UTF8_BINARY_LCASE_COLLATION_ID) { return getStringSearch(targetUTF8String.toLowerCase(), patternUTF8String.toLowerCase()); } From ea3730c37b2af9591040a24bebf20db2dae65b0d Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 5 Apr 2024 10:04:18 +0200 Subject: [PATCH 20/28] Improve Java style --- .../main/java/org/apache/spark/unsafe/types/UTF8String.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index a7019bd7e8ffb..b608cfd6ee309 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1137,7 +1137,7 @@ public UTF8String replace(UTF8String search, UTF8String replace) { } /** - * Replace all occurrences of search in this with replace respecting collation with id = collationId. + * Replace all occurrences of search in this with replace respecting given collation. * @param search the string to be searched * @param replace the start position of the current string for searching * @param collationId the id of applied collation @@ -1151,7 +1151,7 @@ public UTF8String replace(UTF8String search, UTF8String replace, int collationId } private UTF8String collationAwareReplace(UTF8String search, UTF8String replace, int collationId) { - // This collation aware implementation is based on existing implementation on UTF8String with default collation + // This collation aware implementation is based on existing implementation on UTF8String if (numBytes == 0 || search.numBytes == 0) { return this; } From bc5c256607e28a88e1a280cae9308610035c85d9 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 12 Apr 2024 16:50:42 +0200 Subject: [PATCH 21/28] Refactor StringReplace --- .../sql/catalyst/util/CollationSupport.java | 122 ++++++++++++++++++ .../apache/spark/unsafe/types/UTF8String.java | 62 +-------- .../expressions/stringExpressions.scala | 17 +-- .../sql/CollationStringExpressionsSuite.scala | 10 +- 4 files changed, 133 insertions(+), 78 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index fe1952921b7fb..7f5d8932396e7 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -18,6 +18,7 @@ import com.ibm.icu.text.StringSearch; +import org.apache.spark.unsafe.UTF8StringBuilder; import org.apache.spark.unsafe.types.UTF8String; /** @@ -137,6 +138,39 @@ public static boolean execICU(final UTF8String l, final UTF8String r, } } + public static class StringReplace { + public static UTF8String exec(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(src, search, replace); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(src, search, replace); + } else { + return execICU(src, search, replace, collationId); + } + } + public static String genCode(final String src, final String search, final String replace, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringReplace.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s, %s)", src, search, replace); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s, %s)", src, search, replace); + } else { + return String.format(expr + "ICU(%s, %s, %s, %d)", src, search, replace, collationId); + } + } + public static UTF8String execBinary(final UTF8String src, final UTF8String search, final UTF8String replace) { + return src.replace(search, replace); + } + public static UTF8String execLowercase(final UTF8String src, final UTF8String search, final UTF8String replace) { + return CollationAwareUTF8String.lowercaseReplace(src, search, replace); + } + public static UTF8String execICU(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { + return CollationAwareUTF8String.replace(src, search, replace, collationId); + } + } + // TODO: Add more collation-aware string expressions. /** @@ -169,6 +203,94 @@ private static boolean matchAt(final UTF8String target, final UTF8String pattern pos, pos + pattern.numChars()), pattern, collationId).last() == 0; } + private static UTF8String replace(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { + // This collation aware implementation is based on existing implementation on UTF8String + if (src.numBytes() == 0 || search.numBytes() == 0) { + return src; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(src, search, collationId); + + // Find the first occurrence of the search string. + int end = stringSearch.next(); + if (end == StringSearch.DONE) { + // Search string was not found, so string is unchanged. + return src; + } + + // Initialize byte positions + int c = 0; + int byteStart = 0; // position in byte + int byteEnd = 0; // position in byte + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + + // At least one match was found. Estimate space needed for result. + // The 16x multiplier here is chosen to match commons-lang3's implementation. + int increase = Math.max(0, Math.abs(replace.numBytes() - search.numBytes())) * 16; + final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); + while (end != StringSearch.DONE) { + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); + buf.append(replace); + byteStart = byteEnd + search.numBytes(); + // Go to next match + end = stringSearch.next(); + // Update byte positions + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + } + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, src.numBytes() - byteStart); + return buf.build(); + } + + private static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search, final UTF8String replace) { + if (src.numBytes() == 0 || search.numBytes() == 0) { + return src; + } + UTF8String lowercaseString = src.toLowerCase(); + UTF8String lowercaseSearch = search.toLowerCase(); + + int start = 0; + int end = lowercaseString.indexOf(lowercaseSearch, 0); + if (end == -1) { + // Search string was not found, so string is unchanged. + return src; + } + + // Initialize byte positions + int c = 0; + int byteStart = 0; // position in byte + int byteEnd = 0; // position in byte + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + + // At least one match was found. Estimate space needed for result. + // The 16x multiplier here is chosen to match commons-lang3's implementation. + int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16; + final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); + while (end != -1) { + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); + buf.append(replace); + // Update character positions + start = end + lowercaseSearch.numChars(); + end = lowercaseString.indexOf(lowercaseSearch, start); + // Update byte positions + byteStart = byteEnd + search.numBytes(); + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + } + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, src.numBytes() - byteStart); + return buf.build(); + } + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index b4a8f182b49ce..f79fafb7e9ec8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -224,7 +224,7 @@ public void writeTo(OutputStream out) throws IOException { * Returns the number of bytes for a code point with the first byte as `b` * @param b The first byte of a code point */ - private static int numBytesForFirstByte(final byte b) { + public static int numBytesForFirstByte(final byte b) { final int offset = b & 0xFF; byte numBytes = bytesOfCodePointInUTF8[offset]; return (numBytes == 0) ? 1: numBytes; // Skip the first byte disallowed in UTF-8 @@ -344,7 +344,7 @@ public boolean contains(final UTF8String substring) { /** * Returns the byte at position `i`. */ - private byte getByte(int i) { + public byte getByte(int i) { return Platform.getByte(base, offset + i); } @@ -1102,64 +1102,6 @@ public UTF8String replace(UTF8String search, UTF8String replace) { return buf.build(); } - /** - * Replace all occurrences of search in this with replace respecting given collation. - * @param search the string to be searched - * @param replace the start position of the current string for searching - * @param collationId the id of applied collation - * @return the string with replace instead of search in all places - */ - public UTF8String replace(UTF8String search, UTF8String replace, int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - return this.replace(search, replace); - } - return collationAwareReplace(search, replace, collationId); - } - - private UTF8String collationAwareReplace(UTF8String search, UTF8String replace, int collationId) { - // This collation aware implementation is based on existing implementation on UTF8String - if (numBytes == 0 || search.numBytes == 0) { - return this; - } - - StringSearch stringSearch = CollationFactory.getStringSearch(this, search, collationId); - - // Find the first occurrence of the search string. - int end = stringSearch.next(); - if (end == StringSearch.DONE) { - // Search string was not found, so string is unchanged. - return this; - } - - // Initialize byte positions - int c = 0; - int byteStart = 0; // position in byte - int byteEnd = 0; // position in byte - while (byteEnd < numBytes && c < end) { - byteEnd += numBytesForFirstByte(getByte(byteEnd)); - c += 1; - } - - // At least one match was found. Estimate space needed for result. - // The 16x multiplier here is chosen to match commons-lang3's implementation. - int increase = Math.max(0, Math.abs(replace.numBytes - search.numBytes)) * 16; - final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase); - while (end != StringSearch.DONE) { - buf.appendBytes(this.base, this.offset + byteStart, byteEnd - byteStart); - buf.append(replace); - byteStart = byteEnd + search.numBytes; - // Go to next match - end = stringSearch.next(); - // Update byte positions - while (byteEnd < numBytes && c < end) { - byteEnd += numBytesForFirstByte(getByte(byteEnd)); - c += 1; - } - } - buf.appendBytes(this.base, this.offset + byteStart, numBytes - byteStart); - return buf.build(); - } - public UTF8String translate(Map dict) { String srcStr = this.toString(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 7b813b952c635..f2762863e503d 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -714,20 +714,15 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp } override def nullSafeEval(srcEval: Any, searchEval: Any, replaceEval: Any): Any = { - srcEval.asInstanceOf[UTF8String].replace( - searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String], collationId) + CollationSupport.StringReplace.exec(srcEval.asInstanceOf[UTF8String], + searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String], collationId); } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - nullSafeCodeGen(ctx, ev, (src, search, replace) => { - s"""${ev.value} = $src.replace($search, $replace);""" - }) - } else { - nullSafeCodeGen(ctx, ev, (src, search, replace) => { - s"""${ev.value} = $src.replace($search, $replace, $collationId);""" - }) - } + nullSafeCodeGen(ctx, ev, (src, search, replace) => { + s"${ev.value} = " + + CollationSupport.StringReplace.genCode(src, search, replace, collationId) + ";" + }) } override def dataType: DataType = srcExpr.dataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 0c4403ab3d791..7582797a8e199 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql -import scala.collection.immutable.Seq - import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal, StringReplace} +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{BooleanType, StringType} @@ -120,6 +119,7 @@ class CollationStringExpressionsSuite } assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } + test("REPLACE check result on explicitly collated strings") { def testReplace(source: String, search: String, replace: String, collationId: Integer, expected: String): Unit = { @@ -163,10 +163,6 @@ class CollationStringExpressionsSuite // scalastyle:on } - test("REPEAT check output type on explicitly collated string") { - def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = { - val s = Literal.create(input, StringType(collationId)) - test("Support EndsWith string expression with collation") { // Supported collations case class EndsWithTestCase[R](l: String, r: String, c: String, result: R) From 8a81536328f9219e215164fc5897df231d62a1a2 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Fri, 12 Apr 2024 20:50:27 +0200 Subject: [PATCH 22/28] Break lines to 100 characters --- .../sql/catalyst/util/CollationSupport.java | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 7f5d8932396e7..1a265fae67ef0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -139,7 +139,8 @@ public static boolean execICU(final UTF8String l, final UTF8String r, } public static class StringReplace { - public static UTF8String exec(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { + public static UTF8String exec(final UTF8String src, final UTF8String search, + final UTF8String replace, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); if (collation.supportsBinaryEquality) { return execBinary(src, search, replace); @@ -149,7 +150,8 @@ public static UTF8String exec(final UTF8String src, final UTF8String search, fin return execICU(src, search, replace, collationId); } } - public static String genCode(final String src, final String search, final String replace, final int collationId) { + public static String genCode(final String src, final String search, final String replace, + final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringReplace.exec"; if (collation.supportsBinaryEquality) { @@ -160,13 +162,16 @@ public static String genCode(final String src, final String search, final String return String.format(expr + "ICU(%s, %s, %s, %d)", src, search, replace, collationId); } } - public static UTF8String execBinary(final UTF8String src, final UTF8String search, final UTF8String replace) { + public static UTF8String execBinary(final UTF8String src, final UTF8String search, + final UTF8String replace) { return src.replace(search, replace); } - public static UTF8String execLowercase(final UTF8String src, final UTF8String search, final UTF8String replace) { + public static UTF8String execLowercase(final UTF8String src, final UTF8String search, + final UTF8String replace) { return CollationAwareUTF8String.lowercaseReplace(src, search, replace); } - public static UTF8String execICU(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { + public static UTF8String execICU(final UTF8String src, final UTF8String search, + final UTF8String replace, final int collationId) { return CollationAwareUTF8String.replace(src, search, replace, collationId); } } @@ -203,7 +208,8 @@ private static boolean matchAt(final UTF8String target, final UTF8String pattern pos, pos + pattern.numChars()), pattern, collationId).last() == 0; } - private static UTF8String replace(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { + private static UTF8String replace(final UTF8String src, final UTF8String search, + final UTF8String replace, final int collationId) { // This collation aware implementation is based on existing implementation on UTF8String if (src.numBytes() == 0 || search.numBytes() == 0) { return src; @@ -243,11 +249,13 @@ private static UTF8String replace(final UTF8String src, final UTF8String search, c += 1; } } - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, src.numBytes() - byteStart); + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, + src.numBytes() - byteStart); return buf.build(); } - private static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search, final UTF8String replace) { + private static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search, + final UTF8String replace) { if (src.numBytes() == 0 || search.numBytes() == 0) { return src; } @@ -287,7 +295,8 @@ private static UTF8String lowercaseReplace(final UTF8String src, final UTF8Strin c += 1; } } - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, src.numBytes() - byteStart); + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, + src.numBytes() - byteStart); return buf.build(); } From c456325e8c7e55dca47a31f0a248af3ebc70f5c5 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 15 Apr 2024 15:22:01 +0200 Subject: [PATCH 23/28] Refactor tests --- .../unsafe/types/CollationSupportSuite.java | 36 ++++++++++ .../expressions/stringExpressions.scala | 6 +- .../sql/CollationStringExpressionsSuite.scala | 72 ++++++++----------- 3 files changed, 69 insertions(+), 45 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index bfb696c35fff6..cb464de8f8637 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -249,6 +249,42 @@ public void testEndsWith() throws SparkException { assertEndsWith("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false); } + private void assertReplace(String source, String search, String replace, String collationName, + String expected) throws SparkException { + UTF8String src = UTF8String.fromString(source); + UTF8String sear = UTF8String.fromString(search); + UTF8String repl = UTF8String.fromString(replace); + int collationId = CollationFactory.collationNameToId(collationName); + assertEquals(expected, CollationSupport.StringReplace. + exec(src, sear, repl, collationId).toString()); + } + + @Test + public void testReplace() throws SparkException { + assertReplace("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"); + assertReplace("replace", "pl", "", "UTF8_BINARY", "reace"); + assertReplace("repl世ace", "Pl", "", "UTF8_BINARY", "repl世ace"); + assertReplace("replace", "", "123", "UTF8_BINARY", "replace"); + assertReplace("abcabc", "b", "12", "UTF8_BINARY", "a12ca12c"); + assertReplace("abcdabcd", "bc", "", "UTF8_BINARY", "adad"); + assertReplace("r世eplace", "pl", "xx", "UTF8_BINARY_LCASE", "r世exxace"); + assertReplace("repl世ace", "PL", "AB", "UTF8_BINARY_LCASE", "reAB世ace"); + assertReplace("Replace", "", "123", "UTF8_BINARY_LCASE", "Replace"); + assertReplace("re世place", "世", "x", "UTF8_BINARY_LCASE", "rexplace"); + assertReplace("abcaBc", "B", "12", "UTF8_BINARY_LCASE", "a12ca12c"); + assertReplace("AbcdabCd", "Bc", "", "UTF8_BINARY_LCASE", "Adad"); + assertReplace("re世place", "plx", "123", "UNICODE", "re世place"); + assertReplace("世Replace", "re", "", "UNICODE", "世Replace"); + assertReplace("replace世", "", "123", "UNICODE", "replace世"); + assertReplace("aBc世abc", "b", "12", "UNICODE", "aBc世a12c"); + assertReplace("abcdabcd", "bc", "", "UNICODE", "adad"); + assertReplace("replace", "plx", "123", "UNICODE_CI", "replace"); + assertReplace("Replace", "re", "", "UNICODE_CI", "place"); + assertReplace("replace", "", "123", "UNICODE_CI", "replace"); + assertReplace("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"); + assertReplace("a世Bcdabcd", "bC", "", "UNICODE_CI", "a世dad"); + } + // TODO: Test more collation-aware string expressions. /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index f2762863e503d..b08e81970adfa 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -719,10 +719,8 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (src, search, replace) => { - s"${ev.value} = " + - CollationSupport.StringReplace.genCode(src, search, replace, collationId) + ";" - }) + defineCodeGen(ctx, ev, (src, search, replace) => + CollationSupport.StringReplace.genCode(src, search, replace, collationId)) } override def dataType: DataType = srcExpr.dataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 7582797a8e199..f3130180ec56b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal, StringReplace} +import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -120,47 +120,37 @@ class CollationStringExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } - test("REPLACE check result on explicitly collated strings") { - def testReplace(source: String, search: String, replace: String, - collationId: Integer, expected: String): Unit = { - val sourceLiteral = Literal.create(source, StringType(collationId)) - val searchLiteral = Literal.create(search, StringType(collationId)) - val replaceLiteral = Literal.create(replace, StringType(collationId)) - - checkEvaluation(StringReplace(sourceLiteral, searchLiteral, replaceLiteral), expected) + test("Support Replace string expression with collation") { + case class StartsWithTestCase[R](source: String, search: String, replace: String, + c: String, result: R) + val testCases = Seq( + // scalastyle:off + StartsWithTestCase("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"), + StartsWithTestCase("repl世ace", "PL", "AB", "UTF8_BINARY_LCASE", "reAB世ace"), + StartsWithTestCase("abcdabcd", "bc", "", "UNICODE", "adad"), + StartsWithTestCase("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c") + // scalastyle:on + ) + testCases.foreach(t => { + val query = s"SELECT replace(collate('${t.source}','${t.c}'),collate('${t.search}'," + + s"'${t.c}'),collate('${t.replace}','${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType( + StringType(CollationFactory.collationNameToId(t.c)))) + // Implicit casting + checkAnswer(sql(s"SELECT replace(collate('${t.source}','${t.c}'),'${t.search}'," + + s"'${t.replace}')"), Row(t.result)) + checkAnswer(sql(s"SELECT replace('${t.source}',collate('${t.search}','${t.c}')," + + s"'${t.replace}')"), Row(t.result)) + checkAnswer(sql(s"SELECT replace('${t.source}','${t.search}'," + + s"collate('${t.replace}','${t.c}'))"), Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql("SELECT startswith(collate('abcde', 'UTF8_BINARY_LCASE'),collate('C', 'UNICODE_CI'))") } - - // scalastyle:off - var collationId = CollationFactory.collationNameToId("UTF8_BINARY") - testReplace("r世eplace", "pl", "123", collationId, "r世e123ace") - testReplace("replace", "pl", "", collationId, "reace") - testReplace("repl世ace", "Pl", "", collationId, "repl世ace") - testReplace("replace", "", "123", collationId, "replace") - testReplace("abcabc", "b", "12", collationId, "a12ca12c") - testReplace("abcdabcd", "bc", "", collationId, "adad") - - collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE") - testReplace("r世eplace", "pl", "xx", collationId, "r世exxace") - testReplace("repl世ace", "PL", "AB", collationId, "reAB世ace") - testReplace("Replace", "", "123", collationId, "Replace") - testReplace("re世place", "世", "x", collationId, "rexplace") - testReplace("abcaBc", "B", "12", collationId, "a12ca12c") - testReplace("AbcdabCd", "Bc", "", collationId, "Adad") - - collationId = CollationFactory.collationNameToId("UNICODE") - testReplace("re世place", "plx", "123", collationId, "re世place") - testReplace("世Replace", "re", "", collationId, "世Replace") - testReplace("replace世", "", "123", collationId, "replace世") - testReplace("aBc世abc", "b", "12", collationId, "aBc世a12c") - testReplace("abcdabcd", "bc", "", collationId, "adad") - - collationId = CollationFactory.collationNameToId("UNICODE_CI") - testReplace("replace", "plx", "123", collationId, "replace") - testReplace("Replace", "re", "", collationId, "place") - testReplace("replace", "", "123", collationId, "replace") - testReplace("aBc世abc", "b", "12", collationId, "a12c世a12c") - testReplace("a世Bcdabcd", "bC", "", collationId, "a世dad") - // scalastyle:on + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } test("Support EndsWith string expression with collation") { From 09f13d8218657abb7825322bebca7737a86fa925 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Tue, 16 Apr 2024 11:19:44 +0200 Subject: [PATCH 24/28] Sync with the latest master --- .../org/apache/spark/sql/CollationStringExpressionsSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index c227e1c328aa8..c1288d6fd6a75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import scala.collection.immutable.Seq import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession From d9f56d6664bc16a65e363da3bc1d80daccdf6458 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 18 Apr 2024 17:41:16 +0200 Subject: [PATCH 25/28] Added new tests (2 failing) --- .../org/apache/spark/unsafe/types/CollationSupportSuite.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index d23744d597067..3e4b816d845d5 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -295,6 +295,8 @@ public void testReplace() throws SparkException { assertReplace("replace", "", "123", "UNICODE_CI", "replace"); assertReplace("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"); assertReplace("a世Bcdabcd", "bC", "", "UNICODE_CI", "a世dad"); + assertReplace("abİo12", "i̇o", "xx", "UNICODE_CI", "abxx12"); // FAILING + assertReplace("abi̇o12", "İo", "yy", "UNICODE_CI", "abyy12"); // FAILING } // TODO: Test more collation-aware string expressions. From 0c725f989a52f81f667ed71c1de3edfe12f0ef8c Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 24 Apr 2024 13:46:59 +0200 Subject: [PATCH 26/28] Fix bug with case-variable lenght characters --- .../spark/sql/catalyst/util/CollationSupport.java | 11 ++++++++++- .../spark/unsafe/types/CollationSupportSuite.java | 4 ++-- .../spark/sql/CollationStringExpressionsSuite.scala | 12 +++++++----- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 0fbafe7b9e19b..a1404e5b271f6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -395,7 +395,16 @@ private static UTF8String replace(final UTF8String src, final UTF8String search, while (end != StringSearch.DONE) { buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); buf.append(replace); - byteStart = byteEnd + search.numBytes(); + + // Move byteStart to the beginning of the current match + byteStart = byteEnd; + int cs = c; + // Move cs to the end of the current match + // This is necessary because the search string may contain 'multi-character' characters + while (byteStart < src.numBytes() && cs < c + stringSearch.getMatchLength()) { + byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart)); + cs += 1; + } // Go to next match end = stringSearch.next(); // Update byte positions diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 1b854bc3ebce6..bbf64fab9e005 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -562,8 +562,8 @@ public void testReplace() throws SparkException { assertReplace("replace", "", "123", "UNICODE_CI", "replace"); assertReplace("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"); assertReplace("a世Bcdabcd", "bC", "", "UNICODE_CI", "a世dad"); - assertReplace("abİo12", "i̇o", "xx", "UNICODE_CI", "abxx12"); // FAILING - assertReplace("abi̇o12", "İo", "yy", "UNICODE_CI", "abyy12"); // FAILING + assertReplace("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx"); + assertReplace("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"); } // TODO: Test more collation-aware string expressions. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 805b072890d9e..78e6c318b1f44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -202,14 +202,16 @@ class CollationStringExpressionsSuite } test("Support Replace string expression with collation") { - case class StartsWithTestCase[R](source: String, search: String, replace: String, + case class ReplaceTestCase[R](source: String, search: String, replace: String, c: String, result: R) val testCases = Seq( // scalastyle:off - StartsWithTestCase("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"), - StartsWithTestCase("repl世ace", "PL", "AB", "UTF8_BINARY_LCASE", "reAB世ace"), - StartsWithTestCase("abcdabcd", "bc", "", "UNICODE", "adad"), - StartsWithTestCase("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c") + ReplaceTestCase("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"), + ReplaceTestCase("repl世ace", "PL", "AB", "UTF8_BINARY_LCASE", "reAB世ace"), + ReplaceTestCase("abcdabcd", "bc", "", "UNICODE", "adad"), + ReplaceTestCase("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"), + ReplaceTestCase("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"), + ReplaceTestCase("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx") // scalastyle:on ) testCases.foreach(t => { From 816a49ab6f0fc3eae9dcab5220cd33a73a013880 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Wed, 24 Apr 2024 15:51:36 +0200 Subject: [PATCH 27/28] Fix java linter errors --- .../org/apache/spark/unsafe/types/CollationSupportSuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index bbf64fab9e005..a4eb43270cc67 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -534,8 +534,8 @@ private void assertReplace(String source, String search, String replace, String UTF8String sear = UTF8String.fromString(search); UTF8String repl = UTF8String.fromString(replace); int collationId = CollationFactory.collationNameToId(collationName); - assertEquals(expected, CollationSupport.StringReplace. - exec(src, sear, repl, collationId).toString()); + assertEquals(expected, CollationSupport.StringReplace + .exec(src, sear, repl, collationId).toString()); } @Test From 0ef49d09a7f93929a26ff25380e82f05b656f89a Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Thu, 25 Apr 2024 13:33:54 +0200 Subject: [PATCH 28/28] Fix import scalastyle --- .../apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 094ab7916fada..5556e0d1ee9b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -22,7 +22,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} -import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Literal, Overlay, RegExpReplace, StringReplace, StringLPad, StringRPad} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Literal, Overlay, RegExpReplace, StringLPad, StringReplace, StringRPad} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, StringType}