diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index 4be7aa0507b6e..d471d754e7f8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, NamedExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} object TableOutputResolver { def resolveOutputColumns( @@ -41,16 +42,7 @@ object TableOutputResolver { val errors = new mutable.ArrayBuffer[String]() val resolved: Seq[NamedExpression] = if (byName) { - expected.flatMap { tableAttr => - query.resolve(Seq(tableAttr.name), conf.resolver) match { - case Some(queryExpr) => - checkField(tableAttr, queryExpr, byName, conf, err => errors += err) - case None => - errors += s"Cannot find data for output column '${tableAttr.name}'" - None - } - } - + reorderColumnsByName(query.output, expected, conf, errors += _) } else { if (expected.size > query.output.size) { throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError( @@ -74,6 +66,160 @@ object TableOutputResolver { } } + private def reorderColumnsByName( + inputCols: Seq[NamedExpression], + expectedCols: Seq[Attribute], + conf: SQLConf, + addError: String => Unit, + colPath: Seq[String] = Nil): Seq[NamedExpression] = { + val matchedCols = mutable.HashSet.empty[String] + val reordered = expectedCols.flatMap { expectedCol => + val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name)) + val newColPath = colPath :+ expectedCol.name + if (matched.isEmpty) { + addError(s"Cannot find data for output column '${newColPath.quoted}'") + None + } else if (matched.length > 1) { + addError(s"Ambiguous column name in the input data: '${newColPath.quoted}'") + None + } else { + matchedCols += matched.head.name + val expectedName = expectedCol.name + val matchedCol = matched.head match { + // Save an Alias if we can change the name directly. + case a: Attribute => a.withName(expectedName) + case a: Alias => a.withName(expectedName) + case other => other + } + (matchedCol.dataType, expectedCol.dataType) match { + case (matchedType: StructType, expectedType: StructType) => + checkNullability(matchedCol, expectedCol, conf, addError, newColPath) + resolveStructType( + matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath) + case (matchedType: ArrayType, expectedType: ArrayType) => + checkNullability(matchedCol, expectedCol, conf, addError, newColPath) + resolveArrayType( + matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath) + case (matchedType: MapType, expectedType: MapType) => + checkNullability(matchedCol, expectedCol, conf, addError, newColPath) + resolveMapType( + matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath) + case _ => + checkField(expectedCol, matchedCol, byName = true, conf, addError) + } + } + } + + if (reordered.length == expectedCols.length) { + if (matchedCols.size < inputCols.length) { + val extraCols = inputCols.filterNot(col => matchedCols.contains(col.name)) + .map(col => s"'${col.name}'").mkString(", ") + addError(s"Cannot write extra fields to struct '${colPath.quoted}': $extraCols") + Nil + } else { + reordered + } + } else { + Nil + } + } + + private def checkNullability( + input: Expression, + expected: Attribute, + conf: SQLConf, + addError: String => Unit, + colPath: Seq[String]): Unit = { + if (input.nullable && !expected.nullable && + conf.storeAssignmentPolicy != StoreAssignmentPolicy.LEGACY) { + addError(s"Cannot write nullable values to non-null column '${colPath.quoted}'") + } + } + + private def resolveStructType( + input: NamedExpression, + inputType: StructType, + expectedType: StructType, + expectedName: String, + conf: SQLConf, + addError: String => Unit, + colPath: Seq[String]): Option[NamedExpression] = { + val fields = inputType.zipWithIndex.map { case (f, i) => + Alias(GetStructField(input, i, Some(f.name)), f.name)() + } + val reordered = reorderColumnsByName(fields, expectedType.toAttributes, conf, addError, colPath) + if (reordered.length == expectedType.length) { + val struct = CreateStruct(reordered) + val res = if (input.nullable) { + If(IsNull(input), Literal(null, struct.dataType), struct) + } else { + struct + } + Some(Alias(res, expectedName)()) + } else { + None + } + } + + private def resolveArrayType( + input: NamedExpression, + inputType: ArrayType, + expectedType: ArrayType, + expectedName: String, + conf: SQLConf, + addError: String => Unit, + colPath: Seq[String]): Option[NamedExpression] = { + if (inputType.containsNull && !expectedType.containsNull) { + addError(s"Cannot write nullable elements to array of non-nulls: '${colPath.quoted}'") + None + } else { + val param = NamedLambdaVariable("x", inputType.elementType, inputType.containsNull) + val fakeAttr = AttributeReference("x", expectedType.elementType, expectedType.containsNull)() + val res = reorderColumnsByName(Seq(param), Seq(fakeAttr), conf, addError, colPath) + if (res.length == 1) { + val func = LambdaFunction(res.head, Seq(param)) + Some(Alias(ArrayTransform(input, func), expectedName)()) + } else { + None + } + } + } + + private def resolveMapType( + input: NamedExpression, + inputType: MapType, + expectedType: MapType, + expectedName: String, + conf: SQLConf, + addError: String => Unit, + colPath: Seq[String]): Option[NamedExpression] = { + if (inputType.valueContainsNull && !expectedType.valueContainsNull) { + addError(s"Cannot write nullable values to map of non-nulls: '${colPath.quoted}'") + None + } else { + val keyParam = NamedLambdaVariable("k", inputType.keyType, nullable = false) + val fakeKeyAttr = AttributeReference("k", expectedType.keyType, nullable = false)() + val resKey = reorderColumnsByName( + Seq(keyParam), Seq(fakeKeyAttr), conf, addError, colPath :+ "key") + + val valueParam = NamedLambdaVariable("v", inputType.valueType, inputType.valueContainsNull) + val fakeValueAttr = + AttributeReference("v", expectedType.valueType, expectedType.valueContainsNull)() + val resValue = reorderColumnsByName( + Seq(valueParam), Seq(fakeValueAttr), conf, addError, colPath :+ "value") + + if (resKey.length == 1 && resValue.length == 1) { + val keyFunc = LambdaFunction(resKey.head, Seq(keyParam)) + val valueFunc = LambdaFunction(resValue.head, Seq(valueParam)) + val newKeys = ArrayTransform(MapKeys(input), keyFunc) + val newValues = ArrayTransform(MapValues(input), valueFunc) + Some(Alias(MapFromArrays(newKeys, newValues), expectedName)()) + } else { + None + } + } + } + private def checkField( tableAttr: Attribute, queryExpr: NamedExpression, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala similarity index 88% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index 5065276747ad3..81043cd2dd9c3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.catalyst.analysis -import java.net.URI import java.util.Locale -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, AttributeReference, Cast, LessThanOrEqual, Literal} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.types._ -class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { +class V2AppendDataANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { AppendData.byName(table, query) } @@ -37,7 +38,7 @@ class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { } } -class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { +class V2AppendDataStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { AppendData.byName(table, query) } @@ -47,7 +48,7 @@ class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { } } -class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { +class V2OverwritePartitionsDynamicANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwritePartitionsDynamic.byName(table, query) } @@ -57,7 +58,7 @@ class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnal } } -class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { +class V2OverwritePartitionsDynamicStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwritePartitionsDynamic.byName(table, query) } @@ -67,7 +68,7 @@ class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2Strict } } -class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { +class V2OverwriteByExpressionANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwriteByExpression.byName(table, query, Literal(true)) } @@ -85,7 +86,7 @@ class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisS } } -class V2OverwriteByExpressionStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { +class V2OverwriteByExpressionStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwriteByExpression.byName(table, query, Literal(true)) } @@ -113,7 +114,7 @@ case class TestRelationAcceptAnySchema(output: Seq[AttributeReference]) override def skipSchemaResolution: Boolean = true } -abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSuite { +abstract class V2ANSIWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { // For Ansi store assignment policy, expression `AnsiCast` is used instead of `Cast`. override def checkAnalysis( @@ -140,7 +141,7 @@ abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSui } } -abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseSuite { +abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { override def checkAnalysis( inputPlan: LogicalPlan, expectedPlan: LogicalPlan, @@ -187,7 +188,6 @@ abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseS "Cannot find data for output column", "'y'")) } - test("byPosition: fail canWrite check") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), @@ -220,29 +220,15 @@ abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseS } } -abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { +abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { - override def getAnalyzer: Analyzer = { - val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin) - catalog.createDatabase( - CatalogDatabase("default", "", new URI("loc"), Map.empty), - ignoreIfExists = false) - new Analyzer(catalog) { - override val extendedResolutionRules = EliminateSubqueryAliases :: Nil - } - } + override def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Seq(EliminateSubqueryAliases) - val table = TestRelation(StructType(Seq( - StructField("x", FloatType), - StructField("y", FloatType))).toAttributes) + val table = TestRelation(Seq('x.float, 'y.float)) - val requiredTable = TestRelation(StructType(Seq( - StructField("x", FloatType, nullable = false), - StructField("y", FloatType, nullable = false))).toAttributes) + val requiredTable = TestRelation(Seq('x.float.notNull, 'y.float.notNull)) - val widerTable = TestRelation(StructType(Seq( - StructField("x", DoubleType), - StructField("y", DoubleType))).toAttributes) + val widerTable = TestRelation(Seq('x.double, 'y.double)) def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan @@ -459,6 +445,16 @@ abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { "Data columns: 'x', 'y', 'z'")) } + test("byName: fail extra data fields in struct") { + val table = TestRelation(Seq('a.int, 'b.struct('x.int, 'y.int))) + val query = TestRelation(Seq('b.struct('y.int, 'x.int, 'z.int), 'a.int)) + + val writePlan = byName(table, query) + assertAnalysisError(writePlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write extra fields to struct 'b': 'z'")) + } + test("byPosition: basic behavior") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), @@ -616,8 +612,8 @@ abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", - "Struct 'col' 0-th field name does not match", "expected 'a', found 'x'", - "Struct 'col' 1-th field name does not match", "expected 'b', found 'y'")) + "Cannot find data for output column 'col.a'", + "Cannot find data for output column 'col.b'")) } withClue("byPosition") { @@ -700,4 +696,40 @@ abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { assertNotResolved(parsedPlan2) assertAnalysisError(parsedPlan2, Seq("cannot resolve", "a", "given input columns", "x, y")) } + + test("SPARK-36498: reorder inner fields with byName mode") { + val table = TestRelation(Seq('a.int, 'b.struct('x.int, 'y.int))) + val query = TestRelation(Seq('b.struct('y.int, 'x.byte), 'a.int)) + + val writePlan = byName(table, query).analyze + assert(writePlan.children.head.schema == table.schema) + } + + test("SPARK-36498: reorder inner fields in array of struct with byName mode") { + val table = TestRelation(Seq( + 'a.int, + 'arr.array(new StructType().add("x", "int").add("y", "int")))) + val query = TestRelation(Seq( + 'arr.array(new StructType().add("y", "int").add("x", "byte")), + 'a.int)) + + val writePlan = byName(table, query).analyze + assert(writePlan.children.head.schema == table.schema) + } + + test("SPARK-36498: reorder inner fields in map of struct with byName mode") { + val table = TestRelation(Seq( + 'a.int, + 'm.map( + new StructType().add("x", "int").add("y", "int"), + new StructType().add("x", "int").add("y", "int")))) + val query = TestRelation(Seq( + 'm.map( + new StructType().add("y", "int").add("x", "byte"), + new StructType().add("y", "int").add("x", "byte")), + 'a.int)) + + val writePlan = byName(table, query).analyze + assert(writePlan.children.head.schema == table.schema) + } }