diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 9c862581bfe47..bddaf574cd3ab 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1987,15 +1987,15 @@ case class Right(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = If( IsNull(str), - Literal(null, StringType), + Literal(null, str.dataType), If( LessThanOrEqual(len, Literal(0)), - Literal(UTF8String.EMPTY_UTF8, StringType), + Literal(UTF8String.EMPTY_UTF8, str.dataType), new Substring(str, UnaryMinus(len)) ) ) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, IntegerType) override def left: Expression = str override def right: Expression = len override protected def withNewChildrenInternal( @@ -2026,7 +2026,7 @@ case class Left(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = Substring(str, Literal(1), len) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(StringType, BinaryType), IntegerType) + Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType) } override def left: Expression = str diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 97dea66975410..dde97751d0135 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.Seq import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{BooleanType, StringType} @@ -163,6 +163,144 @@ class CollationStringExpressionsSuite }) } + test("substring check output type on explicitly collated string") { + case class SubstringTestCase[R](args: Seq[String], collation: String, result: R) + val checks = Seq( + SubstringTestCase(Seq("Spark", "2"), "UTF8_BINARY", "park"), + SubstringTestCase(Seq("Spark", "2"), "UTF8_BINARY_LCASE", "park") + ) + checks.foreach(ct => { + val query = s"SELECT substr(collate('${ct.args.head}', '${ct.collation}')," + + s" ${ct.args.tail.head})" + // Result & data type + checkAnswer(sql(query), Row(ct.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(ct.collation))) + }) + } + + test("left/right/substr on implicitly collated string returns proper value and type") { + case class QTestCase(query: String, collation: String, result: Row) + val longString = "In the course of human events" + val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( + c => Seq( + QTestCase(s"select left(left('$longString' collate " + c + ", 5), 1)", c, Row("I")), + QTestCase(s"select right(right('$longString' collate " + c + ", 5), 1)", c, Row("s")), + QTestCase( + s"select substr(substr('$longString' collate " + c + ", 4), 2)", c, + Row("he course of human events")) + ) + ) + + checks.foreach { check => + // Result & data type + checkAnswer(sql(check.query), check.result) + assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) + } + } + + test("left/right/substr on explicitly collated proper string returns proper value and type") { + case class QTestCase(query: String, collation: String, result: Row) + val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( + c => Seq( + QTestCase("select left('abc' collate " + c + ", 1)", c, Row("a")), + QTestCase("select right('def' collate " + c + ", 1)", c, Row("f")), + QTestCase("select substr('abc' collate " + c + ", 2)", c, Row("bc")), + QTestCase("select substr('example' collate " + c + ", 0, 2)", c, Row("ex")), + QTestCase("select substr('example' collate " + c + ", 1, 2)", c, Row("ex")), + QTestCase("select substr('example' collate " + c + ", 0, 7)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 1, 7)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 0, 100)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 1, 100)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 2, 2)", c, Row("xa")), + QTestCase("select substr('example' collate " + c + ", 1, 6)", c, Row("exampl")), + QTestCase("select substr('example' collate " + c + ", 2, 100)", c, Row("xample")), + QTestCase("select substr('example' collate " + c + ", 0, 0)", c, Row("")), + QTestCase("select substr('example' collate " + c + ", 100, 4)", c, Row("")), + QTestCase("select substr('example' collate " + c + ", 0, 100)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 1, 100)", c, Row("example")), + QTestCase("select substr('example' collate " + c + ", 2, 100)", c, Row("xample")), + QTestCase("select substr('example' collate " + c + ", -3, 2)", c, Row("pl")), + QTestCase("select substr('example' collate " + c + ", -100, 4)", c, Row("")), + QTestCase("select substr('example' collate " + c + ", -2147483648, 6)", c, Row("")), + QTestCase("select substr(' a世a ' collate " + c + ", 2, 3)", c, Row("a世a")), // scalastyle:ignore + QTestCase("select left(' a世a ' collate " + c + ", 3)", c, Row(" a世")), // scalastyle:ignore + QTestCase("select right(' a世a ' collate " + c + ", 3)", c, Row("世a ")), // scalastyle:ignore + QTestCase("select substr('AaAaAaAa000000' collate " + c + ", 2, 3)", c, Row("aAa")), + QTestCase("select left('AaAaAaAa000000' collate " + c + ", 3)", c, Row("AaA")), + QTestCase("select right('AaAaAaAa000000' collate " + c + ", 3)", c, Row("000")), + QTestCase("select substr('' collate " + c + ", 1, 1)", c, Row("")), + QTestCase("select left('' collate " + c + ", 1)", c, Row("")), + QTestCase("select right('' collate " + c + ", 1)", c, Row("")), + QTestCase("select left('ghi' collate " + c + ", 1)", c, Row("g")) + ) + ) + + checks.foreach { check => + // Result & data type + checkAnswer(sql(check.query), check.result) + assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) + } + } + + test("left/right/substr on explicitly collated improper string returns proper value and type") { + case class QTestCase(query: String, collation: String, result: Row) + val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( + c => Seq( + QTestCase("select left(null collate " + c + ", 1)", c, Row(null)), + QTestCase("select right(null collate " + c + ", 1)", c, Row(null)), + QTestCase("select substr(null collate " + c + ", 1)", c, Row(null)), + QTestCase("select substr(null collate " + c + ", 1, 1)", c, Row(null)), + QTestCase("select left('' collate " + c + ", null)", c, Row(null)), + QTestCase("select right('' collate " + c + ", null)", c, Row(null)), + QTestCase("select substr('' collate " + c + ", null)", c, Row(null)), + QTestCase("select substr('' collate " + c + ", null, null)", c, Row(null)) + ) + ) + checks.foreach(check => { + // Result & data type + checkAnswer(sql(check.query), check.result) + assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) + }) + } + + test("left/right/substr on explicitly collated improper length & position") { + case class QTestCase(query: String, collation: String, result: Row) + val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci").flatMap( + c => Seq( + QTestCase("select left(' a世a ' collate " + c + ", '3')", c, Row(" a世")), // scalastyle:ignore + QTestCase("select right(' a世a ' collate " + c + ", '3')", c, Row("世a ")), // scalastyle:ignore + QTestCase("select right('' collate " + c + ", null)", c, Row(null)), + QTestCase("select substr('' collate " + c + ", null)", c, Row(null)), + QTestCase("select substr('' collate " + c + ", null, null)", c, Row(null)), + QTestCase("select left('' collate " + c + ", null)", c, Row(null)) + ) + ) + checks.foreach(check => { + // Result & data type + checkAnswer(sql(check.query), check.result) + assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) + }) + } + + test("left/right/substr on session-collated string returns proper and value type") { + case class QTestCase(query: String, collation: String, result: Row) + val checks = Seq("utf8_binary_lcase", "utf8_binary", "unicode", "unicode_ci") + .flatMap { c => + Seq( + QTestCase("select left('abc', 1)", c, Row("a")), + QTestCase("select right('def', 1)", c, Row("f")), + QTestCase("select substr('ghi', 1)", c, Row("ghi")) + ) + } + checks.foreach(check => { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> check.collation) { + // Result & data type + checkAnswer(sql(check.query), check.result) + assert(sql(check.query).schema.fields.head.dataType.sameType(StringType(check.collation))) + } + }) + } + // TODO: Add more tests for other string expressions }