From 614c10856da66643b7c1ceb8bea58db2fdec9af1 Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Fri, 22 Mar 2024 20:34:16 +0100 Subject: [PATCH 1/8] Support collations in complex types operations --- .../expressions/collectionOperations.scala | 33 ++++++++++++------- .../catalyst/util/ArrayBasedMapBuilder.scala | 4 ++- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 27225b4ac74a8..d52cf105d27b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1351,7 +1351,8 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { // Input types are utilized by type coercion in ImplicitTypeCasts. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(StringTypeAnyCollation, ArrayType)) override def dataType: DataType = child.dataType @@ -1365,7 +1366,7 @@ case class Reverse(child: Expression) val arrayData = input.asInstanceOf[ArrayData] new GenericArrayData(arrayData.toObjectArray(elementType).reverse) } - case StringType => _.asInstanceOf[UTF8String].reverse() + case _: StringType => _.asInstanceOf[UTF8String].reverse() } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -1994,17 +1995,26 @@ case class Slice(x: Expression, start: Expression, length: Expression) case class ArrayJoin( array: Expression, delimiter: Expression, - nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes { + nullReplacement: Option[Expression]) extends ImplicitCastInputTypes with ExpectsInputTypes { def this(array: Expression, delimiter: Expression) = this(array, delimiter, None) def this(array: Expression, delimiter: Expression, nullReplacement: Expression) = this(array, delimiter, Some(nullReplacement)) + private val commonType: DataType = { + val elementType = array.dataType.asInstanceOf[ArrayType].elementType + val s = nullReplacement match { + case Some(replacement) => Seq(elementType, delimiter.dataType, replacement.dataType) + case _ => Seq(elementType, delimiter.dataType) + } + TypeCoercion.findWiderCommonType(s).getOrElse(StringType) + } + override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { - Seq(ArrayType(StringType), StringType, StringType) + Seq(ArrayType(commonType), commonType, commonType) } else { - Seq(ArrayType(StringType), StringType) + Seq(ArrayType(commonType), commonType) } override def children: Seq[Expression] = if (nullReplacement.isDefined) { @@ -2047,7 +2057,7 @@ case class ArrayJoin( } case None => (_: Boolean) => false } - arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => { + arrayEval.asInstanceOf[ArrayData].foreach(commonType, (_, item) => { if (item == null) { if (nullHandling(firstItem)) { firstItem = false @@ -2127,7 +2137,7 @@ case class ArrayJoin( | if (!$firstItem) { | $buffer.append(${delimiterGen.value}); | } - | $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)}); + | $buffer.append(${CodeGenerator.getValue(arrayGen.value, commonType, i)}); | $firstItem = false; | } |} @@ -2149,7 +2159,7 @@ case class ArrayJoin( } } - override def dataType: DataType = StringType + override def dataType: DataType = commonType override def prettyName: String = "array_join" } @@ -2724,7 +2734,8 @@ case class TryElementAt(left: Expression, right: Expression, replacement: Expres case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression with QueryErrorsBase { - private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) + private def allowedTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, BinaryType, ArrayType) final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT) @@ -2774,7 +2785,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) ByteArray.concat(inputs: _*) } - case StringType => + case _: StringType => input => { val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) UTF8String.concat(inputs: _*) @@ -2845,7 +2856,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val (concat, initCode) = dataType match { case BinaryType => (s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];") - case StringType => + case _: StringType => ("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];") case ArrayType(elementType, containsNull) => val concat = genCodeForArrays(ctx, elementType, containsNull) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index d358c92dd62c7..98e1001ea979f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -36,7 +36,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria private lazy val keyToIndex = keyType match { // Binary type data is `byte[]`, which can't use `==` to check equality. case _: AtomicType | _: CalendarIntervalType | _: NullType - if !keyType.isInstanceOf[BinaryType] => new java.util.HashMap[Any, Int]() + if !keyType.isInstanceOf[BinaryType] && (!keyType.isInstanceOf[StringType] || + keyType.asInstanceOf[StringType].isBinaryCollation) => + new java.util.HashMap[Any, Int]() case _ => // for complex types, use interpreted ordering to be able to compare unsafe data with safe // data, e.g. UnsafeRow vs GenericInternalRow. From e4c589b9422df86500e75a5d4f30c7012d9893fd Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 25 Mar 2024 09:51:46 +0100 Subject: [PATCH 2/8] Add test --- .../org/apache/spark/sql/CollationSuite.scala | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 146ba63cf402a..77bc342089aa7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -725,6 +725,47 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + test("Support operations on complex types containing collated strings") { + checkAnswer(sql("select reverse('abc' collate utf8_binary_lcase)"), Seq(Row("cba"))) + checkAnswer(sql( + """ + |select reverse(array('a' collate utf8_binary_lcase, + |'b' collate utf8_binary_lcase)) + |""".stripMargin), Seq(Row(Seq("b", "a")))) + checkAnswer(sql( + """ + |select array_join(array('a' collate utf8_binary_lcase, + |'b' collate utf8_binary_lcase), ', ' collate utf8_binary_lcase) + |""".stripMargin), Seq(Row("a, b"))) + checkAnswer(sql( + """ + |select concat('a' collate utf8_binary_lcase, 'b' collate utf8_binary_lcase) + |""".stripMargin), Seq(Row("ab"))) + checkAnswer(sql( + """ + |select concat(array('a' collate utf8_binary_lcase, 'b' collate utf8_binary_lcase)) + |""".stripMargin), Seq(Row(Seq("a", "b")))) + val ctx = "map('aaa' collate utf8_binary_lcase, 1, 'AAA' collate utf8_binary_lcase, 2)['AaA']" + val query = s"select $ctx" + checkError( + exception = intercept[AnalysisException](sql(query)), + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"map(collate(aaa), 1, collate(AAA), 2)[AaA]\"", + "paramIndex" -> "second", + "inputSql" -> "\"AaA\"", + "inputType" -> toSQLType(StringType), + "requiredType" -> toSQLType(StringType( + CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) + ), + context = ExpectedContext( + fragment = ctx, + start = query.length - ctx.length, + stop = query.length - 1 + ) + ) + } + test("window aggregates should respect collation") { val t1 = "T_NON_BINARY" val t2 = "T_BINARY" From b027d8bc683e9dd9206e75db84c4f9fe629c8deb Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 25 Mar 2024 09:59:22 +0100 Subject: [PATCH 3/8] Add checks --- .../scala/org/apache/spark/sql/CollationSuite.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 77bc342089aa7..4b2331da468ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -737,6 +737,12 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { |select array_join(array('a' collate utf8_binary_lcase, |'b' collate utf8_binary_lcase), ', ' collate utf8_binary_lcase) |""".stripMargin), Seq(Row("a, b"))) + checkAnswer(sql( + """ + |select array_join(array('a' collate utf8_binary_lcase, + |'b' collate utf8_binary_lcase, null), ', ' collate utf8_binary_lcase, + |'c' collate utf8_binary_lcase) + |""".stripMargin), Seq(Row("a, b, c"))) checkAnswer(sql( """ |select concat('a' collate utf8_binary_lcase, 'b' collate utf8_binary_lcase) @@ -745,6 +751,11 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { """ |select concat(array('a' collate utf8_binary_lcase, 'b' collate utf8_binary_lcase)) |""".stripMargin), Seq(Row(Seq("a", "b")))) + checkAnswer(sql( + """ + |select map('a' collate utf8_binary_lcase, 1, 'b' collate utf8_binary_lcase, 2) + |['A' collate utf8_binary_lcase] + |""".stripMargin), Seq(Row(1))) val ctx = "map('aaa' collate utf8_binary_lcase, 1, 'AAA' collate utf8_binary_lcase, 2)['AaA']" val query = s"select $ctx" checkError( From 49b7dbce68a4620f552a595e180831c9d678e25d Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 25 Mar 2024 12:14:18 +0100 Subject: [PATCH 4/8] Cast fixes in array_join --- .../expressions/collectionOperations.scala | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d52cf105d27b1..006df4e16f927 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1995,26 +1995,23 @@ case class Slice(x: Expression, start: Expression, length: Expression) case class ArrayJoin( array: Expression, delimiter: Expression, - nullReplacement: Option[Expression]) extends ImplicitCastInputTypes with ExpectsInputTypes { + nullReplacement: Option[Expression]) extends ImplicitCastInputTypes { def this(array: Expression, delimiter: Expression) = this(array, delimiter, None) def this(array: Expression, delimiter: Expression, nullReplacement: Expression) = this(array, delimiter, Some(nullReplacement)) - private val commonType: DataType = { - val elementType = array.dataType.asInstanceOf[ArrayType].elementType - val s = nullReplacement match { - case Some(replacement) => Seq(elementType, delimiter.dataType, replacement.dataType) - case _ => Seq(elementType, delimiter.dataType) + override def inputTypes: Seq[AbstractDataType] = { + val arrayType = array.dataType.asInstanceOf[ArrayType].elementType match { + case _: StringType => array.dataType + case _ => ArrayType(StringType) + } + if (nullReplacement.isDefined) { + Seq(arrayType, StringTypeAnyCollation, StringTypeAnyCollation) + } else { + Seq(arrayType, StringTypeAnyCollation) } - TypeCoercion.findWiderCommonType(s).getOrElse(StringType) - } - - override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { - Seq(ArrayType(commonType), commonType, commonType) - } else { - Seq(ArrayType(commonType), commonType) } override def children: Seq[Expression] = if (nullReplacement.isDefined) { @@ -2057,7 +2054,7 @@ case class ArrayJoin( } case None => (_: Boolean) => false } - arrayEval.asInstanceOf[ArrayData].foreach(commonType, (_, item) => { + arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => { if (item == null) { if (nullHandling(firstItem)) { firstItem = false @@ -2137,7 +2134,7 @@ case class ArrayJoin( | if (!$firstItem) { | $buffer.append(${delimiterGen.value}); | } - | $buffer.append(${CodeGenerator.getValue(arrayGen.value, commonType, i)}); + | $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)}); | $firstItem = false; | } |} @@ -2159,7 +2156,7 @@ case class ArrayJoin( } } - override def dataType: DataType = commonType + override def dataType: DataType = array.dataType.asInstanceOf[ArrayType].elementType override def prettyName: String = "array_join" } From af5b91087ebb00d5e2d2de56bef76d1be38efb86 Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 25 Mar 2024 16:08:06 +0100 Subject: [PATCH 5/8] Fix error messages and return to expected type in array_join --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- .../sql/catalyst/expressions/StringExpressionsSuite.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 006df4e16f927..231940c67a342 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1995,7 +1995,7 @@ case class Slice(x: Expression, start: Expression, length: Expression) case class ArrayJoin( array: Expression, delimiter: Expression, - nullReplacement: Option[Expression]) extends ImplicitCastInputTypes { + nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes { def this(array: Expression, delimiter: Expression) = this(array, delimiter, None) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 1fbd1ac9a29fd..cda9676ca58b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -70,7 +70,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( "paramIndex" -> ordinalNumber(0), - "requiredType" -> "(\"STRING\" or \"BINARY\" or \"ARRAY\")", + "requiredType" -> "(\"STRING_ANY_COLLATION\" or \"BINARY\" or \"ARRAY\")", "inputSql" -> "\"1\"", "inputType" -> "\"INT\"" ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e42f397cbfc29..78cadfa72ec07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1987,7 +1987,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "first", "inputSql" -> "\"struct(1, a)\"", "inputType" -> "\"STRUCT\"", - "requiredType" -> "(\"STRING\" or \"ARRAY\")" + "requiredType" -> "(\"STRING_ANY_COLLATION\" or \"ARRAY\")" ), queryContext = Array(ExpectedContext("", "", 7, 29, "reverse(struct(1, 'a'))")) ) @@ -2526,7 +2526,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"concat(i1, i2, NULL)\"", "functionName" -> "`concat`", - "dataType" -> "(\"ARRAY\" or \"ARRAY\" or \"STRING\")" + "dataType" -> "(\"ARRAY\" or \"ARRAY\" or \"STRING_ANY_COLLATION\")" ), queryContext = Array(ExpectedContext("", "", 0, 19, "concat(i1, i2, null)")) ) From 8a3b32854dca66f66dfd4dc83fa7126671e69930 Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 25 Mar 2024 16:28:43 +0100 Subject: [PATCH 6/8] Fixes --- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 78cadfa72ec07..5beac33703586 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1713,7 +1713,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "second", "inputSql" -> "\"1\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRING\"" + "requiredType" -> "\"STRING_ANY_COLLATION\"" ), queryContext = Array(ExpectedContext("", "", 0, 15, "array_join(x, 1)")) ) @@ -1727,7 +1727,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "third", "inputSql" -> "\"1\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRING\"" + "requiredType" -> "\"STRING_ANY_COLLATION\"" ), queryContext = Array(ExpectedContext("", "", 0, 21, "array_join(x, ', ', 1)")) ) @@ -2002,7 +2002,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "first", "inputSql" -> "\"map(1, a)\"", "inputType" -> "\"MAP\"", - "requiredType" -> "(\"STRING\" or \"ARRAY\")" + "requiredType" -> "(\"STRING_ANY_COLLATION\" or \"ARRAY\")" ), queryContext = Array(ExpectedContext("", "", 7, 26, "reverse(map(1, 'a'))")) ) @@ -2526,7 +2526,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"concat(i1, i2, NULL)\"", "functionName" -> "`concat`", - "dataType" -> "(\"ARRAY\" or \"ARRAY\" or \"STRING_ANY_COLLATION\")" + "dataType" -> "(\"ARRAY\" or \"ARRAY\" or \"STRING\")" ), queryContext = Array(ExpectedContext("", "", 0, 19, "concat(i1, i2, null)")) ) @@ -2552,7 +2552,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"concat(map(1, 2), map(3, 4))\"", "paramIndex" -> "first", - "requiredType" -> "(\"STRING\" or \"BINARY\" or \"ARRAY\")", + "requiredType" -> "(\"STRING_ANY_COLLATION\" or \"BINARY\" or \"ARRAY\")", "inputSql" -> "\"map(1, 2)\"", "inputType" -> "\"MAP\"" ), From 7d3a5bdc9762afb943f6241ba61b0803754a82c0 Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Tue, 2 Apr 2024 13:35:31 +0200 Subject: [PATCH 7/8] Improve ArrayBasedMapBuilder keyToIndex --- .../catalyst/util/ArrayBasedMapBuilder.scala | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index 7024185379db6..136e8824569e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -34,16 +34,22 @@ import org.apache.spark.unsafe.array.ByteArrayMethods class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Serializable { assert(!keyType.existsRecursively(_.isInstanceOf[MapType]), "key of map cannot be/contain map") - private lazy val keyToIndex = keyType match { - // Binary type data is `byte[]`, which can't use `==` to check equality. - case _: AtomicType | _: CalendarIntervalType | _: NullType - if !keyType.isInstanceOf[BinaryType] && (!keyType.isInstanceOf[StringType] || - keyType.asInstanceOf[StringType].isBinaryCollation) => - new java.util.HashMap[Any, Int]() - case _ => - // for complex types, use interpreted ordering to be able to compare unsafe data with safe - // data, e.g. UnsafeRow vs GenericInternalRow. - new java.util.TreeMap[Any, Int](TypeUtils.getInterpretedOrdering(keyType)) + private lazy val keyToIndex = { + def hashMap = new java.util.HashMap[Any, Int]() + def treeMap = new java.util.TreeMap[Any, Int](TypeUtils.getInterpretedOrdering(keyType)) + + keyType match { + // StringType binary equality support implies hashing support + case s: StringType if s.supportsBinaryEquality => hashMap + case _: StringType => treeMap + // Binary type data is `byte[]`, which can't use `==` to check equality. + case _: BinaryType => treeMap + case _: AtomicType | _: CalendarIntervalType | _: NullType => hashMap + case _ => + // for complex types, use interpreted ordering to be able to compare unsafe data with safe + // data, e.g. UnsafeRow vs GenericInternalRow. + treeMap + } } // TODO: specialize it From 231d28fe1a4fcb507d9166dcda6bc27b1dbf80a5 Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Wed, 3 Apr 2024 12:51:16 +0200 Subject: [PATCH 8/8] Update ArrayJoin inputTypes --- .../expressions/collectionOperations.scala | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 231940c67a342..f331a489968aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2002,16 +2002,10 @@ case class ArrayJoin( def this(array: Expression, delimiter: Expression, nullReplacement: Expression) = this(array, delimiter, Some(nullReplacement)) - override def inputTypes: Seq[AbstractDataType] = { - val arrayType = array.dataType.asInstanceOf[ArrayType].elementType match { - case _: StringType => array.dataType - case _ => ArrayType(StringType) - } - if (nullReplacement.isDefined) { - Seq(arrayType, StringTypeAnyCollation, StringTypeAnyCollation) - } else { - Seq(arrayType, StringTypeAnyCollation) - } + override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { + Seq(ArrayType, StringTypeAnyCollation, StringTypeAnyCollation) + } else { + Seq(ArrayType, StringTypeAnyCollation) } override def children: Seq[Expression] = if (nullReplacement.isDefined) {