From 0a32eac6db7b390ad40a7202041faf1a73a0624d Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Wed, 27 Mar 2024 16:57:30 +0100 Subject: [PATCH 1/8] Add all code paths for collation support to TRIM functions --- .../apache/spark/unsafe/types/UTF8String.java | 86 +++++++++ .../expressions/stringExpressions.scala | 168 +++++++++++++----- 2 files changed, 214 insertions(+), 40 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index c5dfb91f06c63..87018a850ad42 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -585,6 +585,16 @@ public UTF8String trim() { return copyUTF8String(s, e); } + public UTF8String trim(int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality + || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return trim(); + } + else { + return trim(UTF8String.fromString(" "), collationId); + } + } + /** * Trims whitespace ASCII characters from both ends of this string. * @@ -628,6 +638,18 @@ public UTF8String trim(UTF8String trimString) { } } + public UTF8String trim(UTF8String trimString, int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + return trim(trimString); + } + + if (CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return lowercaseTrimLeft(trimString).lowercaseTrimRight(trimString); + } + + return trimLeft(trimString, collationId).trimRight(trimString, collationId); + } + /** * Trims space characters (ASCII 32) from the start of this string. * @@ -648,6 +670,16 @@ public UTF8String trimLeft() { return copyUTF8String(s, this.numBytes - 1); } + public UTF8String trimLeft(int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality + || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return trimLeft(); + } + else { + return trimLeft(UTF8String.fromString(" "), collationId); + } + } + /** * Trims instances of the given trim string from the start of this string. * @@ -686,6 +718,28 @@ public UTF8String trimLeft(UTF8String trimString) { return copyUTF8String(trimIdx, numBytes - 1); } + public UTF8String trimLeft(UTF8String trimString, int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + return trimLeft(trimString); + } + + if (CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return lowercaseTrimLeft(trimString); + } + + return collatedTrimLeft(trimString, collationId); + } + + public UTF8String lowercaseTrimLeft(UTF8String trimString) { + // TODO + return EMPTY_UTF8; + } + + public UTF8String collatedTrimLeft(UTF8String trimString, int collationId) { + // TODO + return EMPTY_UTF8; + } + /** * Trims space characters (ASCII 32) from the end of this string. * @@ -706,6 +760,16 @@ public UTF8String trimRight() { return copyUTF8String(0, e); } + public UTF8String trimRight(int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality + || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return trimRight(); + } + else { + return trimRight(UTF8String.fromString(" "), collationId); + } + } + /** * Trims at most `numSpaces` space characters (ASCII 32) from the end of this string. */ @@ -767,6 +831,28 @@ public UTF8String trimRight(UTF8String trimString) { return copyUTF8String(0, trimEnd); } + public UTF8String trimRight(UTF8String trimString, int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + return trimRight(trimString); + } + + if (CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return lowercaseTrimRight(trimString); + } + + return collatedTrimRight(trimString, collationId); + } + + public UTF8String lowercaseTrimRight(UTF8String trimString) { + // TODO + return EMPTY_UTF8; + } + + public UTF8String collatedTrimRight(UTF8String trimString, int collationId) { + // TODO + return EMPTY_UTF8; + } + public UTF8String reverse() { byte[] result = new byte[this.numBytes]; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e73dc5f2ee1b4..eb66403e4e56d 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1028,8 +1028,11 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { protected def direction: String override def children: Seq[Expression] = srcStr +: trimStr.toSeq - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = srcStr.dataType + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation) + + final lazy val collationId: Int = srcStr.dataType.asInstanceOf[StringType].collationId override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -1037,6 +1040,20 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { protected def doEval(srcString: UTF8String): UTF8String protected def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheckResult = super.checkInputDataTypes() + if (defaultCheckResult.isFailure) { + return defaultCheckResult + } + + trimStr match { + case None => TypeCheckResult.TypeCheckSuccess + case Some(trimChars) => + val collationId = srcStr.dataType.asInstanceOf[StringType].collationId + CollationTypeConstraints.checkCollationCompatibility(collationId, Seq(trimChars.dataType)) + } + } + override def eval(input: InternalRow): Any = { val srcString = srcStr.eval(input).asInstanceOf[UTF8String] if (srcString == null) { @@ -1054,32 +1071,64 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { val evals = children.map(_.genCode(ctx)) val srcString = evals(0) - if (evals.length == 1) { - ev.copy(code = code""" - |${srcString.code} - |boolean ${ev.isNull} = false; - |UTF8String ${ev.value} = null; - |if (${srcString.isNull}) { - | ${ev.isNull} = true; - |} else { - | ${ev.value} = ${srcString.value}.$trimMethod(); - |}""".stripMargin) - } else { - val trimString = evals(1) - ev.copy(code = code""" - |${srcString.code} - |boolean ${ev.isNull} = false; - |UTF8String ${ev.value} = null; - |if (${srcString.isNull}) { - | ${ev.isNull} = true; - |} else { - | ${trimString.code} - | if (${trimString.isNull}) { - | ${ev.isNull} = true; - | } else { - | ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value}); - | } - |}""".stripMargin) + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + if (evals.length == 1) { + ev.copy(code = code""" + |${srcString.code} + |boolean ${ev.isNull} = false; + |UTF8String ${ev.value} = null; + |if (${srcString.isNull}) { + | ${ev.isNull} = true; + |} else { + | ${ev.value} = ${srcString.value}.$trimMethod(); + |}""".stripMargin) + } else { + val trimString = evals(1) + ev.copy(code = code""" + |${srcString.code} + |boolean ${ev.isNull} = false; + |UTF8String ${ev.value} = null; + |if (${srcString.isNull}) { + | ${ev.isNull} = true; + |} else { + | ${trimString.code} + | if (${trimString.isNull}) { + | ${ev.isNull} = true; + | } else { + | ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value}); + | } + |}""".stripMargin) + } + } + else { + if (evals.length == 1) { + ev.copy(code = code""" + |${srcString.code} + |boolean ${ev.isNull} = false; + |UTF8String ${ev.value} = null; + |if (${srcString.isNull}) { + | ${ev.isNull} = true; + |} else { + | ${ev.value} = ${srcString.value}.$trimMethod($collationId); + |}""".stripMargin) + } else { + val trimString = evals(1) + ev.copy(code = code""" + |${srcString.code} + |boolean ${ev.isNull} = false; + |UTF8String ${ev.value} = null; + |if (${srcString.isNull}) { + | ${ev.isNull} = true; + |} else { + | ${trimString.code} + | if (${trimString.isNull}) { + | ${ev.isNull} = true; + | } else { + | ${ev.value} = + | ${srcString.value}.$trimMethod(${trimString.value}, $collationId); + | } + |}""".stripMargin) + } } } @@ -1170,12 +1219,25 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) override protected def direction: String = "BOTH" - override def doEval(srcString: UTF8String): UTF8String = srcString.trim() + override val trimMethod: String = "trim" - override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trim(trimString) + override def doEval(srcString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trim() + } + else { + srcString.trim(collationId) + } + } - override val trimMethod: String = "trim" + override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trim(trimString) + } + else { + srcString.trim(trimString, collationId) + } + } override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy( @@ -1278,12 +1340,25 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None override protected def direction: String = "LEADING" - override def doEval(srcString: UTF8String): UTF8String = srcString.trimLeft() + override val trimMethod: String = "trimLeft" - override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trimLeft(trimString) + override def doEval(srcString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trimLeft() + } + else { + srcString.trimLeft(collationId) + } + } - override val trimMethod: String = "trimLeft" + override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trimLeft(trimString) + } + else { + srcString.trimLeft(trimString, collationId) + } + } override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimLeft = @@ -1339,12 +1414,25 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non override protected def direction: String = "TRAILING" - override def doEval(srcString: UTF8String): UTF8String = srcString.trimRight() + override val trimMethod: String = "trimRight" - override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trimRight(trimString) + override def doEval(srcString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trimRight() + } + else { + srcString.trimRight(collationId) + } + } - override val trimMethod: String = "trimRight" + override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trimRight(trimString) + } + else { + srcString.trimRight(trimString, collationId) + } + } override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimRight = From ce8253ef74eba9e0929b1b52452b2f1fbc7a238d Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Wed, 27 Mar 2024 22:00:35 +0100 Subject: [PATCH 2/8] Implement trim functions that support collations --- .../apache/spark/unsafe/types/UTF8String.java | 209 ++++++++++++++++-- 1 file changed, 195 insertions(+), 14 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 87018a850ad42..92db3ce6bee15 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -698,6 +698,7 @@ public UTF8String trimLeft(UTF8String trimString) { UTF8String searchChar = copyUTF8String( searchIdx, searchIdx + numBytesForFirstByte(this.getByte(searchIdx)) - 1); int searchCharBytes = searchChar.numBytes; + // try to find the matching for the searchChar in the trimString set if (trimString.find(searchChar, 0) >= 0) { trimIdx += searchCharBytes; @@ -707,6 +708,7 @@ public UTF8String trimLeft(UTF8String trimString) { } searchIdx += searchCharBytes; } + if (searchIdx == 0) { // Nothing trimmed return this; @@ -730,14 +732,88 @@ public UTF8String trimLeft(UTF8String trimString, int collationId) { return collatedTrimLeft(trimString, collationId); } - public UTF8String lowercaseTrimLeft(UTF8String trimString) { - // TODO - return EMPTY_UTF8; + private UTF8String lowercaseTrimLeft(UTF8String trimString) { + if (trimString == null) { + return null; + } + + // The searching byte position in the lowercase source string + int searchIdx = 0; + // The byte position of a first non-matching character in the lowercase source string + int trimByteIdx = 0; + + // Convert trimString to lowercase so it can be searched properly + trimString = trimString.toLowerCase(); + + while (searchIdx < numBytes) { + UTF8String searchChar = copyUTF8String( + searchIdx, + searchIdx + numBytesForFirstByte(getByte(searchIdx)) - 1); + int searchCharBytes = searchChar.numBytes; + + // Try to find the matching for the lowercase searchChar in the trimString + if (trimString.find(searchChar.toLowerCase(), 0) >= 0) { + trimByteIdx += searchCharBytes; + searchIdx += searchCharBytes; + } + else { + // No matching, exit the search + break; + } + } + + if (searchIdx == 0) { + // Nothing trimmed - return original string (not converted to lowercase) + return this; + } + if (trimByteIdx >= numBytes) { + // Everything trimmed + return EMPTY_UTF8; + } + return copyUTF8String(trimByteIdx, numBytes - 1); } - public UTF8String collatedTrimLeft(UTF8String trimString, int collationId) { - // TODO - return EMPTY_UTF8; + private UTF8String collatedTrimLeft(UTF8String trimString, int collationId) { + if (trimString == null) { + return null; + } + + // The searching byte position in the source string + int searchIdx = 0; + // The byte position of a first non-matching character in the source string + int trimByteIdx = 0; + + while (searchIdx < numBytes) { + UTF8String searchChar = copyUTF8String( + searchIdx, + searchIdx + numBytesForFirstByte(getByte(searchIdx)) - 1); + int searchCharBytes = searchChar.numBytes; + + // Try to find the matching for the searchChar in the trimString + StringSearch stringSearch = CollationFactory.getStringSearch( + trimString, searchChar, collationId); + int searchCharIdx = stringSearch.next(); + + if (searchCharIdx != StringSearch.DONE + && stringSearch.getMatchLength() == stringSearch.getPattern().length()) { + trimByteIdx += searchCharBytes; + searchIdx += searchCharBytes; + } + else { + // No matching, exit the search + break; + } + } + + if (searchIdx == 0) { + // Nothing trimmed - return original string (not converted to lowercase) + return this; + } + if (trimByteIdx >= numBytes) { + // Everything trimmed + return EMPTY_UTF8; + } + return copyUTF8String(trimByteIdx, numBytes - 1); } /** @@ -797,27 +873,30 @@ public UTF8String trimRight(UTF8String trimString) { int[] stringCharLen = new int[numBytes]; // array of the first byte position for each character in the source string int[] stringCharPos = new int[numBytes]; + // build the position and length array while (charIdx < numBytes) { stringCharPos[numChars] = charIdx; stringCharLen[numChars] = numBytesForFirstByte(getByte(charIdx)); charIdx += stringCharLen[numChars]; - numChars ++; + numChars++; } // index trimEnd points to the first no matching byte position from the right side of // the source string. int trimEnd = numBytes - 1; + while (numChars > 0) { UTF8String searchChar = copyUTF8String( stringCharPos[numChars - 1], stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); + if (trimString.find(searchChar, 0) >= 0) { trimEnd -= stringCharLen[numChars - 1]; } else { break; } - numChars --; + numChars--; } if (trimEnd == numBytes - 1) { @@ -843,14 +922,116 @@ public UTF8String trimRight(UTF8String trimString, int collationId) { return collatedTrimRight(trimString, collationId); } - public UTF8String lowercaseTrimRight(UTF8String trimString) { - // TODO - return EMPTY_UTF8; + private UTF8String lowercaseTrimRight(UTF8String trimString) { + if (trimString == null) { + return null; + } + + // Convert trimString to lowercase so it can be searched properly + trimString = trimString.toLowerCase(); + + // Number of bytes iterated from the source string + int byteIdx = 0; + // Number of characters iterated from the source string + int numChars = 0; + // Array of character length for the source string + int[] stringCharLen = new int[numBytes]; + // Array of the first byte position for each character in the source string + int[] stringCharPos = new int[numBytes]; + + // Build the position and length array + while (byteIdx < numBytes) { + stringCharPos[numChars] = byteIdx; + stringCharLen[numChars] = numBytesForFirstByte(getByte(byteIdx)); + byteIdx += stringCharLen[numChars]; + numChars++; + } + + // Index trimEnd points to the first no matching byte position from the right side of + // the source string. + int trimByteIdx = numBytes - 1; + + while (numChars > 0) { + UTF8String searchChar = copyUTF8String( + stringCharPos[numChars - 1], + stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); + + // Try to find the matching for the lowercase searchChar in the trimString + if (trimString.find(searchChar.toLowerCase(), 0) >= 0) { + trimByteIdx -= stringCharLen[numChars - 1]; + numChars--; + } + else { + break; + } + } + + if (trimByteIdx == numBytes - 1) { + // Nothing trimmed + return this; + } + if (trimByteIdx < 0) { + // Everything trimmed + return EMPTY_UTF8; + } + return copyUTF8String(0, trimByteIdx); } - public UTF8String collatedTrimRight(UTF8String trimString, int collationId) { - // TODO - return EMPTY_UTF8; + private UTF8String collatedTrimRight(UTF8String trimString, int collationId) { + if (trimString == null) { + return null; + } + + // Number of bytes iterated from the source string + int byteIdx = 0; + // Number of characters iterated from the source string + int numChars = 0; + // Array of character length for the source string + int[] stringCharLen = new int[numBytes]; + // Array of the first byte position for each character in the source string + int[] stringCharPos = new int[numBytes]; + + // Build the position and length array + while (byteIdx < numBytes) { + stringCharPos[numChars] = byteIdx; + stringCharLen[numChars] = numBytesForFirstByte(getByte(byteIdx)); + byteIdx += stringCharLen[numChars]; + numChars++; + } + + // Index trimEnd points to the first no matching byte position from the right side of + // the source string. + int trimByteIdx = numBytes - 1; + + while (numChars > 0) { + UTF8String searchChar = copyUTF8String( + stringCharPos[numChars - 1], + stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); + + // Try to find the matching for the searchChar in the trimString + StringSearch stringSearch = CollationFactory.getStringSearch( + trimString, searchChar, collationId); + int searchCharIdx = stringSearch.next(); + + if (searchCharIdx != StringSearch.DONE + && stringSearch.getMatchLength() == stringSearch.getPattern().length()) { + trimByteIdx -= stringCharLen[numChars - 1]; + numChars--; + } + else { + break; + } + } + + if (trimByteIdx == numBytes - 1) { + // Nothing trimmed + return this; + } + if (trimByteIdx < 0) { + // Everything trimmed + return EMPTY_UTF8; + } + return copyUTF8String(0, trimByteIdx); } public UTF8String reverse() { From 095748f6e84cdded2705e608477d5fe21e0232af Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 28 Mar 2024 10:22:57 +0100 Subject: [PATCH 3/8] Change param names for getStringSearch to a more descriptive ones --- .../apache/spark/sql/catalyst/util/CollationFactory.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 119508a37e717..09f74a131fc6c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -176,11 +176,11 @@ public Collation( */ public static StringSearch getStringSearch( - final UTF8String left, - final UTF8String right, + final UTF8String targetString, + final UTF8String patternString, final int collationId) { - String pattern = right.toString(); - CharacterIterator target = new StringCharacterIterator(left.toString()); + String pattern = patternString.toString(); + CharacterIterator target = new StringCharacterIterator(targetString.toString()); Collator collator = CollationFactory.fetchCollation(collationId).collator; return new StringSearch(pattern, target, (RuleBasedCollator) collator); } From 804493a5068cb9ecec0ae4ab8b68ed984f916834 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 28 Mar 2024 18:33:18 +0100 Subject: [PATCH 4/8] Tests + revert file reformatting --- .../apache/spark/unsafe/types/UTF8String.java | 9 +- .../sql/CollationStringExpressionsSuite.scala | 320 +++++++++++++++++- 2 files changed, 321 insertions(+), 8 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 92db3ce6bee15..f7601bf19e1d8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -698,7 +698,6 @@ public UTF8String trimLeft(UTF8String trimString) { UTF8String searchChar = copyUTF8String( searchIdx, searchIdx + numBytesForFirstByte(this.getByte(searchIdx)) - 1); int searchCharBytes = searchChar.numBytes; - // try to find the matching for the searchChar in the trimString set if (trimString.find(searchChar, 0) >= 0) { trimIdx += searchCharBytes; @@ -708,7 +707,6 @@ public UTF8String trimLeft(UTF8String trimString) { } searchIdx += searchCharBytes; } - if (searchIdx == 0) { // Nothing trimmed return this; @@ -873,30 +871,27 @@ public UTF8String trimRight(UTF8String trimString) { int[] stringCharLen = new int[numBytes]; // array of the first byte position for each character in the source string int[] stringCharPos = new int[numBytes]; - // build the position and length array while (charIdx < numBytes) { stringCharPos[numChars] = charIdx; stringCharLen[numChars] = numBytesForFirstByte(getByte(charIdx)); charIdx += stringCharLen[numChars]; - numChars++; + numChars ++; } // index trimEnd points to the first no matching byte position from the right side of // the source string. int trimEnd = numBytes - 1; - while (numChars > 0) { UTF8String searchChar = copyUTF8String( stringCharPos[numChars - 1], stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); - if (trimString.find(searchChar, 0) >= 0) { trimEnd -= stringCharLen[numChars - 1]; } else { break; } - numChars--; + numChars --; } if (trimEnd == numBytes - 1) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index ab2d768256c14..13459c24ce827 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -86,8 +86,326 @@ class CollationStringExpressionsSuite extends QueryTest testRepeat("UNICODE_CI", 3, "abc", 2) } - // TODO: Add more tests for other string expressions + case class StringTrimTestCase( + collation: String, + trimFunc: String, + sourceString: String, + trimString: String, + expectedResultString: String) + + test("string trim functions with collation - success") { + // scalastyle:off + + val testCases = Seq( + // Basic test cases + StringTrimTestCase("UTF8_BINARY", "TRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "TRIM", " asd ", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "TRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UTF8_BINARY", "TRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "TRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "TRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UTF8_BINARY", "BTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", " asd ", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UTF8_BINARY", "LTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", " asd ", null, "asd "), + StringTrimTestCase("UTF8_BINARY", "LTRIM", " a世a ", null, "a世a "), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "xxasdxx", "x", "asdxx"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "xa世ax", "x", "a世ax"), + + StringTrimTestCase("UTF8_BINARY", "RTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", " asd ", null, " asd"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", " a世a ", null, " a世a"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "xxasdxx", "x", "xxasd"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "xa世ax", "x", "xa世a"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", " asd ", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", " asd ", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", " asd ", null, "asd "), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", " a世a ", null, "a世a "), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "xxasdxx", "x", "asdxx"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "xa世ax", "x", "a世ax"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", " asd ", null, " asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", " a世a ", null, " a世a"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "xxasdxx", "x", "xxasd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "xa世ax", "x", "xa世a"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", " asd ", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UNICODE", "BTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE", "BTRIM", " asd ", null, "asd"), + StringTrimTestCase("UNICODE", "BTRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UNICODE", "BTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE", "BTRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UNICODE", "BTRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UNICODE", "LTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE", "LTRIM", " asd ", null, "asd "), + StringTrimTestCase("UNICODE", "LTRIM", " a世a ", null, "a世a "), + StringTrimTestCase("UNICODE", "LTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE", "LTRIM", "xxasdxx", "x", "asdxx"), + StringTrimTestCase("UNICODE", "LTRIM", "xa世ax", "x", "a世ax"), + + StringTrimTestCase("UNICODE", "RTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE", "RTRIM", " asd ", null, " asd"), + StringTrimTestCase("UNICODE", "RTRIM", " a世a ", null, " a世a"), + StringTrimTestCase("UNICODE", "RTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE", "RTRIM", "xxasdxx", "x", "xxasd"), + StringTrimTestCase("UNICODE", "RTRIM", "xa世ax", "x", "xa世a"), + + StringTrimTestCase("UNICODE_CI", "TRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE_CI", "TRIM", " asd ", null, "asd"), + StringTrimTestCase("UNICODE_CI", "TRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UNICODE_CI", "TRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "TRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "TRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UNICODE_CI", "BTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE_CI", "BTRIM", " asd ", null, "asd"), + StringTrimTestCase("UNICODE_CI", "BTRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UNICODE_CI", "LTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE_CI", "LTRIM", " asd ", null, "asd "), + StringTrimTestCase("UNICODE_CI", "LTRIM", " a世a ", null, "a世a "), + StringTrimTestCase("UNICODE_CI", "LTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "xxasdxx", "x", "asdxx"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "xa世ax", "x", "a世ax"), + + StringTrimTestCase("UNICODE_CI", "RTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE_CI", "RTRIM", " asd ", null, " asd"), + StringTrimTestCase("UNICODE_CI", "RTRIM", " a世a ", null, " a世a"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "xxasdxx", "x", "xxasd"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "xa世ax", "x", "xa世a"), + + // Test cases where trimString has more than one character + StringTrimTestCase("UTF8_BINARY", "TRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "ddsXXXaa", "asd", "XXXaa"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "ddsXXXaa", "asd", "ddsXXX"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "ddsXXXaa", "asd", "XXXaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "ddsXXXaa", "asd", "ddsXXX"), + + StringTrimTestCase("UNICODE", "TRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UNICODE", "BTRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UNICODE", "LTRIM", "ddsXXXaa", "asd", "XXXaa"), + StringTrimTestCase("UNICODE", "RTRIM", "ddsXXXaa", "asd", "ddsXXX"), + + StringTrimTestCase("UNICODE_CI", "TRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "ddsXXXaa", "asd", "XXXaa"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "ddsXXXaa", "asd", "ddsXXX"), + + // Test cases specific to collation type + // uppercase trim, lowercase src + StringTrimTestCase("UTF8_BINARY", "TRIM", "asd", "A", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "asd", "A", "sd"), + StringTrimTestCase("UNICODE", "TRIM", "asd", "A", "asd"), + StringTrimTestCase("UNICODE_CI", "TRIM", "asd", "A", "sd"), + + // lowercase trim, uppercase src + StringTrimTestCase("UTF8_BINARY", "TRIM", "ASD", "a", "ASD"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "ASD", "a", "SD"), + StringTrimTestCase("UNICODE", "TRIM", "ASD", "a", "ASD"), + StringTrimTestCase("UNICODE_CI", "TRIM", "ASD", "a", "SD"), + + // uppercase and lowercase chars of different byte-length (utf8) + StringTrimTestCase("UTF8_BINARY", "TRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "ẞaaaẞ", "ß", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "ẞaaaẞ", "ß", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "ẞaaaẞ", "ß", "aaaẞ"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "ẞaaaẞ", "ß", "ẞaaa"), + + StringTrimTestCase("UNICODE", "TRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UNICODE", "BTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UNICODE", "LTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UNICODE", "RTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + + StringTrimTestCase("UNICODE_CI", "TRIM", "ẞaaaẞ", "ß", "aaa"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "ẞaaaẞ", "ß", "aaa"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "ẞaaaẞ", "ß", "aaaẞ"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "ẞaaaẞ", "ß", "ẞaaa"), + StringTrimTestCase("UTF8_BINARY", "TRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "ßaaaß", "ẞ", "ßaaaß"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "ßaaaß", "ẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "ßaaaß", "ẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "ßaaaß", "ẞ", "aaaß"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "ßaaaß", "ẞ", "ßaaa"), + + StringTrimTestCase("UNICODE", "TRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UNICODE", "BTRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UNICODE", "LTRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UNICODE", "RTRIM", "ßaaaß", "ẞ", "ßaaaß"), + + StringTrimTestCase("UNICODE_CI", "TRIM", "ßaaaß", "ẞ", "aaa"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "ßaaaß", "ẞ", "aaa"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "ßaaaß", "ẞ", "aaaß"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "ßaaaß", "ẞ", "ßaaa"), + + // different byte-length (utf8) chars trimmed + StringTrimTestCase("UTF8_BINARY", "TRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "Ëaaaẞ", "Ëẞ", "aaaẞ"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "Ëaaaẞ", "Ëẞ", "Ëaaa"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "Ëaaaẞ", "Ëẞ", "aaaẞ"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "Ëaaaẞ", "Ëẞ", "Ëaaa"), + + StringTrimTestCase("UNICODE", "TRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UNICODE", "BTRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UNICODE", "LTRIM", "Ëaaaẞ", "Ëẞ", "aaaẞ"), + StringTrimTestCase("UNICODE", "RTRIM", "Ëaaaẞ", "Ëẞ", "Ëaaa"), + + StringTrimTestCase("UNICODE_CI", "TRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "Ëaaaẞ", "Ëẞ", "aaaẞ"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "Ëaaaẞ", "Ëẞ", "Ëaaa") + ) + + testCases.foreach(testCase => { + var df: DataFrame = null + + if (testCase.trimFunc.equalsIgnoreCase("BTRIM")) { + // BTRIM has arguments in (srcStr, trimStr) order + df = sql(s"SELECT ${testCase.trimFunc}(" + + s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" + + (if (testCase.trimString == null) "" else s", COLLATE('${testCase.trimString}', '${testCase.collation}')") + + ")") + } + else { + // While other functions have arguments in (trimStr, srcStr) order + df = sql(s"SELECT ${testCase.trimFunc}(" + + (if (testCase.trimString == null) "" else s"COLLATE('${testCase.trimString}', '${testCase.collation}'), ") + + s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" + + ")") + } + + checkAnswer(df = df, expectedAnswer = Row(testCase.expectedResultString)) + }) + + // scalastyle:on + } + + test("string trim functions with collation - exceptions") { + // scalastyle:off + + // TRIM + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT TRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> "UNICODE", + "collationNameRight" -> "UTF8_BINARY_LCASE", + "sqlExpr" -> "\"TRIM(BOTH collate(trimStr) FROM collate(sourceStr))\"" + ), + context = ExpectedContext(fragment = + "TRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))", + start = 7, stop = 84) + ) + + // BTRIM + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT BTRIM(COLLATE('sourceStr', 'UTF8_BINARY_LCASE'), COLLATE('trimStr', 'UNICODE'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> "UTF8_BINARY_LCASE", + "collationNameRight" -> "UNICODE", + "sqlExpr" -> "\"TRIM(BOTH collate(trimStr) FROM collate(sourceStr))\"" + ), + context = ExpectedContext(fragment = + "BTRIM(COLLATE('sourceStr', 'UTF8_BINARY_LCASE'), COLLATE('trimStr', 'UNICODE'))", + start = 7, stop = 85) + ) + + // LTRIM + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT LTRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> "UNICODE", + "collationNameRight" -> "UTF8_BINARY_LCASE", + "sqlExpr" -> "\"TRIM(LEADING collate(trimStr) FROM collate(sourceStr))\"" + ), + context = ExpectedContext(fragment = + "LTRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))", + start = 7, stop = 85) + ) + + // RTRIM + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT RTRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> "UNICODE", + "collationNameRight" -> "UTF8_BINARY_LCASE", + "sqlExpr" -> "\"TRIM(TRAILING collate(trimStr) FROM collate(sourceStr))\"" + ), + context = ExpectedContext(fragment = + "RTRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))", + start = 7, stop = 85) + ) + + // scalastyle:on + } + + // TODO: Add more tests for other string expressions } class CollationStringExpressionsANSISuite extends CollationRegexpExpressionsSuite { From 4c03363cf2dbe76a48fc2e085a78c82ed8164237 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Thu, 28 Mar 2024 19:11:32 +0100 Subject: [PATCH 5/8] Add doc comments --- .../sql/catalyst/util/CollationFactory.java | 15 ++++-- .../apache/spark/unsafe/types/UTF8String.java | 51 +++++++++++++++++++ 2 files changed, 62 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 09f74a131fc6c..8533e107db831 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -175,12 +175,19 @@ public Collation( * Auxiliary methods for collation aware string operations. */ + /** + * Creates an instance of ICU's StringSearch with provided parameters. + * @param targetUTF8String UTF8String representation of the string to be searched. + * @param patternUTF8String UTF8String representation of the string to search for. + * @param collationId ID of the collation to use. + * @return Created instance of StringSearch. + */ public static StringSearch getStringSearch( - final UTF8String targetString, - final UTF8String patternString, + final UTF8String targetUTF8String, + final UTF8String patternUTF8String, final int collationId) { - String pattern = patternString.toString(); - CharacterIterator target = new StringCharacterIterator(targetString.toString()); + String pattern = patternUTF8String.toString(); + CharacterIterator target = new StringCharacterIterator(targetUTF8String.toString()); Collator collator = CollationFactory.fetchCollation(collationId).collator; return new StringSearch(pattern, target, (RuleBasedCollator) collator); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f7601bf19e1d8..f729f452eb4db 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -585,6 +585,14 @@ public UTF8String trim() { return copyUTF8String(s, e); } + /** + * Trims space characters from both ends of this string - same as {@link UTF8String#trim()}. + * This variant of the method additionally applies provided collation to this string + * and space character before searching. + * + * @param collationId Id of the collation to use. + * @return this string with no spaces at the start or end. + */ public UTF8String trim(int collationId) { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { @@ -638,6 +646,15 @@ public UTF8String trim(UTF8String trimString) { } } + /** + * Trims characters of the given trim string from both ends of this string. + * This variant of the method additionally applies provided collation to this string + * and trim characters before searching. + * + * @param trimString The trim characters string. + * @param collationId Id of the collation to use. + * @return this string with no occurrences of the characters from trim string. + */ public UTF8String trim(UTF8String trimString, int collationId) { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { return trim(trimString); @@ -670,6 +687,14 @@ public UTF8String trimLeft() { return copyUTF8String(s, this.numBytes - 1); } + /** + * Trims space characters from the start of this string - same as {@link UTF8String#trimLeft()}. + * This variant of the method additionally applies provided collation to this string + * and space character before searching. + * + * @param collationId Id of the collation to use. + * @return this string with no spaces at the start. + */ public UTF8String trimLeft(int collationId) { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { @@ -718,6 +743,15 @@ public UTF8String trimLeft(UTF8String trimString) { return copyUTF8String(trimIdx, numBytes - 1); } + /** + * Trims characters of the given trim string from the start of this string. + * This variant of the method additionally applies provided collation to this string + * and trim characters before searching. + * + * @param trimString The trim characters string. + * @param collationId Id of the collation to use. + * @return this string with no occurrences of the trim characters at the start. + */ public UTF8String trimLeft(UTF8String trimString, int collationId) { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { return trimLeft(trimString); @@ -834,6 +868,14 @@ public UTF8String trimRight() { return copyUTF8String(0, e); } + /** + * Trims space characters from the end of this string - same as {@link UTF8String#trimRight()}. + * This variant of the method additionally applies provided collation to this string + * and space character before searching. + * + * @param collationId Id of the collation to use. + * @return this string with no spaces at the end. + */ public UTF8String trimRight(int collationId) { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { @@ -905,6 +947,15 @@ public UTF8String trimRight(UTF8String trimString) { return copyUTF8String(0, trimEnd); } + /** + * Trims characters of the given trim string from the end of this string. + * This variant of the method additionally applies provided collation to this string + * and trim characters before searching. + * + * @param trimString The trim characters string. + * @param collationId Id of the collation to use. + * @return this string with no occurrences of the trim characters at the end. + */ public UTF8String trimRight(UTF8String trimString, int collationId) { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { return trimRight(trimString); From edfa729997b411f53f3769c69e4def43302c02c7 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Mon, 1 Apr 2024 16:36:46 +0200 Subject: [PATCH 6/8] Revert unnecessary changes --- .../catalyst/expressions/stringExpressions.scala | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index eb66403e4e56d..cfb00dedbafba 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1049,7 +1049,6 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { trimStr match { case None => TypeCheckResult.TypeCheckSuccess case Some(trimChars) => - val collationId = srcStr.dataType.asInstanceOf[StringType].collationId CollationTypeConstraints.checkCollationCompatibility(collationId, Seq(trimChars.dataType)) } } @@ -1219,8 +1218,6 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) override protected def direction: String = "BOTH" - override val trimMethod: String = "trim" - override def doEval(srcString: UTF8String): UTF8String = { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { srcString.trim() @@ -1239,6 +1236,8 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) } } + override val trimMethod: String = "trim" + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy( srcStr = newChildren.head, @@ -1340,8 +1339,6 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None override protected def direction: String = "LEADING" - override val trimMethod: String = "trimLeft" - override def doEval(srcString: UTF8String): UTF8String = { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { srcString.trimLeft() @@ -1360,6 +1357,8 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None } } + override val trimMethod: String = "trimLeft" + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimLeft = copy( @@ -1414,8 +1413,6 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non override protected def direction: String = "TRAILING" - override val trimMethod: String = "trimRight" - override def doEval(srcString: UTF8String): UTF8String = { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { srcString.trimRight() @@ -1434,6 +1431,8 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non } } + override val trimMethod: String = "trimRight" + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimRight = copy( From 9dc8f39dba5bfbe57f6e2512111b8dc354e358a9 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Mon, 1 Apr 2024 16:46:18 +0200 Subject: [PATCH 7/8] Simplify code gen --- .../expressions/stringExpressions.scala | 92 +++++++------------ 1 file changed, 34 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index cfb00dedbafba..6454d3f5c0a08 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1070,64 +1070,40 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { val evals = children.map(_.genCode(ctx)) val srcString = evals(0) - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { - if (evals.length == 1) { - ev.copy(code = code""" - |${srcString.code} - |boolean ${ev.isNull} = false; - |UTF8String ${ev.value} = null; - |if (${srcString.isNull}) { - | ${ev.isNull} = true; - |} else { - | ${ev.value} = ${srcString.value}.$trimMethod(); - |}""".stripMargin) - } else { - val trimString = evals(1) - ev.copy(code = code""" - |${srcString.code} - |boolean ${ev.isNull} = false; - |UTF8String ${ev.value} = null; - |if (${srcString.isNull}) { - | ${ev.isNull} = true; - |} else { - | ${trimString.code} - | if (${trimString.isNull}) { - | ${ev.isNull} = true; - | } else { - | ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value}); - | } - |}""".stripMargin) - } - } - else { - if (evals.length == 1) { - ev.copy(code = code""" - |${srcString.code} - |boolean ${ev.isNull} = false; - |UTF8String ${ev.value} = null; - |if (${srcString.isNull}) { - | ${ev.isNull} = true; - |} else { - | ${ev.value} = ${srcString.value}.$trimMethod($collationId); - |}""".stripMargin) - } else { - val trimString = evals(1) - ev.copy(code = code""" - |${srcString.code} - |boolean ${ev.isNull} = false; - |UTF8String ${ev.value} = null; - |if (${srcString.isNull}) { - | ${ev.isNull} = true; - |} else { - | ${trimString.code} - | if (${trimString.isNull}) { - | ${ev.isNull} = true; - | } else { - | ${ev.value} = - | ${srcString.value}.$trimMethod(${trimString.value}, $collationId); - | } - |}""".stripMargin) - } + if (evals.length == 1) { + val collationIdStr = + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) "" + else collationId + + ev.copy(code = code""" + |${srcString.code} + |boolean ${ev.isNull} = false; + |UTF8String ${ev.value} = null; + |if (${srcString.isNull}) { + | ${ev.isNull} = true; + |} else { + | ${ev.value} = ${srcString.value}.$trimMethod($collationIdStr); + |}""".stripMargin) + } else { + val trimString = evals(1) + val collationIdStr = + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) "" + else ", " + collationId + + ev.copy(code = code""" + |${srcString.code} + |boolean ${ev.isNull} = false; + |UTF8String ${ev.value} = null; + |if (${srcString.isNull}) { + | ${ev.isNull} = true; + |} else { + | ${trimString.code} + | if (${trimString.isNull}) { + | ${ev.isNull} = true; + | } else { + | ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value}$collationIdStr); + | } + |}""".stripMargin) } } From 337878f41b4b012d46a7488b570f87e89af06f68 Mon Sep 17 00:00:00 2001 From: David Milicevic Date: Tue, 2 Apr 2024 09:25:05 +0200 Subject: [PATCH 8/8] Addressing comments --- .../apache/spark/unsafe/types/UTF8String.java | 29 +++++++------------ .../expressions/stringExpressions.scala | 18 ++++-------- 2 files changed, 17 insertions(+), 30 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f729f452eb4db..42000c07aad43 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -597,8 +597,7 @@ public UTF8String trim(int collationId) { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { return trim(); - } - else { + } else { return trim(UTF8String.fromString(" "), collationId); } } @@ -699,8 +698,7 @@ public UTF8String trimLeft(int collationId) { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { return trimLeft(); - } - else { + } else { return trimLeft(UTF8String.fromString(" "), collationId); } } @@ -761,7 +759,7 @@ public UTF8String trimLeft(UTF8String trimString, int collationId) { return lowercaseTrimLeft(trimString); } - return collatedTrimLeft(trimString, collationId); + return collationAwareTrimLeft(trimString, collationId); } private UTF8String lowercaseTrimLeft(UTF8String trimString) { @@ -787,8 +785,7 @@ private UTF8String lowercaseTrimLeft(UTF8String trimString) { if (trimString.find(searchChar.toLowerCase(), 0) >= 0) { trimByteIdx += searchCharBytes; searchIdx += searchCharBytes; - } - else { + } else { // No matching, exit the search break; } @@ -805,7 +802,7 @@ private UTF8String lowercaseTrimLeft(UTF8String trimString) { return copyUTF8String(trimByteIdx, numBytes - 1); } - private UTF8String collatedTrimLeft(UTF8String trimString, int collationId) { + private UTF8String collationAwareTrimLeft(UTF8String trimString, int collationId) { if (trimString == null) { return null; } @@ -830,8 +827,7 @@ private UTF8String collatedTrimLeft(UTF8String trimString, int collationId) { && stringSearch.getMatchLength() == stringSearch.getPattern().length()) { trimByteIdx += searchCharBytes; searchIdx += searchCharBytes; - } - else { + } else { // No matching, exit the search break; } @@ -880,8 +876,7 @@ public UTF8String trimRight(int collationId) { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { return trimRight(); - } - else { + } else { return trimRight(UTF8String.fromString(" "), collationId); } } @@ -965,7 +960,7 @@ public UTF8String trimRight(UTF8String trimString, int collationId) { return lowercaseTrimRight(trimString); } - return collatedTrimRight(trimString, collationId); + return collationAwareTrimRight(trimString, collationId); } private UTF8String lowercaseTrimRight(UTF8String trimString) { @@ -1006,8 +1001,7 @@ private UTF8String lowercaseTrimRight(UTF8String trimString) { if (trimString.find(searchChar.toLowerCase(), 0) >= 0) { trimByteIdx -= stringCharLen[numChars - 1]; numChars--; - } - else { + } else { break; } } @@ -1023,7 +1017,7 @@ private UTF8String lowercaseTrimRight(UTF8String trimString) { return copyUTF8String(0, trimByteIdx); } - private UTF8String collatedTrimRight(UTF8String trimString, int collationId) { + private UTF8String collationAwareTrimRight(UTF8String trimString, int collationId) { if (trimString == null) { return null; } @@ -1063,8 +1057,7 @@ private UTF8String collatedTrimRight(UTF8String trimString, int collationId) { && stringSearch.getMatchLength() == stringSearch.getPattern().length()) { trimByteIdx -= stringCharLen[numChars - 1]; numChars--; - } - else { + } else { break; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 6454d3f5c0a08..1415affaf9c86 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1197,8 +1197,7 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) override def doEval(srcString: UTF8String): UTF8String = { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { srcString.trim() - } - else { + } else { srcString.trim(collationId) } } @@ -1206,8 +1205,7 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { srcString.trim(trimString) - } - else { + } else { srcString.trim(trimString, collationId) } } @@ -1318,8 +1316,7 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None override def doEval(srcString: UTF8String): UTF8String = { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { srcString.trimLeft() - } - else { + } else { srcString.trimLeft(collationId) } } @@ -1327,8 +1324,7 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { srcString.trimLeft(trimString) - } - else { + } else { srcString.trimLeft(trimString, collationId) } } @@ -1392,8 +1388,7 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non override def doEval(srcString: UTF8String): UTF8String = { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { srcString.trimRight() - } - else { + } else { srcString.trimRight(collationId) } } @@ -1401,8 +1396,7 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = { if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { srcString.trimRight(trimString) - } - else { + } else { srcString.trimRight(trimString, collationId) } }