From c6ecbfbdf63a44e7133fc4dea0c06b967e497e2e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Aug 2021 21:35:12 +0800 Subject: [PATCH 1/4] tmp --- .../analysis/TableOutputResolver.scala | 47 ++++++++++++---- ...Suite.scala => V2WriteAnalysisSuite.scala} | 55 +++++++++---------- 2 files changed, 61 insertions(+), 41 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/{DataSourceV2AnalysisSuite.scala => V2WriteAnalysisSuite.scala} (93%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index 4be7aa0507b6e..e392c524ab224 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, CreateStruct, GetStructField, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} object TableOutputResolver { def resolveOutputColumns( @@ -41,16 +42,7 @@ object TableOutputResolver { val errors = new mutable.ArrayBuffer[String]() val resolved: Seq[NamedExpression] = if (byName) { - expected.flatMap { tableAttr => - query.resolve(Seq(tableAttr.name), conf.resolver) match { - case Some(queryExpr) => - checkField(tableAttr, queryExpr, byName, conf, err => errors += err) - case None => - errors += s"Cannot find data for output column '${tableAttr.name}'" - None - } - } - + reorderColumnsByName(query.output, expected, conf, errors += _) } else { if (expected.size > query.output.size) { throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError( @@ -74,6 +66,37 @@ object TableOutputResolver { } } + private def reorderColumnsByName( + inputCols: Seq[NamedExpression], + expectedCols: Seq[Attribute], + conf: SQLConf, + addError: String => Unit, + colPath: Seq[String] = Nil): Seq[NamedExpression] = { + expectedCols.flatMap { expectedCol => + val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name)) + val newColPath = colPath :+ expectedCol.name + if (matched.isEmpty) { + addError(s"Cannot find data for output column ${newColPath.quoted}") + None + } else if (matched.length > 1) { + addError(s"Ambiguous column name in the input data: ${newColPath.quoted}") + None + } else { + (matched.head.dataType, expectedCol.dataType) match { + case (input: StructType, expected: StructType) => + val fields = input.zipWithIndex.map { case (f, i) => + Alias(GetStructField(matched.head, i), f.name)() + } + val reordered = reorderColumnsByName( + fields, expected.toAttributes, conf, addError, newColPath) + Some(Alias(CreateStruct(reordered), expectedCol.name)()) + case _ => + checkField(expectedCol, matched.head, byName = true, conf, addError) + } + } + } + } + private def checkField( tableAttr: Attribute, queryExpr: NamedExpression, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala similarity index 93% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index 5065276747ad3..9da728a42b5e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.catalyst.analysis -import java.net.URI import java.util.Locale -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, AttributeReference, Cast, LessThanOrEqual, Literal} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.types._ -class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { +class V2AppendDataANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { AppendData.byName(table, query) } @@ -37,7 +38,7 @@ class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { } } -class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { +class V2AppendDataStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { AppendData.byName(table, query) } @@ -47,7 +48,7 @@ class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { } } -class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { +class V2OverwritePartitionsDynamicANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwritePartitionsDynamic.byName(table, query) } @@ -57,7 +58,7 @@ class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnal } } -class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { +class V2OverwritePartitionsDynamicStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwritePartitionsDynamic.byName(table, query) } @@ -67,7 +68,7 @@ class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2Strict } } -class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { +class V2OverwriteByExpressionANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwriteByExpression.byName(table, query, Literal(true)) } @@ -85,7 +86,7 @@ class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisS } } -class V2OverwriteByExpressionStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { +class V2OverwriteByExpressionStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { OverwriteByExpression.byName(table, query, Literal(true)) } @@ -113,7 +114,7 @@ case class TestRelationAcceptAnySchema(output: Seq[AttributeReference]) override def skipSchemaResolution: Boolean = true } -abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSuite { +abstract class V2ANSIWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { // For Ansi store assignment policy, expression `AnsiCast` is used instead of `Cast`. override def checkAnalysis( @@ -140,7 +141,7 @@ abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSui } } -abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseSuite { +abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { override def checkAnalysis( inputPlan: LogicalPlan, expectedPlan: LogicalPlan, @@ -220,29 +221,15 @@ abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseS } } -abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { +abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { - override def getAnalyzer: Analyzer = { - val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin) - catalog.createDatabase( - CatalogDatabase("default", "", new URI("loc"), Map.empty), - ignoreIfExists = false) - new Analyzer(catalog) { - override val extendedResolutionRules = EliminateSubqueryAliases :: Nil - } - } + override def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Seq(EliminateSubqueryAliases) - val table = TestRelation(StructType(Seq( - StructField("x", FloatType), - StructField("y", FloatType))).toAttributes) + val table = TestRelation(Seq('x.float, 'y.float)) - val requiredTable = TestRelation(StructType(Seq( - StructField("x", FloatType, nullable = false), - StructField("y", FloatType, nullable = false))).toAttributes) + val requiredTable = TestRelation(Seq('x.float.notNull, 'y.float.notNull)) - val widerTable = TestRelation(StructType(Seq( - StructField("x", DoubleType), - StructField("y", DoubleType))).toAttributes) + val widerTable = TestRelation(Seq('x.double, 'y.double)) def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan @@ -700,4 +687,14 @@ abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { assertNotResolved(parsedPlan2) assertAnalysisError(parsedPlan2, Seq("cannot resolve", "a", "given input columns", "x, y")) } + + test("byName: reorder inner fields") { + val table = TestRelation(Seq('a.int, 'b.struct('x.int, 'y.int))) + val query = TestRelation(Seq('b.struct('y.int, 'x.byte), 'a.int)) + + val writePlan = byName(table, query) + val expectedPlan = byName(table, query.select( + + )) + } } From 81101d08c54ead4a2bb37aae576770005b2fc696 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Aug 2021 23:05:27 +0800 Subject: [PATCH 2/4] reorder inner fields in byName V2 write --- .../catalyst/analysis/TableOutputResolver.scala | 15 ++++++++++----- .../catalyst/analysis/V2WriteAnalysisSuite.scala | 15 ++++++++++----- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index e392c524ab224..cb76a1280c849 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -76,22 +76,27 @@ object TableOutputResolver { val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name)) val newColPath = colPath :+ expectedCol.name if (matched.isEmpty) { - addError(s"Cannot find data for output column ${newColPath.quoted}") + addError(s"Cannot find data for output column '${newColPath.quoted}'") None } else if (matched.length > 1) { - addError(s"Ambiguous column name in the input data: ${newColPath.quoted}") + addError(s"Ambiguous column name in the input data: '${newColPath.quoted}'") None } else { - (matched.head.dataType, expectedCol.dataType) match { + val matchedCol = matched.head match { + case a: Attribute => a.withName(expectedCol.name) + case a: Alias => a.withName(expectedCol.name) + case other => other + } + (matchedCol.dataType, expectedCol.dataType) match { case (input: StructType, expected: StructType) => val fields = input.zipWithIndex.map { case (f, i) => - Alias(GetStructField(matched.head, i), f.name)() + Alias(GetStructField(matchedCol, i, Some(f.name)), f.name)() } val reordered = reorderColumnsByName( fields, expected.toAttributes, conf, addError, newColPath) Some(Alias(CreateStruct(reordered), expectedCol.name)()) case _ => - checkField(expectedCol, matched.head, byName = true, conf, addError) + checkField(expectedCol, matchedCol, byName = true, conf, addError) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index 9da728a42b5e4..19724818666fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -603,8 +603,8 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", - "Struct 'col' 0-th field name does not match", "expected 'a', found 'x'", - "Struct 'col' 1-th field name does not match", "expected 'b', found 'y'")) + "Cannot find data for output column 'col.a'", + "Cannot find data for output column 'col.b'")) } withClue("byPosition") { @@ -688,13 +688,18 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertAnalysisError(parsedPlan2, Seq("cannot resolve", "a", "given input columns", "x, y")) } - test("byName: reorder inner fields") { + test("SPARK-36498: reorder inner fields in byName mode") { val table = TestRelation(Seq('a.int, 'b.struct('x.int, 'y.int))) val query = TestRelation(Seq('b.struct('y.int, 'x.byte), 'a.int)) val writePlan = byName(table, query) val expectedPlan = byName(table, query.select( - - )) + "a".attr, + namedStruct( + "x".expr, "b".attr.getField("x").as("x").cast(IntegerType), + "y".expr, "b".attr.getField("y").as("y") + ).as("b") + )).analyze + checkAnalysis(writePlan, expectedPlan) } } From 3693ce6241701e6451224395c68e716f02b65424 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Aug 2021 16:36:39 +0800 Subject: [PATCH 3/4] support array/map --- .../analysis/TableOutputResolver.scala | 130 ++++++++++++++++-- .../analysis/V2WriteAnalysisSuite.scala | 51 +++++-- 2 files changed, 159 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index cb76a1280c849..d67b0d28a87ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, CreateStruct, GetStructField, NamedExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} object TableOutputResolver { def resolveOutputColumns( @@ -72,7 +72,8 @@ object TableOutputResolver { conf: SQLConf, addError: String => Unit, colPath: Seq[String] = Nil): Seq[NamedExpression] = { - expectedCols.flatMap { expectedCol => + val matchedCols = mutable.HashSet.empty[String] + val reordered = expectedCols.flatMap { expectedCol => val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name)) val newColPath = colPath :+ expectedCol.name if (matched.isEmpty) { @@ -81,25 +82,130 @@ object TableOutputResolver { } else if (matched.length > 1) { addError(s"Ambiguous column name in the input data: '${newColPath.quoted}'") None + } else if (matched.head.nullable && !expectedCol.nullable) { + addError(s"Cannot write nullable values to non-null column '${newColPath.quoted}'") + None } else { + matchedCols += matched.head.name + val expectedName = expectedCol.name val matchedCol = matched.head match { - case a: Attribute => a.withName(expectedCol.name) - case a: Alias => a.withName(expectedCol.name) + // Save an Alias if we can change the name directly. + case a: Attribute => a.withName(expectedName) + case a: Alias => a.withName(expectedName) case other => other } (matchedCol.dataType, expectedCol.dataType) match { - case (input: StructType, expected: StructType) => - val fields = input.zipWithIndex.map { case (f, i) => - Alias(GetStructField(matchedCol, i, Some(f.name)), f.name)() - } - val reordered = reorderColumnsByName( - fields, expected.toAttributes, conf, addError, newColPath) - Some(Alias(CreateStruct(reordered), expectedCol.name)()) + case (matchedType: StructType, expectedType: StructType) => + resolveStructType( + matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath) + case (matchedType: ArrayType, expectedType: ArrayType) => + resolveArrayType( + matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath) + case (matchedType: MapType, expectedType: MapType) => + resolveMapType( + matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath) case _ => checkField(expectedCol, matchedCol, byName = true, conf, addError) } } } + + if (reordered.length == expectedCols.length) { + if (matchedCols.size < inputCols.length) { + val extraCols = inputCols.filterNot(col => matchedCols.contains(col.name)) + .map(col => s"'${col.name}'").mkString(", ") + addError(s"Cannot write extra fields to struct '${colPath.quoted}': $extraCols") + Nil + } else { + reordered + } + } else { + Nil + } + } + + private def resolveStructType( + input: NamedExpression, + inputType: StructType, + expectedType: StructType, + expectedName: String, + conf: SQLConf, + addError: String => Unit, + colPath: Seq[String]): Option[NamedExpression] = { + val fields = inputType.zipWithIndex.map { case (f, i) => + Alias(GetStructField(input, i, Some(f.name)), f.name)() + } + val reordered = reorderColumnsByName(fields, expectedType.toAttributes, conf, addError, colPath) + if (reordered.length == expectedType.length) { + val struct = CreateStruct(reordered) + val res = if (input.nullable) { + If(IsNull(input), Literal(null, struct.dataType), struct) + } else { + struct + } + Some(Alias(res, expectedName)()) + } else { + None + } + } + + private def resolveArrayType( + input: NamedExpression, + inputType: ArrayType, + expectedType: ArrayType, + expectedName: String, + conf: SQLConf, + addError: String => Unit, + colPath: Seq[String]): Option[NamedExpression] = { + if (inputType.containsNull && !expectedType.containsNull) { + addError(s"Cannot write nullable elements to array of non-nulls: '${colPath.quoted}'") + None + } else { + val param = NamedLambdaVariable("x", inputType.elementType, inputType.containsNull) + val fakeAttr = AttributeReference("x", expectedType.elementType, expectedType.containsNull)() + val res = reorderColumnsByName(Seq(param), Seq(fakeAttr), conf, addError, colPath) + if (res.length == 1) { + val func = LambdaFunction(res.head, Seq(param)) + Some(Alias(ArrayTransform(input, func), expectedName)()) + } else { + None + } + } + } + + private def resolveMapType( + input: NamedExpression, + inputType: MapType, + expectedType: MapType, + expectedName: String, + conf: SQLConf, + addError: String => Unit, + colPath: Seq[String]): Option[NamedExpression] = { + if (inputType.valueContainsNull && !expectedType.valueContainsNull) { + addError(s"Cannot write nullable values to map of non-nulls: '${colPath.quoted}'") + None + } else { + val keyParam = NamedLambdaVariable("k", inputType.keyType, nullable = false) + val fakeKeyAttr = AttributeReference("k", expectedType.keyType, nullable = false)() + val resKey = reorderColumnsByName( + Seq(keyParam), Seq(fakeKeyAttr), conf, addError, colPath :+ "key") + + val valueParam = NamedLambdaVariable("v", inputType.valueType, inputType.valueContainsNull) + val fakeValueAttr = + AttributeReference("v", expectedType.valueType, expectedType.valueContainsNull)() + val resValue = reorderColumnsByName( + Seq(valueParam), Seq(fakeValueAttr), conf, addError, colPath :+ "value") + + if (resKey.length == 1 && resValue.length == 1) { + val keyFunc = LambdaFunction(resKey.head, Seq(keyParam)) + val valueFunc = LambdaFunction(resValue.head, Seq(valueParam)) + val newKeys = ArrayTransform(MapKeys(input), keyFunc) + val newValues = ArrayTransform(MapValues(input), valueFunc) + Some(Alias(MapFromArrays(newKeys, newValues), expectedName)()) + } else { + None + } + } } private def checkField( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index 19724818666fd..49fb0c18920b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -446,6 +446,16 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { "Data columns: 'x', 'y', 'z'")) } + test("byName: fail extra data fields in struct") { + val table = TestRelation(Seq('a.int, 'b.struct('x.int, 'y.int))) + val query = TestRelation(Seq('b.struct('y.int, 'x.int, 'z.int), 'a.int)) + + val writePlan = byName(table, query) + assertAnalysisError(writePlan, Seq( + "Cannot write incompatible data to table", "'table-name'", + "Cannot write extra fields to struct 'b': 'z'")) + } + test("byPosition: basic behavior") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), @@ -688,18 +698,39 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertAnalysisError(parsedPlan2, Seq("cannot resolve", "a", "given input columns", "x, y")) } - test("SPARK-36498: reorder inner fields in byName mode") { + test("SPARK-36498: reorder inner fields with byName mode") { val table = TestRelation(Seq('a.int, 'b.struct('x.int, 'y.int))) val query = TestRelation(Seq('b.struct('y.int, 'x.byte), 'a.int)) - val writePlan = byName(table, query) - val expectedPlan = byName(table, query.select( - "a".attr, - namedStruct( - "x".expr, "b".attr.getField("x").as("x").cast(IntegerType), - "y".expr, "b".attr.getField("y").as("y") - ).as("b") - )).analyze - checkAnalysis(writePlan, expectedPlan) + val writePlan = byName(table, query).analyze + assert(writePlan.children.head.schema == table.schema) + } + + test("SPARK-36498: reorder inner fields in array of struct with byName mode") { + val table = TestRelation(Seq( + 'a.int, + 'arr.array(new StructType().add("x", "int").add("y", "int")))) + val query = TestRelation(Seq( + 'arr.array(new StructType().add("y", "int").add("x", "byte")), + 'a.int)) + + val writePlan = byName(table, query).analyze + assert(writePlan.children.head.schema == table.schema) + } + + test("SPARK-36498: reorder inner fields in map of struct with byName mode") { + val table = TestRelation(Seq( + 'a.int, + 'm.map( + new StructType().add("x", "int").add("y", "int"), + new StructType().add("x", "int").add("y", "int")))) + val query = TestRelation(Seq( + 'm.map( + new StructType().add("y", "int").add("x", "byte"), + new StructType().add("y", "int").add("x", "byte")), + 'a.int)) + + val writePlan = byName(table, query).analyze + assert(writePlan.children.head.schema == table.schema) } } From 452b53571f26ad8b4994910b2cb99325f43a4215 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Aug 2021 23:48:23 +0800 Subject: [PATCH 4/4] fix test --- .../analysis/TableOutputResolver.scala | 18 +++++++++++++++--- .../analysis/V2WriteAnalysisSuite.scala | 1 - 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index d67b0d28a87ce..d471d754e7f8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -82,9 +82,6 @@ object TableOutputResolver { } else if (matched.length > 1) { addError(s"Ambiguous column name in the input data: '${newColPath.quoted}'") None - } else if (matched.head.nullable && !expectedCol.nullable) { - addError(s"Cannot write nullable values to non-null column '${newColPath.quoted}'") - None } else { matchedCols += matched.head.name val expectedName = expectedCol.name @@ -96,12 +93,15 @@ object TableOutputResolver { } (matchedCol.dataType, expectedCol.dataType) match { case (matchedType: StructType, expectedType: StructType) => + checkNullability(matchedCol, expectedCol, conf, addError, newColPath) resolveStructType( matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath) case (matchedType: ArrayType, expectedType: ArrayType) => + checkNullability(matchedCol, expectedCol, conf, addError, newColPath) resolveArrayType( matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath) case (matchedType: MapType, expectedType: MapType) => + checkNullability(matchedCol, expectedCol, conf, addError, newColPath) resolveMapType( matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath) case _ => @@ -124,6 +124,18 @@ object TableOutputResolver { } } + private def checkNullability( + input: Expression, + expected: Attribute, + conf: SQLConf, + addError: String => Unit, + colPath: Seq[String]): Unit = { + if (input.nullable && !expected.nullable && + conf.storeAssignmentPolicy != StoreAssignmentPolicy.LEGACY) { + addError(s"Cannot write nullable values to non-null column '${colPath.quoted}'") + } + } + private def resolveStructType( input: NamedExpression, inputType: StructType, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index 49fb0c18920b8..81043cd2dd9c3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -188,7 +188,6 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { "Cannot find data for output column", "'y'")) } - test("byPosition: fail canWrite check") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType),