diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 27225b4ac74a8..f331a489968aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1351,7 +1351,8 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { // Input types are utilized by type coercion in ImplicitTypeCasts. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(StringTypeAnyCollation, ArrayType)) override def dataType: DataType = child.dataType @@ -1365,7 +1366,7 @@ case class Reverse(child: Expression) val arrayData = input.asInstanceOf[ArrayData] new GenericArrayData(arrayData.toObjectArray(elementType).reverse) } - case StringType => _.asInstanceOf[UTF8String].reverse() + case _: StringType => _.asInstanceOf[UTF8String].reverse() } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -2002,9 +2003,9 @@ case class ArrayJoin( this(array, delimiter, Some(nullReplacement)) override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { - Seq(ArrayType(StringType), StringType, StringType) + Seq(ArrayType, StringTypeAnyCollation, StringTypeAnyCollation) } else { - Seq(ArrayType(StringType), StringType) + Seq(ArrayType, StringTypeAnyCollation) } override def children: Seq[Expression] = if (nullReplacement.isDefined) { @@ -2149,7 +2150,7 @@ case class ArrayJoin( } } - override def dataType: DataType = StringType + override def dataType: DataType = array.dataType.asInstanceOf[ArrayType].elementType override def prettyName: String = "array_join" } @@ -2724,7 +2725,8 @@ case class TryElementAt(left: Expression, right: Expression, replacement: Expres case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression with QueryErrorsBase { - private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) + private def allowedTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, BinaryType, ArrayType) final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT) @@ -2774,7 +2776,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) ByteArray.concat(inputs: _*) } - case StringType => + case _: StringType => input => { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) UTF8String.concat(inputs: _*) @@ -2845,7 +2847,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val (concat, initCode) = dataType match { case BinaryType => (s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];") - case StringType => + case _: StringType => ("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];") case ArrayType(elementType, containsNull) => val concat = genCodeForArrays(ctx, elementType, containsNull) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index a2d41ebf04e17..136e8824569e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -34,14 +34,22 @@ import org.apache.spark.unsafe.array.ByteArrayMethods class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Serializable { assert(!keyType.existsRecursively(_.isInstanceOf[MapType]), "key of map cannot be/contain map") - private lazy val keyToIndex = keyType match { - // Binary type data is `byte[]`, which can't use `==` to check equality. - case _: AtomicType | _: CalendarIntervalType | _: NullType - if !keyType.isInstanceOf[BinaryType] => new java.util.HashMap[Any, Int]() - case _ => - // for complex types, use interpreted ordering to be able to compare unsafe data with safe - // data, e.g. UnsafeRow vs GenericInternalRow. - new java.util.TreeMap[Any, Int](TypeUtils.getInterpretedOrdering(keyType)) + private lazy val keyToIndex = { + def hashMap = new java.util.HashMap[Any, Int]() + def treeMap = new java.util.TreeMap[Any, Int](TypeUtils.getInterpretedOrdering(keyType)) + + keyType match { + // StringType binary equality support implies hashing support + case s: StringType if s.supportsBinaryEquality => hashMap + case _: StringType => treeMap + // Binary type data is `byte[]`, which can't use `==` to check equality. + case _: BinaryType => treeMap + case _: AtomicType | _: CalendarIntervalType | _: NullType => hashMap + case _ => + // for complex types, use interpreted ordering to be able to compare unsafe data with safe + // data, e.g. UnsafeRow vs GenericInternalRow. + treeMap + } } // TODO: specialize it diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 1fbd1ac9a29fd..cda9676ca58b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -70,7 +70,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( "paramIndex" -> ordinalNumber(0), - "requiredType" -> "(\"STRING\" or \"BINARY\" or \"ARRAY\")", + "requiredType" -> "(\"STRING_ANY_COLLATION\" or \"BINARY\" or \"ARRAY\")", "inputSql" -> "\"1\"", "inputType" -> "\"INT\"" ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 99c0dbfcb1448..f135d8d0234fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -800,6 +800,58 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + test("Support operations on complex types containing collated strings") { + checkAnswer(sql("select reverse('abc' collate utf8_binary_lcase)"), Seq(Row("cba"))) + checkAnswer(sql( + """ + |select reverse(array('a' collate utf8_binary_lcase, + |'b' collate utf8_binary_lcase)) + |""".stripMargin), Seq(Row(Seq("b", "a")))) + checkAnswer(sql( + """ + |select array_join(array('a' collate utf8_binary_lcase, + |'b' collate utf8_binary_lcase), ', ' collate utf8_binary_lcase) + |""".stripMargin), Seq(Row("a, b"))) + checkAnswer(sql( + """ + |select array_join(array('a' collate utf8_binary_lcase, + |'b' collate utf8_binary_lcase, null), ', ' collate utf8_binary_lcase, + |'c' collate utf8_binary_lcase) + |""".stripMargin), Seq(Row("a, b, c"))) + checkAnswer(sql( + """ + |select concat('a' collate utf8_binary_lcase, 'b' collate utf8_binary_lcase) + |""".stripMargin), Seq(Row("ab"))) + checkAnswer(sql( + """ + |select concat(array('a' collate utf8_binary_lcase, 'b' collate utf8_binary_lcase)) + |""".stripMargin), Seq(Row(Seq("a", "b")))) + checkAnswer(sql( + """ + |select map('a' collate utf8_binary_lcase, 1, 'b' collate utf8_binary_lcase, 2) + |['A' collate utf8_binary_lcase] + |""".stripMargin), Seq(Row(1))) + val ctx = "map('aaa' collate utf8_binary_lcase, 1, 'AAA' collate utf8_binary_lcase, 2)['AaA']" + val query = s"select $ctx" + checkError( + exception = intercept[AnalysisException](sql(query)), + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"map(collate(aaa), 1, collate(AAA), 2)[AaA]\"", + "paramIndex" -> "second", + "inputSql" -> "\"AaA\"", + "inputType" -> toSQLType(StringType), + "requiredType" -> toSQLType(StringType( + CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) + ), + context = ExpectedContext( + fragment = ctx, + start = query.length - ctx.length, + stop = query.length - 1 + ) + ) + } + test("window aggregates should respect collation") { val t1 = "T_NON_BINARY" val t2 = "T_BINARY" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e42f397cbfc29..5beac33703586 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1713,7 +1713,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "second", "inputSql" -> "\"1\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRING\"" + "requiredType" -> "\"STRING_ANY_COLLATION\"" ), queryContext = Array(ExpectedContext("", "", 0, 15, "array_join(x, 1)")) ) @@ -1727,7 +1727,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "third", "inputSql" -> "\"1\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRING\"" + "requiredType" -> "\"STRING_ANY_COLLATION\"" ), queryContext = Array(ExpectedContext("", "", 0, 21, "array_join(x, ', ', 1)")) ) @@ -1987,7 +1987,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "first", "inputSql" -> "\"struct(1, a)\"", "inputType" -> "\"STRUCT\"", - "requiredType" -> "(\"STRING\" or \"ARRAY\")" + "requiredType" -> "(\"STRING_ANY_COLLATION\" or \"ARRAY\")" ), queryContext = Array(ExpectedContext("", "", 7, 29, "reverse(struct(1, 'a'))")) ) @@ -2002,7 +2002,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "first", "inputSql" -> "\"map(1, a)\"", "inputType" -> "\"MAP\"", - "requiredType" -> "(\"STRING\" or \"ARRAY\")" + "requiredType" -> "(\"STRING_ANY_COLLATION\" or \"ARRAY\")" ), queryContext = Array(ExpectedContext("", "", 7, 26, "reverse(map(1, 'a'))")) ) @@ -2552,7 +2552,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"concat(map(1, 2), map(3, 4))\"", "paramIndex" -> "first", - "requiredType" -> "(\"STRING\" or \"BINARY\" or \"ARRAY\")", + "requiredType" -> "(\"STRING_ANY_COLLATION\" or \"BINARY\" or \"ARRAY\")", "inputSql" -> "\"map(1, 2)\"", "inputType" -> "\"MAP\"" ),