Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.iceberg.spark.extensions

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.analysis.{AlignMergeIntoTable, DeleteFromTablePredicateCheck, ProcedureArgumentCoercion, ResolveProcedures}
import org.apache.spark.sql.catalyst.optimizer.{OptimizeConditionsInRowLevelOperations, PullupCorrelatedPredicatesInRowLevelOperations, RewriteDelete}
import org.apache.spark.sql.catalyst.optimizer.{OptimizeConditionsInRowLevelOperations, PullupCorrelatedPredicatesInRowLevelOperations, RewriteDelete, RewriteMergeInto}
import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser
import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy

Expand All @@ -43,6 +43,7 @@ class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
// TODO: PullupCorrelatedPredicates should handle row-level operations
extensions.injectOptimizerRule { _ => PullupCorrelatedPredicatesInRowLevelOperations }
extensions.injectOptimizerRule { spark => RewriteDelete(spark.sessionState.conf) }
extensions.injectOptimizerRule { spark => RewriteMergeInto(spark.sessionState.conf) }

// planner extensions
extensions.injectPlannerStrategy { spark => ExtendedDataSourceV2Strategy(spark) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ case class RewriteDelete(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRo
remainingRowsPlan: LogicalPlan,
output: Seq[AttributeReference]): LogicalPlan = {

val fileNameCol = findOutputAttr(remainingRowsPlan, FILE_NAME_COL)
val rowPosCol = findOutputAttr(remainingRowsPlan, ROW_POS_COL)
val fileNameCol = findOutputAttr(remainingRowsPlan.output, FILE_NAME_COL)
val rowPosCol = findOutputAttr(remainingRowsPlan.output, ROW_POS_COL)
val order = Seq(SortOrder(fileNameCol, Ascending), SortOrder(rowPosCol, Ascending))
val numShufflePartitions = SQLConf.get.numShufflePartitions
val repartition = RepartitionByExpression(Seq(fileNameCol), remainingRowsPlan, numShufflePartitions)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.InputFileName
import org.apache.spark.sql.catalyst.expressions.IsNull
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.DeleteAction
import org.apache.spark.sql.catalyst.plans.logical.InsertAction
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.plans.logical.JoinHint
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.MergeAction
import org.apache.spark.sql.catalyst.plans.logical.MergeInto
import org.apache.spark.sql.catalyst.plans.logical.MergeIntoParams
import org.apache.spark.sql.catalyst.plans.logical.MergeIntoTable
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
import org.apache.spark.sql.catalyst.plans.logical.UpdateAction
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.utils.RewriteRowLevelOperationHelper
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.BooleanType

/**
 * Optimizer rule that rewrites a `MergeIntoTable` plan whose target is a `DataSourceV2Relation`
 * into a `ReplaceData` plan.
 *
 * The rewrite:
 *  1. builds a scan of the target through `buildScanPlan`, passing the merge condition and an
 *     inner join of source and target as the matching-rows plan;
 *  2. appends a boolean marker column to each side and full-outer-joins source with target, so
 *     that a null marker identifies the side that contributed no row;
 *  3. wraps the join in a `MergeInto` node whose `MergeIntoParams` carry the per-clause conditions
 *     and output expressions; every output list built here ends with a boolean literal that is
 *     `true` only for rows that must be dropped;
 *  4. hands the result to `ReplaceData` together with the batch write of the table's merge builder.
 *
 * @param conf session configuration; supplies the resolver used to look up attributes by name
 */
case class RewriteMergeInto(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRowLevelOperationHelper {
  // Names of the marker columns added to the source and target sides of the outer join.
  private val ROW_FROM_SOURCE = "_row_from_source_"
  private val ROW_FROM_TARGET = "_row_from_target_"
  private val TRUE_LITERAL = Literal(true, BooleanType)
  private val FALSE_LITERAL = Literal(false, BooleanType)

  import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Implicits._

  override def resolver: Resolver = conf.resolver

  /**
   * Rewrites every `MergeIntoTable` over a `DataSourceV2Relation` found in `plan`.
   *
   * @param plan the plan to rewrite
   * @return the plan with each matching `MergeIntoTable` replaced by a `ReplaceData` node
   */
  override def apply(plan: LogicalPlan): LogicalPlan = {
    // This rule is injected as an optimizer rule, so traverse with `transform`: `resolveOperators`
    // skips any subtree that is already marked as analyzed and could leave the MERGE unrewritten.
    plan transform {
      case MergeIntoTable(target: DataSourceV2Relation, source: LogicalPlan, cond, matchedActions, notMatchedActions) =>
        // Construct the plan to prune target based on join condition between source and target.
        val writeInfo = newWriteInfo(target.schema)
        val mergeBuilder = target.table.asMergeable.newMergeBuilder("merge", writeInfo)
        val matchingRowsPlanBuilder = (rel: DataSourceV2ScanRelation) =>
          Join(source, rel, Inner, Some(cond), JoinHint.NONE)
        val targetTableScan = buildScanPlan(target.table, target.output, mergeBuilder, cond, matchingRowsPlanBuilder)

        // Construct an outer join to help track changes in source and target.
        // TODO : Optimize this to use LEFT ANTI or RIGHT OUTER when applicable.
        val sourceTableProj = source.output ++ Seq(Alias(TRUE_LITERAL, ROW_FROM_SOURCE)())
        val targetTableProj = target.output ++ Seq(Alias(TRUE_LITERAL, ROW_FROM_TARGET)())
        val newTargetTableScan = Project(targetTableProj, targetTableScan)
        val newSourceTableScan = Project(sourceTableProj, source)
        val joinPlan = Join(newSourceTableScan, newTargetTableScan, FullOuter, Some(cond), JoinHint.NONE)

        // Construct the plan to replace the data based on the output of `MergeInto`.
        // A marker column is null exactly when the outer join found no row on that side.
        val mergeParams = MergeIntoParams(
          isSourceRowNotPresent = IsNull(findOutputAttr(joinPlan.output, ROW_FROM_SOURCE)),
          isTargetRowNotPresent = IsNull(findOutputAttr(joinPlan.output, ROW_FROM_TARGET)),
          matchedConditions = matchedActions.map(getClauseCondition),
          matchedOutputs = matchedActions.map(actionOutput(_, target.output)),
          notMatchedConditions = notMatchedActions.map(getClauseCondition),
          notMatchedOutputs = notMatchedActions.map(actionOutput(_, target.output)),
          targetOutput = target.output :+ FALSE_LITERAL,
          deleteOutput = target.output :+ TRUE_LITERAL,
          joinedAttributes = joinPlan.output
        )
        val mergePlan = MergeInto(mergeParams, target, joinPlan)
        val batchWrite = mergeBuilder.asWriteBuilder.buildForBatch()
        ReplaceData(target, batchWrite, mergePlan)
    }
  }

  /**
   * Builds the output expressions for one merge clause, followed by the trailing flag.
   *
   * @param clause the UPDATE, DELETE or INSERT action of the clause
   * @param targetOutputCols the target table columns, used as the payload of a DELETE
   * @return the assignment values (UPDATE/INSERT) or target columns (DELETE), plus a boolean
   *         literal that is `true` only for DELETE
   */
  private def actionOutput(clause: MergeAction, targetOutputCols: Seq[Expression]): Seq[Expression] = {
    clause match {
      case u: UpdateAction =>
        u.assignments.map(_.value) :+ FALSE_LITERAL
      case _: DeleteAction =>
        targetOutputCols :+ TRUE_LITERAL
      case i: InsertAction =>
        i.assignments.map(_.value) :+ FALSE_LITERAL
    }
  }

  /**
   * Returns the condition of a merge clause, or the literal `true` when the clause has none.
   *
   * @param clause the merge action whose optional condition is read
   */
  private def getClauseCondition(clause: MergeAction): Expression = {
    clause.condition.getOrElse(TRUE_LITERAL)
  }
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

/**
 * Logical plan node that carries the parameters of a rewritten MERGE INTO operation.
 *
 * @param mergeIntoProcessor conditions and output expressions of the merge clauses
 * @param targetRelation the target table relation; its attributes are reported as this node's output
 * @param child the plan producing the rows to merge (RewriteMergeInto passes the full outer join
 *              of source and target)
 */
case class MergeInto(
mergeIntoProcessor: MergeIntoParams,
targetRelation: DataSourceV2Relation,
child: LogicalPlan) extends UnaryNode {
// The node exposes the target relation's attributes rather than the child's.
override def output: Seq[Attribute] = targetRelation.output
}

/**
 * Serializable bundle of the expressions needed to process the joined rows of a MERGE INTO.
 * The descriptions below reflect how RewriteMergeInto populates the fields.
 *
 * @param isSourceRowNotPresent predicate over a joined row; an IsNull test on the source marker column
 * @param isTargetRowNotPresent predicate over a joined row; an IsNull test on the target marker column
 * @param matchedConditions one condition per matched action (literal true when the clause has none)
 * @param matchedOutputs output expressions per matched action, in the same order as matchedConditions
 * @param notMatchedConditions one condition per not-matched action (literal true when the clause has none)
 * @param notMatchedOutputs output expressions per not-matched action, in the same order as
 *                          notMatchedConditions
 * @param targetOutput the target columns followed by a false literal
 * @param deleteOutput the target columns followed by a true literal
 * @param joinedAttributes the output attributes of the join that the expressions above refer to
 */
case class MergeIntoParams(
isSourceRowNotPresent: Expression,
isTargetRowNotPresent: Expression,
matchedConditions: Seq[Expression],
matchedOutputs: Seq[Seq[Expression]],
notMatchedConditions: Seq[Expression],
notMatchedOutputs: Seq[Seq[Expression]],
targetOutput: Seq[Expression],
deleteOutput: Seq[Expression],
joinedAttributes: Seq[Attribute]) extends Serializable
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging {

scan match {
case filterable: SupportsFileFilter =>
val matchingFilePlan = buildFileFilterPlan(matchingRowsPlanBuilder(scanRelation))
val matchingFilePlan = buildFileFilterPlan(scanRelation.output, matchingRowsPlanBuilder(scanRelation))
DynamicFileFilter(scanRelation, matchingFilePlan, filterable)
case _ =>
scanRelation
Expand Down Expand Up @@ -102,15 +102,15 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging {
LogicalWriteInfoImpl(queryId = uuid.toString, schema, CaseInsensitiveStringMap.empty)
}

private def buildFileFilterPlan(matchingRowsPlan: LogicalPlan): LogicalPlan = {
val fileAttr = findOutputAttr(matchingRowsPlan, FILE_NAME_COL)
private def buildFileFilterPlan(tableAttrs: Seq[AttributeReference], matchingRowsPlan: LogicalPlan): LogicalPlan = {
val fileAttr = findOutputAttr(tableAttrs, FILE_NAME_COL)
val agg = Aggregate(Seq(fileAttr), Seq(fileAttr), matchingRowsPlan)
Project(Seq(findOutputAttr(agg, FILE_NAME_COL)), agg)
Project(Seq(findOutputAttr(agg.output, FILE_NAME_COL)), agg)
}

protected def findOutputAttr(plan: LogicalPlan, attrName: String): Attribute = {
plan.output.find(attr => resolver(attr.name, attrName)).getOrElse {
throw new AnalysisException(s"Cannot find $attrName in ${plan.output}")
protected def findOutputAttr(attrs: Seq[Attribute], attrName: String): Attribute = {
attrs.find(attr => resolver(attr.name, attrName)).getOrElse {
throw new AnalysisException(s"Cannot find $attrName in $attrs")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Call
import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField
import org.apache.spark.sql.catalyst.plans.logical.DynamicFileFilter
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.MergeInto
import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
import org.apache.spark.sql.catalyst.plans.logical.SetWriteOrder
import org.apache.spark.sql.connector.catalog.Identifier
Expand Down Expand Up @@ -75,6 +76,9 @@ case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy {
case ReplaceData(_, batchWrite, query) =>
ReplaceDataExec(batchWrite, planLater(query)) :: Nil

case MergeInto(mergeIntoProcessor, targetRelation, child) =>
MergeIntoExec(mergeIntoProcessor, targetRelation, planLater(child)) :: Nil
Comment thread
rdblue marked this conversation as resolved.

case _ => Nil
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.BasePredicate
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.catalyst.plans.logical.MergeIntoParams
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.UnaryExecNode

case class MergeIntoExec(
mergeIntoParams: MergeIntoParams,
@transient targetRelation: DataSourceV2Relation,
override val child: SparkPlan) extends UnaryExecNode {

override def output: Seq[Attribute] = targetRelation.output

/**
 * Executes the child plan and applies the merge row processing to each partition of its output.
 *
 * @return the rows produced by `processPartition` for every partition of the child's RDD
 */
protected override def doExecute(): RDD[InternalRow] = {
  val childRows = child.execute()
  childRows.mapPartitions(partition => processPartition(mergeIntoParams, partition))
}

/**
 * Creates an `UnsafeProjection` for `exprs`, using `attrs` as the input schema.
 *
 * @param exprs the expressions to project
 * @param attrs the attributes of the input rows
 */
private def generateProjection(exprs: Seq[Expression], attrs: Seq[Attribute]): UnsafeProjection =
  UnsafeProjection.create(exprs, attrs)

/**
 * Generates a predicate for `expr`, using `attrs` as the input schema.
 *
 * @param expr the boolean expression to evaluate
 * @param attrs the attributes of the input rows
 */
private def generatePredicate(expr: Expression, attrs: Seq[Attribute]): BasePredicate =
  GeneratePredicate.generate(expr, attrs)

/**
 * Produces the output row for a single joined input row.
 *
 * The (predicate, projection) pairs in `actions` are evaluated in order and the projection of
 * the first pair whose predicate holds is applied. When no predicate holds, `projectDeleteRow`
 * is applied if `targetRowNotPresent` is true, otherwise `projectTargetCols`.
 *
 * @param actions clause predicates paired with their projections, in clause order
 * @param projectTargetCols projection used when a target row exists but no clause applies
 * @param projectDeleteRow projection used when there is no target row and no clause applies
 * @param inputRow the joined row to process
 * @param targetRowNotPresent whether the joined row has no target side
 * @return the row produced by the selected projection
 */
def applyProjection(
actions: Seq[(BasePredicate, UnsafeProjection)],
projectTargetCols: UnsafeProjection,
projectDeleteRow: UnsafeProjection,
inputRow: InternalRow,
targetRowNotPresent: Boolean): InternalRow = {


// Pick the first (predicate, projection) pair whose predicate evaluates to true for this row.
// When clause conditions overlap, only the first satisfied clause is applied. For example:
//   WHEN MATCHED AND id > 1 AND id < 10 UPDATE *
//   WHEN MATCHED AND id = 5 OR id = 21 DELETE
// A row with id = 5 satisfies both conditions, so the UPDATE (listed first) is the one applied.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: no need for an empty comment and an empty line.

val pair = actions.find {
case (predicate, _) => predicate.eval(inputRow)
}
Comment thread
rdblue marked this conversation as resolved.

// Apply the selected projection, which inserts, updates or deletes a target row.
// If no clause matched:
// - without a target row there is nothing to keep, so the delete projection is applied
//   (processPartition builds it from params.deleteOutput);
// - with a target row, the row is carried over through the target-columns projection.
pair match {
case Some((_, projection)) =>
projection.apply(inputRow)
case None =>
if (targetRowNotPresent) {
projectDeleteRow.apply(inputRow)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems a bit odd to apply this projection because the target row will be deleted. It seems like we could use the same lazily-initialized row for every delete.

Copy link
Copy Markdown
Contributor Author

@dilipbiswal dilipbiswal Jan 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdblue I had thought about it. But couldn't think of a way to do it. How about, we create a materialized delete row once per partition like this :

val deleteExpr = params.targetOutput.dropRight(1).map(e => Literal.default(e.dataType)) ++ Seq(Literal.create(true, BooleanType))
    val deletedRow1 = UnsafeProjection.create(deleteExpr)
    val deletedRow = deletedRow1.apply(null)

deteExpr will come from rewriteMergeInto just like its passed now. Here we will just create the InternalRow once and use it ? Will that work Ryan ?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we just returned null instead?

I think the problem is that this is trying to create one output for each input row, then filtering happens afterward. An extra column is added to signal that the row should be kept or not. But we don't need to copy the row if it is going to be removed. We also don't need to copy incoming target rows just to add a true at the end if they are going to be kept.

So what if we changed all of the delete cases to produce null instead?

Let's not worry about this for now, but I'll open a PR after this is merged to simplify and avoid some of the copies.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdblue OK.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdblue One thing to note is that, the output of the outer join is target cols + source cols. So we have to project out the necessary target columns, i think.

} else {
projectTargetCols.apply(inputRow)
}
}
}

/**
 * Processes one partition of joined rows.
 *
 * Builds the predicates and projections described by `params`, all bound to
 * `params.joinedAttributes`, maps every input row through `processRow`, and drops the rows
 * whose trailing boolean flag is true.
 *
 * NOTE(review): the rows that are kept still carry the trailing flag column; nothing in this
 * method projects it away - confirm that downstream consumers expect this.
 *
 * @param params the merge conditions and output expressions
 * @param rowIterator the joined rows of one partition
 * @return an iterator over the rows that are kept
 */
def processPartition(
params: MergeIntoParams,
rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = {

// Predicates and projections are generated once per partition and reused for every row.
val joinedAttrs = params.joinedAttributes
val isSourceRowNotPresentPred = generatePredicate(params.isSourceRowNotPresent, joinedAttrs)
val isTargetRowNotPresentPred = generatePredicate(params.isTargetRowNotPresent, joinedAttrs)
val matchedPreds = params.matchedConditions.map(generatePredicate(_, joinedAttrs))
val matchedProjs = params.matchedOutputs.map(generateProjection(_, joinedAttrs))
val notMatchedPreds = params.notMatchedConditions.map(generatePredicate(_, joinedAttrs))
val notMatchedProjs = params.notMatchedOutputs.map(generateProjection(_, joinedAttrs))
val projectTargetCols = generateProjection(params.targetOutput, joinedAttrs)
val projectDeletedRow = generateProjection(params.deleteOutput, joinedAttrs)
Copy link
Copy Markdown
Contributor

@rdblue rdblue Jan 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These last two projections are only needed when notMatchedPreds or matchedPreds does not have a default case, i.e. lit(true).

In the rewrite, there is also a function, getClauseCondition, that fills in lit(true) if there is no clause condition. But I don't think that any predicates after the true condition are dropped.

I think we could simplify the logic here and avoid extra clauses by ensuring that both matchedPreds and notMatchedPreds end with lit(true). Then this class would not need to account for the case where no predicate matches and we wouldn't have extra predicates passed through. Last, we wouldn't need the last two projections here or in MergeIntoParams because they would be added to notMatchedProjs or matchedProjs.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess projectTargetCols would still be needed for the case where the source row isn't present, but it would still make this a little simpler. Especially the applyProjection method.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdblue

  1. projectTargetCols represents the expression that needs to be applied on the output of outer join which has columns from both the tables to only project the target output columns plus the deleted flag set to false.
  2. projectDeletedRow does the same but with the "deleted flag". I think in the earlier comment we discussed possible ideas to optimize this (will address in follow-up)
  3. ``matchedPredsandnotMatchedPred``` go hand in hand with their corresponding projections that is specified by the user in the WHEN MATCHED and `WHEN NOT MATCHED` clauses.

Given this background, can you please explain your idea a little bit ?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's clear this up in a follow-up.

// Pair each clause predicate with the projection of the same clause (same position in both lists).
val nonMatchedPairs = notMatchedPreds zip notMatchedProjs
val matchedPairs = matchedPreds zip matchedProjs

// Reads the trailing boolean flag of a projected row. The ordinal is derived from targetOutput,
// which assumes every output list has the same length - TODO confirm.
def shouldDeleteRow(row: InternalRow): Boolean =
row.getBoolean(params.targetOutput.size - 1)

Comment thread
rdblue marked this conversation as resolved.
/**
* Processes an input row and emits the resulting row with an additional column that
* indicates whether the row should be dropped from the final output of the merge.
* 1. If there is a target row with no corresponding source row (join condition not met):
* - Only project the target columns, with the deleted flag set to false.
* 2. If there is a source row with no corresponding target row (join condition not met):
* - Apply the first not-matched action (i.e. INSERT) whose condition is met.
* 3. If there is a source row with a corresponding target row (join condition met):
* - Apply the first matched action (i.e. DELETE or UPDATE) whose condition is met.
*/
def processRow(inputRow: InternalRow): InternalRow = {
Comment thread
rdblue marked this conversation as resolved.
if (isSourceRowNotPresentPred.eval(inputRow)) {
projectTargetCols.apply(inputRow)
} else if (isTargetRowNotPresentPred.eval(inputRow)) {
applyProjection(nonMatchedPairs, projectTargetCols, projectDeletedRow, inputRow, true)
} else {
applyProjection(matchedPairs, projectTargetCols, projectDeletedRow, inputRow, false)
}
}

// Project every row first, then drop the rows whose trailing flag is true.
rowIterator
.map(processRow)
.filterNot(shouldDeleteRow)
}
}
Loading