diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index cf9796ef1948f..770d9695ee8bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1046,8 +1046,9 @@ trait ArraySortLike extends ExpectsInputTypes { } else { s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" } - val nonNullPrimitiveAscendingSort = - if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) { + val canPerformFastSort = + CodeGenerator.isPrimitiveType(elementType) && elementType != BooleanType && !containsNull + val nonNullPrimitiveAscendingSort = if (canPerformFastSort) { val javaType = CodeGenerator.javaType(elementType) val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7b345aabd19c8..31e0a9b6a8ca7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -326,12 +326,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val d2 = new Decimal().set(100) val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0))) val a5 = Literal.create(Seq(null, null), ArrayType(NullType)) + val a6 = Literal.create(Seq(true, false, true, false), + ArrayType(BooleanType, containsNull = false)) + val a7 = Literal.create(Seq(true, false, true, false), ArrayType(BooleanType)) + val a8 = Literal.create(Seq(true, false, true, null, false), ArrayType(BooleanType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) checkEvaluation(new SortArray(a4), Seq(d1, d2)) + checkEvaluation(new SortArray(a6), Seq(false, false, true, true)) + checkEvaluation(new SortArray(a7), Seq(false, false, true, true)) + checkEvaluation(new SortArray(a8), Seq(null, false, false, true, true)) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))