Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1818,7 +1818,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
"sqlExpr" -> a.sql,
"cols" -> cols))
}
resolved
resolved match {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is needed because references to nested columns in MERGE are resolved as `a.nested_col.i AS i` (with an alias added), which breaks our assumption about the kinds of assignment keys we expect. Let me know if anyone can spot an edge case where unwrapping the alias is not safe.

case Alias(child: ExtractValue, _) => child
case other => other
}
}

// Expand the star expression using the input plan first. If failed, try resolve
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
object AssignmentUtils extends SQLConfHelper with CastSupport {

/**
* Aligns assignments to match table columns.
* Aligns update assignments to match table columns.
* <p>
* This method processes and reorders given assignments so that each target column gets
* an expression it should be set to. If a column does not have a matching assignment,
Expand All @@ -46,9 +46,9 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
*
* @param attrs table attributes
* @param assignments assignments to align
* @return aligned assignments that match table attributes
* @return aligned update assignments that match table attributes
*/
def alignAssignments(
def alignUpdateAssignments(
attrs: Seq[Attribute],
assignments: Seq[Assignment]): Seq[Assignment] = {

Expand All @@ -70,6 +70,61 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
attrs.zip(output).map { case (attr, expr) => Assignment(attr, expr) }
}

/**
 * Aligns INSERT assignments with the table columns.
 * <p>
 * The result contains one assignment per attribute in `attrs`, in the same order as `attrs`.
 * Every top-level attribute must be assigned exactly once; the assigned value is resolved
 * against the attribute through `TableOutputResolver.resolveUpdate`.
 * <p>
 * Assignment keys that are not attributes (i.e. nested fields) are reported as errors.
 * All detected problems are collected first and reported together in a single exception.
 *
 * @param attrs table attributes
 * @param assignments insert assignments to align
 * @return aligned insert assignments that match table attributes
 */
def alignInsertAssignments(
    attrs: Seq[Attribute],
    assignments: Seq[Assignment]): Seq[Assignment] = {

  // problems are accumulated here and raised together after all attributes are processed
  val problems = new mutable.ArrayBuffer[String]()

  val (attrAssignments, fieldAssignments) = assignments.partition(_.key.isInstanceOf[Attribute])

  if (fieldAssignments.nonEmpty) {
    val nestedAssignmentsStr = fieldAssignments.map(_.sql).mkString(", ")
    problems += s"INSERT assignment keys cannot be nested fields: $nestedAssignmentsStr"
  }

  val result = attrs.map { attr =>
    val candidates = attrAssignments.filter(_.key.semanticEquals(attr))
    val newValue = candidates match {
      case Seq() =>
        // nothing assigned to this column: record the problem, keep the attribute as placeholder
        problems += s"No assignment for '${attr.name}'"
        attr
      case Seq(single) =>
        val path = Seq(attr.name)
        val typedAttr = restoreActualType(attr)
        val assignedValue = single.value
        TableOutputResolver.resolveUpdate(
          assignedValue, typedAttr, conf, err => problems += err, path)
      case _ =>
        // more than one assignment targets this column
        val conflictingValuesStr = candidates.map(_.value.sql).mkString(", ")
        problems += s"Multiple assignments for '${attr.name}': $conflictingValuesStr"
        attr
    }
    Assignment(attr, newValue)
  }

  if (problems.nonEmpty) {
    throw QueryCompilationErrors.invalidRowLevelOperationAssignments(assignments, problems.toSeq)
  }

  result
}

// Returns the attribute with its data type replaced by the raw type that
// CharVarcharUtils.getRawType finds in the attribute metadata; when no raw type is found,
// the attribute's current data type is used.
private def restoreActualType(attr: Attribute): Attribute = {
  val actualType = CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)
  attr.withDataType(actualType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast}
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan, MergeIntoTable, UpdateTable}
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
Expand All @@ -42,17 +43,23 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
case u: UpdateTable if !u.skipSchemaResolution && u.resolved &&
supportsRowLevelOperations(u.table) && !u.aligned =>
validateStoreAssignmentPolicy()
val newTable = u.table.transform {
case r: DataSourceV2Relation =>
r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))
}
val newAssignments = AssignmentUtils.alignAssignments(u.table.output, u.assignments)
val newTable = cleanAttrMetadata(u.table)
val newAssignments = AssignmentUtils.alignUpdateAssignments(u.table.output, u.assignments)
u.copy(table = newTable, assignments = newAssignments)

case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned =>
resolveAssignments(u)

case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved &&
supportsRowLevelOperations(m.targetTable) && !m.aligned =>
validateStoreAssignmentPolicy()
m.copy(
targetTable = cleanAttrMetadata(m.targetTable),
matchedActions = alignActions(m.targetTable.output, m.matchedActions),
notMatchedActions = alignActions(m.targetTable.output, m.notMatchedActions),
notMatchedBySourceActions = alignActions(m.targetTable.output, m.notMatchedBySourceActions))

case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && !m.aligned =>
resolveAssignments(m)
}

Expand All @@ -63,6 +70,13 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
}
}

// Rewrites every DataSourceV2Relation in the given plan so that each of its output
// attributes is passed through CharVarcharUtils.cleanAttrMetadata.
private def cleanAttrMetadata(table: LogicalPlan): LogicalPlan = {
  val cleanRelationOutput: PartialFunction[LogicalPlan, LogicalPlan] = {
    case relation: DataSourceV2Relation =>
      val cleanedOutput = relation.output.map(CharVarcharUtils.cleanAttrMetadata)
      relation.copy(output = cleanedOutput)
  }
  table.transform(cleanRelationOutput)
}

private def supportsRowLevelOperations(table: LogicalPlan): Boolean = {
EliminateSubqueryAliases(table) match {
case DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _) => true
Expand Down Expand Up @@ -100,4 +114,19 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
Assignment(cleanedKey, finalValue)
}
}

// Aligns the assignments of each MERGE action with the given table attributes.
// UPDATE actions go through AssignmentUtils.alignUpdateAssignments, INSERT actions through
// AssignmentUtils.alignInsertAssignments, DELETE actions are returned unchanged, and any
// other kind of action is rejected with an AnalysisException.
private def alignActions(
    attrs: Seq[Attribute],
    actions: Seq[MergeAction]): Seq[MergeAction] = {

  def alignAction(action: MergeAction): MergeAction = action match {
    case update: UpdateAction =>
      val alignedAssignments = AssignmentUtils.alignUpdateAssignments(attrs, update.assignments)
      update.copy(assignments = alignedAssignments)
    case delete: DeleteAction =>
      delete
    case insert: InsertAction =>
      val alignedAssignments = AssignmentUtils.alignInsertAssignments(attrs, insert.assignments)
      insert.copy(assignments = alignedAssignments)
    case unexpected =>
      throw new AnalysisException(s"Unexpected resolved action: $unexpected")
  }

  actions.map(alignAction)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,21 @@ case class MergeIntoTable(
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction],
notMatchedBySourceActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery {

// Whether all actions of this MERGE are aligned with the target table output.
// UPDATE and INSERT actions are checked with AssignmentUtils.aligned, DELETE actions always
// count as aligned, and any other kind of action counts as not aligned.
lazy val aligned: Boolean = {
  def isAligned(action: MergeAction): Boolean = action match {
    case update: UpdateAction =>
      AssignmentUtils.aligned(targetTable.output, update.assignments)
    case _: DeleteAction =>
      true
    case insert: InsertAction =>
      AssignmentUtils.aligned(targetTable.output, insert.assignments)
    case _ =>
      false
  }

  val allActions = matchedActions ++ notMatchedActions ++ notMatchedBySourceActions
  allActions.forall(isAligned)
}

def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty

def skipSchemaResolution: Boolean = targetTable match {
Expand Down
Loading