diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 8ffff63445b69..45994fbf58a68 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -143,6 +143,7 @@ public Collation( private static final HashMap collationNameToIdMap = new HashMap<>(); public static final int UTF8_BINARY_COLLATION_ID = 0; + public static final int INDETERMINATE_COLLATION_ID = -1; public static final int UTF8_BINARY_LCASE_COLLATION_ID = 1; static { diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index fbd4987713e26..f10589cf00f4a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -266,6 +266,8 @@ def fromCollationId(self, collationId: int) -> "StringType": def collationIdToName(self) -> str: if self.collationId == 0: return "" + elif self.collationId == -1: + return " collate INDETERMINATE" else: return " collate %s" % StringType.collationNames[self.collationId] diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 74c714ff63f41..990717abc6de2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -59,6 +59,9 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa */ override def typeName: String = if (isUTF8BinaryCollation) "string" + else if (collationId == CollationFactory.INDETERMINATE_COLLATION_ID) { + "string collate INDETERMINATE" + } else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}" override def equals(obj: Any): Boolean = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index a50dad7c8cdb8..01bb58d9e2109 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -23,6 +23,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, DataType, StringType} @@ -37,7 +38,9 @@ object CollationTypeCasts extends TypeCoercionRule { case caseWhenExpr: CaseWhen if !haveSameType(caseWhenExpr.inputTypesForMerging) => val outputStringType = - getOutputCollation(caseWhenExpr.branches.map(_._2) ++ caseWhenExpr.elseValue) + getOutputCollation( + caseWhenExpr.branches.map(_._2) ++ caseWhenExpr.elseValue, + failOnIndeterminate = true) val newBranches = caseWhenExpr.branches.map { case (condition, value) => (condition, castStringType(value, outputStringType).getOrElse(value)) } @@ -71,10 +74,14 @@ object CollationTypeCasts extends TypeCoercionRule { val Seq(newStr, newPad) = collateToSingleType(Seq(str, pad)) stringPadExpr.withNewChildren(Seq(newStr, len, newPad)) + case concatExprs @ (_: Concat | _: ConcatWs) => + val newChildren = collateToSingleType(concatExprs.children, failOnIndeterminate = false) + concatExprs.withNewChildren(newChildren) + case otherExpr @ ( - _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | - _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace | - _: StringTranslate | _: StringTrim | _: StringTrimLeft | _: StringTrimRight) => + _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Greatest | _: Least | _: Coalesce | + _: BinaryExpression | _: Mask | _: StringReplace | _: StringTranslate | _: StringTrim | + _: StringTrimLeft | _: StringTrimRight) => val newChildren = collateToSingleType(otherExpr.children) otherExpr.withNewChildren(newChildren) } @@ -110,8 +117,10 @@ object CollationTypeCasts extends TypeCoercionRule { /** * Collates input expressions to a single collation. */ - def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = { - val st = getOutputCollation(exprs) + def collateToSingleType( + exprs: Seq[Expression], + failOnIndeterminate: Boolean = true): Seq[Expression] = { + val st = getOutputCollation(exprs, failOnIndeterminate) exprs.map(e => castStringType(e, st).getOrElse(e)) } @@ -122,7 +131,7 @@ object CollationTypeCasts extends TypeCoercionRule { * any expressions, but will only be affected by collated StringTypes or * complex DataTypes with collated StringTypes (e.g. ArrayType) */ - def getOutputCollation(expr: Seq[Expression]): StringType = { + def getOutputCollation(expr: Seq[Expression], failOnIndeterminate: Boolean): StringType = { val explicitTypes = expr.filter { case _: Collate => true case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined => @@ -155,7 +164,20 @@ object CollationTypeCasts extends TypeCoercionRule { .distinct if (implicitTypes.length > 1) { - throw QueryCompilationErrors.implicitCollationMismatchError() + if (failOnIndeterminate) { + throw QueryCompilationErrors.implicitCollationMismatchError() + } + else { + StringType(CollationFactory.INDETERMINATE_COLLATION_ID) + } + } + else if (implicitTypes.contains(-1)) { + if (failOnIndeterminate) { + throw QueryCompilationErrors.indeterminateCollationError() + } + else { + StringType(CollationFactory.INDETERMINATE_COLLATION_ID) + } } else { implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala index 83e50aa33c70d..7140b84253a9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala @@ -21,14 +21,14 @@ import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, UnaryExpression, Unevaluable} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.util.GeneratedColumn +import org.apache.spark.sql.catalyst.util.{CollationFactory, GeneratedColumn} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.validateDefaultValueExpr import org.apache.spark.sql.catalyst.util.ResolveDefaultColumnsUtils.{CURRENT_DEFAULT_COLUMN_METADATA_KEY, EXISTS_DEFAULT_COLUMN_METADATA_KEY} import org.apache.spark.sql.connector.catalog.{Column => V2Column, ColumnDefaultValue} import org.apache.spark.sql.connector.expressions.LiteralValue import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.connector.ColumnImpl -import org.apache.spark.sql.types.{DataType, Metadata, MetadataBuilder, StructField} +import org.apache.spark.sql.types.{DataType, Metadata, MetadataBuilder, StringType, StructField} /** * Column definition for tables. This is an expression so that analyzer can resolve the default @@ -43,6 +43,10 @@ case class ColumnDefinition( generationExpression: Option[String] = None, metadata: Metadata = Metadata.empty) extends Expression with Unevaluable { + if (dataType == StringType(CollationFactory.INDETERMINATE_COLLATION_ID)) { + throw QueryCompilationErrors.indeterminateCollationError() + } + override def children: Seq[Expression] = defaultValue.toSeq override protected def withNewChildrenInternal( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 2f39a1962d2c0..f53852d74b7f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -29,13 +29,13 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.connector.expressions.{FieldReference, RewritableTransform} import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, CreateViewCommand, DDLUtils} import org.apache.spark.sql.execution.command.ViewHelper.generateViewProperties import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.InsertableRelation -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.ArrayImplicits._ @@ -683,3 +683,23 @@ object ViewSyncSchemaToMetaStore extends (LogicalPlan => Unit) { } } } + +object IndeterminateCheck extends (LogicalPlan => Unit) { + def apply(plan: LogicalPlan): Unit = { + plan match { + case CreateDataSourceTableAsSelectCommand(_, _, query, _) if query.resolved => + if (query.schema.exists(sf => sf.dataType == StringType(-1))) { + throw QueryCompilationErrors.indeterminateCollationError() + } + case CreateViewCommand(_, _, _, _, _, plan, _, _, _, _, _, _) if plan.resolved => + if (plan.schema.exists(sf => sf.dataType == StringType(-1))) { + throw QueryCompilationErrors.indeterminateCollationError() + } + case Sort(order, _, child) if child.resolved => + if (order.exists(sf => sf.dataType == StringType(-1))) { + throw QueryCompilationErrors.indeterminateCollationError() + } + case _ => () + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 4660970814e21..f767613ac02d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -225,6 +225,7 @@ abstract class BaseSessionStateBuilder( CommandCheck +: CollationCheck +: ViewSyncSchemaToMetaStore +: + IndeterminateCheck +: customCheckRules } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index fce9ad3cc184b..87a01e0904c1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -460,13 +460,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { // concat of columns of different collations is allowed // as long as we don't use the result in an unsupported function - // TODO(SPARK-47210): Add indeterminate support - checkError( - exception = intercept[AnalysisException] { - sql(s"SELECT c1 || c2 FROM $tableName") - }, - errorClass = "COLLATION_MISMATCH.IMPLICIT" - ) + checkAnswer(sql(s"SELECT c1 || c2 FROM $tableName"), Seq(Row("aa"), Row("AA"))) // concat + in @@ -526,7 +520,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT c1 FROM $tableName WHERE c1 || c3 = 'aa'") }, - errorClass = "COLLATION_MISMATCH.IMPLICIT" + errorClass = "INDETERMINATE_COLLATION" ) // concat on different implicit collations should succeed, @@ -535,7 +529,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT * FROM $tableName ORDER BY c1 || c3") }, - errorClass = "COLLATION_MISMATCH.IMPLICIT" + errorClass = "INDETERMINATE_COLLATION" ) // concat + in @@ -552,7 +546,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[AnalysisException] { sql(s"SELECT * FROM $tableName WHERE contains(c1||c3, 'a')") }, - errorClass = "COLLATION_MISMATCH.IMPLICIT" + errorClass = "INDETERMINATE_COLLATION" ) checkError( @@ -650,7 +644,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - // TODO(SPARK-47210): Add indeterminate support test("SPARK-47210: Indeterminate collation checks") { val tableName = "t1" val newTableName = "t2" @@ -671,9 +664,16 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { withTable(newTableName) { checkError( exception = intercept[AnalysisException] { - sql(s"CREATE TABLE $newTableName AS SELECT c1 || c2 FROM $tableName") + sql(s"CREATE TABLE $newTableName AS SELECT c1 || c2 FROM $tableName").explain(true) + }, + errorClass = "INDETERMINATE_COLLATION") + } + withView("v") { + checkError( + exception = intercept[AnalysisException] { + sql(s"CREATE VIEW v AS SELECT c1 || c2 as con FROM $tableName").explain(true) }, - errorClass = "COLLATION_MISMATCH.IMPLICIT") + errorClass = "INDETERMINATE_COLLATION") } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 979ff1e24ef5c..b2c57521007b1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -115,6 +115,7 @@ class HiveSessionStateBuilder( CommandCheck +: CollationCheck +: ViewSyncSchemaToMetaStore +: + IndeterminateCheck +: customCheckRules }