diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 62555c9a99cc3..5a50c13839300 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1818,7 +1818,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor "sqlExpr" -> a.sql, "cols" -> cols)) } - resolved + resolved match { + case Alias(child: ExtractValue, _) => child + case other => other + } } // Expand the star expression using the input plan first. If failed, try resolve diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala index 265909d3a7e7d..6d8118548fb41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DataType, StructType} object AssignmentUtils extends SQLConfHelper with CastSupport { /** - * Aligns assignments to match table columns. + * Aligns update assignments to match table columns. *

* This method processes and reorders given assignments so that each target column gets * an expression it should be set to. If a column does not have a matching assignment, @@ -46,9 +46,9 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { * * @param attrs table attributes * @param assignments assignments to align - * @return aligned assignments that match table attributes + * @return aligned update assignments that match table attributes */ - def alignAssignments( + def alignUpdateAssignments( attrs: Seq[Attribute], assignments: Seq[Assignment]): Seq[Assignment] = { @@ -70,6 +70,61 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { attrs.zip(output).map { case (attr, expr) => Assignment(attr, expr) } } + /** + * Aligns insert assignments to match table columns. + *

+ * This method processes and reorders given assignments so that each target column gets + * an expression it should be set to. There must be exactly one assignment for each top-level + * attribute and its value must be compatible. + *

+ * Insert assignments cannot refer to nested columns. + * + * @param attrs table attributes + * @param assignments insert assignments to align + * @return aligned insert assignments that match table attributes + */ + def alignInsertAssignments( + attrs: Seq[Attribute], + assignments: Seq[Assignment]): Seq[Assignment] = { + + val errors = new mutable.ArrayBuffer[String]() + + val (topLevelAssignments, nestedAssignments) = assignments.partition { assignment => + assignment.key.isInstanceOf[Attribute] + } + + if (nestedAssignments.nonEmpty) { + val nestedAssignmentsStr = nestedAssignments.map(_.sql).mkString(", ") + errors += s"INSERT assignment keys cannot be nested fields: $nestedAssignmentsStr" + } + + val alignedAssignments = attrs.map { attr => + val matchingAssignments = topLevelAssignments.collect { + case assignment if assignment.key.semanticEquals(attr) => assignment + } + val resolvedValue = if (matchingAssignments.isEmpty) { + errors += s"No assignment for '${attr.name}'" + attr + } else if (matchingAssignments.length > 1) { + val conflictingValuesStr = matchingAssignments.map(_.value.sql).mkString(", ") + errors += s"Multiple assignments for '${attr.name}': $conflictingValuesStr" + attr + } else { + val colPath = Seq(attr.name) + val actualAttr = restoreActualType(attr) + val value = matchingAssignments.head.value + TableOutputResolver.resolveUpdate(value, actualAttr, conf, err => errors += err, colPath) + } + Assignment(attr, resolvedValue) + } + + if (errors.nonEmpty) { + throw QueryCompilationErrors.invalidRowLevelOperationAssignments(assignments, errors.toSeq) + } + + alignedAssignments + } + private def restoreActualType(attr: Attribute): Attribute = { attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala index 596dc00b9176b..b22c91973f794 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast} import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull -import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan, MergeIntoTable, UpdateTable} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND import org.apache.spark.sql.catalyst.util.CharVarcharUtils @@ -42,17 +43,23 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] { case u: UpdateTable if !u.skipSchemaResolution && u.resolved && supportsRowLevelOperations(u.table) && !u.aligned => validateStoreAssignmentPolicy() - val newTable = u.table.transform { - case r: DataSourceV2Relation => - r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata)) - } - val newAssignments = AssignmentUtils.alignAssignments(u.table.output, u.assignments) + val newTable = cleanAttrMetadata(u.table) + val newAssignments = AssignmentUtils.alignUpdateAssignments(u.table.output, u.assignments) u.copy(table = newTable, assignments = newAssignments) case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned => resolveAssignments(u) - case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved => + case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && + supportsRowLevelOperations(m.targetTable) && !m.aligned => + validateStoreAssignmentPolicy() + m.copy( + targetTable = cleanAttrMetadata(m.targetTable), + matchedActions = alignActions(m.targetTable.output, m.matchedActions), + notMatchedActions = alignActions(m.targetTable.output, m.notMatchedActions), + notMatchedBySourceActions = alignActions(m.targetTable.output, m.notMatchedBySourceActions)) + + case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && !m.aligned => resolveAssignments(m) } @@ -63,6 +70,13 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] { } } + private def cleanAttrMetadata(table: LogicalPlan): LogicalPlan = { + table.transform { + case r: DataSourceV2Relation => + r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata)) + } + } + private def supportsRowLevelOperations(table: LogicalPlan): Boolean = { EliminateSubqueryAliases(table) match { case DataSourceV2Relation(_: SupportsRowLevelOperations, _, _, _, _) => true @@ -100,4 +114,19 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] { Assignment(cleanedKey, finalValue) } } + + private def alignActions( + attrs: Seq[Attribute], + actions: Seq[MergeAction]): Seq[MergeAction] = { + actions.map { + case u @ UpdateAction(_, assignments) => + u.copy(assignments = AssignmentUtils.alignUpdateAssignments(attrs, assignments)) + case d: DeleteAction => + d + case i @ InsertAction(_, assignments) => + i.copy(assignments = AssignmentUtils.alignInsertAssignments(attrs, assignments)) + case other => + throw new AnalysisException(s"Unexpected resolved action: $other") + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 68943c918b12f..a5ef21c4db8c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -708,6 +708,21 @@ case class MergeIntoTable( matchedActions: Seq[MergeAction], notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery { + + lazy val aligned: Boolean = { + val actions = matchedActions ++ notMatchedActions ++ notMatchedBySourceActions + actions.forall { + case UpdateAction(_, assignments) => + AssignmentUtils.aligned(targetTable.output, assignments) + case _: DeleteAction => + true + case InsertAction(_, assignments) => + AssignmentUtils.aligned(targetTable.output, assignments) + case _ => + false + } + } + def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty def skipSchemaResolution: Boolean = targetTable match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuite.scala new file mode 100644 index 0000000000000..e6e124f1d5fd9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuite.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import java.util.Collections + +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, FunctionRegistry, NoSuchTableException, ResolveSessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, CatalogV2Util, Column, ColumnDefaultValue, Identifier, SupportsRowLevelOperations, TableCapability, TableCatalog} +import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform} +import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{BooleanType, IntegerType, StructType} + +class AlignAssignmentsSuite extends AnalysisTest { + + private val primitiveTable = { + val t = mock(classOf[SupportsRowLevelOperations]) + val schema = new StructType() + .add("i", "INT", nullable = false) + .add("l", "LONG") + .add("txt", "STRING") + when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) + when(t.partitioning()).thenReturn(Array.empty[Transform]) + t + } + + private val primitiveTableSource = { + val t = mock(classOf[SupportsRowLevelOperations]) + val schema = new StructType() + .add("l", "LONG") + .add("txt", "STRING") + .add("i", "INT") + when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) + when(t.partitioning()).thenReturn(Array.empty[Transform]) + t + } + + private val nestedStructTable = { + val t = mock(classOf[SupportsRowLevelOperations]) + val schema = new StructType() + .add("i", "INT") + .add( + "s", + "STRUCT>", + nullable = false) + .add("txt", "STRING") + when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) + when(t.partitioning()).thenReturn(Array.empty[Transform]) + t + } + + private val nestedStructTableSource = { + val t = mock(classOf[SupportsRowLevelOperations]) + val schema = new StructType() + .add("i", "INT") + .add("s", "STRUCT>") + .add("txt", "STRING") + when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) + when(t.partitioning()).thenReturn(Array.empty[Transform]) + t + } + + private val mapArrayTable = { + val t = mock(classOf[SupportsRowLevelOperations]) + val schema = new StructType() + .add("i", "INT") + .add("a", "ARRAY>") + .add("m", "MAP") + .add("txt", "STRING") + when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) + when(t.partitioning()).thenReturn(Array.empty[Transform]) + t + } + + private val charVarcharTable = { + val t = mock(classOf[SupportsRowLevelOperations]) + val schema = new StructType() + .add("c", "CHAR(5)") + .add( + "s", + "STRUCT", + nullable = false) + .add( + "a", + "ARRAY>", + nullable = false) + .add( + "mk", + "MAP, STRING>", + nullable = false) + .add( + "mv", + "MAP>", + nullable = false) + when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) + when(t.partitioning()).thenReturn(Array.empty[Transform]) + t + } + + private val acceptsAnySchemaTable = { + val t = mock(classOf[SupportsRowLevelOperations]) + val schema = new StructType() + .add("i", "INT", nullable = false) + .add("l", "LONG") + .add("txt", "STRING") + when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) + when(t.partitioning()).thenReturn(Array.empty[Transform]) + when(t.capabilities()).thenReturn(Collections.singleton(TableCapability.ACCEPT_ANY_SCHEMA)) + t + } + + private val defaultValuesTable = { + val t = mock(classOf[SupportsRowLevelOperations]) + val iDefault = new ColumnDefaultValue("42", LiteralValue(42, IntegerType)) + when(t.columns()).thenReturn(Array( + Column.create("b", BooleanType, true, null, null), + Column.create("i", IntegerType, true, null, iDefault, null))) + when(t.partitioning()).thenReturn(Array.empty[Transform]) + t + } + + private val v2Catalog = { + val newCatalog = mock(classOf[TableCatalog]) + when(newCatalog.loadTable(any())).thenAnswer((invocation: InvocationOnMock) => { + val ident = invocation.getArgument[Identifier](0) + ident.name match { + case "primitive_table" => primitiveTable + case "primitive_table_src" => primitiveTableSource + case "nested_struct_table" => nestedStructTable + case "nested_struct_table_src" => nestedStructTableSource + case "map_array_table" => mapArrayTable + case "char_varchar_table" => charVarcharTable + case "accepts_any_schema_table" => acceptsAnySchemaTable + case "default_values_table" => defaultValuesTable + case name => throw new NoSuchTableException(Seq(name)) + } + }) + when(newCatalog.name()).thenReturn("cat") + newCatalog + } + + private val v1SessionCatalog = + new SessionCatalog(new InMemoryCatalog(), FunctionRegistry.builtin, new SQLConf()) + + private val v2SessionCatalog = new V2SessionCatalog(v1SessionCatalog) + + private val catalogManager = { + val manager = mock(classOf[CatalogManager]) + when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => { + invocation.getArgument[String](0) match { + case "testcat" => v2Catalog + case CatalogManager.SESSION_CATALOG_NAME => v2SessionCatalog + case name => throw new CatalogNotFoundException(s"No such catalog: $name") + } + }) + when(manager.currentCatalog).thenReturn(v2Catalog) + when(manager.currentNamespace).thenReturn(Array.empty[String]) + when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog) + when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog) + manager + } + + protected def parseAndResolve(query: String): LogicalPlan = { + val analyzer = new Analyzer(catalogManager) { + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq( + new ResolveSessionCatalog(catalogManager)) + } + val analyzed = analyzer.execute(CatalystSqlParser.parsePlan(query)) + analyzer.checkAnalysis(analyzed) + analyzed + } + + protected def assertNullCheckExists(plan: LogicalPlan, colPath: Seq[String]): Unit = { + val asserts = plan.expressions.flatMap(e => e.collect { + case assert: AssertNotNull if assert.walkedTypePath == colPath => assert + }) + assert(asserts.nonEmpty, s"Must have NOT NULL checks for col $colPath") + } + + protected def assertAnalysisException(query: String, messages: String*): Unit = { + val exception = intercept[AnalysisException] { + parseAndResolve(query) + } + messages.foreach(message => assert(exception.message.contains(message))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala new file mode 100644 index 0000000000000..e569166633908 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala @@ -0,0 +1,937 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, AttributeReference, BooleanLiteral, Cast, CheckOverflowInTableInsert, CreateNamedStruct, EvalMode, GetStructField, IntegerLiteral, LambdaFunction, LongLiteral, MapFromArrays, StringLiteral} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, StaticInvoke} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, InsertAction, MergeAction, MergeIntoTable, UpdateAction} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy +import org.apache.spark.sql.types.IntegerType + +class AlignMergeAssignmentsSuite extends AlignAssignmentsSuite { + + test("align assignments (primitive types)") { + val (matchedActions, notMatchedActions, notMatchedBySourceActions) = + parseAndAlignAssignments( + """MERGE INTO primitive_table t USING primitive_table_src s + |ON t.l = s.l + |WHEN MATCHED THEN + | UPDATE SET t.txt = 'a', t.i = s.i + |WHEN NOT MATCHED THEN + | INSERT * + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET t.txt = "error", t.i = CAST(null AS INT)""".stripMargin) + + matchedActions match { + case Seq(UpdateAction(None, assignments)) => + assignments match { + case Seq( + Assignment(i: AttributeReference, AssertNotNull(iValue: AttributeReference, _)), + Assignment(l: AttributeReference, lValue: AttributeReference), + Assignment(txt: AttributeReference, StringLiteral("a"))) => + + assert(i.name == "i" && iValue.name == "i" && i != iValue) + assert(l.name == "l" && l == lValue) + assert(txt.name == "txt") + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + + notMatchedActions match { + case Seq(InsertAction(None, assignments)) => + assignments match { + case Seq( + Assignment(i: AttributeReference, AssertNotNull(iValue: AttributeReference, _)), + Assignment(l: AttributeReference, lValue: AttributeReference), + Assignment(txt: AttributeReference, txtValue: AttributeReference)) => + + assert(i.name == "i" && iValue.name == "i" && i != iValue) + assert(l.name == "l" && lValue.name == "l" && l != lValue) + assert(txt.name == "txt" && txtValue.name == "txt" && txt != txtValue) + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + + notMatchedBySourceActions match { + case Seq(UpdateAction(None, assignments)) => + assignments match { + case Seq( + Assignment(i: AttributeReference, AssertNotNull(_: Cast, _)), + Assignment(l: AttributeReference, lValue: AttributeReference), + Assignment(txt: AttributeReference, StringLiteral("error"))) => + + assert(i.name == "i") + assert(l.name == "l" && l == lValue) + assert(txt.name == "txt") + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + } + + test("align assignments (top-level structs)") { + val (matchedActions, notMatchedActions, notMatchedBySourceActions) = + parseAndAlignAssignments( + """MERGE INTO nested_struct_table t USING nested_struct_table_src s + |ON t.i = s.i + |WHEN MATCHED THEN + | UPDATE SET t.s = named_struct('n_s', named_struct('dn_i', 1, 'dn_l', 100L), 'n_i', 1) + |WHEN NOT MATCHED THEN + | INSERT (txt, s, i) VALUES ( + | 'new', + | named_struct('n_s', named_struct('dn_i', 1, 'dn_l', 100L), 'n_i', 1), + | 1) + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET t.s = named_struct('n_s', named_struct('dn_i', 1, 'dn_l', 100L), 'n_i', 1) + |""".stripMargin) + + def checkStruct(value: CreateNamedStruct): Unit = { + value.children match { + case Seq( + StringLiteral("n_i"), GetStructField(_, _, Some("n_i")), + StringLiteral("n_s"), nsValue: CreateNamedStruct) => + + nsValue.children match { + case Seq( + StringLiteral("dn_i"), GetStructField(_, _, Some("dn_i")), + StringLiteral("dn_l"), GetStructField(_, _, Some("dn_l"))) => + // OK + + case nsValueChildren => + fail(s"Unexpected children for 's.n_s': $nsValueChildren") + } + + case sValueChildren => + fail(s"Unexpected children for 's': $sValueChildren") + } + } + + matchedActions match { + case Seq(UpdateAction(None, assignments)) => + assignments match { + case Seq( + Assignment(i: AttributeReference, iValue: AttributeReference), + Assignment(s: AttributeReference, sValue: CreateNamedStruct), + Assignment(txt: AttributeReference, txtValue: AttributeReference)) => + + assert(i.name == "i" && i == iValue) + + assert(s.name == "s") + checkStruct(sValue) + + assert(txt.name == "txt" && txt == txtValue) + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + + notMatchedActions match { + case Seq(InsertAction(None, assignments)) => + assignments match { + case Seq( + Assignment(i: AttributeReference, IntegerLiteral(1)), + Assignment(s: AttributeReference, sValue: CreateNamedStruct), + Assignment(txt: AttributeReference, StringLiteral("new"))) => + + assert(i.name == "i") + + assert(s.name == "s") + checkStruct(sValue) + + assert(txt.name == "txt") + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + + notMatchedBySourceActions match { + case Seq(UpdateAction(None, assignments)) => + assignments match { + case Seq( + Assignment(i: AttributeReference, iValue: AttributeReference), + Assignment(s: AttributeReference, sValue: CreateNamedStruct), + Assignment(txt: AttributeReference, txtValue: AttributeReference)) => + + assert(i.name == "i" && i == iValue) + + assert(s.name == "s") + checkStruct(sValue) + + assert(txt.name == "txt" && txt == txtValue) + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + } + + test("align UPDATE assignments with references to nested attributes on both sides") { + val (matchedActions, _, _) = + parseAndAlignAssignments( + """MERGE INTO nested_struct_table_src t USING nested_struct_table_src src + |ON t.i = src.i + |WHEN MATCHED THEN + | UPDATE SET t.s.n_s = src.s.n_s + |""".stripMargin) + + matchedActions match { + case Seq(UpdateAction(None, assignments)) => + assignments match { + case Seq( + Assignment(i: AttributeReference, iValue: AttributeReference), + Assignment(s: AttributeReference, sValue: CreateNamedStruct), + Assignment(txt: AttributeReference, txtValue: AttributeReference)) => + + assert(i.name == "i" && i == iValue) + + assert(s.name == "s") + sValue.children match { + case Seq( + StringLiteral("n_i"), + GetStructField(_, _, Some("n_i")), + StringLiteral("n_s"), + GetStructField(source: AttributeReference, _, Some("n_s"))) => + + assert(source.name == "s" && s != source) + + case sValueChildren => + fail(s"Unexpected children for 's': $sValueChildren") + } + + assert(txt.name == "txt" && txt == txtValue) + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + } + + test("align assignments (nested structs)") { + val (matchedActions, notMatchedActions, notMatchedBySourceActions) = + parseAndAlignAssignments( + """MERGE INTO nested_struct_table t USING nested_struct_table_src s + |ON t.i = s.i + |WHEN MATCHED THEN + | UPDATE SET t.s.n_s = named_struct('dn_l', 1L, 'dn_i', 1) + |WHEN NOT MATCHED THEN + | INSERT (txt, s, i) VALUES ( + | 'new', + | named_struct('n_i', 1, 'n_s', named_struct('dn_l', 1L, 'dn_i', 1)), + | 1) + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET t.s.n_s = named_struct('dn_l', 1L, 'dn_i', 1) + |""".stripMargin) + + def checkNestedStruct(value: CreateNamedStruct): Unit = { + value.children match { + case Seq( + StringLiteral("dn_i"), GetStructField(_, _, Some("dn_i")), + StringLiteral("dn_l"), GetStructField(_, _, Some("dn_l"))) => + // OK + + case nsValueChildren => + fail(s"Unexpected children for 's.n_s': $nsValueChildren") + } + } + + matchedActions match { + case Seq(UpdateAction(None, assignments)) => + assignments match { + case Seq( + Assignment(i: AttributeReference, iValue: AttributeReference), + Assignment(s: AttributeReference, sValue: CreateNamedStruct), + Assignment(txt: AttributeReference, txtValue: AttributeReference)) => + + assert(i.name == "i" && i == iValue) + + assert(s.name == "s") + sValue.children match { + case Seq( + StringLiteral("n_i"), GetStructField(_, _, Some("n_i")), + StringLiteral("n_s"), nsValue: CreateNamedStruct) => + checkNestedStruct(nsValue) + + case sValueChildren => + fail(s"Unexpected children for 's': $sValueChildren") + } + + assert(txt.name == "txt" && txt == txtValue) + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + + notMatchedActions match { + case Seq(InsertAction(None, assignments)) => + assignments match { + case Seq( + Assignment(i: AttributeReference, IntegerLiteral(1)), + Assignment(s: AttributeReference, sValue: CreateNamedStruct), + Assignment(txt: AttributeReference, StringLiteral("new"))) => + + assert(i.name == "i") + + assert(s.name == "s") + sValue.children match { + case Seq( + StringLiteral("n_i"), GetStructField(_, _, Some("n_i")), + StringLiteral("n_s"), nsValue: CreateNamedStruct) => + checkNestedStruct(nsValue) + + case sValueChildren => + fail(s"Unexpected children for 's': $sValueChildren") + } + + assert(txt.name == "txt") + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + + notMatchedBySourceActions match { + case Seq(UpdateAction(None, assignments)) => + assignments match { + case Seq( + Assignment(i: AttributeReference, iValue: AttributeReference), + Assignment(s: AttributeReference, sValue: CreateNamedStruct), + Assignment(txt: AttributeReference, txtValue: AttributeReference)) => + + assert(i.name == "i" && i == iValue) + + assert(s.name == "s") + sValue.children match { + case Seq( + StringLiteral("n_i"), GetStructField(_, _, Some("n_i")), + StringLiteral("n_s"), nsValue: CreateNamedStruct) => + checkNestedStruct(nsValue) + + case sValueChildren => + fail(s"Unexpected children for 's': $sValueChildren") + } + + assert(txt.name == "txt" && txt == txtValue) + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + } + + test("align assignments (char and varchar types)") { + val (matchedActions, notMatchedActions, notMatchedBySourceActions) = + parseAndAlignAssignments( + """MERGE INTO char_varchar_table t USING char_varchar_table src + |ON t.c = src.c + |WHEN MATCHED THEN + | UPDATE SET + | c = 'a', + | a = array(named_struct('n_i', 1, 'n_vc', 3)), + | s.n_vc = 'a', + | mv = map('v', named_struct('n_vc', 'a', 'n_i', 1)), + | mk = map(named_struct('n_vc', 'a', 'n_i', 1), 'v') + |WHEN NOT MATCHED THEN + | INSERT (c, a, s, mv, mk) VALUES ( + | 'a', + | array(named_struct('n_i', 1, 'n_vc', 3)), + | named_struct('n_vc', 3, 'n_i', 1), + | map('v', named_struct('n_vc', 'a', 'n_i', 1)), + | map(named_struct('n_vc', 'a', 'n_i', 1), 'v')) + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET + | c = 'a', + | a = array(named_struct('n_i', 1, 'n_vc', 3)), + | s = named_struct('n_vc', 3, 'n_i', 1), + | mv = map('v', named_struct('n_vc', 'a', 'n_i', 1)), + | mk = map(named_struct('n_vc', 'a', 'n_i', 1), 'v') + |""".stripMargin) + + def checkStruct(value: CreateNamedStruct): Unit = { + value.children match { + case Seq( + StringLiteral("n_i"), GetStructField(_, _, Some("n_i")), + StringLiteral("n_vc"), invoke: StaticInvoke) => + + assert(invoke.arguments.length == 2) + assert(invoke.functionName == "varcharTypeWriteSideCheck") + + case sValueChildren => + fail(s"Unexpected children for 's': $sValueChildren") + } + } + + def checkArray(value: ArrayTransform): Unit = { + val lambda = value.function.asInstanceOf[LambdaFunction] + lambda.function match { + case CreateNamedStruct(Seq( + StringLiteral("n_i"), GetStructField(_, _, Some("n_i")), + StringLiteral("n_vc"), invoke: StaticInvoke)) => + + assert(invoke.arguments.length == 2) + assert(invoke.functionName == "varcharTypeWriteSideCheck") + + case func => + fail(s"Unexpected lambda function: $func") + } + } + + def checkMapWithStructKey(value: MapFromArrays): Unit = { + val keyTransform = value.left.asInstanceOf[ArrayTransform] + val keyLambda = keyTransform.function.asInstanceOf[LambdaFunction] + keyLambda.function match { + case CreateNamedStruct(Seq( + StringLiteral("n_i"), GetStructField(_, _, Some("n_i")), + StringLiteral("n_vc"), invoke: StaticInvoke)) => + + assert(invoke.arguments.length == 2) + assert(invoke.functionName == "varcharTypeWriteSideCheck") + + case func => + fail(s"Unexpected key lambda function: $func") + } + } + + def checkMapWithStructValue(value: MapFromArrays): Unit = { + val valueTransform = value.right.asInstanceOf[ArrayTransform] + val valueLambda = valueTransform.function.asInstanceOf[LambdaFunction] + valueLambda.function match { + case CreateNamedStruct(Seq( + StringLiteral("n_i"), GetStructField(_, _, Some("n_i")), + StringLiteral("n_vc"), invoke: StaticInvoke)) => + + assert(invoke.arguments.length == 2) + assert(invoke.functionName == "varcharTypeWriteSideCheck") + + case func => + fail(s"Unexpected key lambda function: $func") + } + } + + matchedActions match { + case Seq(UpdateAction(None, assignments)) => + assignments match { + case Seq( + Assignment(c: AttributeReference, cValue: StaticInvoke), + Assignment(s: AttributeReference, sValue: CreateNamedStruct), + Assignment(a: AttributeReference, aValue: ArrayTransform), + Assignment(mk: AttributeReference, mkValue: MapFromArrays), + Assignment(mv: AttributeReference, mvValue: MapFromArrays)) => + + assert(c.name == "c") + assert(cValue.arguments.length == 2) + assert(cValue.functionName == "charTypeWriteSideCheck") + + assert(s.name == "s") + checkStruct(sValue) + + assert(a.name == "a") + checkArray(aValue) + + assert(mk.name == "mk") + checkMapWithStructKey(mkValue) + + assert(mv.name == "mv") + checkMapWithStructValue(mvValue) + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + + notMatchedActions match { + case Seq(InsertAction(None, assignments)) => + assignments match { + case Seq( + Assignment(c: AttributeReference, cValue: StaticInvoke), + Assignment(s: AttributeReference, sValue: CreateNamedStruct), + Assignment(a: AttributeReference, aValue: ArrayTransform), + Assignment(mk: AttributeReference, mkValue: MapFromArrays), + Assignment(mv: AttributeReference, mvValue: MapFromArrays)) => + + assert(c.name == "c") + assert(cValue.arguments.length == 2) + assert(cValue.functionName == "charTypeWriteSideCheck") + + assert(s.name == "s") + checkStruct(sValue) + + assert(a.name == "a") + checkArray(aValue) + + assert(mk.name == "mk") + checkMapWithStructKey(mkValue) + + assert(mv.name == "mv") + checkMapWithStructValue(mvValue) + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + + notMatchedBySourceActions match { + case Seq(UpdateAction(None, assignments)) => + assignments match { + case Seq( + Assignment(c: AttributeReference, cValue: StaticInvoke), + Assignment(s: AttributeReference, sValue: CreateNamedStruct), + Assignment(a: AttributeReference, aValue: ArrayTransform), + Assignment(mk: AttributeReference, mkValue: MapFromArrays), + Assignment(mv: AttributeReference, mvValue: MapFromArrays)) => + + assert(c.name == "c") + assert(cValue.arguments.length == 2) + assert(cValue.functionName == "charTypeWriteSideCheck") + + assert(s.name == "s") + checkStruct(sValue) + + assert(a.name == "a") + checkArray(aValue) + + assert(mk.name == "mk") + checkMapWithStructKey(mkValue) + + assert(mv.name == "mv") + checkMapWithStructValue(mvValue) + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + } + + test("conflicting UPDATE assignments") { + Seq(StoreAssignmentPolicy.ANSI, StoreAssignmentPolicy.STRICT).foreach { policy => + withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) { + Seq("WHEN MATCHED", "WHEN NOT MATCHED BY SOURCE").foreach { clause => + // two updates to a top-level column + assertAnalysisException( + s"""MERGE INTO primitive_table t USING primitive_table_src s + |ON t.l = s.l + |$clause THEN + | UPDATE SET t.txt = 'a', t.txt = 'b' + |""".stripMargin, + "Multiple assignments for 'txt': 'a', 'b'") + + // two updates to a nested column + assertAnalysisException( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_i = 1, s.n_s = null, s.n_i = -1 + |""".stripMargin, + "Multiple assignments for 's.n_i': 1, -1") + + // conflicting updates to a nested struct and its fields + assertAnalysisException( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_s.dn_i = 1, s.n_s = named_struct('dn_i', 1, 'dn_l', 1L) + |""".stripMargin, + "Conflicting assignments for 's.n_s'", + "t.s.`n_s` = named_struct('dn_i', 1, 'dn_l', 1L)", + "t.s.`n_s`.`dn_i` = 1") + } + } + } + } + + test("invalid INSERT assignments") { + assertAnalysisException( + """MERGE INTO primitive_table t USING primitive_table src + |ON t.i = src.i + |WHEN NOT MATCHED THEN + | INSERT (i, txt) VALUES (src.i, src.txt) + |""".stripMargin, + "No assignment for 'l'") + + assertAnalysisException( + """MERGE INTO primitive_table t USING primitive_table src + |ON t.i = src.i + |WHEN NOT MATCHED THEN + | INSERT (i, l, txt, txt) VALUES (src.i, src.l, src.txt, src.txt) + |""".stripMargin, + "Multiple assignments for 'txt'") + + assertAnalysisException( + """MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |WHEN NOT MATCHED THEN + | INSERT (s.n_i) VALUES (1) + |""".stripMargin, + "INSERT assignment keys cannot be nested fields: t.s.`n_i` = 1", + "No assignment for 'i'", + "No assignment for 's'", + "No assignment for 'txt'") + } + + test("updates to nested structs in arrays") { + Seq(StoreAssignmentPolicy.ANSI, StoreAssignmentPolicy.STRICT).foreach { policy => + withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) { + assertAnalysisException( + """MERGE INTO map_array_table t USING map_array_table s + |ON t.i = s.i + |WHEN MATCHED THEN + | UPDATE SET t.a.i1 = 1 + |""".stripMargin, + "Updating nested fields is only supported for StructType but 'a' is of type ArrayType") + } + } + } + + test("ANSI mode in UPDATE assignments") { + withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.ANSI.toString) { + Seq("WHEN MATCHED", "WHEN NOT MATCHED BY SOURCE").foreach { clause => + val plan1 = parseAndResolve( + s"""MERGE INTO primitive_table t USING primitive_table_src src + |ON t.l = src.l + |$clause THEN + | UPDATE SET i = NULL + |""".stripMargin) + assertNullCheckExists(plan1, Seq("i")) + + val plan2 = parseAndResolve( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_i = NULL + |""".stripMargin) + assertNullCheckExists(plan2, Seq("s", "n_i")) + + val plan3 = parseAndResolve( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_s.dn_i = NULL + |""".stripMargin) + assertNullCheckExists(plan3, Seq("s", "n_s", "dn_i")) + + val plan4 = parseAndResolve( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_s = named_struct('dn_i', NULL, 'dn_l', 1L) + |""".stripMargin) + assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i")) + + assertAnalysisException( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_s = named_struct('dn_i', 1) + |""".stripMargin, + "Cannot find data for output column 's.n_s.dn_l'") + + // ANSI mode does NOT allow string to int casts + assertAnalysisException( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_s = named_struct('dn_i', 'string-value', 'dn_l', 1L) + |""".stripMargin, + "Cannot safely cast") + + val (matchedActions, _, notMatchedBySourceActions) = + parseAndAlignAssignments( + s"""MERGE INTO primitive_table t USING primitive_table_src src + |ON t.i = src.i + |$clause THEN + | UPDATE SET i = 1L, txt = 'new', l = 10L + |""".stripMargin) + + val actions = if (matchedActions.nonEmpty) matchedActions else notMatchedBySourceActions + actions match { + case Seq(UpdateAction(_, assignments)) => + assignments match { + case Seq( + Assignment( + i: AttributeReference, + CheckOverflowInTableInsert( + Cast(LongLiteral(1L), IntegerType, _, EvalMode.ANSI), _)), + Assignment(l: AttributeReference, LongLiteral(10L)), + Assignment(txt: AttributeReference, StringLiteral("new"))) => + + assert(i.name == "i") + assert(l.name == "l") + assert(txt.name == "txt") + + case assignments => + fail(s"Unexpected assignments: $assignments") + } + + case other => + fail(s"Unexpected actions: $other") + } + } + } + } + + test("ANSI mode in INSERT assignments") { + withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.ANSI.toString) { + val plan1 = parseAndResolve( + """MERGE INTO primitive_table t USING primitive_table_src src + |ON t.l = src.l + |WHEN NOT MATCHED THEN + | INSERT (i, l, txt) VALUES (NULL, 1, 'value') + |""".stripMargin) + assertNullCheckExists(plan1, Seq("i")) + + // ANSI mode does NOT allow string to int casts + assertAnalysisException( + """MERGE INTO primitive_table t USING primitive_table_src src + |ON t.l = src.l + |WHEN NOT MATCHED THEN + | INSERT (i, l, txt) VALUES ('1', 1, 'value') + |""".stripMargin, + "Cannot safely cast") + + val (_, notMatchedActions, _) = + parseAndAlignAssignments( + """MERGE INTO primitive_table t USING primitive_table_src src + |ON t.i = src.i + |WHEN NOT MATCHED THEN + | INSERT (i, l, txt) VALUES (1L, 10L, 'new') + |""".stripMargin) + + notMatchedActions match { + case Seq(InsertAction(_, assignments)) => + assignments match { + case Seq( + Assignment( + i: AttributeReference, + CheckOverflowInTableInsert( + Cast(LongLiteral(1L), IntegerType, _, EvalMode.ANSI), _)), + Assignment(l: AttributeReference, LongLiteral(10L)), + Assignment(txt: AttributeReference, StringLiteral("new"))) => + + assert(i.name == "i") + assert(l.name == "l") + assert(txt.name == "txt") + + case assignments => + fail(s"Unexpected assignments: $assignments") + } + + case other => + fail(s"Unexpected actions: $other") + } + } + } + + test("strict mode in UPDATE assignments") { + withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.STRICT.toString) { + Seq("WHEN MATCHED", "WHEN NOT MATCHED BY SOURCE").foreach { clause => + val plan1 = parseAndResolve( + s"""MERGE INTO primitive_table t USING primitive_table_src src + |ON t.l = src.l + |$clause THEN + | UPDATE SET i = CAST(NULL AS INT) + |""".stripMargin) + assertNullCheckExists(plan1, Seq("i")) + + val plan2 = parseAndResolve( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_i = CAST(NULL AS INT) + |""".stripMargin) + assertNullCheckExists(plan2, Seq("s", "n_i")) + + val plan3 = parseAndResolve( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_s.dn_i = CAST(NULL AS INT) + |""".stripMargin) + assertNullCheckExists(plan3, Seq("s", "n_s", "dn_i")) + + val plan4 = parseAndResolve( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_s = named_struct('dn_i', CAST (NULL AS INT), 'dn_l', 1L) + |""".stripMargin) + assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i")) + + assertAnalysisException( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_s = named_struct('dn_i', 1) + |""".stripMargin, + "Cannot find data for output column 's.n_s.dn_l'") + + // strict mode does NOT allow string to int casts + assertAnalysisException( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_s = named_struct('dn_i', 'string-value', 'dn_l', 1L) + |""".stripMargin, + "Cannot safely cast") + + // strict mode does not allow long to int casts + assertAnalysisException( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET i = 1L + |""".stripMargin, + "Cannot safely cast") + } + } + } + + test("legacy mode assignments") { + withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.LEGACY.toString) { + assertAnalysisException( + s"""MERGE INTO nested_struct_table t USING nested_struct_table src + |ON t.i = src.i + |WHEN MATCHED THEN + | UPDATE SET i = 1L + |""".stripMargin, + "LEGACY store assignment policy is disallowed in Spark data source V2") + } + } + + test("align assignments with default values") { + val (matchedActions, notMatchedActions, notMatchedBySourceActions) = + parseAndAlignAssignments( + """MERGE INTO default_values_table t USING default_values_table s + |ON t.b = s.b + |WHEN MATCHED THEN + | UPDATE SET t.i = DEFAULT + |WHEN NOT MATCHED THEN + | INSERT (i, b) VALUES (DEFAULT, false) + |WHEN NOT MATCHED BY SOURCE THEN + | UPDATE SET t.i = DEFAULT""".stripMargin) + + matchedActions match { + case Seq(UpdateAction(None, assignments)) => + assignments match { + case Seq( + Assignment(b: AttributeReference, bValue: AttributeReference), + Assignment(i: AttributeReference, IntegerLiteral(42))) => + + assert(b.name == "b" && b == bValue) + assert(i.name == "i") + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + + notMatchedActions match { + case Seq(InsertAction(None, assignments)) => + assignments match { + case Seq( + Assignment(b: AttributeReference, BooleanLiteral(false)), + Assignment(i: AttributeReference, IntegerLiteral(42))) => + + assert(b.name == "b") + assert(i.name == "i") + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + + notMatchedBySourceActions match { + case Seq(UpdateAction(None, assignments)) => + assignments match { + case Seq( + Assignment(b: AttributeReference, bValue: AttributeReference), + Assignment(i: AttributeReference, IntegerLiteral(42))) => + + assert(b.name == "b" && b == bValue) + assert(i.name == "i") + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + } + + private def parseAndAlignAssignments( + query: String): (Seq[MergeAction], Seq[MergeAction], Seq[MergeAction]) = { + + parseAndResolve(query) match { + case m: MergeIntoTable => (m.matchedActions, m.notMatchedActions, m.notMatchedBySourceActions) + case plan => fail("Expected MergeIntoTable, but got:\n" + plan.treeString) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala index a173106db99e9..96c4580745fd2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignUpdateAssignmentsSuite.scala @@ -17,151 +17,14 @@ package org.apache.spark.sql.execution.command -import java.util.Collections - -import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito.{mock, when} -import org.mockito.invocation.InvocationOnMock - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, FunctionRegistry, NoSuchTableException, ResolveSessionCatalog} -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, AttributeReference, BooleanLiteral, Cast, CheckOverflowInTableInsert, CreateNamedStruct, EvalMode, GetStructField, IntegerLiteral, LambdaFunction, LongLiteral, MapFromArrays, StringLiteral} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, StaticInvoke} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan, UpdateTable} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, CatalogV2Util, Column, ColumnDefaultValue, Identifier, SupportsRowLevelOperations, TableCapability, TableCatalog} -import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform} -import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy -import org.apache.spark.sql.types.{BooleanType, IntegerType, StructType} - -class AlignUpdateAssignmentsSuite extends AnalysisTest { - - private val primitiveTable = { - val t = mock(classOf[SupportsRowLevelOperations]) - val schema = new StructType() - .add("i", "INT", nullable = false) - .add("l", "LONG") - .add("txt", "STRING") - when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) - when(t.partitioning()).thenReturn(Array.empty[Transform]) - t - } - - private val nestedStructTable = { - val t = mock(classOf[SupportsRowLevelOperations]) - val schema = new StructType() - .add("i", "INT") - .add( - "s", - "STRUCT>", - nullable = false) - .add("txt", "STRING") - when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) - when(t.partitioning()).thenReturn(Array.empty[Transform]) - t - } - - private val mapArrayTable = { - val t = mock(classOf[SupportsRowLevelOperations]) - val schema = new StructType() - .add("i", "INT") - .add("a", "ARRAY>") - .add("m", "MAP") - .add("txt", "STRING") - when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) - when(t.partitioning()).thenReturn(Array.empty[Transform]) - t - } - - private val charVarcharTable = { - val t = mock(classOf[SupportsRowLevelOperations]) - val schema = new StructType() - .add("c", "CHAR(5)") - .add( - "s", - "STRUCT", - nullable = false) - .add( - "a", - "ARRAY>", - nullable = false) - .add( - "mk", - "MAP, STRING>", - nullable = false) - .add( - "mv", - "MAP>", - nullable = false) - when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) - when(t.partitioning()).thenReturn(Array.empty[Transform]) - t - } - - private val acceptsAnySchemaTable = { - val t = mock(classOf[SupportsRowLevelOperations]) - val schema = new StructType() - .add("i", "INT", nullable = false) - .add("l", "LONG") - .add("txt", "STRING") - when(t.columns()).thenReturn(CatalogV2Util.structTypeToV2Columns(schema)) - when(t.partitioning()).thenReturn(Array.empty[Transform]) - when(t.capabilities()).thenReturn(Collections.singleton(TableCapability.ACCEPT_ANY_SCHEMA)) - t - } - - private val defaultValuesTable = { - val t = mock(classOf[SupportsRowLevelOperations]) - val iDefault = new ColumnDefaultValue("42", LiteralValue(42, IntegerType)) - when(t.columns()).thenReturn(Array( - Column.create("b", BooleanType, true, null, null), - Column.create("i", IntegerType, true, null, iDefault, null))) - when(t.partitioning()).thenReturn(Array.empty[Transform]) - t - } - - private val v2Catalog = { - val newCatalog = mock(classOf[TableCatalog]) - when(newCatalog.loadTable(any())).thenAnswer((invocation: InvocationOnMock) => { - val ident = invocation.getArgument[Identifier](0) - ident.name match { - case "primitive_table" => primitiveTable - case "nested_struct_table" => nestedStructTable - case "map_array_table" => mapArrayTable - case "char_varchar_table" => charVarcharTable - case "accepts_any_schema_table" => acceptsAnySchemaTable - case "default_values_table" => defaultValuesTable - case name => throw new NoSuchTableException(Seq(name)) - } - }) - when(newCatalog.name()).thenReturn("cat") - newCatalog - } +import org.apache.spark.sql.types.IntegerType - private val v1SessionCatalog = - new SessionCatalog(new InMemoryCatalog(), FunctionRegistry.builtin, new SQLConf()) - - private val v2SessionCatalog = new V2SessionCatalog(v1SessionCatalog) - - private val catalogManager = { - val manager = mock(classOf[CatalogManager]) - when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => { - invocation.getArgument[String](0) match { - case "testcat" => v2Catalog - case CatalogManager.SESSION_CATALOG_NAME => v2SessionCatalog - case name => throw new CatalogNotFoundException(s"No such catalog: $name") - } - }) - when(manager.currentCatalog).thenReturn(v2Catalog) - when(manager.currentNamespace).thenReturn(Array.empty[String]) - when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog) - when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog) - manager - } +class AlignUpdateAssignmentsSuite extends AlignAssignmentsSuite { test("align assignments (primitive types)") { val sql1 = "UPDATE primitive_table AS t SET t.txt = 'new', t.i = 1" @@ -752,28 +615,4 @@ class AlignUpdateAssignmentsSuite extends AnalysisTest { case plan => fail("Expected UpdateTable, but got:\n" + plan.treeString) } } - - private def parseAndResolve(query: String): LogicalPlan = { - val analyzer = new Analyzer(catalogManager) { - override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq( - new ResolveSessionCatalog(catalogManager)) - } - val analyzed = analyzer.execute(CatalystSqlParser.parsePlan(query)) - analyzer.checkAnalysis(analyzed) - analyzed - } - - private def assertAnalysisException(query: String, messages: String*): Unit = { - val exception = intercept[AnalysisException] { - parseAndResolve(query) - } - messages.foreach(message => assert(exception.message.contains(message))) - } - - private def assertNullCheckExists(plan: LogicalPlan, colPath: Seq[String]): Unit = { - val asserts = plan.expressions.flatMap(e => e.collect { - case assert: AssertNotNull if assert.walkedTypePath == colPath => assert - }) - assert(asserts.nonEmpty, s"Must have NOT NULL checks for col $colPath") - } }