-
Notifications
You must be signed in to change notification settings - Fork 29.3k
[SPARK-36498][SQL] Reorder inner fields of the input query in byName V2 write #33728
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.analysis | |
|
|
||
| import scala.collection.mutable | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, NamedExpression} | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} | ||
| import org.apache.spark.sql.catalyst.util.CharVarcharUtils | ||
| import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ | ||
| import org.apache.spark.sql.errors.QueryCompilationErrors | ||
| import org.apache.spark.sql.internal.SQLConf | ||
| import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy | ||
| import org.apache.spark.sql.types.DataType | ||
| import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} | ||
|
|
||
| object TableOutputResolver { | ||
| def resolveOutputColumns( | ||
|
|
@@ -41,16 +42,7 @@ object TableOutputResolver { | |
|
|
||
| val errors = new mutable.ArrayBuffer[String]() | ||
| val resolved: Seq[NamedExpression] = if (byName) { | ||
| expected.flatMap { tableAttr => | ||
| query.resolve(Seq(tableAttr.name), conf.resolver) match { | ||
| case Some(queryExpr) => | ||
| checkField(tableAttr, queryExpr, byName, conf, err => errors += err) | ||
| case None => | ||
| errors += s"Cannot find data for output column '${tableAttr.name}'" | ||
| None | ||
| } | ||
| } | ||
|
|
||
| reorderColumnsByName(query.output, expected, conf, errors += _) | ||
| } else { | ||
| if (expected.size > query.output.size) { | ||
| throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError( | ||
|
|
@@ -74,6 +66,160 @@ object TableOutputResolver { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Reorders `inputCols` so that they line up, by name, with `expectedCols`. | ||
| * | ||
| * Each expected column must be matched by exactly one input column according to | ||
| * `conf.resolver`; a missing or ambiguous match is reported through `addError`. Struct, array | ||
| * and map columns are passed to the type-specific resolvers so their inner fields are | ||
| * reordered too; every other column goes through `checkField`. | ||
| * | ||
| * @param inputCols columns (or struct fields) available in the input data | ||
| * @param expectedCols columns (or struct fields) required by the target, in target order | ||
| * @param conf session configuration supplying the name resolver | ||
| * @param addError callback that records an error message | ||
| * @param colPath names of the enclosing columns, used to build paths in error messages | ||
| * @return the resolved columns in the order of `expectedCols`, or `Nil` when an expected | ||
| * column could not be resolved or the input contains columns that were not matched | ||
| */ | ||
| private def reorderColumnsByName( | ||
| inputCols: Seq[NamedExpression], | ||
| expectedCols: Seq[Attribute], | ||
| conf: SQLConf, | ||
| addError: String => Unit, | ||
| colPath: Seq[String] = Nil): Seq[NamedExpression] = { | ||
| // Names of the input columns that were paired with an expected column. | ||
| val consumedNames = mutable.HashSet.empty[String] | ||
| val resolvedCols = expectedCols.flatMap { target => | ||
| val targetName = target.name | ||
| val targetPath = colPath :+ targetName | ||
| inputCols.filter(in => conf.resolver(in.name, targetName)) match { | ||
| case Seq() => | ||
| addError(s"Cannot find data for output column '${targetPath.quoted}'") | ||
| None | ||
| case Seq(found) => | ||
| consumedNames += found.name | ||
| // Rename in place where possible so that no extra Alias is needed. | ||
| val renamed = found match { | ||
| case attr: Attribute => attr.withName(targetName) | ||
| case alias: Alias => alias.withName(targetName) | ||
| case other => other | ||
| } | ||
| (renamed.dataType, target.dataType) match { | ||
| case (fromStruct: StructType, toStruct: StructType) => | ||
| checkNullability(renamed, target, conf, addError, targetPath) | ||
| resolveStructType( | ||
| renamed, fromStruct, toStruct, targetName, conf, addError, targetPath) | ||
| case (fromArray: ArrayType, toArray: ArrayType) => | ||
| checkNullability(renamed, target, conf, addError, targetPath) | ||
| resolveArrayType( | ||
| renamed, fromArray, toArray, targetName, conf, addError, targetPath) | ||
| case (fromMap: MapType, toMap: MapType) => | ||
| checkNullability(renamed, target, conf, addError, targetPath) | ||
| resolveMapType( | ||
| renamed, fromMap, toMap, targetName, conf, addError, targetPath) | ||
| case _ => | ||
| checkField(target, renamed, byName = true, conf, addError) | ||
| } | ||
| case _ => | ||
| addError(s"Ambiguous column name in the input data: '${targetPath.quoted}'") | ||
| None | ||
| } | ||
| } | ||
|
|
||
| if (resolvedCols.length != expectedCols.length) { | ||
| // At least one expected column failed to resolve; its error has already been recorded. | ||
| Nil | ||
| } else if (consumedNames.size < inputCols.length) { | ||
| val extraCols = inputCols | ||
| .filterNot(in => consumedNames.contains(in.name)) | ||
| .map(in => s"'${in.name}'") | ||
| .mkString(", ") | ||
| addError(s"Cannot write extra fields to struct '${colPath.quoted}': $extraCols") | ||
| Nil | ||
| } else { | ||
| resolvedCols | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Records an error when a nullable input would be written to a non-nullable target column. | ||
| * | ||
| * Nothing is reported under the LEGACY store assignment policy. | ||
| * | ||
| * @param input expression providing the data | ||
| * @param expected target column | ||
| * @param conf session configuration supplying the store assignment policy | ||
| * @param addError callback that records an error message | ||
| * @param colPath path of the column, quoted into the error message | ||
| */ | ||
| private def checkNullability( | ||
| input: Expression, | ||
| expected: Attribute, | ||
| conf: SQLConf, | ||
| addError: String => Unit, | ||
| colPath: Seq[String]): Unit = { | ||
| val nullableIntoNonNull = input.nullable && !expected.nullable | ||
| if (nullableIntoNonNull && conf.storeAssignmentPolicy != StoreAssignmentPolicy.LEGACY) { | ||
| addError(s"Cannot write nullable values to non-null column '${colPath.quoted}'") | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Rebuilds a struct column so that its fields follow the order of `expectedType`. | ||
| * | ||
| * Every input field is extracted with `GetStructField`, the extracted fields are reordered by | ||
| * name, and a new struct is created from the result. A nullable input is wrapped so that a | ||
| * null struct stays null instead of becoming a struct of nulls. | ||
| * | ||
| * @param input the struct-typed input column | ||
| * @param inputType struct type of `input` | ||
| * @param expectedType struct type required by the target | ||
| * @param expectedName name given to the resulting column | ||
| * @param conf session configuration | ||
| * @param addError callback that records an error message | ||
| * @param colPath path of this column, used in error messages | ||
| * @return the rebuilt column, or `None` if the fields could not all be resolved | ||
| */ | ||
| private def resolveStructType( | ||
| input: NamedExpression, | ||
| inputType: StructType, | ||
| expectedType: StructType, | ||
| expectedName: String, | ||
| conf: SQLConf, | ||
| addError: String => Unit, | ||
| colPath: Seq[String]): Option[NamedExpression] = { | ||
| val inputFields = inputType.zipWithIndex.map { case (field, ordinal) => | ||
| Alias(GetStructField(input, ordinal, Some(field.name)), field.name)() | ||
| } | ||
| val reorderedFields = | ||
| reorderColumnsByName(inputFields, expectedType.toAttributes, conf, addError, colPath) | ||
| if (reorderedFields.length != expectedType.length) { | ||
| None | ||
| } else { | ||
| val rebuilt = CreateStruct(reorderedFields) | ||
| // Keep a null input struct null rather than turning it into a struct of null fields. | ||
| val nullSafe = if (input.nullable) { | ||
| If(IsNull(input), Literal(null, rebuilt.dataType), rebuilt) | ||
| } else { | ||
| rebuilt | ||
| } | ||
| Some(Alias(nullSafe, expectedName)()) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Resolves an array column by reordering the inner fields of its elements. | ||
| * | ||
| * The element is resolved through `reorderColumnsByName` using a lambda variable as the input | ||
| * and a placeholder attribute as the expected column; the result becomes the body of an | ||
| * `ArrayTransform` over `input`. | ||
| * | ||
| * @param input the array-typed input column | ||
| * @param inputType array type of `input` | ||
| * @param expectedType array type required by the target | ||
| * @param expectedName name given to the resulting column | ||
| * @param conf session configuration | ||
| * @param addError callback that records an error message | ||
| * @param colPath path of this column, used in error messages | ||
| * @return the transformed column, or `None` if the element could not be resolved or nullable | ||
| * elements would be written to an array of non-null elements | ||
| */ | ||
| private def resolveArrayType( | ||
| input: NamedExpression, | ||
| inputType: ArrayType, | ||
| expectedType: ArrayType, | ||
| expectedName: String, | ||
| conf: SQLConf, | ||
| addError: String => Unit, | ||
| colPath: Seq[String]): Option[NamedExpression] = { | ||
| if (inputType.containsNull && !expectedType.containsNull) { | ||
| addError(s"Cannot write nullable elements to array of non-nulls: '${colPath.quoted}'") | ||
| None | ||
| } else { | ||
| // reorderColumnsByName appends the expected column's name to `colPath`, so this name shows | ||
| // up in user-facing error paths. Use a descriptive segment instead of a placeholder. | ||
| val elementName = "element" | ||
| val param = NamedLambdaVariable(elementName, inputType.elementType, inputType.containsNull) | ||
| val fakeAttr = | ||
| AttributeReference(elementName, expectedType.elementType, expectedType.containsNull)() | ||
| val res = reorderColumnsByName(Seq(param), Seq(fakeAttr), conf, addError, colPath) | ||
| if (res.length == 1) { | ||
| val func = LambdaFunction(res.head, Seq(param)) | ||
| Some(Alias(ArrayTransform(input, func), expectedName)()) | ||
| } else { | ||
| None | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Resolves a map column by reordering the inner fields of its keys and values. | ||
| * | ||
| * Keys and values are each resolved through `reorderColumnsByName` using a lambda variable as | ||
| * the input and a placeholder attribute as the expected column. The map is then rebuilt with | ||
| * `MapFromArrays` from the transformed key and value arrays. | ||
| * | ||
| * @param input the map-typed input column | ||
| * @param inputType map type of `input` | ||
| * @param expectedType map type required by the target | ||
| * @param expectedName name given to the resulting column | ||
| * @param conf session configuration | ||
| * @param addError callback that records an error message | ||
| * @param colPath path of this column, used in error messages | ||
| * @return the rebuilt column, or `None` if the key or value could not be resolved or nullable | ||
| * values would be written to a map of non-null values | ||
| */ | ||
| private def resolveMapType( | ||
| input: NamedExpression, | ||
| inputType: MapType, | ||
| expectedType: MapType, | ||
| expectedName: String, | ||
| conf: SQLConf, | ||
| addError: String => Unit, | ||
| colPath: Seq[String]): Option[NamedExpression] = { | ||
| if (inputType.valueContainsNull && !expectedType.valueContainsNull) { | ||
| addError(s"Cannot write nullable values to map of non-nulls: '${colPath.quoted}'") | ||
| None | ||
| } else { | ||
| // reorderColumnsByName appends the expected column's name to `colPath`. Naming the | ||
| // placeholders "key" and "value" and passing `colPath` unchanged yields error paths such as | ||
| // 'm.key.x' instead of 'm.key.k.x'. | ||
| val keyParam = NamedLambdaVariable("key", inputType.keyType, nullable = false) | ||
| val fakeKeyAttr = AttributeReference("key", expectedType.keyType, nullable = false)() | ||
| val resKey = reorderColumnsByName( | ||
| Seq(keyParam), Seq(fakeKeyAttr), conf, addError, colPath) | ||
|
|
||
| val valueParam = | ||
| NamedLambdaVariable("value", inputType.valueType, inputType.valueContainsNull) | ||
| val fakeValueAttr = | ||
| AttributeReference("value", expectedType.valueType, expectedType.valueContainsNull)() | ||
| val resValue = reorderColumnsByName( | ||
| Seq(valueParam), Seq(fakeValueAttr), conf, addError, colPath) | ||
|
|
||
| if (resKey.length == 1 && resValue.length == 1) { | ||
| val keyFunc = LambdaFunction(resKey.head, Seq(keyParam)) | ||
| val valueFunc = LambdaFunction(resValue.head, Seq(valueParam)) | ||
| val newKeys = ArrayTransform(MapKeys(input), keyFunc) | ||
| val newValues = ArrayTransform(MapValues(input), valueFunc) | ||
| Some(Alias(MapFromArrays(newKeys, newValues), expectedName)()) | ||
|
Comment on lines
+214
to
+216
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This creates map twice, can be slower.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok |
||
| } else { | ||
| None | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private def checkField( | ||
| tableAttr: Attribute, | ||
| queryExpr: NamedExpression, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,17 +17,18 @@ | |
|
|
||
| package org.apache.spark.sql.catalyst.analysis | ||
|
|
||
| import java.net.URI | ||
| import java.util.Locale | ||
|
|
||
| import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} | ||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.dsl.plans._ | ||
| import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, AttributeReference, Cast, LessThanOrEqual, Literal} | ||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.catalyst.rules.Rule | ||
| import org.apache.spark.sql.internal.SQLConf | ||
| import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
| class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { | ||
| class V2AppendDataANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase { | ||
| override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { | ||
| AppendData.byName(table, query) | ||
| } | ||
|
|
@@ -37,7 +38,7 @@ class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { | |
| } | ||
| } | ||
|
|
||
| class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { | ||
| class V2AppendDataStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase { | ||
| override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { | ||
| AppendData.byName(table, query) | ||
| } | ||
|
|
@@ -47,7 +48,7 @@ class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { | |
| } | ||
| } | ||
|
|
||
| class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { | ||
| class V2OverwritePartitionsDynamicANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase { | ||
| override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { | ||
| OverwritePartitionsDynamic.byName(table, query) | ||
| } | ||
|
|
@@ -57,7 +58,7 @@ class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnal | |
| } | ||
| } | ||
|
|
||
| class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { | ||
| class V2OverwritePartitionsDynamicStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase { | ||
| override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { | ||
| OverwritePartitionsDynamic.byName(table, query) | ||
| } | ||
|
|
@@ -67,7 +68,7 @@ class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2Strict | |
| } | ||
| } | ||
|
|
||
| class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite { | ||
| class V2OverwriteByExpressionANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase { | ||
| override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { | ||
| OverwriteByExpression.byName(table, query, Literal(true)) | ||
| } | ||
|
|
@@ -85,7 +86,7 @@ class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisS | |
| } | ||
| } | ||
|
|
||
| class V2OverwriteByExpressionStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite { | ||
| class V2OverwriteByExpressionStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase { | ||
| override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { | ||
| OverwriteByExpression.byName(table, query, Literal(true)) | ||
| } | ||
|
|
@@ -113,7 +114,7 @@ case class TestRelationAcceptAnySchema(output: Seq[AttributeReference]) | |
| override def skipSchemaResolution: Boolean = true | ||
| } | ||
|
|
||
| abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSuite { | ||
| abstract class V2ANSIWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { | ||
|
|
||
| // For Ansi store assignment policy, expression `AnsiCast` is used instead of `Cast`. | ||
| override def checkAnalysis( | ||
|
|
@@ -140,7 +141,7 @@ abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSui | |
| } | ||
| } | ||
|
|
||
| abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseSuite { | ||
| abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { | ||
| override def checkAnalysis( | ||
| inputPlan: LogicalPlan, | ||
| expectedPlan: LogicalPlan, | ||
|
|
@@ -187,7 +188,6 @@ abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseS | |
| "Cannot find data for output column", "'y'")) | ||
| } | ||
|
|
||
|
|
||
| test("byPosition: fail canWrite check") { | ||
| val widerTable = TestRelation(StructType(Seq( | ||
| StructField("a", DoubleType), | ||
|
|
@@ -220,29 +220,15 @@ abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseS | |
| } | ||
| } | ||
|
|
||
| abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { | ||
| abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A small update to make the test suite names consistent: V2XXXSuite |
||
|
|
||
| override def getAnalyzer: Analyzer = { | ||
| val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin) | ||
| catalog.createDatabase( | ||
| CatalogDatabase("default", "", new URI("loc"), Map.empty), | ||
| ignoreIfExists = false) | ||
| new Analyzer(catalog) { | ||
| override val extendedResolutionRules = EliminateSubqueryAliases :: Nil | ||
| } | ||
| } | ||
| override def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Seq(EliminateSubqueryAliases) | ||
|
|
||
| val table = TestRelation(StructType(Seq( | ||
| StructField("x", FloatType), | ||
| StructField("y", FloatType))).toAttributes) | ||
| val table = TestRelation(Seq('x.float, 'y.float)) | ||
|
|
||
| val requiredTable = TestRelation(StructType(Seq( | ||
| StructField("x", FloatType, nullable = false), | ||
| StructField("y", FloatType, nullable = false))).toAttributes) | ||
| val requiredTable = TestRelation(Seq('x.float.notNull, 'y.float.notNull)) | ||
|
|
||
| val widerTable = TestRelation(StructType(Seq( | ||
| StructField("x", DoubleType), | ||
| StructField("y", DoubleType))).toAttributes) | ||
| val widerTable = TestRelation(Seq('x.double, 'y.double)) | ||
|
|
||
| def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan | ||
|
|
||
|
|
@@ -459,6 +445,16 @@ abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { | |
| "Data columns: 'x', 'y', 'z'")) | ||
| } | ||
|
|
||
| // The input struct `b` carries a field `z` that the table's struct does not declare. | ||
| test("byName: fail extra data fields in struct") { | ||
| val target = TestRelation(Seq('a.int, 'b.struct('x.int, 'y.int))) | ||
| val input = TestRelation(Seq('b.struct('y.int, 'x.int, 'z.int), 'a.int)) | ||
| val expectedMessages = Seq( | ||
| "Cannot write incompatible data to table", "'table-name'", | ||
| "Cannot write extra fields to struct 'b': 'z'") | ||
|
|
||
| assertAnalysisError(byName(target, input), expectedMessages) | ||
| } | ||
|
|
||
| test("byPosition: basic behavior") { | ||
| val query = TestRelation(StructType(Seq( | ||
| StructField("a", FloatType), | ||
|
|
@@ -616,8 +612,8 @@ abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { | |
| assertNotResolved(parsedPlan) | ||
| assertAnalysisError(parsedPlan, Seq( | ||
| "Cannot write incompatible data to table", "'table-name'", | ||
| "Struct 'col' 0-th field name does not match", "expected 'a', found 'x'", | ||
| "Struct 'col' 1-th field name does not match", "expected 'b', found 'y'")) | ||
| "Cannot find data for output column 'col.a'", | ||
| "Cannot find data for output column 'col.b'")) | ||
| } | ||
|
|
||
| withClue("byPosition") { | ||
|
|
@@ -700,4 +696,40 @@ abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest { | |
| assertNotResolved(parsedPlan2) | ||
| assertAnalysisError(parsedPlan2, Seq("cannot resolve", "a", "given input columns", "x, y")) | ||
| } | ||
|
|
||
| // Struct fields arrive in a different order (and `x` as byte); the written schema must | ||
| // still equal the table schema. | ||
| test("SPARK-36498: reorder inner fields with byName mode") { | ||
| val target = TestRelation(Seq('a.int, 'b.struct('x.int, 'y.int))) | ||
| val input = TestRelation(Seq('b.struct('y.int, 'x.byte), 'a.int)) | ||
|
|
||
| val analyzed = byName(target, input).analyze | ||
| val writtenSchema = analyzed.children.head.schema | ||
| assert(writtenSchema == target.schema) | ||
| } | ||
|
|
||
| // Same as the plain struct case, but the struct is the element type of an array column. | ||
| test("SPARK-36498: reorder inner fields in array of struct with byName mode") { | ||
| val targetElement = new StructType().add("x", "int").add("y", "int") | ||
| val inputElement = new StructType().add("y", "int").add("x", "byte") | ||
| val target = TestRelation(Seq('a.int, 'arr.array(targetElement))) | ||
| val input = TestRelation(Seq('arr.array(inputElement), 'a.int)) | ||
|
|
||
| val analyzed = byName(target, input).analyze | ||
| val writtenSchema = analyzed.children.head.schema | ||
| assert(writtenSchema == target.schema) | ||
| } | ||
|
|
||
| // Same as the plain struct case, but the struct is both the key and the value type of a map. | ||
| test("SPARK-36498: reorder inner fields in map of struct with byName mode") { | ||
| def targetStruct: StructType = new StructType().add("x", "int").add("y", "int") | ||
| def inputStruct: StructType = new StructType().add("y", "int").add("x", "byte") | ||
| val target = TestRelation(Seq('a.int, 'm.map(targetStruct, targetStruct))) | ||
| val input = TestRelation(Seq('m.map(inputStruct, inputStruct), 'a.int)) | ||
|
|
||
| val analyzed = byName(target, input).analyze | ||
| val writtenSchema = analyzed.children.head.schema | ||
| assert(writtenSchema == target.schema) | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.