diff --git a/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala index 3369a49c57ee..ffa3bafacbf4 100644 --- a/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala +++ b/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala @@ -21,7 +21,7 @@ package org.apache.iceberg.spark.extensions import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.catalyst.analysis.{AlignMergeIntoTable, DeleteFromTablePredicateCheck, ProcedureArgumentCoercion, ResolveProcedures} -import org.apache.spark.sql.catalyst.optimizer.{OptimizeConditionsInRowLevelOperations, PullupCorrelatedPredicatesInRowLevelOperations, RewriteDelete} +import org.apache.spark.sql.catalyst.optimizer.{OptimizeConditionsInRowLevelOperations, PullupCorrelatedPredicatesInRowLevelOperations, RewriteDelete, RewriteMergeInto} import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy @@ -43,6 +43,7 @@ class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // TODO: PullupCorrelatedPredicates should handle row-level operations extensions.injectOptimizerRule { _ => PullupCorrelatedPredicatesInRowLevelOperations } extensions.injectOptimizerRule { spark => RewriteDelete(spark.sessionState.conf) } + extensions.injectOptimizerRule { spark => RewriteMergeInto(spark.sessionState.conf) } // planner extensions extensions.injectPlannerStrategy { spark => ExtendedDataSourceV2Strategy(spark) } diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala index e86f21f553bb..8e9b4a1f541e 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala @@ -76,8 +76,8 @@ case class RewriteDelete(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRo remainingRowsPlan: LogicalPlan, output: Seq[AttributeReference]): LogicalPlan = { - val fileNameCol = findOutputAttr(remainingRowsPlan, FILE_NAME_COL) - val rowPosCol = findOutputAttr(remainingRowsPlan, ROW_POS_COL) + val fileNameCol = findOutputAttr(remainingRowsPlan.output, FILE_NAME_COL) + val rowPosCol = findOutputAttr(remainingRowsPlan.output, ROW_POS_COL) val order = Seq(SortOrder(fileNameCol, Ascending), SortOrder(rowPosCol, Ascending)) val numShufflePartitions = SQLConf.get.numShufflePartitions val repartition = RepartitionByExpression(Seq(fileNameCol), remainingRowsPlan, numShufflePartitions) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala new file mode 100644 index 000000000000..0c53dbcd7639 --- /dev/null +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.InputFileName +import org.apache.spark.sql.catalyst.expressions.IsNull +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.FullOuter +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.JoinHint +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeAction +import org.apache.spark.sql.catalyst.plans.logical.MergeInto +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoParams +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoTable +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceData +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.utils.RewriteRowLevelOperationHelper +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.BooleanType + +case class RewriteMergeInto(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRowLevelOperationHelper { + private val ROW_FROM_SOURCE = "_row_from_source_" + private val ROW_FROM_TARGET = "_row_from_target_" + private val TRUE_LITERAL = Literal(true, BooleanType) + private val FALSE_LITERAL = Literal(false, BooleanType) + + import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Implicits._ + + override def resolver: Resolver = conf.resolver + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan resolveOperators { + case MergeIntoTable(target: DataSourceV2Relation, source: LogicalPlan, cond, matchedActions, notMatchedActions) => + // Construct the plan to prune target based on join condition between source and target. + val writeInfo = newWriteInfo(target.schema) + val mergeBuilder = target.table.asMergeable.newMergeBuilder("merge", writeInfo) + val matchingRowsPlanBuilder = (rel: DataSourceV2ScanRelation) => + Join(source, rel, Inner, Some(cond), JoinHint.NONE) + val targetTableScan = buildScanPlan(target.table, target.output, mergeBuilder, cond, matchingRowsPlanBuilder) + + // Construct an outer join to help track changes in source and target. + // TODO : Optimize this to use LEFT ANTI or RIGHT OUTER when applicable. + val sourceTableProj = source.output ++ Seq(Alias(TRUE_LITERAL, ROW_FROM_SOURCE)()) + val targetTableProj = target.output ++ Seq(Alias(TRUE_LITERAL, ROW_FROM_TARGET)()) + val newTargetTableScan = Project(targetTableProj, targetTableScan) + val newSourceTableScan = Project(sourceTableProj, source) + val joinPlan = Join(newSourceTableScan, newTargetTableScan, FullOuter, Some(cond), JoinHint.NONE) + + // Construct the plan to replace the data based on the output of `MergeInto` + val mergeParams = MergeIntoParams( + isSourceRowNotPresent = IsNull(findOutputAttr(joinPlan.output, ROW_FROM_SOURCE)), + isTargetRowNotPresent = IsNull(findOutputAttr(joinPlan.output, ROW_FROM_TARGET)), + matchedConditions = matchedActions.map(getClauseCondition), + matchedOutputs = matchedActions.map(actionOutput(_, target.output)), + notMatchedConditions = notMatchedActions.map(getClauseCondition), + notMatchedOutputs = notMatchedActions.map(actionOutput(_, target.output)), + targetOutput = target.output :+ FALSE_LITERAL, + deleteOutput = target.output :+ TRUE_LITERAL, + joinedAttributes = joinPlan.output + ) + val mergePlan = MergeInto(mergeParams, target, joinPlan) + val batchWrite = mergeBuilder.asWriteBuilder.buildForBatch() + ReplaceData(target, batchWrite, mergePlan) + } + } + + private def actionOutput(clause: MergeAction, targetOutputCols: Seq[Expression]): Seq[Expression] = { + clause match { + case u: UpdateAction => + u.assignments.map(_.value) :+ FALSE_LITERAL + case _: DeleteAction => + targetOutputCols :+ TRUE_LITERAL + case i: InsertAction => + i.assignments.map(_.value) :+ FALSE_LITERAL + } + } + + private def getClauseCondition(clause: MergeAction): Expression = { + clause.condition.getOrElse(TRUE_LITERAL) + } +} + diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala new file mode 100644 index 000000000000..78a0cb57a3ed --- /dev/null +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +case class MergeInto( + mergeIntoProcessor: MergeIntoParams, + targetRelation: DataSourceV2Relation, + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = targetRelation.output +} + +case class MergeIntoParams( + isSourceRowNotPresent: Expression, + isTargetRowNotPresent: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Expression]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Expression]], + targetOutput: Seq[Expression], + deleteOutput: Seq[Expression], + joinedAttributes: Seq[Attribute]) extends Serializable diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala index f7ad083be9fe..0d79e8b0e498 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala @@ -67,7 +67,7 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging { scan match { case filterable: SupportsFileFilter => - val matchingFilePlan = buildFileFilterPlan(matchingRowsPlanBuilder(scanRelation)) + val matchingFilePlan = buildFileFilterPlan(scanRelation.output, matchingRowsPlanBuilder(scanRelation)) DynamicFileFilter(scanRelation, matchingFilePlan, filterable) case _ => scanRelation @@ -102,15 +102,15 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging { LogicalWriteInfoImpl(queryId = uuid.toString, schema, CaseInsensitiveStringMap.empty) } - private def buildFileFilterPlan(matchingRowsPlan: LogicalPlan): LogicalPlan = { - val fileAttr = findOutputAttr(matchingRowsPlan, FILE_NAME_COL) + private def buildFileFilterPlan(tableAttrs: Seq[AttributeReference], matchingRowsPlan: LogicalPlan): LogicalPlan = { + val fileAttr = findOutputAttr(tableAttrs, FILE_NAME_COL) val agg = Aggregate(Seq(fileAttr), Seq(fileAttr), matchingRowsPlan) - Project(Seq(findOutputAttr(agg, FILE_NAME_COL)), agg) + Project(Seq(findOutputAttr(agg.output, FILE_NAME_COL)), agg) } - protected def findOutputAttr(plan: LogicalPlan, attrName: String): Attribute = { - plan.output.find(attr => resolver(attr.name, attrName)).getOrElse { - throw new AnalysisException(s"Cannot find $attrName in ${plan.output}") + protected def findOutputAttr(attrs: Seq[Attribute], attrName: String): Attribute = { + attrs.find(attr => resolver(attr.name, attrName)).getOrElse { + throw new AnalysisException(s"Cannot find $attrName in $attrs") } } diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala index 9a93a591962a..3ba876da8846 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Call import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField import org.apache.spark.sql.catalyst.plans.logical.DynamicFileFilter import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeInto import org.apache.spark.sql.catalyst.plans.logical.ReplaceData import org.apache.spark.sql.catalyst.plans.logical.SetWriteOrder import org.apache.spark.sql.connector.catalog.Identifier @@ -75,6 +76,9 @@ case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy { case ReplaceData(_, batchWrite, query) => ReplaceDataExec(batchWrite, planLater(query)) :: Nil + case MergeInto(mergeIntoProcessor, targetRelation, child) => + MergeIntoExec(mergeIntoProcessor, targetRelation, planLater(child)) :: Nil + case _ => Nil } diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala new file mode 100644 index 000000000000..80370ad21c5a --- /dev/null +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.BasePredicate +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoParams +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.UnaryExecNode + +case class MergeIntoExec( + mergeIntoParams: MergeIntoParams, + @transient targetRelation: DataSourceV2Relation, + override val child: SparkPlan) extends UnaryExecNode { + + override def output: Seq[Attribute] = targetRelation.output + + protected override def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { + processPartition(mergeIntoParams, _) + } + } + + private def generateProjection(exprs: Seq[Expression], attrs: Seq[Attribute]): UnsafeProjection = { + UnsafeProjection.create(exprs, attrs) + } + + private def generatePredicate(expr: Expression, attrs: Seq[Attribute]): BasePredicate = { + GeneratePredicate.generate(expr, attrs) + } + + def applyProjection( + actions: Seq[(BasePredicate, UnsafeProjection)], + projectTargetCols: UnsafeProjection, + projectDeleteRow: UnsafeProjection, + inputRow: InternalRow, + targetRowNotPresent: Boolean): InternalRow = { + + + // Find the first combination where the predicate evaluates to true. + // In case when there are overlapping condition in the MATCHED + // clauses, for the first one that satisfies the predicate, the + // corresponding action is applied. For example: + // WHEN MATCHED AND id > 1 AND id < 10 UPDATE * + // WHEN MATCHED AND id = 5 OR id = 21 DELETE + // In above case, when id = 5, it applies both that matched predicates. In this + // case the first one we see is applied. + + val pair = actions.find { + case (predicate, _) => predicate.eval(inputRow) + } + + // Now apply the appropriate projection to either : + // - Insert a row into target + // - Update a row of target + // - Delete a row in target. The projected row will have the deleted bit set. + pair match { + case Some((_, projection)) => + projection.apply(inputRow) + case None => + if (targetRowNotPresent) { + projectDeleteRow.apply(inputRow) + } else { + projectTargetCols.apply(inputRow) + } + } + } + + def processPartition( + params: MergeIntoParams, + rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = { + + val joinedAttrs = params.joinedAttributes + val isSourceRowNotPresentPred = generatePredicate(params.isSourceRowNotPresent, joinedAttrs) + val isTargetRowNotPresentPred = generatePredicate(params.isTargetRowNotPresent, joinedAttrs) + val matchedPreds = params.matchedConditions.map(generatePredicate(_, joinedAttrs)) + val matchedProjs = params.matchedOutputs.map(generateProjection(_, joinedAttrs)) + val notMatchedPreds = params.notMatchedConditions.map(generatePredicate(_, joinedAttrs)) + val notMatchedProjs = params.notMatchedOutputs.map(generateProjection(_, joinedAttrs)) + val projectTargetCols = generateProjection(params.targetOutput, joinedAttrs) + val projectDeletedRow = generateProjection(params.deleteOutput, joinedAttrs) + val nonMatchedPairs = notMatchedPreds zip notMatchedProjs + val matchedPairs = matchedPreds zip matchedProjs + + def shouldDeleteRow(row: InternalRow): Boolean = + row.getBoolean(params.targetOutput.size - 1) + + /** + * This method is responsible for processing a input row to emit the resultant row with an + * additional column that indicates whether the row is going to be included in the final + * output of merge or not. + * 1. If there is a target row for which there is no corresponding source row (join condition not met) + * - Only project the target columns with deleted flag set to false. + * 2. If there is a source row for which there is no corresponding target row (join condition not met) + * - Apply the not matched actions (i.e INSERT actions) if non match conditions are met. + * 3. If there is a source row for which there is a corresponding target row (join condition met) + * - Apply the matched actions (i.e DELETE or UPDATE actions) if match conditions are met. + */ + def processRow(inputRow: InternalRow): InternalRow = { + if (isSourceRowNotPresentPred.eval(inputRow)) { + projectTargetCols.apply(inputRow) + } else if (isTargetRowNotPresentPred.eval(inputRow)) { + applyProjection(nonMatchedPairs, projectTargetCols, projectDeletedRow, inputRow, true) + } else { + applyProjection(matchedPairs, projectTargetCols, projectDeletedRow, inputRow, false) + } + } + + rowIterator + .map(processRow) + .filterNot(shouldDeleteRow) + } +} diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java new file mode 100644 index 000000000000..bf304e1890cf --- /dev/null +++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; + +public class TestMergeIntoTable extends SparkRowLevelOperationsTestBase { + private final String sourceName; + private final String targetName; + + @Parameterized.Parameters( + name = "catalogName = {0}, implementation = {1}, config = {2}, format = {3}, vectorized = {4}") + public static Object[][] parameters() { + return new Object[][] { + { "testhive", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default" + ), + "parquet", + true + }, + { "spark_catalog", SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + "parquet", + false + } + }; + } + + public TestMergeIntoTable(String catalogName, String implementation, Map config, + String fileFormat, Boolean vectorized) { + super(catalogName, implementation, config, fileFormat, vectorized); + this.sourceName = tableName("source"); + this.targetName = tableName("target"); + } + + @BeforeClass + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + protected Map extraTableProperties() { + return ImmutableMap.of(TableProperties.MERGE_MODE, TableProperties.MERGE_MODE_DEFAULT); + } + + @Before + public void createTables() { + createAndInitUnPartitionedTargetTable(targetName); + createAndInitSourceTable(sourceName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", targetName); + sql("DROP TABLE IF EXISTS %s", sourceName); + } + + @Test + public void testEmptyTargetInsertAllNonMatchingRows() throws NoSuchTableException { + append(sourceName, new Employee(1, "emp-id-1"), new Employee(2, "emp-id-2"), new Employee(3, "emp-id-3")); + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + + "ON target.id = source.id " + + "WHEN NOT MATCHED THEN INSERT * "; + + sql(sqlText, targetName, sourceName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2"), row(3, "emp-id-3")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testEmptyTargetInsertOnlyMatchingRows() throws NoSuchTableException { + append(sourceName, new Employee(1, "emp-id-1"), new Employee(2, "emp-id-2"), new Employee(3, "emp-id-3")); + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + + "ON target.id = source.id " + + "WHEN NOT MATCHED AND (source.id >= 2) THEN INSERT * "; + + sql(sqlText, targetName, sourceName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "emp-id-2"), row(3, "emp-id-3")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testOnlyUpdate() throws NoSuchTableException { + append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-six")); + append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + + "ON target.id = source.id " + + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * "; + + sql(sqlText, targetName, sourceName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-1"), row(6, "emp-id-six")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testOnlyDelete() throws NoSuchTableException { + append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); + append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + + "ON target.id = source.id " + + "WHEN MATCHED AND target.id = 6 THEN DELETE"; + + sql(sqlText, targetName, sourceName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-one")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testAllCauses() throws NoSuchTableException { + append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); + append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + + "ON target.id = source.id " + + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * " + + "WHEN MATCHED AND target.id = 6 THEN DELETE " + + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * "; + + sql(sqlText, targetName, sourceName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testAllCausesWithExplicitColumnSpecification() throws NoSuchTableException { + append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); + append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + + "ON target.id = source.id " + + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET target.id = source.id, target.dep = source.dep " + + "WHEN MATCHED AND target.id = 6 THEN DELETE " + + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT (target.id, target.dep) VALUES (source.id, source.dep) "; + + sql(sqlText, targetName, sourceName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testSourceCTE() throws NoSuchTableException { + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhive")); + + append(targetName, new Employee(2, "emp-id-two"), new Employee(6, "emp-id-6")); + append(sourceName, new Employee(2, "emp-id-3"), new Employee(1, "emp-id-2"), new Employee(5, "emp-id-6")); + String sourceCTE = "WITH cte1 AS (SELECT id + 1 AS id, dep FROM source)"; + String sqlText = sourceCTE + " MERGE INTO %s AS target " + + "USING cte1" + " AS source " + + "ON target.id = source.id " + + "WHEN MATCHED AND target.id = 2 THEN UPDATE SET * " + + "WHEN MATCHED AND target.id = 6 THEN DELETE " + + "WHEN NOT MATCHED AND source.id = 3 THEN INSERT * "; + + sql(sqlText, targetName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "emp-id-2"), row(3, "emp-id-3")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testSourceFromSetOps() throws NoSuchTableException { + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhive")); + + append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); + append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); + String derivedSource = " ( SELECT * FROM source WHERE id = 2 " + + " UNION ALL " + + " SELECT * FROM source WHERE id = 1 OR id = 6)"; + String sqlText = "MERGE INTO %s AS target " + + "USING " + derivedSource + " AS source " + + "ON target.id = source.id " + + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * " + + "WHEN MATCHED AND target.id = 6 THEN DELETE " + + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * "; + + sql(sqlText, targetName); + sql("SELECT * FROM %s ORDER BY id, dep", targetName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + protected void createAndInitUnPartitionedTargetTable(String tabName) { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg", tabName); + initTable(tabName); + } + + protected void createAndInitSourceTable(String tabName) { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY (dep)", tabName); + initTable(tabName); + } + + private void initTable(String tabName) { + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tabName, DEFAULT_FILE_FORMAT, fileFormat); + + switch (fileFormat) { + case "parquet": + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%b')", tabName, PARQUET_VECTORIZATION_ENABLED, vectorized); + break; + case "orc": + Assert.assertTrue(vectorized); + break; + case "avro": + Assert.assertFalse(vectorized); + break; + } + + Map props = extraTableProperties(); + props.forEach((prop, value) -> { + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tabName, prop, value); + }); + } + + protected void append(String tabName, Employee... employees) throws NoSuchTableException { + List input = Arrays.asList(employees); + Dataset inputDF = spark.createDataFrame(input, Employee.class); + inputDF.coalesce(1).writeTo(tabName).append(); + } +}