diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 62555c9a99cc3..5a50c13839300 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1818,7 +1818,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
"sqlExpr" -> a.sql,
"cols" -> cols))
}
- resolved
+ resolved match {
+ case Alias(child: ExtractValue, _) => child
+ case other => other
+ }
}
// Expand the star expression using the input plan first. If failed, try resolve
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
index 265909d3a7e7d..6d8118548fb41 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
object AssignmentUtils extends SQLConfHelper with CastSupport {
/**
- * Aligns assignments to match table columns.
+ * Aligns update assignments to match table columns.
*
* This method processes and reorders given assignments so that each target column gets
* an expression it should be set to. If a column does not have a matching assignment,
@@ -46,9 +46,9 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
*
* @param attrs table attributes
* @param assignments assignments to align
- * @return aligned assignments that match table attributes
+ * @return aligned update assignments that match table attributes
*/
- def alignAssignments(
+ def alignUpdateAssignments(
attrs: Seq[Attribute],
assignments: Seq[Assignment]): Seq[Assignment] = {
@@ -70,6 +70,61 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
attrs.zip(output).map { case (attr, expr) => Assignment(attr, expr) }
}
+ /**
+ * Aligns insert assignments to match table columns.
+ *
+ * Reorders the given assignments so that the result holds one assignment per table
+ * attribute, in attribute order. Each attribute must be matched by exactly one top-level
+ * assignment; its value is resolved via `TableOutputResolver.resolveUpdate`.
+ *
+ * Keys that are not attributes (nested fields) are rejected. All problems are collected
+ * and reported together through `invalidRowLevelOperationAssignments`.
+ * @param attrs table attributes
+ * @param assignments insert assignments to align
+ * @return aligned insert assignments that match table attributes
+ */
+ def alignInsertAssignments(
+ attrs: Seq[Attribute],
+ assignments: Seq[Assignment]): Seq[Assignment] = {
+ // problems are accumulated here and thrown together at the end
+ val errors = new mutable.ArrayBuffer[String]()
+ // keys that are plain attributes are top-level; anything else is treated as nested
+ val (topLevelAssignments, nestedAssignments) = assignments.partition { assignment =>
+ assignment.key.isInstanceOf[Attribute]
+ }
+
+ if (nestedAssignments.nonEmpty) {
+ val nestedAssignmentsStr = nestedAssignments.map(_.sql).mkString(", ")
+ errors += s"INSERT assignment keys cannot be nested fields: $nestedAssignmentsStr"
+ }
+ // iterate over table attributes so the result follows the order of attrs
+ val alignedAssignments = attrs.map { attr =>
+ val matchingAssignments = topLevelAssignments.collect {
+ case assignment if assignment.key.semanticEquals(attr) => assignment
+ }
+ val resolvedValue = if (matchingAssignments.isEmpty) {
+ errors += s"No assignment for '${attr.name}'"
+ attr // placeholder only; the recorded error is thrown below
+ } else if (matchingAssignments.length > 1) {
+ val conflictingValuesStr = matchingAssignments.map(_.value.sql).mkString(", ")
+ errors += s"Multiple assignments for '${attr.name}': $conflictingValuesStr"
+ attr // placeholder only; the recorded error is thrown below
+ } else {
+ val colPath = Seq(attr.name)
+ val actualAttr = restoreActualType(attr) // raw char/varchar type if present in metadata
+ val value = matchingAssignments.head.value
+ TableOutputResolver.resolveUpdate(value, actualAttr, conf, err => errors += err, colPath)
+ }
+ Assignment(attr, resolvedValue)
+ }
+ // fail once, listing every problem found above
+ if (errors.nonEmpty) {
+ throw QueryCompilationErrors.invalidRowLevelOperationAssignments(assignments, errors.toSeq)
+ }
+
+ alignedAssignments
+ }
+
private def restoreActualType(attr: Attribute): Attribute = {
attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
index 596dc00b9176b..b22c91973f794 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast}
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
-import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan, MergeIntoTable, UpdateTable}
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -42,17 +43,23 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
case u: UpdateTable if !u.skipSchemaResolution && u.resolved &&
supportsRowLevelOperations(u.table) && !u.aligned =>
validateStoreAssignmentPolicy()
- val newTable = u.table.transform {
- case r: DataSourceV2Relation =>
- r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))
- }
- val newAssignments = AssignmentUtils.alignAssignments(u.table.output, u.assignments)
+ val newTable = cleanAttrMetadata(u.table)
+ val newAssignments = AssignmentUtils.alignUpdateAssignments(u.table.output, u.assignments)
u.copy(table = newTable, assignments = newAssignments)
case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned =>
resolveAssignments(u)
- case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
+ case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved &&
+ supportsRowLevelOperations(m.targetTable) && !m.aligned =>
+ validateStoreAssignmentPolicy()
+ m.copy(
+ targetTable = cleanAttrMetadata(m.targetTable),
+ matchedActions = alignActions(m.targetTable.output, m.matchedActions),
+ notMatchedActions = alignActions(m.targetTable.output, m.notMatchedActions),
+ notMatchedBySourceActions = alignActions(m.targetTable.output, m.notMatchedBySourceActions))
+
+ case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && !m.aligned =>
resolveAssignments(m)
}
@@ -63,6 +70,13 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
}
}
+ private def cleanAttrMetadata(table: LogicalPlan): LogicalPlan = {
+ table.transform { // only DataSourceV2Relation nodes are rewritten
+ case r: DataSourceV2Relation =>
+ r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))
+ }
+ }
+
private def supportsRowLevelOperations(table: LogicalPlan): Boolean = {
EliminateSubqueryAliases(table) match {
case DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _) => true
@@ -100,4 +114,19 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
Assignment(cleanedKey, finalValue)
}
}
+
+ private def alignActions( // aligns the assignments of each MERGE action with attrs
+ attrs: Seq[Attribute],
+ actions: Seq[MergeAction]): Seq[MergeAction] = {
+ actions.map { // one output action per input action, same order
+ case u @ UpdateAction(_, assignments) =>
+ u.copy(assignments = AssignmentUtils.alignUpdateAssignments(attrs, assignments))
+ case d: DeleteAction => // returned unchanged
+ d
+ case i @ InsertAction(_, assignments) =>
+ i.copy(assignments = AssignmentUtils.alignInsertAssignments(attrs, assignments))
+ case other => // any other action type is rejected
+ throw new AnalysisException(s"Unexpected resolved action: $other")
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 68943c918b12f..a5ef21c4db8c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -708,6 +708,21 @@ case class MergeIntoTable(
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction],
notMatchedBySourceActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery {
+
+ lazy val aligned: Boolean = { // true only if every action passes the check below
+ val actions = matchedActions ++ notMatchedActions ++ notMatchedBySourceActions
+ actions.forall { // UPDATE and INSERT assignments are checked against targetTable.output
+ case UpdateAction(_, assignments) =>
+ AssignmentUtils.aligned(targetTable.output, assignments)
+ case _: DeleteAction => // always counts as aligned
+ true
+ case InsertAction(_, assignments) =>
+ AssignmentUtils.aligned(targetTable.output, assignments)
+ case _ => // any other action type counts as not aligned
+ false
+ }
+ }
+
def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty
def skipSchemaResolution: Boolean = targetTable match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuite.scala
new file mode 100644
index 0000000000000..e6e124f1d5fd9
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuite.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import java.util.Collections
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, FunctionRegistry, NoSuchTableException, ResolveSessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, CatalogV2Util, Column, ColumnDefaultValue, Identifier, SupportsRowLevelOperations, TableCapability, TableCatalog}
+import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform}
+import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{BooleanType, IntegerType, StructType}
+
+class AlignAssignmentsSuite extends AnalysisTest {
+ // Shared fixtures: mocked row-level tables served by a mocked catalog, plus analysis helpers.
+ private val primitiveTable = {
+ val t = mock(classOf[SupportsRowLevelOperations])
+ val schema = new StructType()
+ .add("i", "INT", nullable = false)
+ .add("l", "LONG")
+ .add("txt", "STRING")
+ when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+ when(t.partitioning()).thenReturn(Array.empty[Transform])
+ t
+ }
+
+ private val primitiveTableSource = {
+ val t = mock(classOf[SupportsRowLevelOperations])
+ val schema = new StructType()
+ .add("l", "LONG")
+ .add("txt", "STRING")
+ .add("i", "INT")
+ when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+ when(t.partitioning()).thenReturn(Array.empty[Transform])
+ t
+ }
+ // NOTE(review): type arguments of "s" (incl. NOT NULL) were reconstructed - confirm.
+ private val nestedStructTable = {
+ val t = mock(classOf[SupportsRowLevelOperations])
+ val schema = new StructType()
+ .add("i", "INT")
+ .add(
+ "s",
+ "STRUCT<n_i: INT NOT NULL, n_s: STRUCT<dn_i: INT NOT NULL, dn_l: LONG>>",
+ nullable = false)
+ .add("txt", "STRING")
+ when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+ when(t.partitioning()).thenReturn(Array.empty[Transform])
+ t
+ }
+ // NOTE(review): type arguments of "s" were reconstructed - confirm.
+ private val nestedStructTableSource = {
+ val t = mock(classOf[SupportsRowLevelOperations])
+ val schema = new StructType()
+ .add("i", "INT")
+ .add("s", "STRUCT<n_i: INT, n_s: STRUCT<dn_i: INT, dn_l: LONG>>")
+ .add("txt", "STRING")
+ when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+ when(t.partitioning()).thenReturn(Array.empty[Transform])
+ t
+ }
+ // NOTE(review): element/key/value types of "a" and "m" were reconstructed - confirm.
+ private val mapArrayTable = {
+ val t = mock(classOf[SupportsRowLevelOperations])
+ val schema = new StructType()
+ .add("i", "INT")
+ .add("a", "ARRAY<STRUCT<n_i: INT, n_l: LONG>>")
+ .add("m", "MAP<STRING, STRING>")
+ .add("txt", "STRING")
+ when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+ when(t.partitioning()).thenReturn(Array.empty[Transform])
+ t
+ }
+ // NOTE(review): nested types (incl. the VARCHAR length) were reconstructed - confirm.
+ private val charVarcharTable = {
+ val t = mock(classOf[SupportsRowLevelOperations])
+ val schema = new StructType()
+ .add("c", "CHAR(5)")
+ .add(
+ "s",
+ "STRUCT<n_i: INT, n_vc: VARCHAR(5)>",
+ nullable = false)
+ .add(
+ "a",
+ "ARRAY<STRUCT<n_i: INT, n_vc: VARCHAR(5)>>",
+ nullable = false)
+ .add(
+ "mk",
+ "MAP<STRUCT<n_i: INT, n_vc: VARCHAR(5)>, STRING>",
+ nullable = false)
+ .add(
+ "mv",
+ "MAP<STRING, STRUCT<n_i: INT, n_vc: VARCHAR(5)>>",
+ nullable = false)
+ when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+ when(t.partitioning()).thenReturn(Array.empty[Transform])
+ t
+ }
+
+ private val acceptsAnySchemaTable = {
+ val t = mock(classOf[SupportsRowLevelOperations])
+ val schema = new StructType()
+ .add("i", "INT", nullable = false)
+ .add("l", "LONG")
+ .add("txt", "STRING")
+ when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
+ when(t.partitioning()).thenReturn(Array.empty[Transform])
+ when(t.capabilities()).thenReturn(Collections.singleton(TableCapability.ACCEPT_ANY_SCHEMA))
+ t
+ }
+
+ private val defaultValuesTable = {
+ val t = mock(classOf[SupportsRowLevelOperations])
+ val iDefault = new ColumnDefaultValue("42", LiteralValue(42, IntegerType))
+ when(t.columns()).thenReturn(Array(
+ Column.create("b", BooleanType, true, null, null),
+ Column.create("i", IntegerType, true, null, iDefault, null)))
+ when(t.partitioning()).thenReturn(Array.empty[Transform])
+ t
+ }
+ // Catalog mock that serves the fixture tables above by identifier name.
+ private val v2Catalog = {
+ val newCatalog = mock(classOf[TableCatalog])
+ when(newCatalog.loadTable(any())).thenAnswer((invocation: InvocationOnMock) => {
+ val ident = invocation.getArgument[Identifier](0)
+ ident.name match {
+ case "primitive_table" => primitiveTable
+ case "primitive_table_src" => primitiveTableSource
+ case "nested_struct_table" => nestedStructTable
+ case "nested_struct_table_src" => nestedStructTableSource
+ case "map_array_table" => mapArrayTable
+ case "char_varchar_table" => charVarcharTable
+ case "accepts_any_schema_table" => acceptsAnySchemaTable
+ case "default_values_table" => defaultValuesTable
+ case name => throw new NoSuchTableException(Seq(name))
+ }
+ })
+ when(newCatalog.name()).thenReturn("cat")
+ newCatalog
+ }
+
+ private val v1SessionCatalog =
+ new SessionCatalog(new InMemoryCatalog(), FunctionRegistry.builtin, new SQLConf())
+
+ private val v2SessionCatalog = new V2SessionCatalog(v1SessionCatalog)
+ // Resolves "testcat" and the session catalog by name; v2Catalog is the current catalog.
+ private val catalogManager = {
+ val manager = mock(classOf[CatalogManager])
+ when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => {
+ invocation.getArgument[String](0) match {
+ case "testcat" => v2Catalog
+ case CatalogManager.SESSION_CATALOG_NAME => v2SessionCatalog
+ case name => throw new CatalogNotFoundException(s"No such catalog: $name")
+ }
+ })
+ when(manager.currentCatalog).thenReturn(v2Catalog)
+ when(manager.currentNamespace).thenReturn(Array.empty[String])
+ when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog)
+ when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog)
+ manager
+ }
+ // Parses the query, runs the analyzer and checkAnalysis, and returns the analyzed plan.
+ protected def parseAndResolve(query: String): LogicalPlan = {
+ val analyzer = new Analyzer(catalogManager) {
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq(
+ new ResolveSessionCatalog(catalogManager))
+ }
+ val analyzed = analyzer.execute(CatalystSqlParser.parsePlan(query))
+ analyzer.checkAnalysis(analyzed)
+ analyzed
+ }
+ // Asserts the plan's expressions contain an AssertNotNull whose walkedTypePath is colPath.
+ protected def assertNullCheckExists(plan: LogicalPlan, colPath: Seq[String]): Unit = {
+ val asserts = plan.expressions.flatMap(e => e.collect {
+ case assert: AssertNotNull if assert.walkedTypePath == colPath => assert
+ })
+ assert(asserts.nonEmpty, s"Must have NOT NULL checks for col $colPath")
+ }
+ // Asserts that analysis fails with an AnalysisException containing every given message.
+ protected def assertAnalysisException(query: String, messages: String*): Unit = {
+ val exception = intercept[AnalysisException] {
+ parseAndResolve(query)
+ }
+ messages.foreach(message => assert(exception.message.contains(message)))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
new file mode 100644
index 0000000000000..e569166633908
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
@@ -0,0 +1,937 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, AttributeReference, BooleanLiteral, Cast, CheckOverflowInTableInsert, CreateNamedStruct, EvalMode, GetStructField, IntegerLiteral, LambdaFunction, LongLiteral, MapFromArrays, StringLiteral}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, StaticInvoke}
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, InsertAction, MergeAction, MergeIntoTable, UpdateAction}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
+import org.apache.spark.sql.types.IntegerType
+
+class AlignMergeAssignmentsSuite extends AlignAssignmentsSuite {
+
+ test("align assignments (primitive types)") {
+ val (matchedActions, notMatchedActions, notMatchedBySourceActions) =
+ parseAndAlignAssignments(
+ """MERGE INTO primitive_table t USING primitive_table_src s
+ |ON t.l = s.l
+ |WHEN MATCHED THEN
+ | UPDATE SET t.txt = 'a', t.i = s.i
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |WHEN NOT MATCHED BY SOURCE THEN
+ | UPDATE SET t.txt = "error", t.i = CAST(null AS INT)""".stripMargin)
+ // MATCHED UPDATE: assignments must come back as i, l, txt; the unassigned l is set to itself.
+ matchedActions match {
+ case Seq(UpdateAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(i: AttributeReference, AssertNotNull(iValue: AttributeReference, _)),
+ Assignment(l: AttributeReference, lValue: AttributeReference),
+ Assignment(txt: AttributeReference, StringLiteral("a"))) =>
+ // i takes a same-named but different attribute, wrapped in a null check
+ assert(i.name == "i" && iValue.name == "i" && i != iValue)
+ assert(l.name == "l" && l == lValue)
+ assert(txt.name == "txt")
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+ // NOT MATCHED INSERT *: every column is set from a same-named but different attribute.
+ notMatchedActions match {
+ case Seq(InsertAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(i: AttributeReference, AssertNotNull(iValue: AttributeReference, _)),
+ Assignment(l: AttributeReference, lValue: AttributeReference),
+ Assignment(txt: AttributeReference, txtValue: AttributeReference)) =>
+
+ assert(i.name == "i" && iValue.name == "i" && i != iValue)
+ assert(l.name == "l" && lValue.name == "l" && l != lValue)
+ assert(txt.name == "txt" && txtValue.name == "txt" && txt != txtValue)
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+ // NOT MATCHED BY SOURCE UPDATE: the Cast assigned to i must sit under AssertNotNull.
+ notMatchedBySourceActions match {
+ case Seq(UpdateAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(i: AttributeReference, AssertNotNull(_: Cast, _)),
+ Assignment(l: AttributeReference, lValue: AttributeReference),
+ Assignment(txt: AttributeReference, StringLiteral("error"))) =>
+
+ assert(i.name == "i")
+ assert(l.name == "l" && l == lValue)
+ assert(txt.name == "txt")
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+ }
+
+ test("align assignments (top-level structs)") {
+ val (matchedActions, notMatchedActions, notMatchedBySourceActions) =
+ parseAndAlignAssignments(
+ """MERGE INTO nested_struct_table t USING nested_struct_table_src s
+ |ON t.i = s.i
+ |WHEN MATCHED THEN
+ | UPDATE SET t.s = named_struct('n_s', named_struct('dn_i', 1, 'dn_l', 100L), 'n_i', 1)
+ |WHEN NOT MATCHED THEN
+ | INSERT (txt, s, i) VALUES (
+ | 'new',
+ | named_struct('n_s', named_struct('dn_i', 1, 'dn_l', 100L), 'n_i', 1),
+ | 1)
+ |WHEN NOT MATCHED BY SOURCE THEN
+ | UPDATE SET t.s = named_struct('n_s', named_struct('dn_i', 1, 'dn_l', 100L), 'n_i', 1)
+ |""".stripMargin)
+ // Expects the aligned struct to list n_i before n_s although the SQL lists n_s first.
+ def checkStruct(value: CreateNamedStruct): Unit = {
+ value.children match {
+ case Seq(
+ StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+ StringLiteral("n_s"), nsValue: CreateNamedStruct) =>
+
+ nsValue.children match {
+ case Seq(
+ StringLiteral("dn_i"), GetStructField(_, _, Some("dn_i")),
+ StringLiteral("dn_l"), GetStructField(_, _, Some("dn_l"))) =>
+ // OK
+
+ case nsValueChildren =>
+ fail(s"Unexpected children for 's.n_s': $nsValueChildren")
+ }
+
+ case sValueChildren =>
+ fail(s"Unexpected children for 's': $sValueChildren")
+ }
+ }
+ // MATCHED UPDATE: only s is assigned; i and txt must be set to themselves.
+ matchedActions match {
+ case Seq(UpdateAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(i: AttributeReference, iValue: AttributeReference),
+ Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+ Assignment(txt: AttributeReference, txtValue: AttributeReference)) =>
+
+ assert(i.name == "i" && i == iValue)
+
+ assert(s.name == "s")
+ checkStruct(sValue)
+
+ assert(txt.name == "txt" && txt == txtValue)
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+ // NOT MATCHED INSERT: values given as (txt, s, i) must come back ordered i, s, txt.
+ notMatchedActions match {
+ case Seq(InsertAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(i: AttributeReference, IntegerLiteral(1)),
+ Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+ Assignment(txt: AttributeReference, StringLiteral("new"))) =>
+
+ assert(i.name == "i")
+
+ assert(s.name == "s")
+ checkStruct(sValue)
+
+ assert(txt.name == "txt")
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+ // NOT MATCHED BY SOURCE UPDATE: same expectations as for the MATCHED clause.
+ notMatchedBySourceActions match {
+ case Seq(UpdateAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(i: AttributeReference, iValue: AttributeReference),
+ Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+ Assignment(txt: AttributeReference, txtValue: AttributeReference)) =>
+
+ assert(i.name == "i" && i == iValue)
+
+ assert(s.name == "s")
+ checkStruct(sValue)
+
+ assert(txt.name == "txt" && txt == txtValue)
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+ }
+
+ test("align UPDATE assignments with references to nested attributes on both sides") {
+ val (matchedActions, _, _) =
+ parseAndAlignAssignments(
+ """MERGE INTO nested_struct_table_src t USING nested_struct_table_src src
+ |ON t.i = src.i
+ |WHEN MATCHED THEN
+ | UPDATE SET t.s.n_s = src.s.n_s
+ |""".stripMargin)
+ // Only s.n_s is assigned: s must become a struct; i and txt must be set to themselves.
+ matchedActions match {
+ case Seq(UpdateAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(i: AttributeReference, iValue: AttributeReference),
+ Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+ Assignment(txt: AttributeReference, txtValue: AttributeReference)) =>
+
+ assert(i.name == "i" && i == iValue)
+
+ assert(s.name == "s")
+ sValue.children match {
+ case Seq(
+ StringLiteral("n_i"),
+ GetStructField(_, _, Some("n_i")),
+ StringLiteral("n_s"),
+ GetStructField(source: AttributeReference, _, Some("n_s"))) =>
+ // n_s is read from an attribute named s that is not the assigned key s
+ assert(source.name == "s" && s != source)
+
+ case sValueChildren =>
+ fail(s"Unexpected children for 's': $sValueChildren")
+ }
+
+ assert(txt.name == "txt" && txt == txtValue)
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+ }
+
+ test("align assignments (nested structs)") {
+ val (matchedActions, notMatchedActions, notMatchedBySourceActions) =
+ parseAndAlignAssignments(
+ """MERGE INTO nested_struct_table t USING nested_struct_table_src s
+ |ON t.i = s.i
+ |WHEN MATCHED THEN
+ | UPDATE SET t.s.n_s = named_struct('dn_l', 1L, 'dn_i', 1)
+ |WHEN NOT MATCHED THEN
+ | INSERT (txt, s, i) VALUES (
+ | 'new',
+ | named_struct('n_i', 1, 'n_s', named_struct('dn_l', 1L, 'dn_i', 1)),
+ | 1)
+ |WHEN NOT MATCHED BY SOURCE THEN
+ | UPDATE SET t.s.n_s = named_struct('dn_l', 1L, 'dn_i', 1)
+ |""".stripMargin)
+ // Expects dn_i before dn_l in the aligned struct although the SQL lists dn_l first.
+ def checkNestedStruct(value: CreateNamedStruct): Unit = {
+ value.children match {
+ case Seq(
+ StringLiteral("dn_i"), GetStructField(_, _, Some("dn_i")),
+ StringLiteral("dn_l"), GetStructField(_, _, Some("dn_l"))) =>
+ // OK
+
+ case nsValueChildren =>
+ fail(s"Unexpected children for 's.n_s': $nsValueChildren")
+ }
+ }
+ // MATCHED UPDATE of s.n_s: s must be a struct of n_i and the checked n_s; i, txt unchanged.
+ matchedActions match {
+ case Seq(UpdateAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(i: AttributeReference, iValue: AttributeReference),
+ Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+ Assignment(txt: AttributeReference, txtValue: AttributeReference)) =>
+
+ assert(i.name == "i" && i == iValue)
+
+ assert(s.name == "s")
+ sValue.children match {
+ case Seq(
+ StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+ StringLiteral("n_s"), nsValue: CreateNamedStruct) =>
+ checkNestedStruct(nsValue)
+
+ case sValueChildren =>
+ fail(s"Unexpected children for 's': $sValueChildren")
+ }
+
+ assert(txt.name == "txt" && txt == txtValue)
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+ // NOT MATCHED INSERT: values given as (txt, s, i) must come back ordered i, s, txt.
+ notMatchedActions match {
+ case Seq(InsertAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(i: AttributeReference, IntegerLiteral(1)),
+ Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+ Assignment(txt: AttributeReference, StringLiteral("new"))) =>
+
+ assert(i.name == "i")
+
+ assert(s.name == "s")
+ sValue.children match {
+ case Seq(
+ StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+ StringLiteral("n_s"), nsValue: CreateNamedStruct) =>
+ checkNestedStruct(nsValue)
+
+ case sValueChildren =>
+ fail(s"Unexpected children for 's': $sValueChildren")
+ }
+
+ assert(txt.name == "txt")
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+ // NOT MATCHED BY SOURCE UPDATE: same expectations as for the MATCHED clause.
+ notMatchedBySourceActions match {
+ case Seq(UpdateAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(i: AttributeReference, iValue: AttributeReference),
+ Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+ Assignment(txt: AttributeReference, txtValue: AttributeReference)) =>
+
+ assert(i.name == "i" && i == iValue)
+
+ assert(s.name == "s")
+ sValue.children match {
+ case Seq(
+ StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+ StringLiteral("n_s"), nsValue: CreateNamedStruct) =>
+ checkNestedStruct(nsValue)
+
+ case sValueChildren =>
+ fail(s"Unexpected children for 's': $sValueChildren")
+ }
+
+ assert(txt.name == "txt" && txt == txtValue)
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+ }
+
+ test("align assignments (char and varchar types)") {
+ val (matchedActions, notMatchedActions, notMatchedBySourceActions) =
+ parseAndAlignAssignments(
+ """MERGE INTO char_varchar_table t USING char_varchar_table src
+ |ON t.c = src.c
+ |WHEN MATCHED THEN
+ | UPDATE SET
+ | c = 'a',
+ | a = array(named_struct('n_i', 1, 'n_vc', 3)),
+ | s.n_vc = 'a',
+ | mv = map('v', named_struct('n_vc', 'a', 'n_i', 1)),
+ | mk = map(named_struct('n_vc', 'a', 'n_i', 1), 'v')
+ |WHEN NOT MATCHED THEN
+ | INSERT (c, a, s, mv, mk) VALUES (
+ | 'a',
+ | array(named_struct('n_i', 1, 'n_vc', 3)),
+ | named_struct('n_vc', 3, 'n_i', 1),
+ | map('v', named_struct('n_vc', 'a', 'n_i', 1)),
+ | map(named_struct('n_vc', 'a', 'n_i', 1), 'v'))
+ |WHEN NOT MATCHED BY SOURCE THEN
+ | UPDATE SET
+ | c = 'a',
+ | a = array(named_struct('n_i', 1, 'n_vc', 3)),
+ | s = named_struct('n_vc', 3, 'n_i', 1),
+ | mv = map('v', named_struct('n_vc', 'a', 'n_i', 1)),
+ | mk = map(named_struct('n_vc', 'a', 'n_i', 1), 'v')
+ |""".stripMargin)
+
+ def checkStruct(value: CreateNamedStruct): Unit = {
+ value.children match {
+ case Seq(
+ StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+ StringLiteral("n_vc"), invoke: StaticInvoke) =>
+ // n_vc must be a two-argument StaticInvoke of varcharTypeWriteSideCheck
+ assert(invoke.arguments.length == 2)
+ assert(invoke.functionName == "varcharTypeWriteSideCheck")
+
+ case sValueChildren =>
+ fail(s"Unexpected children for 's': $sValueChildren")
+ }
+ }
+
+ def checkArray(value: ArrayTransform): Unit = {
+ val lambda = value.function.asInstanceOf[LambdaFunction]
+ lambda.function match { // the lambda body must build a struct of n_i and n_vc
+ case CreateNamedStruct(Seq(
+ StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+ StringLiteral("n_vc"), invoke: StaticInvoke)) =>
+ // n_vc must be a two-argument StaticInvoke of varcharTypeWriteSideCheck
+ assert(invoke.arguments.length == 2)
+ assert(invoke.functionName == "varcharTypeWriteSideCheck")
+
+ case func =>
+ fail(s"Unexpected lambda function: $func")
+ }
+ }
+
+ def checkMapWithStructKey(value: MapFromArrays): Unit = {
+ val keyTransform = value.left.asInstanceOf[ArrayTransform] // left child: transformed keys
+ val keyLambda = keyTransform.function.asInstanceOf[LambdaFunction]
+ keyLambda.function match {
+ case CreateNamedStruct(Seq(
+ StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+ StringLiteral("n_vc"), invoke: StaticInvoke)) =>
+ // n_vc must be a two-argument StaticInvoke of varcharTypeWriteSideCheck
+ assert(invoke.arguments.length == 2)
+ assert(invoke.functionName == "varcharTypeWriteSideCheck")
+
+ case func =>
+ fail(s"Unexpected key lambda function: $func")
+ }
+ }
+
+ def checkMapWithStructValue(value: MapFromArrays): Unit = {
+ val valueTransform = value.right.asInstanceOf[ArrayTransform] // right child: transformed values
+ val valueLambda = valueTransform.function.asInstanceOf[LambdaFunction]
+ valueLambda.function match {
+ case CreateNamedStruct(Seq(
+ StringLiteral("n_i"), GetStructField(_, _, Some("n_i")),
+ StringLiteral("n_vc"), invoke: StaticInvoke)) =>
+ // n_vc must be a two-argument StaticInvoke of varcharTypeWriteSideCheck
+ assert(invoke.arguments.length == 2)
+ assert(invoke.functionName == "varcharTypeWriteSideCheck")
+
+ case func =>
+ fail(s"Unexpected value lambda function: $func")
+ }
+ }
+
+ matchedActions match {
+ case Seq(UpdateAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(c: AttributeReference, cValue: StaticInvoke),
+ Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+ Assignment(a: AttributeReference, aValue: ArrayTransform),
+ Assignment(mk: AttributeReference, mkValue: MapFromArrays),
+ Assignment(mv: AttributeReference, mvValue: MapFromArrays)) =>
+
+ assert(c.name == "c")
+ assert(cValue.arguments.length == 2)
+ assert(cValue.functionName == "charTypeWriteSideCheck")
+
+ assert(s.name == "s")
+ checkStruct(sValue)
+
+ assert(a.name == "a")
+ checkArray(aValue)
+
+ assert(mk.name == "mk")
+ checkMapWithStructKey(mkValue)
+
+ assert(mv.name == "mv")
+ checkMapWithStructValue(mvValue)
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+
+ notMatchedActions match {
+ case Seq(InsertAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(c: AttributeReference, cValue: StaticInvoke),
+ Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+ Assignment(a: AttributeReference, aValue: ArrayTransform),
+ Assignment(mk: AttributeReference, mkValue: MapFromArrays),
+ Assignment(mv: AttributeReference, mvValue: MapFromArrays)) =>
+
+ assert(c.name == "c")
+ assert(cValue.arguments.length == 2)
+ assert(cValue.functionName == "charTypeWriteSideCheck")
+
+ assert(s.name == "s")
+ checkStruct(sValue)
+
+ assert(a.name == "a")
+ checkArray(aValue)
+
+ assert(mk.name == "mk")
+ checkMapWithStructKey(mkValue)
+
+ assert(mv.name == "mv")
+ checkMapWithStructValue(mvValue)
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+
+ notMatchedBySourceActions match {
+ case Seq(UpdateAction(None, assignments)) =>
+ assignments match {
+ case Seq(
+ Assignment(c: AttributeReference, cValue: StaticInvoke),
+ Assignment(s: AttributeReference, sValue: CreateNamedStruct),
+ Assignment(a: AttributeReference, aValue: ArrayTransform),
+ Assignment(mk: AttributeReference, mkValue: MapFromArrays),
+ Assignment(mv: AttributeReference, mvValue: MapFromArrays)) =>
+
+ assert(c.name == "c")
+ assert(cValue.arguments.length == 2)
+ assert(cValue.functionName == "charTypeWriteSideCheck")
+
+ assert(s.name == "s")
+ checkStruct(sValue)
+
+ assert(a.name == "a")
+ checkArray(aValue)
+
+ assert(mk.name == "mk")
+ checkMapWithStructKey(mkValue)
+
+ assert(mv.name == "mv")
+ checkMapWithStructValue(mvValue)
+
+ case other =>
+ fail(s"Unexpected assignments: $other")
+ }
+
+ case other =>
+ fail(s"Unexpected actions: $other")
+ }
+ }
+
+ test("conflicting UPDATE assignments") {
+   // Each statement below is run for both UPDATE-capable MERGE clauses under the ANSI and
+   // STRICT store assignment policies and is passed to assertAnalysisException together
+   // with the message fragments listed next to it.
+   val policies = Seq(StoreAssignmentPolicy.ANSI, StoreAssignmentPolicy.STRICT)
+   val updateClauses = Seq("WHEN MATCHED", "WHEN NOT MATCHED BY SOURCE")
+
+   for (policy <- policies) {
+     withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) {
+       for (clause <- updateClauses) {
+         // two updates to a top-level column
+         val duplicateTopLevelSql =
+           s"""MERGE INTO primitive_table t USING primitive_table_src s
+              |ON t.l = s.l
+              |$clause THEN
+              | UPDATE SET t.txt = 'a', t.txt = 'b'
+              |""".stripMargin
+         assertAnalysisException(
+           duplicateTopLevelSql,
+           "Multiple assignments for 'txt': 'a', 'b'")
+
+         // two updates to a nested column
+         val duplicateNestedSql =
+           s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+              |ON t.i = src.i
+              |$clause THEN
+              | UPDATE SET s.n_i = 1, s.n_s = null, s.n_i = -1
+              |""".stripMargin
+         assertAnalysisException(
+           duplicateNestedSql,
+           "Multiple assignments for 's.n_i': 1, -1")
+
+         // conflicting updates to a nested struct and its fields
+         val structAndFieldSql =
+           s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+              |ON t.i = src.i
+              |$clause THEN
+              | UPDATE SET s.n_s.dn_i = 1, s.n_s = named_struct('dn_i', 1, 'dn_l', 1L)
+              |""".stripMargin
+         assertAnalysisException(
+           structAndFieldSql,
+           "Conflicting assignments for 's.n_s'",
+           "t.s.`n_s` = named_struct('dn_i', 1, 'dn_l', 1L)",
+           "t.s.`n_s`.`dn_i` = 1")
+       }
+     }
+   }
+ }
+
+ test("invalid INSERT assignments") {
+   // INSERT column list that leaves out table column `l`
+   val missingColumnSql =
+     """MERGE INTO primitive_table t USING primitive_table src
+       |ON t.i = src.i
+       |WHEN NOT MATCHED THEN
+       | INSERT (i, txt) VALUES (src.i, src.txt)
+       |""".stripMargin
+   assertAnalysisException(missingColumnSql, "No assignment for 'l'")
+
+   // INSERT column list that names `txt` twice
+   val duplicateColumnSql =
+     """MERGE INTO primitive_table t USING primitive_table src
+       |ON t.i = src.i
+       |WHEN NOT MATCHED THEN
+       | INSERT (i, l, txt, txt) VALUES (src.i, src.l, src.txt, src.txt)
+       |""".stripMargin
+   assertAnalysisException(duplicateColumnSql, "Multiple assignments for 'txt'")
+
+   // INSERT column list that targets a nested field instead of a top-level column
+   val nestedKeySql =
+     """MERGE INTO nested_struct_table t USING nested_struct_table src
+       |ON t.i = src.i
+       |WHEN NOT MATCHED THEN
+       | INSERT (s.n_i) VALUES (1)
+       |""".stripMargin
+   assertAnalysisException(
+     nestedKeySql,
+     "INSERT assignment keys cannot be nested fields: t.s.`n_i` = 1",
+     "No assignment for 'i'",
+     "No assignment for 's'",
+     "No assignment for 'txt'")
+ }
+
+ test("updates to nested structs in arrays") {
+   // Assigning to a field reached through array column `a` is checked under both the ANSI
+   // and STRICT policies against the same "only supported for StructType" message.
+   val nestedInArraySql =
+     """MERGE INTO map_array_table t USING map_array_table s
+       |ON t.i = s.i
+       |WHEN MATCHED THEN
+       | UPDATE SET t.a.i1 = 1
+       |""".stripMargin
+
+   for (policy <- Seq(StoreAssignmentPolicy.ANSI, StoreAssignmentPolicy.STRICT)) {
+     withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) {
+       assertAnalysisException(
+         nestedInArraySql,
+         "Updating nested fields is only supported for StructType but 'a' is of type ArrayType")
+     }
+   }
+ }
+
+ test("ANSI mode in UPDATE assignments") {
+   // Runs every scenario once per UPDATE-capable MERGE clause with the ANSI policy set.
+   withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.ANSI.toString) {
+     for (clause <- Seq("WHEN MATCHED", "WHEN NOT MATCHED BY SOURCE")) {
+       // NULL assigned to top-level column `i`
+       val topLevelNullPlan = parseAndResolve(
+         s"""MERGE INTO primitive_table t USING primitive_table_src src
+            |ON t.l = src.l
+            |$clause THEN
+            | UPDATE SET i = NULL
+            |""".stripMargin)
+       assertNullCheckExists(topLevelNullPlan, Seq("i"))
+
+       // NULL assigned to nested field `s.n_i`
+       val nestedNullPlan = parseAndResolve(
+         s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+            |ON t.i = src.i
+            |$clause THEN
+            | UPDATE SET s.n_i = NULL
+            |""".stripMargin)
+       assertNullCheckExists(nestedNullPlan, Seq("s", "n_i"))
+
+       // NULL assigned to doubly nested field `s.n_s.dn_i`
+       val deepNullPlan = parseAndResolve(
+         s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+            |ON t.i = src.i
+            |$clause THEN
+            | UPDATE SET s.n_s.dn_i = NULL
+            |""".stripMargin)
+       assertNullCheckExists(deepNullPlan, Seq("s", "n_s", "dn_i"))
+
+       // whole nested struct replaced by a struct carrying a NULL `dn_i`
+       val structNullPlan = parseAndResolve(
+         s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+            |ON t.i = src.i
+            |$clause THEN
+            | UPDATE SET s.n_s = named_struct('dn_i', NULL, 'dn_l', 1L)
+            |""".stripMargin)
+       assertNullCheckExists(structNullPlan, Seq("s", "n_s", "dn_i"))
+
+       // replacement struct that omits field `dn_l`
+       val missingFieldSql =
+         s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+            |ON t.i = src.i
+            |$clause THEN
+            | UPDATE SET s.n_s = named_struct('dn_i', 1)
+            |""".stripMargin
+       assertAnalysisException(
+         missingFieldSql,
+         "Cannot find data for output column 's.n_s.dn_l'")
+
+       // ANSI mode does NOT allow string to int casts
+       val stringToIntSql =
+         s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+            |ON t.i = src.i
+            |$clause THEN
+            | UPDATE SET s.n_s = named_struct('dn_i', 'string-value', 'dn_l', 1L)
+            |""".stripMargin
+       assertAnalysisException(stringToIntSql, "Cannot safely cast")
+
+       val (matched, _, notMatchedBySource) =
+         parseAndAlignAssignments(
+           s"""MERGE INTO primitive_table t USING primitive_table_src src
+              |ON t.i = src.i
+              |$clause THEN
+              | UPDATE SET i = 1L, txt = 'new', l = 10L
+              |""".stripMargin)
+
+       // Only one of the two action lists is populated, depending on the clause in use.
+       val updateActions = if (matched.isEmpty) notMatchedBySource else matched
+       updateActions match {
+         case Seq(UpdateAction(_, alignedAssignments)) =>
+           alignedAssignments match {
+             case Seq(
+                 Assignment(
+                   iAttr: AttributeReference,
+                   CheckOverflowInTableInsert(
+                     Cast(LongLiteral(1L), IntegerType, _, EvalMode.ANSI), _)),
+                 Assignment(lAttr: AttributeReference, LongLiteral(10L)),
+                 Assignment(txtAttr: AttributeReference, StringLiteral("new"))) =>
+
+               assert(iAttr.name == "i")
+               assert(lAttr.name == "l")
+               assert(txtAttr.name == "txt")
+
+             case unexpected =>
+               fail(s"Unexpected assignments: $unexpected")
+           }
+
+         case unexpected =>
+           fail(s"Unexpected actions: $unexpected")
+       }
+     }
+   }
+ }
+
+ test("ANSI mode in INSERT assignments") {
+   withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.ANSI.toString) {
+     // NULL inserted into column `i`
+     val nullInsertPlan = parseAndResolve(
+       """MERGE INTO primitive_table t USING primitive_table_src src
+         |ON t.l = src.l
+         |WHEN NOT MATCHED THEN
+         | INSERT (i, l, txt) VALUES (NULL, 1, 'value')
+         |""".stripMargin)
+     assertNullCheckExists(nullInsertPlan, Seq("i"))
+
+     // ANSI mode does NOT allow string to int casts
+     val stringToIntSql =
+       """MERGE INTO primitive_table t USING primitive_table_src src
+         |ON t.l = src.l
+         |WHEN NOT MATCHED THEN
+         | INSERT (i, l, txt) VALUES ('1', 1, 'value')
+         |""".stripMargin
+     assertAnalysisException(stringToIntSql, "Cannot safely cast")
+
+     val (_, insertActions, _) =
+       parseAndAlignAssignments(
+         """MERGE INTO primitive_table t USING primitive_table_src src
+           |ON t.i = src.i
+           |WHEN NOT MATCHED THEN
+           | INSERT (i, l, txt) VALUES (1L, 10L, 'new')
+           |""".stripMargin)
+
+     // The long literal written to `i` must be wrapped in an ANSI cast to IntegerType
+     // with an overflow check; `l` and `txt` must receive their literals unchanged.
+     insertActions match {
+       case Seq(InsertAction(_, alignedAssignments)) =>
+         alignedAssignments match {
+           case Seq(
+               Assignment(
+                 iAttr: AttributeReference,
+                 CheckOverflowInTableInsert(
+                   Cast(LongLiteral(1L), IntegerType, _, EvalMode.ANSI), _)),
+               Assignment(lAttr: AttributeReference, LongLiteral(10L)),
+               Assignment(txtAttr: AttributeReference, StringLiteral("new"))) =>
+
+             assert(iAttr.name == "i")
+             assert(lAttr.name == "l")
+             assert(txtAttr.name == "txt")
+
+           case unexpected =>
+             fail(s"Unexpected assignments: $unexpected")
+         }
+
+       case unexpected =>
+         fail(s"Unexpected actions: $unexpected")
+     }
+   }
+ }
+
+ test("strict mode in UPDATE assignments") {
+   // Runs every scenario once per UPDATE-capable MERGE clause with the STRICT policy set.
+   withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.STRICT.toString) {
+     for (clause <- Seq("WHEN MATCHED", "WHEN NOT MATCHED BY SOURCE")) {
+       // typed NULL assigned to top-level column `i`
+       val topLevelNullPlan = parseAndResolve(
+         s"""MERGE INTO primitive_table t USING primitive_table_src src
+            |ON t.l = src.l
+            |$clause THEN
+            | UPDATE SET i = CAST(NULL AS INT)
+            |""".stripMargin)
+       assertNullCheckExists(topLevelNullPlan, Seq("i"))
+
+       // typed NULL assigned to nested field `s.n_i`
+       val nestedNullPlan = parseAndResolve(
+         s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+            |ON t.i = src.i
+            |$clause THEN
+            | UPDATE SET s.n_i = CAST(NULL AS INT)
+            |""".stripMargin)
+       assertNullCheckExists(nestedNullPlan, Seq("s", "n_i"))
+
+       // typed NULL assigned to doubly nested field `s.n_s.dn_i`
+       val deepNullPlan = parseAndResolve(
+         s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+            |ON t.i = src.i
+            |$clause THEN
+            | UPDATE SET s.n_s.dn_i = CAST(NULL AS INT)
+            |""".stripMargin)
+       assertNullCheckExists(deepNullPlan, Seq("s", "n_s", "dn_i"))
+
+       // whole nested struct replaced by a struct carrying a typed NULL `dn_i`
+       val structNullPlan = parseAndResolve(
+         s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+            |ON t.i = src.i
+            |$clause THEN
+            | UPDATE SET s.n_s = named_struct('dn_i', CAST (NULL AS INT), 'dn_l', 1L)
+            |""".stripMargin)
+       assertNullCheckExists(structNullPlan, Seq("s", "n_s", "dn_i"))
+
+       // replacement struct that omits field `dn_l`
+       val missingFieldSql =
+         s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+            |ON t.i = src.i
+            |$clause THEN
+            | UPDATE SET s.n_s = named_struct('dn_i', 1)
+            |""".stripMargin
+       assertAnalysisException(
+         missingFieldSql,
+         "Cannot find data for output column 's.n_s.dn_l'")
+
+       // strict mode does NOT allow string to int casts
+       val stringToIntSql =
+         s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+            |ON t.i = src.i
+            |$clause THEN
+            | UPDATE SET s.n_s = named_struct('dn_i', 'string-value', 'dn_l', 1L)
+            |""".stripMargin
+       assertAnalysisException(stringToIntSql, "Cannot safely cast")
+
+       // strict mode does not allow long to int casts
+       val longToIntSql =
+         s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+            |ON t.i = src.i
+            |$clause THEN
+            | UPDATE SET i = 1L
+            |""".stripMargin
+       assertAnalysisException(longToIntSql, "Cannot safely cast")
+     }
+   }
+ }
+
+ test("legacy mode assignments") {
+   // With the LEGACY policy set, the MERGE statement itself is expected to be rejected
+   // with the "disallowed in Spark data source V2" message.
+   val mergeSql =
+     """MERGE INTO nested_struct_table t USING nested_struct_table src
+       |ON t.i = src.i
+       |WHEN MATCHED THEN
+       | UPDATE SET i = 1L
+       |""".stripMargin
+
+   withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.LEGACY.toString) {
+     assertAnalysisException(
+       mergeSql,
+       "LEGACY store assignment policy is disallowed in Spark data source V2")
+   }
+ }
+
+ test("align assignments with default values") {
+   val (matchedActions, notMatchedActions, notMatchedBySourceActions) =
+     parseAndAlignAssignments(
+       """MERGE INTO default_values_table t USING default_values_table s
+         |ON t.b = s.b
+         |WHEN MATCHED THEN
+         | UPDATE SET t.i = DEFAULT
+         |WHEN NOT MATCHED THEN
+         | INSERT (i, b) VALUES (DEFAULT, false)
+         |WHEN NOT MATCHED BY SOURCE THEN
+         | UPDATE SET t.i = DEFAULT""".stripMargin)
+
+   // Shared expectation for both UPDATE clauses: `b` is assigned to itself and `i`
+   // receives the integer literal 42 in place of DEFAULT.
+   def checkUpdateWithDefault(actions: Seq[MergeAction]): Unit = actions match {
+     case Seq(UpdateAction(None, aligned)) =>
+       aligned match {
+         case Seq(
+             Assignment(bAttr: AttributeReference, bExpr: AttributeReference),
+             Assignment(iAttr: AttributeReference, IntegerLiteral(42))) =>
+
+           assert(bAttr.name == "b" && bAttr == bExpr)
+           assert(iAttr.name == "i")
+
+         case unexpected =>
+           fail(s"Unexpected assignments: $unexpected")
+       }
+
+     case unexpected =>
+       fail(s"Unexpected actions: $unexpected")
+   }
+
+   checkUpdateWithDefault(matchedActions)
+
+   // INSERT: `b` gets the explicit `false`, `i` gets the literal 42 in place of DEFAULT.
+   notMatchedActions match {
+     case Seq(InsertAction(None, aligned)) =>
+       aligned match {
+         case Seq(
+             Assignment(bAttr: AttributeReference, BooleanLiteral(false)),
+             Assignment(iAttr: AttributeReference, IntegerLiteral(42))) =>
+
+           assert(bAttr.name == "b")
+           assert(iAttr.name == "i")
+
+         case unexpected =>
+           fail(s"Unexpected assignments: $unexpected")
+       }
+
+     case unexpected =>
+       fail(s"Unexpected actions: $unexpected")
+   }
+
+   checkUpdateWithDefault(notMatchedBySourceActions)
+ }
+
+ /**
+  * Resolves the given query via `parseAndResolve` and returns the action lists of the
+  * resulting [[MergeIntoTable]] as (matched, not matched, not matched by source).
+  * Fails the test, printing the plan tree, if the result is any other plan.
+  */
+ private def parseAndAlignAssignments(
+     query: String): (Seq[MergeAction], Seq[MergeAction], Seq[MergeAction]) = {
+
+   val resolved = parseAndResolve(query)
+   resolved match {
+     case merge: MergeIntoTable =>
+       (merge.matchedActions, merge.notMatchedActions, merge.notMatchedBySourceActions)
+     case unexpected =>
+       fail("Expected MergeIntoTable, but got:\n" + unexpected.treeString)
+   }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala
index a173106db99e9..96c4580745fd2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala
@@ -17,151 +17,14 @@
package org.apache.spark.sql.execution.command
-import java.util.Collections
-
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.{mock, when}
-import org.mockito.invocation.InvocationOnMock
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, FunctionRegistry, NoSuchTableException, ResolveSessionCatalog}
-import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, AttributeReference, BooleanLiteral, Cast, CheckOverflowInTableInsert, CreateNamedStruct, EvalMode, GetStructField, IntegerLiteral, LambdaFunction, LongLiteral, MapFromArrays, StringLiteral}
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, StaticInvoke}
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan, UpdateTable}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, CatalogV2Util, Column, ColumnDefaultValue, Identifier, SupportsRowLevelOperations, TableCapability, TableCatalog}
-import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform}
-import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
-import org.apache.spark.sql.types.{BooleanType, IntegerType, StructType}
-
-class AlignUpdateAssignmentsSuite extends AnalysisTest {
-
- private val primitiveTable = {
- val t = mock(classOf[SupportsRowLevelOperations])
- val schema = new StructType()
- .add("i", "INT", nullable = false)
- .add("l", "LONG")
- .add("txt", "STRING")
- when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
- when(t.partitioning()).thenReturn(Array.empty[Transform])
- t
- }
-
- private val nestedStructTable = {
- val t = mock(classOf[SupportsRowLevelOperations])
- val schema = new StructType()
- .add("i", "INT")
- .add(
- "s",
- "STRUCT>",
- nullable = false)
- .add("txt", "STRING")
- when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
- when(t.partitioning()).thenReturn(Array.empty[Transform])
- t
- }
-
- private val mapArrayTable = {
- val t = mock(classOf[SupportsRowLevelOperations])
- val schema = new StructType()
- .add("i", "INT")
- .add("a", "ARRAY>")
- .add("m", "MAP")
- .add("txt", "STRING")
- when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
- when(t.partitioning()).thenReturn(Array.empty[Transform])
- t
- }
-
- private val charVarcharTable = {
- val t = mock(classOf[SupportsRowLevelOperations])
- val schema = new StructType()
- .add("c", "CHAR(5)")
- .add(
- "s",
- "STRUCT",
- nullable = false)
- .add(
- "a",
- "ARRAY>",
- nullable = false)
- .add(
- "mk",
- "MAP, STRING>",
- nullable = false)
- .add(
- "mv",
- "MAP>",
- nullable = false)
- when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
- when(t.partitioning()).thenReturn(Array.empty[Transform])
- t
- }
-
- private val acceptsAnySchemaTable = {
- val t = mock(classOf[SupportsRowLevelOperations])
- val schema = new StructType()
- .add("i", "INT", nullable = false)
- .add("l", "LONG")
- .add("txt", "STRING")
- when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema))
- when(t.partitioning()).thenReturn(Array.empty[Transform])
- when(t.capabilities()).thenReturn(Collections.singleton(TableCapability.ACCEPT_ANY_SCHEMA))
- t
- }
-
- private val defaultValuesTable = {
- val t = mock(classOf[SupportsRowLevelOperations])
- val iDefault = new ColumnDefaultValue("42", LiteralValue(42, IntegerType))
- when(t.columns()).thenReturn(Array(
- Column.create("b", BooleanType, true, null, null),
- Column.create("i", IntegerType, true, null, iDefault, null)))
- when(t.partitioning()).thenReturn(Array.empty[Transform])
- t
- }
-
- private val v2Catalog = {
- val newCatalog = mock(classOf[TableCatalog])
- when(newCatalog.loadTable(any())).thenAnswer((invocation: InvocationOnMock) => {
- val ident = invocation.getArgument[Identifier](0)
- ident.name match {
- case "primitive_table" => primitiveTable
- case "nested_struct_table" => nestedStructTable
- case "map_array_table" => mapArrayTable
- case "char_varchar_table" => charVarcharTable
- case "accepts_any_schema_table" => acceptsAnySchemaTable
- case "default_values_table" => defaultValuesTable
- case name => throw new NoSuchTableException(Seq(name))
- }
- })
- when(newCatalog.name()).thenReturn("cat")
- newCatalog
- }
+import org.apache.spark.sql.types.IntegerType
- private val v1SessionCatalog =
- new SessionCatalog(new InMemoryCatalog(), FunctionRegistry.builtin, new SQLConf())
-
- private val v2SessionCatalog = new V2SessionCatalog(v1SessionCatalog)
-
- private val catalogManager = {
- val manager = mock(classOf[CatalogManager])
- when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => {
- invocation.getArgument[String](0) match {
- case "testcat" => v2Catalog
- case CatalogManager.SESSION_CATALOG_NAME => v2SessionCatalog
- case name => throw new CatalogNotFoundException(s"No such catalog: $name")
- }
- })
- when(manager.currentCatalog).thenReturn(v2Catalog)
- when(manager.currentNamespace).thenReturn(Array.empty[String])
- when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog)
- when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog)
- manager
- }
+class AlignUpdateAssignmentsSuite extends AlignAssignmentsSuite {
test("align assignments (primitive types)") {
val sql1 = "UPDATE primitive_table AS t SET t.txt = 'new', t.i = 1"
@@ -752,28 +615,4 @@ class AlignUpdateAssignmentsSuite extends AnalysisTest {
case plan => fail("Expected UpdateTable, but got:\n" + plan.treeString)
}
}
-
- private def parseAndResolve(query: String): LogicalPlan = {
- val analyzer = new Analyzer(catalogManager) {
- override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq(
- new ResolveSessionCatalog(catalogManager))
- }
- val analyzed = analyzer.execute(CatalystSqlParser.parsePlan(query))
- analyzer.checkAnalysis(analyzed)
- analyzed
- }
-
- private def assertAnalysisException(query: String, messages: String*): Unit = {
- val exception = intercept[AnalysisException] {
- parseAndResolve(query)
- }
- messages.foreach(message => assert(exception.message.contains(message)))
- }
-
- private def assertNullCheckExists(plan: LogicalPlan, colPath: Seq[String]): Unit = {
- val asserts = plan.expressions.flatMap(e => e.collect {
- case assert: AssertNotNull if assert.walkedTypePath == colPath => assert
- })
- assert(asserts.nonEmpty, s"Must have NOT NULL checks for col $colPath")
- }
}