Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, NamedExpression}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

object TableOutputResolver {
def resolveOutputColumns(
Expand All @@ -41,16 +42,7 @@ object TableOutputResolver {

val errors = new mutable.ArrayBuffer[String]()
val resolved: Seq[NamedExpression] = if (byName) {
expected.flatMap { tableAttr =>
query.resolve(Seq(tableAttr.name), conf.resolver) match {
case Some(queryExpr) =>
checkField(tableAttr, queryExpr, byName, conf, err => errors += err)
case None =>
errors += s"Cannot find data for output column '${tableAttr.name}'"
None
}
}

reorderColumnsByName(query.output, expected, conf, errors += _)
} else {
if (expected.size > query.output.size) {
throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(
Expand All @@ -74,6 +66,160 @@ object TableOutputResolver {
}
}

/**
 * Pairs each expected column with exactly one input column by name and returns the
 * matched inputs in the order of `expectedCols`.
 *
 * Names are compared with `conf.resolver`. Problems are reported through `addError`
 * instead of being thrown.
 *
 * @param inputCols columns (or nested fields) available in the input data
 * @param expectedCols columns (or nested fields) expected, in output order
 * @param conf supplies the name resolver and is passed on to the nested checks
 * @param addError callback that receives each error message
 * @param colPath names of the enclosing columns, used to build the column path shown in
 *                error messages; empty by default
 * @return the reordered columns, or `Nil` when fewer columns were produced than expected
 *         (a column was missing, ambiguous, or its nested resolution produced nothing) or
 *         when some input column matched no expected column
 */
private def reorderColumnsByName(
inputCols: Seq[NamedExpression],
expectedCols: Seq[Attribute],
conf: SQLConf,
addError: String => Unit,
colPath: Seq[String] = Nil): Seq[NamedExpression] = {
// Names of the input columns that were matched; used below to detect leftover input columns.
val matchedCols = mutable.HashSet.empty[String]
val reordered = expectedCols.flatMap { expectedCol =>
val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name))
val newColPath = colPath :+ expectedCol.name
if (matched.isEmpty) {
addError(s"Cannot find data for output column '${newColPath.quoted}'")
None
} else if (matched.length > 1) {
addError(s"Ambiguous column name in the input data: '${newColPath.quoted}'")
None
} else {
matchedCols += matched.head.name
val expectedName = expectedCol.name
val matchedCol = matched.head match {
// Save an Alias if we can change the name directly.
case a: Attribute => a.withName(expectedName)
case a: Alias => a.withName(expectedName)
case other => other
}
// Struct, array and map pairs get a nullability check and are then handed to the matching
// resolve*Type helper; every other type pairing is delegated to checkField.
(matchedCol.dataType, expectedCol.dataType) match {
case (matchedType: StructType, expectedType: StructType) =>
checkNullability(matchedCol, expectedCol, conf, addError, newColPath)
resolveStructType(
matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath)
case (matchedType: ArrayType, expectedType: ArrayType) =>
checkNullability(matchedCol, expectedCol, conf, addError, newColPath)
resolveArrayType(
matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath)
case (matchedType: MapType, expectedType: MapType) =>
checkNullability(matchedCol, expectedCol, conf, addError, newColPath)
resolveMapType(
matchedCol, matchedType, expectedType, expectedName, conf, addError, newColPath)
case _ =>
Comment thread
cloud-fan marked this conversation as resolved.
checkField(expectedCol, matchedCol, byName = true, conf, addError)
}
}
}

// Fewer results than expected columns means at least one expected column produced nothing.
if (reordered.length == expectedCols.length) {
// Input columns that matched no expected column are reported as extra fields.
if (matchedCols.size < inputCols.length) {
val extraCols = inputCols.filterNot(col => matchedCols.contains(col.name))
.map(col => s"'${col.name}'").mkString(", ")
addError(s"Cannot write extra fields to struct '${colPath.quoted}': $extraCols")
Nil
} else {
reordered
}
} else {
Nil
}
}

/**
 * Reports an error through `addError` when a nullable input would be written to a
 * non-nullable expected column, unless the store assignment policy is LEGACY.
 *
 * @param input expression providing the data
 * @param expected column the data is written to
 * @param conf supplies the store assignment policy
 * @param addError callback that receives the error message
 * @param colPath path of the column, used in the error message
 */
private def checkNullability(
    input: Expression,
    expected: Attribute,
    conf: SQLConf,
    addError: String => Unit,
    colPath: Seq[String]): Unit = {
  val nullableIntoNonNull = input.nullable && !expected.nullable
  if (nullableIntoNonNull) {
    // The policy is only consulted once a nullability mismatch has been found.
    conf.storeAssignmentPolicy match {
      case StoreAssignmentPolicy.LEGACY => ()
      case _ =>
        addError(s"Cannot write nullable values to non-null column '${colPath.quoted}'")
    }
  }
}

/**
 * Rebuilds a struct-typed input so that its fields are matched by name against the
 * fields of `expectedType`.
 *
 * Each input field is extracted with `GetStructField`, the extracted fields are passed to
 * `reorderColumnsByName`, and the result is packed into a new struct aliased to
 * `expectedName`. A nullable input is wrapped so that a null struct stays null.
 *
 * @return the rebuilt column, or `None` when the reordering did not yield one field per
 *         expected field
 */
private def resolveStructType(
    input: NamedExpression,
    inputType: StructType,
    expectedType: StructType,
    expectedName: String,
    conf: SQLConf,
    addError: String => Unit,
    colPath: Seq[String]): Option[NamedExpression] = {
  val extractedFields = inputType.zipWithIndex.map { case (field, ordinal) =>
    Alias(GetStructField(input, ordinal, Some(field.name)), field.name)()
  }
  val expectedFields = expectedType.toAttributes
  val reorderedFields =
    reorderColumnsByName(extractedFields, expectedFields, conf, addError, colPath)

  if (reorderedFields.length != expectedType.length) {
    None
  } else {
    val rebuilt = CreateStruct(reorderedFields)
    // Keep a null input struct null rather than turning it into a struct of null fields.
    val nullSafe =
      if (!input.nullable) rebuilt
      else If(IsNull(input), Literal(null, rebuilt.dataType), rebuilt)
    Some(Alias(nullSafe, expectedName)())
  }
}

/**
 * Rebuilds an array-typed input so that its element type is resolved against the
 * expected element type.
 *
 * The element is represented by a lambda variable named "x" and resolved through
 * `reorderColumnsByName` against a stand-in attribute of the expected element type; the
 * resolved expression is then applied to every element with `ArrayTransform`.
 *
 * @return the transformed column aliased to `expectedName`, or `None` when the input may
 *         contain null elements but the expected array may not, or when the element could
 *         not be resolved
 */
private def resolveArrayType(
    input: NamedExpression,
    inputType: ArrayType,
    expectedType: ArrayType,
    expectedName: String,
    conf: SQLConf,
    addError: String => Unit,
    colPath: Seq[String]): Option[NamedExpression] = {
  val nullableIntoNonNull = inputType.containsNull && !expectedType.containsNull
  if (nullableIntoNonNull) {
    addError(s"Cannot write nullable elements to array of non-nulls: '${colPath.quoted}'")
    None
  } else {
    val elementVar = NamedLambdaVariable("x", inputType.elementType, inputType.containsNull)
    val expectedElement =
      AttributeReference("x", expectedType.elementType, expectedType.containsNull)()
    reorderColumnsByName(Seq(elementVar), Seq(expectedElement), conf, addError, colPath) match {
      case Seq(resolvedElement) =>
        val elementFunc = LambdaFunction(resolvedElement, Seq(elementVar))
        Some(Alias(ArrayTransform(input, elementFunc), expectedName)())
      case _ =>
        None
    }
  }
}

/**
 * Rebuilds a map-typed input so that its key and value types are each resolved against
 * the expected key and value types.
 *
 * Keys and values are each represented by a lambda variable ("k" and "v") and resolved
 * through `reorderColumnsByName` against a stand-in attribute of the expected type.
 *
 * @return the rebuilt column aliased to `expectedName`, or `None` when the input values may
 *         be null but the expected values may not, or when the key or the value could not
 *         be resolved
 */
private def resolveMapType(
input: NamedExpression,
inputType: MapType,
expectedType: MapType,
expectedName: String,
conf: SQLConf,
addError: String => Unit,
colPath: Seq[String]): Option[NamedExpression] = {
if (inputType.valueContainsNull && !expectedType.valueContainsNull) {
addError(s"Cannot write nullable values to map of non-nulls: '${colPath.quoted}'")
None
} else {
// Keys are declared non-nullable on both the input and the expected side.
val keyParam = NamedLambdaVariable("k", inputType.keyType, nullable = false)
val fakeKeyAttr = AttributeReference("k", expectedType.keyType, nullable = false)()
// "key" / "value" are appended to the path so error messages say which side failed.
val resKey = reorderColumnsByName(
Seq(keyParam), Seq(fakeKeyAttr), conf, addError, colPath :+ "key")

val valueParam = NamedLambdaVariable("v", inputType.valueType, inputType.valueContainsNull)
val fakeValueAttr =
AttributeReference("v", expectedType.valueType, expectedType.valueContainsNull)()
val resValue = reorderColumnsByName(
Seq(valueParam), Seq(fakeValueAttr), conf, addError, colPath :+ "value")

if (resKey.length == 1 && resValue.length == 1) {
val keyFunc = LambdaFunction(resKey.head, Seq(keyParam))
val valueFunc = LambdaFunction(resValue.head, Seq(valueParam))
// Keys and values are transformed as separate arrays and recombined with MapFromArrays.
// The review thread embedded below records the reasoning: chaining TransformKeys and
// TransformValues would create the map twice, which the author says can be slower.
val newKeys = ArrayTransform(MapKeys(input), keyFunc)
val newValues = ArrayTransform(MapValues(input), valueFunc)
Some(Alias(MapFromArrays(newKeys, newValues), expectedName)())
Comment on lines +214 to +216

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TransformValues(TransformKeys(input, keyFunc), valueFunc)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This creates map twice, can be slower.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

} else {
None
}
}
}

private def checkField(
tableAttr: Attribute,
queryExpr: NamedExpression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@

package org.apache.spark.sql.catalyst.analysis

import java.net.URI
import java.util.Locale

import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, AttributeReference, Cast, LessThanOrEqual, Literal}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.types._

class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite {
class V2AppendDataANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase {
override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
AppendData.byName(table, query)
}
Expand All @@ -37,7 +38,7 @@ class V2AppendDataANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite {
}
}

class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite {
class V2AppendDataStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase {
override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
AppendData.byName(table, query)
}
Expand All @@ -47,7 +48,7 @@ class V2AppendDataStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite {
}
}

class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite {
class V2OverwritePartitionsDynamicANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase {
override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
OverwritePartitionsDynamic.byName(table, query)
}
Expand All @@ -57,7 +58,7 @@ class V2OverwritePartitionsDynamicANSIAnalysisSuite extends DataSourceV2ANSIAnal
}
}

class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite {
class V2OverwritePartitionsDynamicStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase {
override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
OverwritePartitionsDynamic.byName(table, query)
}
Expand All @@ -67,7 +68,7 @@ class V2OverwritePartitionsDynamicStrictAnalysisSuite extends DataSourceV2Strict
}
}

class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisSuite {
class V2OverwriteByExpressionANSIAnalysisSuite extends V2ANSIWriteAnalysisSuiteBase {
override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
OverwriteByExpression.byName(table, query, Literal(true))
}
Expand All @@ -85,7 +86,7 @@ class V2OverwriteByExpressionANSIAnalysisSuite extends DataSourceV2ANSIAnalysisS
}
}

class V2OverwriteByExpressionStrictAnalysisSuite extends DataSourceV2StrictAnalysisSuite {
class V2OverwriteByExpressionStrictAnalysisSuite extends V2StrictWriteAnalysisSuiteBase {
override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
OverwriteByExpression.byName(table, query, Literal(true))
}
Expand Down Expand Up @@ -113,7 +114,7 @@ case class TestRelationAcceptAnySchema(output: Seq[AttributeReference])
override def skipSchemaResolution: Boolean = true
}

abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSuite {
abstract class V2ANSIWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase {

// For Ansi store assignment policy, expression `AnsiCast` is used instead of `Cast`.
override def checkAnalysis(
Expand All @@ -140,7 +141,7 @@ abstract class DataSourceV2ANSIAnalysisSuite extends DataSourceV2AnalysisBaseSui
}
}

abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseSuite {
abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase {
override def checkAnalysis(
inputPlan: LogicalPlan,
expectedPlan: LogicalPlan,
Expand Down Expand Up @@ -187,7 +188,6 @@ abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseS
"Cannot find data for output column", "'y'"))
}


test("byPosition: fail canWrite check") {
val widerTable = TestRelation(StructType(Seq(
StructField("a", DoubleType),
Expand Down Expand Up @@ -220,29 +220,15 @@ abstract class DataSourceV2StrictAnalysisSuite extends DataSourceV2AnalysisBaseS
}
}

abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest {
abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a small update to make the test suite names consistent: V2XXXSuite


override def getAnalyzer: Analyzer = {
val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin)
catalog.createDatabase(
CatalogDatabase("default", "", new URI("loc"), Map.empty),
ignoreIfExists = false)
new Analyzer(catalog) {
override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
}
}
override def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Seq(EliminateSubqueryAliases)

val table = TestRelation(StructType(Seq(
StructField("x", FloatType),
StructField("y", FloatType))).toAttributes)
val table = TestRelation(Seq('x.float, 'y.float))

val requiredTable = TestRelation(StructType(Seq(
StructField("x", FloatType, nullable = false),
StructField("y", FloatType, nullable = false))).toAttributes)
val requiredTable = TestRelation(Seq('x.float.notNull, 'y.float.notNull))

val widerTable = TestRelation(StructType(Seq(
StructField("x", DoubleType),
StructField("y", DoubleType))).toAttributes)
val widerTable = TestRelation(Seq('x.double, 'y.double))

def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan

Expand Down Expand Up @@ -459,6 +445,16 @@ abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest {
"Data columns: 'x', 'y', 'z'"))
}

test("byName: fail extra data fields in struct") {
  // The query's struct column `b` carries a field `z` that the table's struct does not declare.
  val targetTable = TestRelation(Seq('a.int, 'b.struct('x.int, 'y.int)))
  val sourceQuery = TestRelation(Seq('b.struct('y.int, 'x.int, 'z.int), 'a.int))

  val unresolvedWrite = byName(targetTable, sourceQuery)
  val expectedFragments = Seq(
    "Cannot write incompatible data to table", "'table-name'",
    "Cannot write extra fields to struct 'b': 'z'")
  assertAnalysisError(unresolvedWrite, expectedFragments)
}

test("byPosition: basic behavior") {
val query = TestRelation(StructType(Seq(
StructField("a", FloatType),
Expand Down Expand Up @@ -616,8 +612,8 @@ abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest {
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
"Cannot write incompatible data to table", "'table-name'",
"Struct 'col' 0-th field name does not match", "expected 'a', found 'x'",
"Struct 'col' 1-th field name does not match", "expected 'b', found 'y'"))
"Cannot find data for output column 'col.a'",
"Cannot find data for output column 'col.b'"))
}

withClue("byPosition") {
Expand Down Expand Up @@ -700,4 +696,40 @@ abstract class DataSourceV2AnalysisBaseSuite extends AnalysisTest {
assertNotResolved(parsedPlan2)
assertAnalysisError(parsedPlan2, Seq("cannot resolve", "a", "given input columns", "x, y"))
}

test("SPARK-36498: reorder inner fields with byName mode") {
  // The query lists both the top-level columns and the struct's inner fields in a different
  // order than the table, and supplies `x` as a byte where the table declares an int.
  val targetTable = TestRelation(Seq('a.int, 'b.struct('x.int, 'y.int)))
  val sourceQuery = TestRelation(Seq('b.struct('y.int, 'x.byte), 'a.int))

  val analyzedWrite = byName(targetTable, sourceQuery).analyze
  val writtenSchema = analyzedWrite.children.head.schema
  assert(writtenSchema == targetTable.schema)
}

test("SPARK-36498: reorder inner fields in array of struct with byName mode") {
  // Element structs differ between table and query in field order and in the type of `x`.
  val expectedElement = new StructType().add("x", "int").add("y", "int")
  val targetTable = TestRelation(Seq('a.int, 'arr.array(expectedElement)))

  val suppliedElement = new StructType().add("y", "int").add("x", "byte")
  val sourceQuery = TestRelation(Seq('arr.array(suppliedElement), 'a.int))

  val analyzedWrite = byName(targetTable, sourceQuery).analyze
  val writtenSchema = analyzedWrite.children.head.schema
  assert(writtenSchema == targetTable.schema)
}

test("SPARK-36498: reorder inner fields in map of struct with byName mode") {
  // Both the key struct and the value struct differ between table and query in field order
  // and in the type of `x`.
  val expectedKey = new StructType().add("x", "int").add("y", "int")
  val expectedValue = new StructType().add("x", "int").add("y", "int")
  val targetTable = TestRelation(Seq('a.int, 'm.map(expectedKey, expectedValue)))

  val suppliedKey = new StructType().add("y", "int").add("x", "byte")
  val suppliedValue = new StructType().add("y", "int").add("x", "byte")
  val sourceQuery = TestRelation(Seq('m.map(suppliedKey, suppliedValue), 'a.int))

  val analyzedWrite = byName(targetTable, sourceQuery).analyze
  val writtenSchema = analyzedWrite.children.head.schema
  assert(writtenSchema == targetTable.schema)
}
}