From 9db80c8f340b2e9c3e11f19a0a687d411c8dbc29 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 16 Dec 2020 14:31:31 -0800 Subject: [PATCH 1/9] Spark MERGE INTO Support (copy-on-write implementation --- .../IcebergSparkSessionExtensions.scala | 3 +- .../catalyst/optimizer/RewriteMergeInto.scala | 164 ++++++++++++++++++ .../catalyst/plans/logical/MergeInto.scala | 101 +++++++++++ .../v2/ExtendedDataSourceV2Strategy.scala | 3 + .../datasources/v2/MergeIntoExec.scala | 38 ++++ 5 files changed, 308 insertions(+), 1 deletion(-) create mode 100644 spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala create mode 100644 spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala create mode 100644 spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala diff --git a/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala index 3369a49c57ee..1884c4c62046 100644 --- a/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala +++ b/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala @@ -21,7 +21,7 @@ package org.apache.iceberg.spark.extensions import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.catalyst.analysis.{AlignMergeIntoTable, DeleteFromTablePredicateCheck, ProcedureArgumentCoercion, ResolveProcedures} -import org.apache.spark.sql.catalyst.optimizer.{OptimizeConditionsInRowLevelOperations, PullupCorrelatedPredicatesInRowLevelOperations, RewriteDelete} +import org.apache.spark.sql.catalyst.optimizer.{OptimizeConditionsInRowLevelOperations, PullupCorrelatedPredicatesInRowLevelOperations, RewriteDelete, RewriteMergeInto} import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy @@ -43,6 +43,7 @@ class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // TODO: PullupCorrelatedPredicates should handle row-level operations extensions.injectOptimizerRule { _ => PullupCorrelatedPredicatesInRowLevelOperations } extensions.injectOptimizerRule { spark => RewriteDelete(spark.sessionState.conf) } + extensions.injectOptimizerRule { _ => RewriteMergeInto } // planner extensions extensions.injectPlannerStrategy { spark => ExtendedDataSourceV2Strategy(spark) } diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala new file mode 100644 index 000000000000..c91c8c0d5949 --- /dev/null +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import java.util.UUID +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.iceberg.read.SupportsFileFilter +import org.apache.spark.sql.connector.iceberg.write.MergeBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +object RewriteMergeInto extends Rule[LogicalPlan] + with PredicateHelper + with Logging { + val ROW_ID_COL = "_row_id_" + val FILE_NAME_COL = "_file_name_" + val SOURCE_ROW_PRESENT_COL = "_source_row_present_" + val TARGET_ROW_PRESENT_COL = "_target_row_present_" + + import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan resolveOperators { + // rewrite all operations that require reading the table to delete records + case MergeIntoTable(target: DataSourceV2Relation, + source: LogicalPlan, cond, actions, notActions) => + // Find the files in target that matches the JOIN condition from source. + val targetOutputCols = target.output + val newProjectCols = target.output ++ Seq(Alias(InputFileName(), FILE_NAME_COL)()) + val newTargetTable = Project(newProjectCols, target) + val prunedTargetPlan = Join(source, newTargetTable, Inner, Some(cond), JoinHint.NONE) + + val writeInfo = newWriteInfo(target.schema) + val mergeBuilder = target.table.asMergeable.newMergeBuilder("delete", writeInfo) + val targetTableScan = buildScanPlan(target.table, target.output, mergeBuilder, prunedTargetPlan) + val sourceTableProj = source.output ++ Seq(Alias(lit(true).expr, SOURCE_ROW_PRESENT_COL)()) + val targetTableProj = target.output ++ Seq(Alias(lit(true).expr, TARGET_ROW_PRESENT_COL)()) + val newTargetTableScan = Project(targetTableProj, targetTableScan) + val newSourceTableScan = Project(sourceTableProj, source) + val joinPlan = Join(newSourceTableScan, newTargetTableScan, FullOuter, Some(cond), JoinHint.NONE) + + val mergeIntoProcessor = new MergeIntoProcessor( + isSourceRowNotPresent = resolveExprs(Seq(col(SOURCE_ROW_PRESENT_COL).isNull.expr), joinPlan).head, + isTargetRowNotPresent = resolveExprs(Seq(col(TARGET_ROW_PRESENT_COL).isNull.expr), joinPlan).head, + matchedConditions = actions.map(resolveClauseCondition(_, joinPlan)), + matchedOutputs = actions.map(actionOutput(_, targetOutputCols, joinPlan)), + notMatchedConditions = notActions.map(resolveClauseCondition(_, joinPlan)), + notMatchedOutputs = notActions.map(actionOutput(_, targetOutputCols, joinPlan)), + targetOutput = resolveExprs(targetOutputCols :+ Literal(false), joinPlan), + joinedAttributes = joinPlan.output + ) + + val mergePlan = MergeInto(mergeIntoProcessor, target, joinPlan) + val batchWrite = mergeBuilder.asWriteBuilder.buildForBatch() + ReplaceData(target, batchWrite, mergePlan) + } + } + + private def buildScanPlan( + table: Table, + output: Seq[AttributeReference], + mergeBuilder: MergeBuilder, + prunedTargetPlan: LogicalPlan): LogicalPlan = { + + val scanBuilder = mergeBuilder.asScanBuilder + val scan = scanBuilder.build() + val scanRelation = DataSourceV2ScanRelation(table, scan, output) + + scan match { + case _: SupportsFileFilter => + val matchingFilePlan = buildFileFilterPlan(prunedTargetPlan) + val dynamicFileFilter = DynamicFileFilter(scanRelation, matchingFilePlan) + dynamicFileFilter + case _ => + scanRelation + } + } + + private def newWriteInfo(schema: StructType): LogicalWriteInfo = { + val uuid = UUID.randomUUID() + LogicalWriteInfoImpl(queryId = uuid.toString, schema, CaseInsensitiveStringMap.empty) + } + + private def buildFileFilterPlan(prunedTargetPlan: LogicalPlan): LogicalPlan = { + val fileAttr = findOutputAttr(prunedTargetPlan, FILE_NAME_COL) + Aggregate(Seq(fileAttr), Seq(fileAttr), prunedTargetPlan) + } + + private def findOutputAttr(plan: LogicalPlan, attrName: String): Attribute = { + val resolver = SQLConf.get.resolver + plan.output.find(attr => resolver(attr.name, attrName)).getOrElse { + throw new AnalysisException(s"Cannot find $attrName in ${plan.output}") + } + } + + private def resolveExprs(exprs: Seq[Expression], plan: LogicalPlan): Seq[Expression] = { + val spark = SparkSession.active + exprs.map { expr => resolveExpressionInternal(spark, expr, plan) } + } + + def getTargetOutputCols(target: DataSourceV2Relation): Seq[NamedExpression] = { + target.schema.map { col => + target.output.find(attr => SQLConf.get.resolver(attr.name, col.name)).getOrElse { + Alias(Literal(null, col.dataType), col.name)() + } + } + } + + def actionOutput(clause: MergeAction, + targetOutputCols: Seq[Expression], + plan: LogicalPlan): Seq[Expression] = { + val exprs = clause match { + case u: UpdateAction => + u.assignments.map(_.value) :+ Literal(false) + case _: DeleteAction => + targetOutputCols :+ Literal(true) + case i: InsertAction => + i.assignments.map(_.value) :+ Literal(false) + } + resolveExprs(exprs, plan) + } + + def resolveClauseCondition(clause: MergeAction, plan: LogicalPlan): Expression = { + val condExpr = clause.condition.getOrElse(Literal(true)) + resolveExprs(Seq(condExpr), plan).head + } + + def resolveExpressionInternal(spark: SparkSession, expr: Expression, plan: LogicalPlan): Expression = { + val dummyPlan = Filter(expr, plan) + spark.sessionState.analyzer.execute(dummyPlan) match { + case Filter(resolvedExpr, _) => resolvedExpr + case _ => throw new AnalysisException(s"Could not resolve expression $expr", plan = Option(plan)) + } + } +} + diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala new file mode 100644 index 000000000000..810a72e3a6b7 --- /dev/null +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.functions.col + +case class MergeInto(mergeIntoProcessor: MergeIntoProcessor, + targetRelation: DataSourceV2Relation, + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = targetRelation.output +} + +class MergeIntoProcessor(isSourceRowNotPresent: Expression, + isTargetRowNotPresent: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Expression]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Expression]], + targetOutput: Seq[Expression], + joinedAttributes: Seq[Attribute]) extends Serializable { + + private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = { + UnsafeProjection.create(exprs, joinedAttributes) + } + + private def generatePredicate(expr: Expression): BasePredicate = { + GeneratePredicate.generate(expr, joinedAttributes) + } + + def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = { + val isSourceRowNotPresentPred = generatePredicate(isSourceRowNotPresent) + val isTargetRowNotPresentPred = generatePredicate(isTargetRowNotPresent) + val matchedPreds = matchedConditions.map(generatePredicate) + val matchedProjs = matchedOutputs.map(generateProjection) + val notMatchedPreds = notMatchedConditions.map(generatePredicate) + val notMatchedProjs = notMatchedOutputs.map(generateProjection) + val projectTargetCols = generateProjection(targetOutput) + + def shouldDeleteRow(row: InternalRow): Boolean = + row.getBoolean(targetOutput.size - 1) + + def applyProjection(predicates: Seq[BasePredicate], + projections: Seq[UnsafeProjection], + inputRow: InternalRow): InternalRow = { + // Find the first combination where the predicate evaluates to true + val pair = (predicates zip projections).find { + case (predicate, _) => predicate.eval(inputRow) + } + + // Now apply the appropriate projection to either : + // - Insert a row into target + // - Update a row of target + // - Delete a row in target. The projected row will have the deleted bit set. + pair match { + case Some((_, projection)) => + projection.apply(inputRow) + case None => + projectTargetCols.apply(inputRow) + } + } + + def processRow(inputRow: InternalRow): InternalRow = { + isSourceRowNotPresentPred.eval(inputRow) match { + case true => projectTargetCols.apply(inputRow) + case false => + if (isTargetRowNotPresentPred.eval(inputRow)) { + applyProjection(notMatchedPreds, notMatchedProjs, inputRow) + } else { + applyProjection(matchedPreds, matchedProjs, inputRow) + } + } + } + + rowIterator + .map(processRow) + .filter(!shouldDeleteRow(_)) + } +} diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala index 9a93a591962a..6dc9b56aae94 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField import org.apache.spark.sql.catalyst.plans.logical.DynamicFileFilter import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.ReplaceData +import org.apache.spark.sql.catalyst.plans.logical.MergeInto import org.apache.spark.sql.catalyst.plans.logical.SetWriteOrder import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableCatalog @@ -75,6 +76,8 @@ case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy { case ReplaceData(_, batchWrite, query) => ReplaceDataExec(batchWrite, planLater(query)) :: Nil + case MergeInto(mergeIntoProcessor, targetRelation, child) => + MergeIntoExec(mergeIntoProcessor, targetRelation, planLater(child)) :: Nil case _ => Nil } diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala new file mode 100644 index 000000000000..c9fd1bf22a12 --- /dev/null +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoProcessor +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} + +case class MergeIntoExec(mergeIntoProcessor: MergeIntoProcessor, + @transient targetRelation: DataSourceV2Relation, + override val child: SparkPlan) extends UnaryExecNode { + protected override def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { + mergeIntoProcessor.processPartition + } + } + + override def output: Seq[Attribute] = targetRelation.output +} From d82afba8da1da306791e82969c12fb6da0b2f0de Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 24 Dec 2020 10:16:13 -0800 Subject: [PATCH 2/9] Rebase + Scalastyle + cleancompile --- .../spark/sql/catalyst/optimizer/RewriteMergeInto.scala | 4 ++-- .../datasources/v2/ExtendedDataSourceV2Strategy.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala index c91c8c0d5949..2f863eb00360 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala @@ -95,9 +95,9 @@ object RewriteMergeInto extends Rule[LogicalPlan] val scanRelation = DataSourceV2ScanRelation(table, scan, output) scan match { - case _: SupportsFileFilter => + case filterable: SupportsFileFilter => val matchingFilePlan = buildFileFilterPlan(prunedTargetPlan) - val dynamicFileFilter = DynamicFileFilter(scanRelation, matchingFilePlan) + val dynamicFileFilter = DynamicFileFilter(scanRelation, matchingFilePlan, filterable) dynamicFileFilter case _ => scanRelation diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala index 6dc9b56aae94..1e8e90b99f61 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala @@ -35,8 +35,8 @@ import org.apache.spark.sql.catalyst.plans.logical.Call import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField import org.apache.spark.sql.catalyst.plans.logical.DynamicFileFilter import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.logical.ReplaceData import org.apache.spark.sql.catalyst.plans.logical.MergeInto +import org.apache.spark.sql.catalyst.plans.logical.ReplaceData import org.apache.spark.sql.catalyst.plans.logical.SetWriteOrder import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableCatalog From 35f68137c102c9609754c83b8a11a462fd5a8e1f Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 4 Jan 2021 22:48:22 -0800 Subject: [PATCH 3/9] Code review + base infrastructure --- .../catalyst/optimizer/RewriteMergeInto.scala | 119 ++------ .../catalyst/plans/logical/MergeInto.scala | 82 +----- .../spark/sql/catalyst/utils/PlanHelper.scala | 87 ++++++ .../datasources/v2/MergeIntoExec.scala | 83 +++++- .../spark/extensions/TestMergeIntoTable.java | 267 ++++++++++++++++++ 5 files changed, 471 insertions(+), 167 deletions(-) create mode 100644 spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanHelper.scala create mode 100644 spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala index 2f863eb00360..4a89557ef41f 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala @@ -19,31 +19,19 @@ package org.apache.spark.sql.catalyst.optimizer -import java.util.UUID import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.iceberg.read.SupportsFileFilter -import org.apache.spark.sql.connector.iceberg.write.MergeBuilder -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} +import org.apache.spark.sql.catalyst.utils.PlanHelper +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.CaseInsensitiveStringMap -object RewriteMergeInto extends Rule[LogicalPlan] - with PredicateHelper - with Logging { - val ROW_ID_COL = "_row_id_" - val FILE_NAME_COL = "_file_name_" - val SOURCE_ROW_PRESENT_COL = "_source_row_present_" - val TARGET_ROW_PRESENT_COL = "_target_row_present_" +object RewriteMergeInto extends Rule[LogicalPlan] with PlanHelper with Logging { + val ROW_FROM_SOURCE = "_row_from_source_" + val ROW_FROM_TARGET = "_row_from_target_" import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Implicits._ @@ -52,80 +40,43 @@ object RewriteMergeInto extends Rule[LogicalPlan] // rewrite all operations that require reading the table to delete records case MergeIntoTable(target: DataSourceV2Relation, source: LogicalPlan, cond, actions, notActions) => - // Find the files in target that matches the JOIN condition from source. val targetOutputCols = target.output val newProjectCols = target.output ++ Seq(Alias(InputFileName(), FILE_NAME_COL)()) val newTargetTable = Project(newProjectCols, target) - val prunedTargetPlan = Join(source, newTargetTable, Inner, Some(cond), JoinHint.NONE) + // Construct the plan to prune target based on join condition between source and + // target. + val prunedTargetPlan = Join(source, newTargetTable, Inner, Some(cond), JoinHint.NONE) val writeInfo = newWriteInfo(target.schema) val mergeBuilder = target.table.asMergeable.newMergeBuilder("delete", writeInfo) val targetTableScan = buildScanPlan(target.table, target.output, mergeBuilder, prunedTargetPlan) - val sourceTableProj = source.output ++ Seq(Alias(lit(true).expr, SOURCE_ROW_PRESENT_COL)()) - val targetTableProj = target.output ++ Seq(Alias(lit(true).expr, TARGET_ROW_PRESENT_COL)()) + + // Construct an outer join to help track changes in source and target. + // TODO : Optimize this to use LEFT ANTI or RIGHT OUTER when applicable. + val sourceTableProj = source.output ++ Seq(Alias(lit(true).expr, ROW_FROM_SOURCE)()) + val targetTableProj = target.output ++ Seq(Alias(lit(true).expr, ROW_FROM_TARGET)()) val newTargetTableScan = Project(targetTableProj, targetTableScan) val newSourceTableScan = Project(sourceTableProj, source) val joinPlan = Join(newSourceTableScan, newTargetTableScan, FullOuter, Some(cond), JoinHint.NONE) - val mergeIntoProcessor = new MergeIntoProcessor( - isSourceRowNotPresent = resolveExprs(Seq(col(SOURCE_ROW_PRESENT_COL).isNull.expr), joinPlan).head, - isTargetRowNotPresent = resolveExprs(Seq(col(TARGET_ROW_PRESENT_COL).isNull.expr), joinPlan).head, - matchedConditions = actions.map(resolveClauseCondition(_, joinPlan)), - matchedOutputs = actions.map(actionOutput(_, targetOutputCols, joinPlan)), - notMatchedConditions = notActions.map(resolveClauseCondition(_, joinPlan)), - notMatchedOutputs = notActions.map(actionOutput(_, targetOutputCols, joinPlan)), - targetOutput = resolveExprs(targetOutputCols :+ Literal(false), joinPlan), + // Construct the plan to replace the data based on the output of `MergeInto` + val mergeParams = MergeIntoParams( + isSourceRowNotPresent = IsNull(findOutputAttr(joinPlan, ROW_FROM_SOURCE)), + isTargetRowNotPresent = IsNull(findOutputAttr(joinPlan, ROW_FROM_TARGET)), + matchedConditions = actions.map(getClauseCondition), + matchedOutputs = actions.map(actionOutput(_, targetOutputCols)), + notMatchedConditions = notActions.map(getClauseCondition), + notMatchedOutputs = notActions.map(actionOutput(_, targetOutputCols)), + targetOutput = targetOutputCols :+ Literal(false), + deleteOutput = targetOutputCols :+ Literal(true), joinedAttributes = joinPlan.output ) - - val mergePlan = MergeInto(mergeIntoProcessor, target, joinPlan) + val mergePlan = MergeInto(mergeParams, target, joinPlan) val batchWrite = mergeBuilder.asWriteBuilder.buildForBatch() ReplaceData(target, batchWrite, mergePlan) } } - private def buildScanPlan( - table: Table, - output: Seq[AttributeReference], - mergeBuilder: MergeBuilder, - prunedTargetPlan: LogicalPlan): LogicalPlan = { - - val scanBuilder = mergeBuilder.asScanBuilder - val scan = scanBuilder.build() - val scanRelation = DataSourceV2ScanRelation(table, scan, output) - - scan match { - case filterable: SupportsFileFilter => - val matchingFilePlan = buildFileFilterPlan(prunedTargetPlan) - val dynamicFileFilter = DynamicFileFilter(scanRelation, matchingFilePlan, filterable) - dynamicFileFilter - case _ => - scanRelation - } - } - - private def newWriteInfo(schema: StructType): LogicalWriteInfo = { - val uuid = UUID.randomUUID() - LogicalWriteInfoImpl(queryId = uuid.toString, schema, CaseInsensitiveStringMap.empty) - } - - private def buildFileFilterPlan(prunedTargetPlan: LogicalPlan): LogicalPlan = { - val fileAttr = findOutputAttr(prunedTargetPlan, FILE_NAME_COL) - Aggregate(Seq(fileAttr), Seq(fileAttr), prunedTargetPlan) - } - - private def findOutputAttr(plan: LogicalPlan, attrName: String): Attribute = { - val resolver = SQLConf.get.resolver - plan.output.find(attr => resolver(attr.name, attrName)).getOrElse { - throw new AnalysisException(s"Cannot find $attrName in ${plan.output}") - } - } - - private def resolveExprs(exprs: Seq[Expression], plan: LogicalPlan): Seq[Expression] = { - val spark = SparkSession.active - exprs.map { expr => resolveExpressionInternal(spark, expr, plan) } - } - def getTargetOutputCols(target: DataSourceV2Relation): Seq[NamedExpression] = { target.schema.map { col => target.output.find(attr => SQLConf.get.resolver(attr.name, col.name)).getOrElse { @@ -134,10 +85,8 @@ object RewriteMergeInto extends Rule[LogicalPlan] } } - def actionOutput(clause: MergeAction, - targetOutputCols: Seq[Expression], - plan: LogicalPlan): Seq[Expression] = { - val exprs = clause match { + def actionOutput(clause: MergeAction, targetOutputCols: Seq[Expression]): Seq[Expression] = { + clause match { case u: UpdateAction => u.assignments.map(_.value) :+ Literal(false) case _: DeleteAction => @@ -145,20 +94,10 @@ object RewriteMergeInto extends Rule[LogicalPlan] case i: InsertAction => i.assignments.map(_.value) :+ Literal(false) } - resolveExprs(exprs, plan) } - def resolveClauseCondition(clause: MergeAction, plan: LogicalPlan): Expression = { - val condExpr = clause.condition.getOrElse(Literal(true)) - resolveExprs(Seq(condExpr), plan).head - } - - def resolveExpressionInternal(spark: SparkSession, expr: Expression, plan: LogicalPlan): Expression = { - val dummyPlan = Filter(expr, plan) - spark.sessionState.analyzer.execute(dummyPlan) match { - case Filter(resolvedExpr, _) => resolvedExpr - case _ => throw new AnalysisException(s"Could not resolve expression $expr", plan = Option(plan)) - } + def getClauseCondition(clause: MergeAction): Expression = { + clause.condition.getOrElse(Literal(true)) } } diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala index 810a72e3a6b7..a3a0ac68c0df 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala @@ -19,83 +19,21 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.functions.col -case class MergeInto(mergeIntoProcessor: MergeIntoProcessor, +case class MergeInto(mergeIntoProcessor: MergeIntoParams, targetRelation: DataSourceV2Relation, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = targetRelation.output } -class MergeIntoProcessor(isSourceRowNotPresent: Expression, - isTargetRowNotPresent: Expression, - matchedConditions: Seq[Expression], - matchedOutputs: Seq[Seq[Expression]], - notMatchedConditions: Seq[Expression], - notMatchedOutputs: Seq[Seq[Expression]], - targetOutput: Seq[Expression], - joinedAttributes: Seq[Attribute]) extends Serializable { - - private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = { - UnsafeProjection.create(exprs, joinedAttributes) - } - - private def generatePredicate(expr: Expression): BasePredicate = { - GeneratePredicate.generate(expr, joinedAttributes) - } - - def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = { - val isSourceRowNotPresentPred = generatePredicate(isSourceRowNotPresent) - val isTargetRowNotPresentPred = generatePredicate(isTargetRowNotPresent) - val matchedPreds = matchedConditions.map(generatePredicate) - val matchedProjs = matchedOutputs.map(generateProjection) - val notMatchedPreds = notMatchedConditions.map(generatePredicate) - val notMatchedProjs = notMatchedOutputs.map(generateProjection) - val projectTargetCols = generateProjection(targetOutput) - - def shouldDeleteRow(row: InternalRow): Boolean = - row.getBoolean(targetOutput.size - 1) - - def applyProjection(predicates: Seq[BasePredicate], - projections: Seq[UnsafeProjection], - inputRow: InternalRow): InternalRow = { - // Find the first combination where the predicate evaluates to true - val pair = (predicates zip projections).find { - case (predicate, _) => predicate.eval(inputRow) - } - - // Now apply the appropriate projection to either : - // - Insert a row into target - // - Update a row of target - // - Delete a row in target. The projected row will have the deleted bit set. - pair match { - case Some((_, projection)) => - projection.apply(inputRow) - case None => - projectTargetCols.apply(inputRow) - } - } - - def processRow(inputRow: InternalRow): InternalRow = { - isSourceRowNotPresentPred.eval(inputRow) match { - case true => projectTargetCols.apply(inputRow) - case false => - if (isTargetRowNotPresentPred.eval(inputRow)) { - applyProjection(notMatchedPreds, notMatchedProjs, inputRow) - } else { - applyProjection(matchedPreds, matchedProjs, inputRow) - } - } - } - - rowIterator - .map(processRow) - .filter(!shouldDeleteRow(_)) - } -} +case class MergeIntoParams(isSourceRowNotPresent: Expression, + isTargetRowNotPresent: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Expression]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Expression]], + targetOutput: Seq[Expression], + deleteOutput: Seq[Expression], + joinedAttributes: Seq[Attribute]) extends Serializable diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanHelper.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanHelper.scala new file mode 100644 index 000000000000..5724b2700b44 --- /dev/null +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanHelper.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.catalyst.utils + +import java.util.UUID +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, PredicateHelper} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DynamicFileFilter, LogicalPlan} +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.iceberg.read.SupportsFileFilter +import org.apache.spark.sql.connector.iceberg.write.MergeBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +trait PlanHelper extends PredicateHelper { + val FILE_NAME_COL = "_file" + val ROW_POS_COL = "_pos" + + def buildScanPlan(table: Table, + output: Seq[AttributeReference], + mergeBuilder: MergeBuilder, + prunedTargetPlan: LogicalPlan): LogicalPlan = { + + val scanBuilder = mergeBuilder.asScanBuilder + val scan = scanBuilder.build() + val scanRelation = DataSourceV2ScanRelation(table, scan, toOutputAttrs(scan.readSchema(), output)) + + scan match { + case filterable: SupportsFileFilter => + val matchingFilePlan = buildFileFilterPlan(prunedTargetPlan) + val dynamicFileFilter = DynamicFileFilter(scanRelation, matchingFilePlan, filterable) + dynamicFileFilter + case _ => + scanRelation + } + } + + private def buildFileFilterPlan(prunedTargetPlan: LogicalPlan): LogicalPlan = { + val fileAttr = findOutputAttr(prunedTargetPlan, FILE_NAME_COL) + Aggregate(Seq(fileAttr), Seq(fileAttr), prunedTargetPlan) + } + + def findOutputAttr(plan: LogicalPlan, attrName: String): Attribute = { + val resolver = SQLConf.get.resolver + plan.output.find(attr => resolver(attr.name, attrName)).getOrElse { + throw new AnalysisException(s"Cannot find $attrName in ${plan.output}") + } + } + + def newWriteInfo(schema: StructType): LogicalWriteInfo = { + val uuid = UUID.randomUUID() + LogicalWriteInfoImpl(queryId = uuid.toString, schema, CaseInsensitiveStringMap.empty) + } + + private def toOutputAttrs(schema: StructType, output: Seq[AttributeReference]): Seq[AttributeReference] = { + val nameToAttr = output.map(_.name).zip(output).toMap + schema.toAttributes.map { + a => nameToAttr.get(a.name) match { + case Some(ref) => + // keep the attribute id if it was present in the relation + a.withExprId(ref.exprId) + case _ => + // if the field is new, create a new attribute + AttributeReference(a.name, a.dataType, a.nullable, a.metadata)() + } + } + } +} diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala index c9fd1bf22a12..fad0f754bd1a 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala @@ -21,18 +21,91 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.MergeIntoProcessor +import org.apache.spark.sql.catalyst.expressions.{Attribute, BasePredicate, Expression, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoParams import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -case class MergeIntoExec(mergeIntoProcessor: MergeIntoProcessor, +case class MergeIntoExec(mergeIntoProcessor: MergeIntoParams, @transient targetRelation: DataSourceV2Relation, override val child: SparkPlan) extends UnaryExecNode { + + override def output: Seq[Attribute] = targetRelation.output + protected override def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { - mergeIntoProcessor.processPartition + processPartition(mergeIntoProcessor, _) } } - override def output: Seq[Attribute] = targetRelation.output + private def generateProjection(exprs: Seq[Expression], attrs: Seq[Attribute]): UnsafeProjection = { + UnsafeProjection.create(exprs, attrs) + } + + private def generatePredicate(expr: Expression, attrs: Seq[Attribute]): BasePredicate = { + GeneratePredicate.generate(expr, attrs) + } + + def applyProjection(predicates: Seq[BasePredicate], + projections: Seq[UnsafeProjection], + projectTargetCols: UnsafeProjection, + projectDeleteRow: UnsafeProjection, + inputRow: InternalRow, + targetRowNotPresent: Boolean): InternalRow = { + // Find the first combination where the predicate evaluates to true + val pair = (predicates zip projections).find { + case (predicate, _) => predicate.eval(inputRow) + } + + // Now apply the appropriate projection to either : + // - Insert a row into target + // - Update a row of target + // - Delete a row in target. The projected row will have the deleted bit set. + pair match { + case Some((_, projection)) => + projection.apply(inputRow) + case None => + if (targetRowNotPresent) { + projectDeleteRow.apply(inputRow) + } else { + projectTargetCols.apply(inputRow) + } + } + } + + def processPartition(params: MergeIntoParams, + rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = { + val joinedAttrs = params.joinedAttributes + val isSourceRowNotPresentPred = generatePredicate(params.isSourceRowNotPresent, joinedAttrs) + val isTargetRowNotPresentPred = generatePredicate(params.isTargetRowNotPresent, joinedAttrs) + val matchedPreds = params.matchedConditions.map(generatePredicate(_, joinedAttrs)) + val matchedProjs = params.matchedOutputs.map(generateProjection(_, joinedAttrs)) + val notMatchedPreds = params.notMatchedConditions.map(generatePredicate(_, joinedAttrs)) + val notMatchedProjs = params.notMatchedOutputs.map(generateProjection(_, joinedAttrs)) + val projectTargetCols = generateProjection(params.targetOutput, joinedAttrs) + val projectDeletedRow = generateProjection(params.deleteOutput, joinedAttrs) + + def shouldDeleteRow(row: InternalRow): Boolean = + row.getBoolean(params.targetOutput.size - 1) + + + def processRow(inputRow: InternalRow): InternalRow = { + isSourceRowNotPresentPred.eval(inputRow) match { + case true => + projectTargetCols.apply(inputRow) + case false => + if (isTargetRowNotPresentPred.eval(inputRow)) { + applyProjection(notMatchedPreds, notMatchedProjs, projectTargetCols, + projectDeletedRow, inputRow, true) + } else { + applyProjection(matchedPreds, matchedProjs, projectTargetCols, + projectDeletedRow,inputRow, false) + } + } + } + + rowIterator + .map(processRow) + .filter(!shouldDeleteRow(_)) + } } diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java new file mode 100644 index 000000000000..4780fb4aef3e --- /dev/null +++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.extensions; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; + +public class TestMergeIntoTable extends SparkRowLevelOperationsTestBase { + private final String sourceName; + private final String targetName; + + public TestMergeIntoTable(String catalogName, String implementation, Map config, + String fileFormat, Boolean vectorized) { + super(catalogName, implementation, config, fileFormat, vectorized); + this.sourceName = tableName("source"); + this.targetName = tableName("target"); + } + + @BeforeClass + public static void setupSparkConf() { + spark.conf().set("spark.sql.shuffle.partitions", "4"); + } + + protected Map extraTableProperties() { + return ImmutableMap.of(TableProperties.DELETE_MODE, "copy-on-write"); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", targetName); + sql("DROP TABLE IF EXISTS %s", sourceName); + } + + @Test + public void testEmptyTargetInsertAllNonMatchingRows() throws NoSuchTableException { + createAndInitUnPartitionedTargetTable(targetName); + createAndInitSourceTable(sourceName); + append(sourceName, new Employee(1, "emp-id-1"), new Employee(2, "emp-id-2"), new Employee(3, "emp-id-3")); + String sqlText = "MERGE INTO " + targetName + " AS target \n" + + "USING " + sourceName + " AS source \n" + + "ON target.id = source.id \n" + + "WHEN NOT MATCHED THEN INSERT * "; + + sql(sqlText, ""); + sql("SELECT * FROM %s ORDER BY id, dep", targetName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2"), row(3, "emp-id-3")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testEmptyTargetInsertOnlyMatchingRows() throws NoSuchTableException { + createAndInitUnPartitionedTargetTable(targetName); + createAndInitSourceTable(sourceName); + append(sourceName, new Employee(1, "emp-id-1"), new Employee(2, "emp-id-2"), new Employee(3, "emp-id-3")); + String sqlText = "MERGE INTO " + targetName + " AS target \n" + + "USING " + sourceName + " AS source \n" + + "ON target.id = source.id \n" + + "WHEN NOT MATCHED AND (source.id >= 2) THEN INSERT * "; + + sql(sqlText, ""); + List res = sql("SELECT * FROM %s ORDER BY id, dep", targetName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "emp-id-2"), row(3, "emp-id-3")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testOnlyUpdate() throws NoSuchTableException { + createAndInitUnPartitionedTargetTable(targetName); + createAndInitSourceTable(sourceName); + append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); + append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); + String sqlText = "MERGE INTO " + targetName + " AS target \n" + + "USING " + sourceName + " AS source \n" + + "ON target.id = source.id \n" + + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * "; + + sql(sqlText, ""); + List res = sql("SELECT * FROM %s ORDER BY id, dep", targetName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-1"), row(6, "emp-id-6")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testOnlyDelete() throws NoSuchTableException { + createAndInitUnPartitionedTargetTable(targetName); + createAndInitSourceTable(sourceName); + append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); + append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); + String sqlText = "MERGE INTO " + targetName + " AS target \n" + + "USING " + sourceName + " AS source \n" + + "ON target.id = source.id \n" + + "WHEN MATCHED AND target.id = 6 THEN DELETE"; + + sql(sqlText, ""); + List res = sql("SELECT * FROM %s ORDER BY id, dep", targetName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-one")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testAllCauses() throws NoSuchTableException { + createAndInitUnPartitionedTargetTable(targetName); + createAndInitSourceTable(sourceName); + append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); + append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); + String sqlText = "MERGE INTO " + targetName + " AS target \n" + + "USING " + sourceName + " AS source \n" + + "ON target.id = source.id \n" + + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" + + "WHEN MATCHED AND target.id = 6 THEN DELETE \n" + + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * "; + + sql(sqlText, ""); + sql("SELECT * FROM %s ORDER BY id, dep", targetName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testAllCausesWithExplicitColumnSpecification() throws NoSuchTableException { + createAndInitUnPartitionedTargetTable(targetName); + createAndInitSourceTable(sourceName); + append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); + append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); + String sqlText = "MERGE INTO " + targetName + " AS target \n" + + "USING " + sourceName + " AS source \n" + + "ON target.id = source.id \n" + + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET target.id = source.id, target.dep = source.dep \n" + + "WHEN MATCHED AND target.id = 6 THEN DELETE \n" + + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT (target.id, target.dep) VALUES (source.id, source.dep) "; + + sql(sqlText, ""); + sql("SELECT * FROM %s ORDER BY id, dep", targetName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testSourceCTE() throws NoSuchTableException { + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhive")); + + createAndInitUnPartitionedTargetTable(targetName); + createAndInitSourceTable(sourceName); + append(targetName, new Employee(2, "emp-id-two"), new Employee(6, "emp-id-6")); + append(sourceName, new Employee(2, "emp-id-3"), new Employee(1, "emp-id-2"), new Employee(5, "emp-id-6")); + String sourceCTE = "WITH cte1 AS (SELECT id + 1 AS id, dep FROM source)"; + String sqlText = sourceCTE + " " + "MERGE INTO " + targetName + " AS target \n" + + "USING cte1" + " AS source \n" + + "ON target.id = source.id \n" + + "WHEN MATCHED AND target.id = 2 THEN UPDATE SET * \n" + + "WHEN MATCHED AND target.id = 6 THEN DELETE \n" + + "WHEN NOT MATCHED AND source.id = 3 THEN INSERT * "; + + sql(sqlText, ""); + sql("SELECT * FROM %s ORDER BY id, dep", targetName); + assertEquals("Should have expected rows", + ImmutableList.of(row(2, "emp-id-2"), row(3, "emp-id-3")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + @Test + public void testSourceFromSetOps() throws NoSuchTableException { + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); + Assume.assumeFalse(catalogName.equalsIgnoreCase("testhive")); + + createAndInitUnPartitionedTargetTable(targetName); + createAndInitSourceTable(sourceName); + append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); + append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); + String derivedSource = " ( SELECT * FROM source WHERE id = 2 \n" + + " UNION ALL \n" + + " SELECT * FROM source WHERE id = 1 OR id = 6)"; + String sqlText = "MERGE INTO " + targetName + " AS target \n" + + "USING " + derivedSource + " AS source \n" + + "ON target.id = source.id \n" + + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" + + "WHEN MATCHED AND target.id = 6 THEN DELETE \n" + + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * "; + + sql(sqlText, ""); + sql("SELECT * FROM %s ORDER BY id, dep", targetName); + assertEquals("Should have expected rows", + ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); + } + + protected void createAndInitPartitionedTargetTable(String tabName) { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY (dep)", tabName); + initTable(tabName); + } + + protected void createAndInitUnPartitionedTargetTable(String tabName) { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg", tabName); + initTable(tabName); + } + + protected void createAndInitSourceTable(String tabName) { + sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY (dep)", tabName); + initTable(tabName); + } + + private void initTable(String tabName) { + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tabName, DEFAULT_FILE_FORMAT, fileFormat); + + switch (fileFormat) { + case "parquet": + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%b')", tabName, PARQUET_VECTORIZATION_ENABLED, vectorized); + break; + case "orc": + Assert.assertTrue(vectorized); + break; + case "avro": + Assert.assertFalse(vectorized); + break; + } + + Map props = extraTableProperties(); + props.forEach((prop, value) -> { + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tabName, prop, value); + }); + } + + protected void append(String tabName, Employee... employees) throws NoSuchTableException { + List input = Arrays.asList(employees); + Dataset inputDF = spark.createDataFrame(input, Employee.class); + inputDF.coalesce(1).writeTo(tabName).append(); + } +} From 9cb2e86f1962fc02b65065253f5ab3c3c18ade09 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sun, 17 Jan 2021 00:13:23 -0800 Subject: [PATCH 4/9] Code review comments (Round-2) --- .../IcebergSparkSessionExtensions.scala | 2 +- .../catalyst/optimizer/RewriteDelete.scala | 4 +- .../catalyst/optimizer/RewriteMergeInto.scala | 82 ++++++++++-------- .../RewriteRowLevelOperationHelper.scala | 5 +- .../v2/ExtendedDataSourceV2Strategy.scala | 1 + .../datasources/v2/MergeIntoExec.scala | 82 ++++++++++++------ .../spark/extensions/TestMergeIntoTable.java | 84 +++++++++---------- 7 files changed, 152 insertions(+), 108 deletions(-) diff --git a/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala index 1884c4c62046..ffa3bafacbf4 100644 --- a/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala +++ b/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala @@ -43,7 +43,7 @@ class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // TODO: PullupCorrelatedPredicates should handle row-level operations extensions.injectOptimizerRule { _ => PullupCorrelatedPredicatesInRowLevelOperations } extensions.injectOptimizerRule { spark => RewriteDelete(spark.sessionState.conf) } - extensions.injectOptimizerRule { _ => RewriteMergeInto } + extensions.injectOptimizerRule { spark => RewriteMergeInto(spark.sessionState.conf) } // planner extensions extensions.injectPlannerStrategy { spark => ExtendedDataSourceV2Strategy(spark) } diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala index e86f21f553bb..de7cd9ad6792 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala @@ -56,13 +56,13 @@ case class RewriteDelete(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRo d // rewrite all operations that require reading the table to delete records - case DeleteFromTable(r: DataSourceV2Relation, Some(cond)) => + case DeleteFromTable(r: DataSourceV2Relation, optionalCond @ Some(cond)) => // TODO: do a switch based on whether we get BatchWrite or DeltaBatchWrite val writeInfo = newWriteInfo(r.schema) val mergeBuilder = r.table.asMergeable.newMergeBuilder("delete", writeInfo) val matchingRowsPlanBuilder = scanRelation => Filter(cond, scanRelation) - val scanPlan = buildScanPlan(r.table, r.output, mergeBuilder, cond, matchingRowsPlanBuilder) + val scanPlan = buildScanPlan(r.table, r.output, mergeBuilder, optionalCond, matchingRowsPlanBuilder) val remainingRowFilter = Not(EqualNullSafe(cond, Literal(true, BooleanType))) val remainingRowsPlan = Filter(remainingRowFilter, scanPlan) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala index 4a89557ef41f..0015e316ab3b 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala @@ -19,42 +19,64 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner} -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.InputFileName +import org.apache.spark.sql.catalyst.expressions.IsNull +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.FullOuter +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.DeleteAction +import org.apache.spark.sql.catalyst.plans.logical.InsertAction +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.JoinHint +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.MergeAction +import org.apache.spark.sql.catalyst.plans.logical.MergeInto +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoParams +import org.apache.spark.sql.catalyst.plans.logical.MergeIntoTable +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceData +import org.apache.spark.sql.catalyst.plans.logical.UpdateAction import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.utils.PlanHelper +import org.apache.spark.sql.catalyst.utils.RewriteRowLevelOperationHelper import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.BooleanType -object RewriteMergeInto extends Rule[LogicalPlan] with PlanHelper with Logging { +case class RewriteMergeInto(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRowLevelOperationHelper { val ROW_FROM_SOURCE = "_row_from_source_" val ROW_FROM_TARGET = "_row_from_target_" + private val TRUE_LITERAL = Literal(true, BooleanType) + private val FALSE_LITERAL = Literal(false, BooleanType) import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Implicits._ + override def resolver: Resolver = conf.resolver + override def apply(plan: LogicalPlan): LogicalPlan = { plan resolveOperators { - // rewrite all operations that require reading the table to delete records - case MergeIntoTable(target: DataSourceV2Relation, - source: LogicalPlan, cond, actions, notActions) => + case MergeIntoTable(target: DataSourceV2Relation, source: LogicalPlan, cond, matchedActions, notMatchedActions) => val targetOutputCols = target.output val newProjectCols = target.output ++ Seq(Alias(InputFileName(), FILE_NAME_COL)()) val newTargetTable = Project(newProjectCols, target) // Construct the plan to prune target based on join condition between source and // target. - val prunedTargetPlan = Join(source, newTargetTable, Inner, Some(cond), JoinHint.NONE) val writeInfo = newWriteInfo(target.schema) - val mergeBuilder = target.table.asMergeable.newMergeBuilder("delete", writeInfo) - val targetTableScan = buildScanPlan(target.table, target.output, mergeBuilder, prunedTargetPlan) + val mergeBuilder = target.table.asMergeable.newMergeBuilder("merge", writeInfo) + val matchingRowsPlanBuilder = (_: DataSourceV2ScanRelation) => + Join(source, newTargetTable, Inner, Some(cond), JoinHint.NONE) + // TODO - extract the local predicates that references the target from the join condition and + // pass to buildScanPlan to ensure push-down. + val targetTableScan = buildScanPlan(target.table, target.output, mergeBuilder, None, matchingRowsPlanBuilder) // Construct an outer join to help track changes in source and target. // TODO : Optimize this to use LEFT ANTI or RIGHT OUTER when applicable. - val sourceTableProj = source.output ++ Seq(Alias(lit(true).expr, ROW_FROM_SOURCE)()) - val targetTableProj = target.output ++ Seq(Alias(lit(true).expr, ROW_FROM_TARGET)()) + val sourceTableProj = source.output ++ Seq(Alias(TRUE_LITERAL, ROW_FROM_SOURCE)()) + val targetTableProj = target.output ++ Seq(Alias(TRUE_LITERAL, ROW_FROM_TARGET)()) val newTargetTableScan = Project(targetTableProj, targetTableScan) val newSourceTableScan = Project(sourceTableProj, source) val joinPlan = Join(newSourceTableScan, newTargetTableScan, FullOuter, Some(cond), JoinHint.NONE) @@ -63,12 +85,12 @@ object RewriteMergeInto extends Rule[LogicalPlan] with PlanHelper with Logging val mergeParams = MergeIntoParams( isSourceRowNotPresent = IsNull(findOutputAttr(joinPlan, ROW_FROM_SOURCE)), isTargetRowNotPresent = IsNull(findOutputAttr(joinPlan, ROW_FROM_TARGET)), - matchedConditions = actions.map(getClauseCondition), - matchedOutputs = actions.map(actionOutput(_, targetOutputCols)), - notMatchedConditions = notActions.map(getClauseCondition), - notMatchedOutputs = notActions.map(actionOutput(_, targetOutputCols)), - targetOutput = targetOutputCols :+ Literal(false), - deleteOutput = targetOutputCols :+ Literal(true), + matchedConditions = matchedActions.map(getClauseCondition), + matchedOutputs = matchedActions.map(actionOutput(_, targetOutputCols)), + notMatchedConditions = notMatchedActions.map(getClauseCondition), + notMatchedOutputs = notMatchedActions.map(actionOutput(_, targetOutputCols)), + targetOutput = targetOutputCols :+ FALSE_LITERAL, + deleteOutput = targetOutputCols :+ TRUE_LITERAL, joinedAttributes = joinPlan.output ) val mergePlan = MergeInto(mergeParams, target, joinPlan) @@ -77,26 +99,18 @@ object RewriteMergeInto extends Rule[LogicalPlan] with PlanHelper with Logging } } - def getTargetOutputCols(target: DataSourceV2Relation): Seq[NamedExpression] = { - target.schema.map { col => - target.output.find(attr => SQLConf.get.resolver(attr.name, col.name)).getOrElse { - Alias(Literal(null, col.dataType), col.name)() - } - } - } - - def actionOutput(clause: MergeAction, targetOutputCols: Seq[Expression]): Seq[Expression] = { + private def actionOutput(clause: MergeAction, targetOutputCols: Seq[Expression]): Seq[Expression] = { clause match { case u: UpdateAction => - u.assignments.map(_.value) :+ Literal(false) + u.assignments.map(_.value) :+ FALSE_LITERAL case _: DeleteAction => - targetOutputCols :+ Literal(true) + targetOutputCols :+ TRUE_LITERAL case i: InsertAction => - i.assignments.map(_.value) :+ Literal(false) + i.assignments.map(_.value) :+ FALSE_LITERAL } } - def getClauseCondition(clause: MergeAction): Expression = { + private def getClauseCondition(clause: MergeAction): Expression = { clause.condition.getOrElse(Literal(true)) } } diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala index f7ad083be9fe..e026fb12bd97 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala @@ -54,12 +54,12 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging { table: Table, tableAttrs: Seq[AttributeReference], mergeBuilder: MergeBuilder, - cond: Expression, + cond: Option[Expression] = None, matchingRowsPlanBuilder: DataSourceV2ScanRelation => LogicalPlan): LogicalPlan = { val scanBuilder = mergeBuilder.asScanBuilder - pushFilters(scanBuilder, cond, tableAttrs) + cond.map(pushFilters(scanBuilder, _, tableAttrs)) val scan = scanBuilder.build() val outputAttrs = toOutputAttrs(scan.readSchema(), tableAttrs) @@ -103,6 +103,7 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging { } private def buildFileFilterPlan(matchingRowsPlan: LogicalPlan): LogicalPlan = { + // TODO: For merge-into make sure _file is resolved only from target table. val fileAttr = findOutputAttr(matchingRowsPlan, FILE_NAME_COL) val agg = Aggregate(Seq(fileAttr), Seq(fileAttr), matchingRowsPlan) Project(Seq(findOutputAttr(agg, FILE_NAME_COL)), agg) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala index 1e8e90b99f61..3ba876da8846 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala @@ -78,6 +78,7 @@ case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy { case MergeInto(mergeIntoProcessor, targetRelation, child) => MergeIntoExec(mergeIntoProcessor, targetRelation, planLater(child)) :: Nil + case _ => Nil } diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala index fad0f754bd1a..e7da5011759a 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala @@ -21,20 +21,25 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BasePredicate, Expression, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.BasePredicate +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.plans.logical.MergeIntoParams -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.UnaryExecNode -case class MergeIntoExec(mergeIntoProcessor: MergeIntoParams, - @transient targetRelation: DataSourceV2Relation, - override val child: SparkPlan) extends UnaryExecNode { +case class MergeIntoExec( + mergeIntoParams: MergeIntoParams, + @transient targetRelation: DataSourceV2Relation, + override val child: SparkPlan) extends UnaryExecNode { override def output: Seq[Attribute] = targetRelation.output protected override def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { - processPartition(mergeIntoProcessor, _) + processPartition(mergeIntoParams, _) } } @@ -46,14 +51,27 @@ case class MergeIntoExec(mergeIntoProcessor: MergeIntoParams, GeneratePredicate.generate(expr, attrs) } - def applyProjection(predicates: Seq[BasePredicate], - projections: Seq[UnsafeProjection], - projectTargetCols: UnsafeProjection, - projectDeleteRow: UnsafeProjection, - inputRow: InternalRow, - targetRowNotPresent: Boolean): InternalRow = { - // Find the first combination where the predicate evaluates to true - val pair = (predicates zip projections).find { + def applyProjection( + actions: Seq[(BasePredicate, UnsafeProjection)], + projectTargetCols: UnsafeProjection, + projectDeleteRow: UnsafeProjection, + inputRow: InternalRow, + targetRowNotPresent: Boolean): InternalRow = { + + /** + * Find the first combination where the predicate evaluates to true. + * In case when there are overlapping condition in the MATCHED + * clauses, for the first one that satisfies the predicate, the + * corresponding action is applied. For example: + * + * WHEN MATCHED AND id > 1 AND id < 10 UPDATE * + * WHEN MATCHED AND id = 5 OR id = 21 DELETE + * + * In above case, when id = 5, it applies both that matched predicates. In this + * case the first one we see is applied. + */ + + val pair = actions.find { case (predicate, _) => predicate.eval(inputRow) } @@ -73,8 +91,10 @@ case class MergeIntoExec(mergeIntoProcessor: MergeIntoParams, } } - def processPartition(params: MergeIntoParams, - rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = { + def processPartition( + params: MergeIntoParams, + rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = { + val joinedAttrs = params.joinedAttributes val isSourceRowNotPresentPred = generatePredicate(params.isSourceRowNotPresent, joinedAttrs) val isTargetRowNotPresentPred = generatePredicate(params.isTargetRowNotPresent, joinedAttrs) @@ -84,23 +104,31 @@ case class MergeIntoExec(mergeIntoProcessor: MergeIntoParams, val notMatchedProjs = params.notMatchedOutputs.map(generateProjection(_, joinedAttrs)) val projectTargetCols = generateProjection(params.targetOutput, joinedAttrs) val projectDeletedRow = generateProjection(params.deleteOutput, joinedAttrs) + val nonMatchedPairs = notMatchedPreds zip notMatchedProjs + val matchedPairs = matchedPreds zip matchedProjs def shouldDeleteRow(row: InternalRow): Boolean = row.getBoolean(params.targetOutput.size - 1) + /** + * This method is responsible for processing a input row to emit the resultant row with an + * additional column that indicates whether the row is going to be included in the final + * output of merge or not. + * 1. If there is a target row for which there is no corresponding source row (join condition not met) + * - Only project the target columns with deleted flag set to false. + * 2. If there is a source row for which there is no corresponding target row (join condition not met) + * - Apply the not matched actions (i.e INSERT actions) if non match conditions are met. + * 3. If there is a source row for which there is a corresponding target row (join condition met) + * - Apply the matched actions (i.e DELETE or UPDATE actions) if match conditions are met. + */ def processRow(inputRow: InternalRow): InternalRow = { - isSourceRowNotPresentPred.eval(inputRow) match { - case true => - projectTargetCols.apply(inputRow) - case false => - if (isTargetRowNotPresentPred.eval(inputRow)) { - applyProjection(notMatchedPreds, notMatchedProjs, projectTargetCols, - projectDeletedRow, inputRow, true) - } else { - applyProjection(matchedPreds, matchedProjs, projectTargetCols, - projectDeletedRow,inputRow, false) - } + if (isSourceRowNotPresentPred.eval(inputRow)) { + projectTargetCols.apply(inputRow) + } else if (isTargetRowNotPresentPred.eval(inputRow)) { + applyProjection(nonMatchedPairs, projectTargetCols, projectDeletedRow, inputRow, true) + } else { + applyProjection(matchedPairs, projectTargetCols, projectDeletedRow, inputRow, false) } } diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java index 4780fb4aef3e..538db7209aca 100644 --- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java +++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java @@ -68,12 +68,12 @@ public void testEmptyTargetInsertAllNonMatchingRows() throws NoSuchTableExceptio createAndInitUnPartitionedTargetTable(targetName); createAndInitSourceTable(sourceName); append(sourceName, new Employee(1, "emp-id-1"), new Employee(2, "emp-id-2"), new Employee(3, "emp-id-3")); - String sqlText = "MERGE INTO " + targetName + " AS target \n" + - "USING " + sourceName + " AS source \n" + - "ON target.id = source.id \n" + + String sqlText = "MERGE INTO " + targetName + " AS target " + + "USING " + sourceName + " AS source " + + "ON target.id = source.id " + "WHEN NOT MATCHED THEN INSERT * "; - sql(sqlText, ""); + sql(sqlText); sql("SELECT * FROM %s ORDER BY id, dep", targetName); assertEquals("Should have expected rows", ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2"), row(3, "emp-id-3")), @@ -85,12 +85,12 @@ public void testEmptyTargetInsertOnlyMatchingRows() throws NoSuchTableException createAndInitUnPartitionedTargetTable(targetName); createAndInitSourceTable(sourceName); append(sourceName, new Employee(1, "emp-id-1"), new Employee(2, "emp-id-2"), new Employee(3, "emp-id-3")); - String sqlText = "MERGE INTO " + targetName + " AS target \n" + - "USING " + sourceName + " AS source \n" + - "ON target.id = source.id \n" + + String sqlText = "MERGE INTO " + targetName + " AS target " + + "USING " + sourceName + " AS source " + + "ON target.id = source.id " + "WHEN NOT MATCHED AND (source.id >= 2) THEN INSERT * "; - sql(sqlText, ""); + sql(sqlText); List res = sql("SELECT * FROM %s ORDER BY id, dep", targetName); assertEquals("Should have expected rows", ImmutableList.of(row(2, "emp-id-2"), row(3, "emp-id-3")), @@ -103,12 +103,12 @@ public void testOnlyUpdate() throws NoSuchTableException { createAndInitSourceTable(sourceName); append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); - String sqlText = "MERGE INTO " + targetName + " AS target \n" + - "USING " + sourceName + " AS source \n" + - "ON target.id = source.id \n" + + String sqlText = "MERGE INTO " + targetName + " AS target " + + "USING " + sourceName + " AS source " + + "ON target.id = source.id " + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * "; - sql(sqlText, ""); + sql(sqlText); List res = sql("SELECT * FROM %s ORDER BY id, dep", targetName); assertEquals("Should have expected rows", ImmutableList.of(row(1, "emp-id-1"), row(6, "emp-id-6")), @@ -121,12 +121,12 @@ public void testOnlyDelete() throws NoSuchTableException { createAndInitSourceTable(sourceName); append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); - String sqlText = "MERGE INTO " + targetName + " AS target \n" + - "USING " + sourceName + " AS source \n" + - "ON target.id = source.id \n" + + String sqlText = "MERGE INTO " + targetName + " AS target " + + "USING " + sourceName + " AS source " + + "ON target.id = source.id " + "WHEN MATCHED AND target.id = 6 THEN DELETE"; - sql(sqlText, ""); + sql(sqlText); List res = sql("SELECT * FROM %s ORDER BY id, dep", targetName); assertEquals("Should have expected rows", ImmutableList.of(row(1, "emp-id-one")), @@ -139,14 +139,14 @@ public void testAllCauses() throws NoSuchTableException { createAndInitSourceTable(sourceName); append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); - String sqlText = "MERGE INTO " + targetName + " AS target \n" + - "USING " + sourceName + " AS source \n" + - "ON target.id = source.id \n" + - "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" + - "WHEN MATCHED AND target.id = 6 THEN DELETE \n" + + String sqlText = "MERGE INTO " + targetName + " AS target " + + "USING " + sourceName + " AS source " + + "ON target.id = source.id " + + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * " + + "WHEN MATCHED AND target.id = 6 THEN DELETE " + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * "; - sql(sqlText, ""); + sql(sqlText); sql("SELECT * FROM %s ORDER BY id, dep", targetName); assertEquals("Should have expected rows", ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), @@ -159,14 +159,14 @@ public void testAllCausesWithExplicitColumnSpecification() throws NoSuchTableExc createAndInitSourceTable(sourceName); append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); - String sqlText = "MERGE INTO " + targetName + " AS target \n" + - "USING " + sourceName + " AS source \n" + - "ON target.id = source.id \n" + - "WHEN MATCHED AND target.id = 1 THEN UPDATE SET target.id = source.id, target.dep = source.dep \n" + - "WHEN MATCHED AND target.id = 6 THEN DELETE \n" + + String sqlText = "MERGE INTO " + targetName + " AS target " + + "USING " + sourceName + " AS source " + + "ON target.id = source.id " + + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET target.id = source.id, target.dep = source.dep " + + "WHEN MATCHED AND target.id = 6 THEN DELETE " + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT (target.id, target.dep) VALUES (source.id, source.dep) "; - sql(sqlText, ""); + sql(sqlText); sql("SELECT * FROM %s ORDER BY id, dep", targetName); assertEquals("Should have expected rows", ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), @@ -183,14 +183,14 @@ public void testSourceCTE() throws NoSuchTableException { append(targetName, new Employee(2, "emp-id-two"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-3"), new Employee(1, "emp-id-2"), new Employee(5, "emp-id-6")); String sourceCTE = "WITH cte1 AS (SELECT id + 1 AS id, dep FROM source)"; - String sqlText = sourceCTE + " " + "MERGE INTO " + targetName + " AS target \n" + - "USING cte1" + " AS source \n" + - "ON target.id = source.id \n" + - "WHEN MATCHED AND target.id = 2 THEN UPDATE SET * \n" + - "WHEN MATCHED AND target.id = 6 THEN DELETE \n" + + String sqlText = sourceCTE + " " + "MERGE INTO " + targetName + " AS target " + + "USING cte1" + " AS source " + + "ON target.id = source.id " + + "WHEN MATCHED AND target.id = 2 THEN UPDATE SET * " + + "WHEN MATCHED AND target.id = 6 THEN DELETE " + "WHEN NOT MATCHED AND source.id = 3 THEN INSERT * "; - sql(sqlText, ""); + sql(sqlText); sql("SELECT * FROM %s ORDER BY id, dep", targetName); assertEquals("Should have expected rows", ImmutableList.of(row(2, "emp-id-2"), row(3, "emp-id-3")), @@ -206,17 +206,17 @@ public void testSourceFromSetOps() throws NoSuchTableException { createAndInitSourceTable(sourceName); append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); - String derivedSource = " ( SELECT * FROM source WHERE id = 2 \n" + - " UNION ALL \n" + + String derivedSource = " ( SELECT * FROM source WHERE id = 2 " + + " UNION ALL " + " SELECT * FROM source WHERE id = 1 OR id = 6)"; - String sqlText = "MERGE INTO " + targetName + " AS target \n" + - "USING " + derivedSource + " AS source \n" + - "ON target.id = source.id \n" + - "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * \n" + - "WHEN MATCHED AND target.id = 6 THEN DELETE \n" + + String sqlText = "MERGE INTO " + targetName + " AS target " + + "USING " + derivedSource + " AS source " + + "ON target.id = source.id " + + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * " + + "WHEN MATCHED AND target.id = 6 THEN DELETE " + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * "; - sql(sqlText, ""); + sql(sqlText); sql("SELECT * FROM %s ORDER BY id, dep", targetName); assertEquals("Should have expected rows", ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), From 227a1081dfc460840c0c611a01d6eb5fed9de15f Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 18 Jan 2021 14:58:47 -0800 Subject: [PATCH 5/9] Missed code review comments --- .../catalyst/optimizer/RewriteMergeInto.scala | 4 +- .../catalyst/plans/logical/MergeInto.scala | 29 ++++--- .../spark/sql/catalyst/utils/PlanHelper.scala | 87 ------------------- 3 files changed, 18 insertions(+), 102 deletions(-) delete mode 100644 spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanHelper.scala diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala index 0015e316ab3b..8ffe0a773fd1 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala @@ -47,8 +47,8 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.BooleanType case class RewriteMergeInto(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRowLevelOperationHelper { - val ROW_FROM_SOURCE = "_row_from_source_" - val ROW_FROM_TARGET = "_row_from_target_" + private val ROW_FROM_SOURCE = "_row_from_source_" + private val ROW_FROM_TARGET = "_row_from_target_" private val TRUE_LITERAL = Literal(true, BooleanType) private val FALSE_LITERAL = Literal(false, BooleanType) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala index a3a0ac68c0df..78a0cb57a3ed 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala @@ -19,21 +19,24 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -case class MergeInto(mergeIntoProcessor: MergeIntoParams, - targetRelation: DataSourceV2Relation, - child: LogicalPlan) extends UnaryNode { +case class MergeInto( + mergeIntoProcessor: MergeIntoParams, + targetRelation: DataSourceV2Relation, + child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = targetRelation.output } -case class MergeIntoParams(isSourceRowNotPresent: Expression, - isTargetRowNotPresent: Expression, - matchedConditions: Seq[Expression], - matchedOutputs: Seq[Seq[Expression]], - notMatchedConditions: Seq[Expression], - notMatchedOutputs: Seq[Seq[Expression]], - targetOutput: Seq[Expression], - deleteOutput: Seq[Expression], - joinedAttributes: Seq[Attribute]) extends Serializable +case class MergeIntoParams( + isSourceRowNotPresent: Expression, + isTargetRowNotPresent: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Expression]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Expression]], + targetOutput: Seq[Expression], + deleteOutput: Seq[Expression], + joinedAttributes: Seq[Attribute]) extends Serializable diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanHelper.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanHelper.scala deleted file mode 100644 index 5724b2700b44..000000000000 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/PlanHelper.scala +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.catalyst.utils - -import java.util.UUID -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, PredicateHelper} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DynamicFileFilter, LogicalPlan} -import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.iceberg.read.SupportsFileFilter -import org.apache.spark.sql.connector.iceberg.write.MergeBuilder -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -trait PlanHelper extends PredicateHelper { - val FILE_NAME_COL = "_file" - val ROW_POS_COL = "_pos" - - def buildScanPlan(table: Table, - output: Seq[AttributeReference], - mergeBuilder: MergeBuilder, - prunedTargetPlan: LogicalPlan): LogicalPlan = { - - val scanBuilder = mergeBuilder.asScanBuilder - val scan = scanBuilder.build() - val scanRelation = DataSourceV2ScanRelation(table, scan, toOutputAttrs(scan.readSchema(), output)) - - scan match { - case filterable: SupportsFileFilter => - val matchingFilePlan = buildFileFilterPlan(prunedTargetPlan) - val dynamicFileFilter = DynamicFileFilter(scanRelation, matchingFilePlan, filterable) - dynamicFileFilter - case _ => - scanRelation - } - } - - private def buildFileFilterPlan(prunedTargetPlan: LogicalPlan): LogicalPlan = { - val fileAttr = findOutputAttr(prunedTargetPlan, FILE_NAME_COL) - Aggregate(Seq(fileAttr), Seq(fileAttr), prunedTargetPlan) - } - - def findOutputAttr(plan: LogicalPlan, attrName: String): Attribute = { - val resolver = SQLConf.get.resolver - plan.output.find(attr => resolver(attr.name, attrName)).getOrElse { - throw new AnalysisException(s"Cannot find $attrName in ${plan.output}") - } - } - - def newWriteInfo(schema: StructType): LogicalWriteInfo = { - val uuid = UUID.randomUUID() - LogicalWriteInfoImpl(queryId = uuid.toString, schema, CaseInsensitiveStringMap.empty) - } - - private def toOutputAttrs(schema: StructType, output: Seq[AttributeReference]): Seq[AttributeReference] = { - val nameToAttr = output.map(_.name).zip(output).toMap - schema.toAttributes.map { - a => nameToAttr.get(a.name) match { - case Some(ref) => - // keep the attribute id if it was present in the relation - a.withExprId(ref.exprId) - case _ => - // if the field is new, create a new attribute - AttributeReference(a.name, a.dataType, a.nullable, a.metadata)() - } - } - } -} From 9f264b752cb6a1c81133985d139f3cdb4b42b008 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 18 Jan 2021 15:03:17 -0800 Subject: [PATCH 6/9] More review --- .../apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala index 8ffe0a773fd1..8041bc5cf5b0 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala @@ -111,7 +111,7 @@ case class RewriteMergeInto(conf: SQLConf) extends Rule[LogicalPlan] with Rewrit } private def getClauseCondition(clause: MergeAction): Expression = { - clause.condition.getOrElse(Literal(true)) + clause.condition.getOrElse(TRUE_LITERAL) } } From 6f293356f830d1fe05056bfd820cc470b1de6110 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 18 Jan 2021 15:38:24 -0800 Subject: [PATCH 7/9] Review - contd. --- .../catalyst/optimizer/RewriteDelete.scala | 2 +- .../catalyst/optimizer/RewriteMergeInto.scala | 23 ++++++---------- .../RewriteRowLevelOperationHelper.scala | 4 +-- .../datasources/v2/MergeIntoExec.scala | 26 +++++++++---------- .../spark/extensions/TestMergeIntoTable.java | 2 +- 5 files changed, 25 insertions(+), 32 deletions(-) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala index de7cd9ad6792..7df282ee1d7e 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala @@ -62,7 +62,7 @@ case class RewriteDelete(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRo val mergeBuilder = r.table.asMergeable.newMergeBuilder("delete", writeInfo) val matchingRowsPlanBuilder = scanRelation => Filter(cond, scanRelation) - val scanPlan = buildScanPlan(r.table, r.output, mergeBuilder, optionalCond, matchingRowsPlanBuilder) + val scanPlan = buildScanPlan(r.table, r.output, mergeBuilder, cond, matchingRowsPlanBuilder) val remainingRowFilter = Not(EqualNullSafe(cond, Literal(true, BooleanType))) val remainingRowsPlan = Filter(remainingRowFilter, scanPlan) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala index 8041bc5cf5b0..0ed0db2e2136 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala @@ -59,19 +59,12 @@ case class RewriteMergeInto(conf: SQLConf) extends Rule[LogicalPlan] with Rewrit override def apply(plan: LogicalPlan): LogicalPlan = { plan resolveOperators { case MergeIntoTable(target: DataSourceV2Relation, source: LogicalPlan, cond, matchedActions, notMatchedActions) => - val targetOutputCols = target.output - val newProjectCols = target.output ++ Seq(Alias(InputFileName(), FILE_NAME_COL)()) - val newTargetTable = Project(newProjectCols, target) - - // Construct the plan to prune target based on join condition between source and - // target. + // Construct the plan to prune target based on join condition between source and target. val writeInfo = newWriteInfo(target.schema) val mergeBuilder = target.table.asMergeable.newMergeBuilder("merge", writeInfo) - val matchingRowsPlanBuilder = (_: DataSourceV2ScanRelation) => - Join(source, newTargetTable, Inner, Some(cond), JoinHint.NONE) - // TODO - extract the local predicates that references the target from the join condition and - // pass to buildScanPlan to ensure push-down. - val targetTableScan = buildScanPlan(target.table, target.output, mergeBuilder, None, matchingRowsPlanBuilder) + val matchingRowsPlanBuilder = (rel: DataSourceV2ScanRelation) => + Join(source, rel, Inner, Some(cond), JoinHint.NONE) + val targetTableScan = buildScanPlan(target.table, target.output, mergeBuilder, cond, matchingRowsPlanBuilder) // Construct an outer join to help track changes in source and target. // TODO : Optimize this to use LEFT ANTI or RIGHT OUTER when applicable. @@ -86,11 +79,11 @@ case class RewriteMergeInto(conf: SQLConf) extends Rule[LogicalPlan] with Rewrit isSourceRowNotPresent = IsNull(findOutputAttr(joinPlan, ROW_FROM_SOURCE)), isTargetRowNotPresent = IsNull(findOutputAttr(joinPlan, ROW_FROM_TARGET)), matchedConditions = matchedActions.map(getClauseCondition), - matchedOutputs = matchedActions.map(actionOutput(_, targetOutputCols)), + matchedOutputs = matchedActions.map(actionOutput(_, target.output)), notMatchedConditions = notMatchedActions.map(getClauseCondition), - notMatchedOutputs = notMatchedActions.map(actionOutput(_, targetOutputCols)), - targetOutput = targetOutputCols :+ FALSE_LITERAL, - deleteOutput = targetOutputCols :+ TRUE_LITERAL, + notMatchedOutputs = notMatchedActions.map(actionOutput(_, target.output)), + targetOutput = target.output :+ FALSE_LITERAL, + deleteOutput = target.output :+ TRUE_LITERAL, joinedAttributes = joinPlan.output ) val mergePlan = MergeInto(mergeParams, target, joinPlan) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala index e026fb12bd97..bd9cbc64265c 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala @@ -54,12 +54,12 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging { table: Table, tableAttrs: Seq[AttributeReference], mergeBuilder: MergeBuilder, - cond: Option[Expression] = None, + cond: Expression, matchingRowsPlanBuilder: DataSourceV2ScanRelation => LogicalPlan): LogicalPlan = { val scanBuilder = mergeBuilder.asScanBuilder - cond.map(pushFilters(scanBuilder, _, tableAttrs)) + pushFilters(scanBuilder, cond, tableAttrs) val scan = scanBuilder.build() val outputAttrs = toOutputAttrs(scan.readSchema(), tableAttrs) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala index e7da5011759a..9077e9ada457 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala @@ -58,18 +58,18 @@ case class MergeIntoExec( inputRow: InternalRow, targetRowNotPresent: Boolean): InternalRow = { - /** - * Find the first combination where the predicate evaluates to true. - * In case when there are overlapping condition in the MATCHED - * clauses, for the first one that satisfies the predicate, the - * corresponding action is applied. For example: - * - * WHEN MATCHED AND id > 1 AND id < 10 UPDATE * - * WHEN MATCHED AND id = 5 OR id = 21 DELETE - * - * In above case, when id = 5, it applies both that matched predicates. In this - * case the first one we see is applied. - */ + + // Find the first combination where the predicate evaluates to true. + // In case when there are overlapping condition in the MATCHED + // clauses, for the first one that satisfies the predicate, the + // corresponding action is applied. For example: + // + // WHEN MATCHED AND id > 1 AND id < 10 UPDATE * + // WHEN MATCHED AND id = 5 OR id = 21 DELETE + // + // In above case, when id = 5, it applies both that matched predicates. In this + // case the first one we see is applied. + // val pair = actions.find { case (predicate, _) => predicate.eval(inputRow) @@ -134,6 +134,6 @@ case class MergeIntoExec( rowIterator .map(processRow) - .filter(!shouldDeleteRow(_)) + .filterNot(shouldDeleteRow) } } diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java index 538db7209aca..deeb41468dd4 100644 --- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java +++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java @@ -54,7 +54,7 @@ public static void setupSparkConf() { } protected Map extraTableProperties() { - return ImmutableMap.of(TableProperties.DELETE_MODE, "copy-on-write"); + return ImmutableMap.of(TableProperties.MERGE_MODE, TableProperties.MERGE_MODE_DEFAULT); } @After From d7caece5de862437092fbbd58702f41fc66d8fdd Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 18 Jan 2021 17:00:46 -0800 Subject: [PATCH 8/9] Code review --- .../catalyst/optimizer/RewriteDelete.scala | 4 +- .../catalyst/optimizer/RewriteMergeInto.scala | 4 +- .../RewriteRowLevelOperationHelper.scala | 15 ++- .../datasources/v2/MergeIntoExec.scala | 1 - .../spark/extensions/TestMergeIntoTable.java | 112 ++++++++++-------- 5 files changed, 71 insertions(+), 65 deletions(-) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala index 7df282ee1d7e..e8ecf65b3750 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala @@ -76,8 +76,8 @@ case class RewriteDelete(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRo remainingRowsPlan: LogicalPlan, output: Seq[AttributeReference]): LogicalPlan = { - val fileNameCol = findOutputAttr(remainingRowsPlan, FILE_NAME_COL) - val rowPosCol = findOutputAttr(remainingRowsPlan, ROW_POS_COL) + val fileNameCol = findOutputAttr(remainingRowsPlan.output, FILE_NAME_COL) + val rowPosCol = findOutputAttr(remainingRowsPlan.output, ROW_POS_COL) val order = Seq(SortOrder(fileNameCol, Ascending), SortOrder(rowPosCol, Ascending)) val numShufflePartitions = SQLConf.get.numShufflePartitions val repartition = RepartitionByExpression(Seq(fileNameCol), remainingRowsPlan, numShufflePartitions) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala index 0ed0db2e2136..0c53dbcd7639 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteMergeInto.scala @@ -76,8 +76,8 @@ case class RewriteMergeInto(conf: SQLConf) extends Rule[LogicalPlan] with Rewrit // Construct the plan to replace the data based on the output of `MergeInto` val mergeParams = MergeIntoParams( - isSourceRowNotPresent = IsNull(findOutputAttr(joinPlan, ROW_FROM_SOURCE)), - isTargetRowNotPresent = IsNull(findOutputAttr(joinPlan, ROW_FROM_TARGET)), + isSourceRowNotPresent = IsNull(findOutputAttr(joinPlan.output, ROW_FROM_SOURCE)), + isTargetRowNotPresent = IsNull(findOutputAttr(joinPlan.output, ROW_FROM_TARGET)), matchedConditions = matchedActions.map(getClauseCondition), matchedOutputs = matchedActions.map(actionOutput(_, target.output)), notMatchedConditions = notMatchedActions.map(getClauseCondition), diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala index bd9cbc64265c..0d79e8b0e498 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala @@ -67,7 +67,7 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging { scan match { case filterable: SupportsFileFilter => - val matchingFilePlan = buildFileFilterPlan(matchingRowsPlanBuilder(scanRelation)) + val matchingFilePlan = buildFileFilterPlan(scanRelation.output, matchingRowsPlanBuilder(scanRelation)) DynamicFileFilter(scanRelation, matchingFilePlan, filterable) case _ => scanRelation @@ -102,16 +102,15 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging { LogicalWriteInfoImpl(queryId = uuid.toString, schema, CaseInsensitiveStringMap.empty) } - private def buildFileFilterPlan(matchingRowsPlan: LogicalPlan): LogicalPlan = { - // TODO: For merge-into make sure _file is resolved only from target table. - val fileAttr = findOutputAttr(matchingRowsPlan, FILE_NAME_COL) + private def buildFileFilterPlan(tableAttrs: Seq[AttributeReference], matchingRowsPlan: LogicalPlan): LogicalPlan = { + val fileAttr = findOutputAttr(tableAttrs, FILE_NAME_COL) val agg = Aggregate(Seq(fileAttr), Seq(fileAttr), matchingRowsPlan) - Project(Seq(findOutputAttr(agg, FILE_NAME_COL)), agg) + Project(Seq(findOutputAttr(agg.output, FILE_NAME_COL)), agg) } - protected def findOutputAttr(plan: LogicalPlan, attrName: String): Attribute = { - plan.output.find(attr => resolver(attr.name, attrName)).getOrElse { - throw new AnalysisException(s"Cannot find $attrName in ${plan.output}") + protected def findOutputAttr(attrs: Seq[Attribute], attrName: String): Attribute = { + attrs.find(attr => resolver(attr.name, attrName)).getOrElse { + throw new AnalysisException(s"Cannot find $attrName in $attrs") } } diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala index 9077e9ada457..dcd1c78776af 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala @@ -110,7 +110,6 @@ case class MergeIntoExec( def shouldDeleteRow(row: InternalRow): Boolean = row.getBoolean(params.targetOutput.size - 1) - /** * This method is responsible for processing a input row to emit the resultant row with an * additional column that indicates whether the row is going to be included in the final diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java index deeb41468dd4..4036d371986c 100644 --- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java +++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java @@ -25,14 +25,18 @@ import org.apache.iceberg.TableProperties; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkSessionCatalog; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.junit.After; import org.junit.Assert; import org.junit.Assume; +import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; +import org.junit.runners.Parameterized; import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; @@ -41,6 +45,32 @@ public class TestMergeIntoTable extends SparkRowLevelOperationsTestBase { private final String sourceName; private final String targetName; + @Parameterized.Parameters( + name = "catalogName = {0}, implementation = {1}, config = {2}, format = {3}, vectorized = {4}") + public static Object[][] parameters() { + return new Object[][] { + { "testhive", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default" + ), + "parquet", + true + }, + { "spark_catalog", SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + "parquet", + false + } + }; + } + public TestMergeIntoTable(String catalogName, String implementation, Map config, String fileFormat, Boolean vectorized) { super(catalogName, implementation, config, fileFormat, vectorized); @@ -57,6 +87,12 @@ protected Map extraTableProperties() { return ImmutableMap.of(TableProperties.MERGE_MODE, TableProperties.MERGE_MODE_DEFAULT); } + @Before + public void createTables() { + createAndInitUnPartitionedTargetTable(targetName); + createAndInitSourceTable(sourceName); + } + @After public void removeTables() { sql("DROP TABLE IF EXISTS %s", targetName); @@ -65,16 +101,13 @@ public void removeTables() { @Test public void testEmptyTargetInsertAllNonMatchingRows() throws NoSuchTableException { - createAndInitUnPartitionedTargetTable(targetName); - createAndInitSourceTable(sourceName); append(sourceName, new Employee(1, "emp-id-1"), new Employee(2, "emp-id-2"), new Employee(3, "emp-id-3")); - String sqlText = "MERGE INTO " + targetName + " AS target " + - "USING " + sourceName + " AS source " + + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + "ON target.id = source.id " + "WHEN NOT MATCHED THEN INSERT * "; - sql(sqlText); - sql("SELECT * FROM %s ORDER BY id, dep", targetName); + sql(sqlText, targetName, sourceName); assertEquals("Should have expected rows", ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2"), row(3, "emp-id-3")), sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); @@ -82,16 +115,13 @@ public void testEmptyTargetInsertAllNonMatchingRows() throws NoSuchTableExceptio @Test public void testEmptyTargetInsertOnlyMatchingRows() throws NoSuchTableException { - createAndInitUnPartitionedTargetTable(targetName); - createAndInitSourceTable(sourceName); append(sourceName, new Employee(1, "emp-id-1"), new Employee(2, "emp-id-2"), new Employee(3, "emp-id-3")); - String sqlText = "MERGE INTO " + targetName + " AS target " + - "USING " + sourceName + " AS source " + + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + "ON target.id = source.id " + "WHEN NOT MATCHED AND (source.id >= 2) THEN INSERT * "; - sql(sqlText); - List res = sql("SELECT * FROM %s ORDER BY id, dep", targetName); + sql(sqlText, targetName, sourceName); assertEquals("Should have expected rows", ImmutableList.of(row(2, "emp-id-2"), row(3, "emp-id-3")), sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); @@ -99,35 +129,29 @@ public void testEmptyTargetInsertOnlyMatchingRows() throws NoSuchTableException @Test public void testOnlyUpdate() throws NoSuchTableException { - createAndInitUnPartitionedTargetTable(targetName); - createAndInitSourceTable(sourceName); - append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); + append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-six")); append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); - String sqlText = "MERGE INTO " + targetName + " AS target " + - "USING " + sourceName + " AS source " + + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + "ON target.id = source.id " + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * "; - sql(sqlText); - List res = sql("SELECT * FROM %s ORDER BY id, dep", targetName); + sql(sqlText, targetName, sourceName); assertEquals("Should have expected rows", - ImmutableList.of(row(1, "emp-id-1"), row(6, "emp-id-6")), + ImmutableList.of(row(1, "emp-id-1"), row(6, "emp-id-six")), sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); } @Test public void testOnlyDelete() throws NoSuchTableException { - createAndInitUnPartitionedTargetTable(targetName); - createAndInitSourceTable(sourceName); append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); - String sqlText = "MERGE INTO " + targetName + " AS target " + - "USING " + sourceName + " AS source " + + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + "ON target.id = source.id " + "WHEN MATCHED AND target.id = 6 THEN DELETE"; - sql(sqlText); - List res = sql("SELECT * FROM %s ORDER BY id, dep", targetName); + sql(sqlText, targetName, sourceName); assertEquals("Should have expected rows", ImmutableList.of(row(1, "emp-id-one")), sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); @@ -135,19 +159,16 @@ public void testOnlyDelete() throws NoSuchTableException { @Test public void testAllCauses() throws NoSuchTableException { - createAndInitUnPartitionedTargetTable(targetName); - createAndInitSourceTable(sourceName); append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); - String sqlText = "MERGE INTO " + targetName + " AS target " + - "USING " + sourceName + " AS source " + + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + "ON target.id = source.id " + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * " + "WHEN MATCHED AND target.id = 6 THEN DELETE " + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * "; - sql(sqlText); - sql("SELECT * FROM %s ORDER BY id, dep", targetName); + sql(sqlText, targetName, sourceName); assertEquals("Should have expected rows", ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); @@ -155,19 +176,16 @@ public void testAllCauses() throws NoSuchTableException { @Test public void testAllCausesWithExplicitColumnSpecification() throws NoSuchTableException { - createAndInitUnPartitionedTargetTable(targetName); - createAndInitSourceTable(sourceName); append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); - String sqlText = "MERGE INTO " + targetName + " AS target " + - "USING " + sourceName + " AS source " + + String sqlText = "MERGE INTO %s AS target " + + "USING %s AS source " + "ON target.id = source.id " + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET target.id = source.id, target.dep = source.dep " + "WHEN MATCHED AND target.id = 6 THEN DELETE " + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT (target.id, target.dep) VALUES (source.id, source.dep) "; - sql(sqlText); - sql("SELECT * FROM %s ORDER BY id, dep", targetName); + sql(sqlText, targetName, sourceName); assertEquals("Should have expected rows", ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); @@ -178,20 +196,17 @@ public void testSourceCTE() throws NoSuchTableException { Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); Assume.assumeFalse(catalogName.equalsIgnoreCase("testhive")); - createAndInitUnPartitionedTargetTable(targetName); - createAndInitSourceTable(sourceName); append(targetName, new Employee(2, "emp-id-two"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-3"), new Employee(1, "emp-id-2"), new Employee(5, "emp-id-6")); String sourceCTE = "WITH cte1 AS (SELECT id + 1 AS id, dep FROM source)"; - String sqlText = sourceCTE + " " + "MERGE INTO " + targetName + " AS target " + + String sqlText = sourceCTE + " " + "MERGE INTO %s AS target " + "USING cte1" + " AS source " + "ON target.id = source.id " + "WHEN MATCHED AND target.id = 2 THEN UPDATE SET * " + "WHEN MATCHED AND target.id = 6 THEN DELETE " + "WHEN NOT MATCHED AND source.id = 3 THEN INSERT * "; - sql(sqlText); - sql("SELECT * FROM %s ORDER BY id, dep", targetName); + sql(sqlText, targetName); assertEquals("Should have expected rows", ImmutableList.of(row(2, "emp-id-2"), row(3, "emp-id-3")), sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); @@ -202,32 +217,25 @@ public void testSourceFromSetOps() throws NoSuchTableException { Assume.assumeFalse(catalogName.equalsIgnoreCase("testhadoop")); Assume.assumeFalse(catalogName.equalsIgnoreCase("testhive")); - createAndInitUnPartitionedTargetTable(targetName); - createAndInitSourceTable(sourceName); append(targetName, new Employee(1, "emp-id-one"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-2"), new Employee(1, "emp-id-1"), new Employee(6, "emp-id-6")); String derivedSource = " ( SELECT * FROM source WHERE id = 2 " + " UNION ALL " + " SELECT * FROM source WHERE id = 1 OR id = 6)"; - String sqlText = "MERGE INTO " + targetName + " AS target " + + String sqlText = "MERGE INTO %s AS target " + "USING " + derivedSource + " AS source " + "ON target.id = source.id " + "WHEN MATCHED AND target.id = 1 THEN UPDATE SET * " + "WHEN MATCHED AND target.id = 6 THEN DELETE " + "WHEN NOT MATCHED AND source.id = 2 THEN INSERT * "; - sql(sqlText); + sql(sqlText, targetName); sql("SELECT * FROM %s ORDER BY id, dep", targetName); assertEquals("Should have expected rows", ImmutableList.of(row(1, "emp-id-1"), row(2, "emp-id-2")), sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", targetName)); } - protected void createAndInitPartitionedTargetTable(String tabName) { - sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg PARTITIONED BY (dep)", tabName); - initTable(tabName); - } - protected void createAndInitUnPartitionedTargetTable(String tabName) { sql("CREATE TABLE %s (id INT, dep STRING) USING iceberg", tabName); initTable(tabName); From 9fadc1d85e60d70f91b180bd5352903916cc13fb Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 18 Jan 2021 18:23:08 -0800 Subject: [PATCH 9/9] More review --- .../spark/sql/catalyst/optimizer/RewriteDelete.scala | 2 +- .../spark/sql/execution/datasources/v2/MergeIntoExec.scala | 7 ++----- .../iceberg/spark/extensions/TestMergeIntoTable.java | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala index e8ecf65b3750..8e9b4a1f541e 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala @@ -56,7 +56,7 @@ case class RewriteDelete(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRo d // rewrite all operations that require reading the table to delete records - case DeleteFromTable(r: DataSourceV2Relation, optionalCond @ Some(cond)) => + case DeleteFromTable(r: DataSourceV2Relation, Some(cond)) => // TODO: do a switch based on whether we get BatchWrite or DeltaBatchWrite val writeInfo = newWriteInfo(r.schema) val mergeBuilder = r.table.asMergeable.newMergeBuilder("delete", writeInfo) diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala index dcd1c78776af..80370ad21c5a 100644 --- a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala +++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeIntoExec.scala @@ -63,13 +63,10 @@ case class MergeIntoExec( // In case when there are overlapping condition in the MATCHED // clauses, for the first one that satisfies the predicate, the // corresponding action is applied. For example: - // - // WHEN MATCHED AND id > 1 AND id < 10 UPDATE * - // WHEN MATCHED AND id = 5 OR id = 21 DELETE - // + // WHEN MATCHED AND id > 1 AND id < 10 UPDATE * + // WHEN MATCHED AND id = 5 OR id = 21 DELETE // In above case, when id = 5, it applies both that matched predicates. In this // case the first one we see is applied. - // val pair = actions.find { case (predicate, _) => predicate.eval(inputRow) diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java index 4036d371986c..bf304e1890cf 100644 --- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java +++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeIntoTable.java @@ -199,7 +199,7 @@ public void testSourceCTE() throws NoSuchTableException { append(targetName, new Employee(2, "emp-id-two"), new Employee(6, "emp-id-6")); append(sourceName, new Employee(2, "emp-id-3"), new Employee(1, "emp-id-2"), new Employee(5, "emp-id-6")); String sourceCTE = "WITH cte1 AS (SELECT id + 1 AS id, dep FROM source)"; - String sqlText = sourceCTE + " " + "MERGE INTO %s AS target " + + String sqlText = sourceCTE + " MERGE INTO %s AS target " + "USING cte1" + " AS source " + "ON target.id = source.id " + "WHEN MATCHED AND target.id = 2 THEN UPDATE SET * " +