From 0755e670f64c0b34f49f8d55366a22f5a0acedac Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Wed, 1 Apr 2026 17:18:35 +0000 Subject: [PATCH 01/20] Add operation metrics for UPDATE queries in DSv2 --- .../sql/connector/write/UpdateSummary.java | 39 ++++++ .../analysis/RewriteUpdateTable.scala | 25 +++- .../UnresolvedIncrementMetric.scala | 112 +++++++++++++++ .../connector/write/UpdateSummaryImpl.scala | 27 ++++ .../spark/sql/execution/QueryExecution.scala | 1 + .../adaptive/InsertAdaptiveSparkPlan.scala | 16 ++- .../datasources/v2/DataSourceV2Strategy.scala | 12 +- .../v2/WriteToDataSourceV2Exec.scala | 50 ++++--- .../execution/metric/IncrementMetric.scala | 130 ++++++++++++++++++ .../metric/ResolveIncrementMetric.scala | 86 ++++++++++++ .../DeltaBasedUpdateTableSuiteBase.scala | 2 + .../sql/connector/UpdateTableSuiteBase.scala | 75 +++++++++- 12 files changed, 545 insertions(+), 30 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/UpdateSummary.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/UpdateSummaryImpl.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/ResolveIncrementMetric.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/UpdateSummary.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/UpdateSummary.java new file mode 100644 index 0000000000000..ef7fc4534811f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/UpdateSummary.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write; + +import org.apache.spark.annotation.Evolving; + +/** + * Provides an informational summary of the UPDATE operation producing write. + * + * @since 4.2.0 + */ +@Evolving +public interface UpdateSummary extends WriteSummary { + + /** + * Returns the number of rows updated. + */ + long numUpdatedRows(); + + /** + * Returns the number of rows copied unmodified. + */ + long numCopiedRows(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index caf7579da889a..c0ca55c524998 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression, UnresolvedIncrementMetricIf, UnresolvedIncrementMetricIfThenReturn} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, Filter, LogicalPlan, Project, ReplaceData, Union, UpdateTable, WriteDelta} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ @@ -71,9 +71,20 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // construct a read relation and include all required metadata columns val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + // add Filter with chained IncrementMetric expressions to the plan + val updatedRowsMetric = UnresolvedIncrementMetricIfThenReturn( + condition = cond, + returnExpr = Literal.TrueLiteral, + metricName = "numUpdatedRows") + val copiedRowsMetric = UnresolvedIncrementMetricIfThenReturn( + condition = Not(EqualNullSafe(cond, Literal.TrueLiteral)), + returnExpr = updatedRowsMetric, + metricName = "numCopiedRows") + val readRelationWithMetrics = Filter(copiedRowsMetric, readRelation) + // build a plan with updated and copied over records val updatedAndRemainingRowsPlan = buildReplaceDataUpdateProjection( - readRelation, assignments, cond) + readRelationWithMetrics, assignments, cond) // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) @@ -100,12 +111,15 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) // build a plan for updated records that match the condition - val matchedRowsPlan = Filter(cond, readRelation) + val condWithMetric = UnresolvedIncrementMetricIf(cond, "numUpdatedRows") + val matchedRowsPlan = Filter(condWithMetric, readRelation) val updatedRowsPlan = buildReplaceDataUpdateProjection(matchedRowsPlan, assignments) // build a plan that contains unmatched rows in matched groups that must be copied over val remainingRowFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral)) - val remainingRowsPlan = Filter(remainingRowFilter, readRelation) + val remainingRowFilterWithMetric = + UnresolvedIncrementMetricIf(remainingRowFilter, "numCopiedRows") + val remainingRowsPlan = Filter(remainingRowFilterWithMetric, readRelation) // the new state is a union of updated and copied over records val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan) @@ -164,7 +178,8 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs, rowIdAttrs) // build a plan for updated records that match the condition - val matchedRowsPlan = Filter(cond, readRelation) + val condWithMetric = UnresolvedIncrementMetricIf(cond, "numUpdatedRows") + val matchedRowsPlan = Filter(condWithMetric, readRelation) val rowDeltaPlan = if (operation.representUpdateAsDeleteAndInsert) { buildDeletesAndInserts(matchedRowsPlan, assignments, rowIdAttrs) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala new file mode 100644 index 0000000000000..b5129e59d9045 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types.DataType + +/** + * Evaluates the boolean [[condition]] and marks a point in the plan where a metric should be + * incremented when the condition is true. Returns the condition's value unchanged. + * + * This is the unresolved form - resolved into IncrementMetricIf by a preparation rule. + * + * Marked as Nondeterministic to prevent the optimizer from pruning or reordering it. + * Cannot mix in [[Unevaluable]] because both [[Unevaluable]] and [[Nondeterministic]] declare + * [[foldable]] as final. + * + * @param condition the boolean expression to evaluate. + * @param metricName the name of the metric to increment. + */ +case class UnresolvedIncrementMetricIf( + condition: Expression, + metricName: String) + extends UnaryExpression with Nondeterministic { + + override def child: Expression = condition + + override def nullable: Boolean = condition.nullable + + override def dataType: DataType = condition.dataType + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override def prettyName: String = "unresolved_increment_metric_if" + + override def toString: String = s"unresolved_increment_metric_if($condition, $metricName)" + + override protected def evalInternal(input: InternalRow): Any = + throw QueryExecutionErrors.cannotEvaluateExpressionError(this) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this) + + override protected def withNewChildInternal( + newChild: Expression): UnresolvedIncrementMetricIf = + copy(condition = newChild) +} + +/** + * Evaluates the boolean [[condition]], marks a point in the plan where a metric should + * be incremented when the condition is true, then evaluates and returns [[returnExpr]]. + * + * This is the unresolved form - resolved into IncrementMetricIfThenReturn by a preparation rule. + * + * Marked as Nondeterministic to prevent the optimizer from pruning or reordering it. + * Cannot mix in [[Unevaluable]] because both [[Unevaluable]] and [[Nondeterministic]] declare + * [[foldable]] as final. + * + * @param condition the boolean expression to evaluate. + * @param returnExpr the expression whose value is returned. + * @param metricName the name of the metric to increment. + */ +case class UnresolvedIncrementMetricIfThenReturn( + condition: Expression, + returnExpr: Expression, + metricName: String) + extends Expression with BinaryLike[Expression] with Nondeterministic { + + override def left: Expression = condition + + override def right: Expression = returnExpr + + override def nullable: Boolean = returnExpr.nullable + + override def dataType: DataType = returnExpr.dataType + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override def prettyName: String = "unresolved_increment_metric_if_then_return" + + override def toString: String = + s"unresolved_increment_metric_if_then_return($condition, $returnExpr, $metricName)" + + override protected def evalInternal(input: InternalRow): Any = + throw QueryExecutionErrors.cannotEvaluateExpressionError(this) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this) + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): UnresolvedIncrementMetricIfThenReturn = + copy(condition = newLeft, returnExpr = newRight) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/UpdateSummaryImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/UpdateSummaryImpl.scala new file mode 100644 index 0000000000000..fc5cef30f000d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/UpdateSummaryImpl.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write + +/** + * Implementation of [[UpdateSummary]] that provides UPDATE operation summary. + */ +private[sql] case class UpdateSummaryImpl( + numUpdatedRows: Long, + numCopiedRows: Long) + extends UpdateSummary { +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index f08b561d6ef9a..036f46ac6409f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -624,6 +624,7 @@ object QueryExecution { subquery: Boolean): Seq[Rule[SparkPlan]] = { // `AdaptiveSparkPlanExec` is a leaf node. If inserted, all the following rules will be no-op // as the original plan is hidden behind `AdaptiveSparkPlanExec`. + Seq(metric.ResolveIncrementMetrics) ++ adaptiveExecutionRule.toSeq ++ Seq( CoalesceBucketsInJoin, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index d8f850ab2189e..dcf85af1a7ca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -30,8 +30,9 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} import org.apache.spark.sql.execution.datasources.V1WriteCommand -import org.apache.spark.sql.execution.datasources.v2.V2CommandExec +import org.apache.spark.sql.execution.datasources.v2.{V2CommandExec, V2ExistingTableWriteExec} import org.apache.spark.sql.execution.exchange.Exchange +import org.apache.spark.sql.execution.metric.ResolveIncrementMetric import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperator import org.apache.spark.sql.internal.SQLConf @@ -46,12 +47,22 @@ case class InsertAdaptiveSparkPlan( override def conf: SQLConf = adaptiveExecutionContext.session.sessionState.conf + // Extra preprocessing rules to pass to the next AQE instance (e.g., metric resolution + // rules from a parent V2 write exec). Set before processing children, cleared after. + @transient private var extraPreprocessingRules: Seq[Rule[SparkPlan]] = Nil + override def apply(plan: SparkPlan): SparkPlan = applyInternal(plan, false) private def applyInternal(plan: SparkPlan, isSubquery: Boolean): SparkPlan = plan match { case _ if !conf.adaptiveExecutionEnabled => plan case _: ExecutedCommandExec => plan case _: CommandResultExec => plan + case w: V2ExistingTableWriteExec if w.operationMetrics.nonEmpty => + val saved = extraPreprocessingRules + extraPreprocessingRules = Seq(ResolveIncrementMetric(w.operationMetrics)) + val result = w.withNewChildren(w.children.map(apply)) + extraPreprocessingRules = saved + result case c: V2CommandExec => c.withNewChildren(c.children.map(apply)) case c: DataWritingCommandExec if !c.cmd.isInstanceOf[V1WriteCommand] || !conf.plannedWriteEnabled => @@ -75,8 +86,7 @@ case class InsertAdaptiveSparkPlan( // Fall back to non-AQE mode if AQE is not supported in any of the sub-queries. val subqueryMap = buildSubqueryMap(plan) val planSubqueriesRule = PlanAdaptiveSubqueries(subqueryMap) - val preprocessingRules = Seq( - planSubqueriesRule) + val preprocessingRules = Seq(planSubqueriesRule) ++ extraPreprocessingRules // Run pre-processing rules. val newPlan = AdaptiveSparkPlanExec.applyPhysicalRules(plan, preprocessingRules) logDebug(s"Adaptive execution enabled for plan: $plan") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 91e753096a238..1c29117818788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -358,15 +358,17 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat throw SparkException.internalError("Unexpected table relation: " + other) } - case ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, _, - Some(write)) => + case rd @ ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, + _, Some(write)) => // use the original relation to refresh the cache - ReplaceDataExec(planLater(query), refreshCache(r), projections, write) :: Nil + ReplaceDataExec(planLater(query), refreshCache(r), projections, write, + Some(rd.operation.command())) :: Nil - case WriteDelta(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, + case wd @ WriteDelta(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, Some(write)) => // use the original relation to refresh the cache - WriteDeltaExec(planLater(query), refreshCache(r), projections, write) :: Nil + WriteDeltaExec(planLater(query), refreshCache(r), projections, write, + Some(wd.operation.command())) :: Nil case MergeRows(isSourceRowPresent, isTargetRowPresent, matchedInstructions, notMatchedInstructions, notMatchedBySourceInstructions, checkCardinality, output, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 5a2da729c1b52..d438a287f0822 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSER import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric -import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, Write, WriterCommitMessage, WriteSummary} +import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperation, UpdateSummaryImpl, Write, WriterCommitMessage, WriteSummary} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution, UnaryExecNode} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -311,7 +311,10 @@ case class ReplaceDataExec( query: SparkPlan, refreshCache: () => Unit, projections: ReplaceDataProjections, - write: Write) extends V2ExistingTableWriteExec { + write: Write, + override val rowLevelCommand: Option[RowLevelOperation.Command], + override val operationMetrics: Map[String, SQLMetric] = Map.empty) + extends V2ExistingTableWriteExec { override def writingTask: WritingSparkTask[_] = { projections match { @@ -334,7 +337,10 @@ case class WriteDeltaExec( query: SparkPlan, refreshCache: () => Unit, projections: WriteDeltaProjections, - write: DeltaWrite) extends V2ExistingTableWriteExec { + write: DeltaWrite, + override val rowLevelCommand: Option[RowLevelOperation.Command] = None, + override val operationMetrics: Map[String, SQLMetric] = Map.empty) + extends V2ExistingTableWriteExec { override lazy val writingTask: WritingSparkTask[_] = { if (projections.metadataProjection.isDefined) { @@ -411,6 +417,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSparkPlanHelper { def query: SparkPlan def writingTask: WritingSparkTask[_] = DataWritingSparkTask + def rowLevelCommand: Option[RowLevelOperation.Command] = None var commitProgress: Option[StreamWriterCommitProgress] = None @@ -418,8 +425,9 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSpa override def output: Seq[Attribute] = Nil protected val customMetrics: Map[String, SQLMetric] = Map.empty + val operationMetrics: Map[String, SQLMetric] = Map.empty - override lazy val metrics = customMetrics + override lazy val metrics = customMetrics ++ operationMetrics protected def writeWithV2(batchWrite: BatchWrite): Seq[InternalRow] = { val rdd: RDD[InternalRow] = { @@ -490,18 +498,28 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSpa } private def getWriteSummary(query: SparkPlan): Option[WriteSummary] = { - collectFirst(query) { case m: MergeRowsExec => m }.map { n => - val metrics = n.metrics - MergeSummaryImpl( - metrics.get("numTargetRowsCopied").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsDeleted").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsUpdated").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsInserted").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsMatchedUpdated").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsMatchedDeleted").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsNotMatchedBySourceUpdated").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsNotMatchedBySourceDeleted").map(_.value).getOrElse(-1L) - ) + rowLevelCommand.flatMap { + case RowLevelOperation.Command.MERGE => + collectFirst(query) { case m: MergeRowsExec => m }.map { n => + val metrics = n.metrics + MergeSummaryImpl( + metrics.get("numTargetRowsCopied").map(_.value).getOrElse(-1L), + metrics.get("numTargetRowsDeleted").map(_.value).getOrElse(-1L), + metrics.get("numTargetRowsUpdated").map(_.value).getOrElse(-1L), + metrics.get("numTargetRowsInserted").map(_.value).getOrElse(-1L), + metrics.get("numTargetRowsMatchedUpdated").map(_.value).getOrElse(-1L), + metrics.get("numTargetRowsMatchedDeleted").map(_.value).getOrElse(-1L), + metrics.get("numTargetRowsNotMatchedBySourceUpdated").map(_.value).getOrElse(-1L), + metrics.get("numTargetRowsNotMatchedBySourceDeleted").map(_.value).getOrElse(-1L) + ) + } + case RowLevelOperation.Command.UPDATE => + Some(UpdateSummaryImpl( + operationMetrics.get("numUpdatedRows").map(_.value).get, + operationMetrics.get("numCopiedRows").map(_.value).get + )) + case RowLevelOperation.Command.DELETE => + None } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala new file mode 100644 index 0000000000000..9ceb8c8a6e15d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, Nondeterministic, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.types.DataType + +/** + * Evaluates the boolean [[condition]] and increments an SQLMetric when it is true. + * Returns the condition's value unchanged. + * + * This is the resolved form of + * [[org.apache.spark.sql.catalyst.expressions.UnresolvedIncrementMetricIf]]. + * + * @param condition the boolean expression to evaluate. + * @param metric the SQLMetric accumulator to conditionally increment. + */ +case class IncrementMetricIf(condition: Expression, metric: SQLMetric) + extends UnaryExpression with Nondeterministic { + + override def child: Expression = condition + + override def nullable: Boolean = condition.nullable + + override def dataType: DataType = condition.dataType + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override def prettyName: String = "increment_metric_if" + + override def toString: String = + s"increment_metric_if($condition, ${metric.name.getOrElse("metric")})" + + override protected def evalInternal(input: InternalRow): Any = { + val result = condition.eval(input) + if (result != null && result.asInstanceOf[Boolean]) { + metric.add(1L) + } + result + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val condEval = condition.genCode(ctx) + val metricRef = ctx.addReferenceObj(metric.name.getOrElse("metric"), metric) + condEval.copy(code = condEval.code + code""" + if (!${condEval.isNull} && ${condEval.value}) { + $metricRef.add(1L); + } + """) + } + + override protected def withNewChildInternal(newChild: Expression): IncrementMetricIf = + copy(condition = newChild) +} + +/** + * Evaluates the boolean [[condition]], increments an SQLMetric when it is true, + * then evaluates and returns [[returnExpr]]. + * + * This is the resolved form of + * [[org.apache.spark.sql.catalyst.expressions.UnresolvedIncrementMetricIfThenReturn]]. + * + * @param condition the boolean expression to evaluate. + * @param returnExpr the expression whose value is returned. + * @param metric the SQLMetric accumulator to conditionally increment. + */ +case class IncrementMetricIfThenReturn( + condition: Expression, + returnExpr: Expression, + metric: SQLMetric) + extends Expression with BinaryLike[Expression] with Nondeterministic { + + override def left: Expression = condition + + override def right: Expression = returnExpr + + override def nullable: Boolean = returnExpr.nullable + + override def dataType: DataType = returnExpr.dataType + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override def prettyName: String = "increment_metric_if_then_return" + + override def toString: String = + s"increment_metric_if_then_return($condition, $returnExpr, ${metric.name.getOrElse("metric")})" + + override protected def evalInternal(input: InternalRow): Any = { + val condResult = condition.eval(input) + if (condResult != null && condResult.asInstanceOf[Boolean]) { + metric.add(1L) + } + returnExpr.eval(input) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val condEval = condition.genCode(ctx) + val returnEval = returnExpr.genCode(ctx) + val metricRef = ctx.addReferenceObj(metric.name.getOrElse("metric"), metric) + returnEval.copy(code = condEval.code + code""" + if (!${condEval.isNull} && ${condEval.value}) { + $metricRef.add(1L); + } + """ + returnEval.code) + } + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): IncrementMetricIfThenReturn = + copy(condition = newLeft, returnExpr = newRight) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/ResolveIncrementMetric.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/ResolveIncrementMetric.scala new file mode 100644 index 0000000000000..5385c5ea1d632 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/ResolveIncrementMetric.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.SparkContext +import org.apache.spark.sql.catalyst.expressions.{UnresolvedIncrementMetricIf, UnresolvedIncrementMetricIfThenReturn} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.write.RowLevelOperation +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.{ReplaceDataExec, V2ExistingTableWriteExec, WriteDeltaExec} + +/** + * Resolves [[UnresolvedIncrementMetricIf]] and [[UnresolvedIncrementMetricIfThenReturn]] + * expressions in a subtree into their resolved counterparts using a provided metrics map. + * + * Used as an AQE preprocessing rule so that resolution survives AQE replanning (which + * re-creates physical plans from logical plans). + * + * @param metricsMap mapping from metric name to SQLMetric accumulator. + */ +case class ResolveIncrementMetric(metricsMap: Map[String, SQLMetric]) + extends Rule[SparkPlan] { + + override def apply(plan: SparkPlan): SparkPlan = { + if (metricsMap.isEmpty) return plan + plan.transformAllExpressions { + case UnresolvedIncrementMetricIf(cond, name) => + IncrementMetricIf(cond, metricsMap(name)) + case UnresolvedIncrementMetricIfThenReturn(cond, ret, name) => + IncrementMetricIfThenReturn(cond, ret, metricsMap(name)) + } + } +} + +/** + * Top-level preparation rule that finds V2 write exec nodes with a `rowLevelCommand`, + * creates operation SQLMetrics, resolves [[UnresolvedIncrementMetricIf]] and + * [[UnresolvedIncrementMetricIfThenReturn]] expressions in the child plan, and stores the + * metrics on the exec node. + */ +object ResolveIncrementMetrics extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case w: V2ExistingTableWriteExec if w.rowLevelCommand.isDefined && w.operationMetrics.isEmpty => + val metricsMap = createOperationMetrics(w.rowLevelCommand.get) + val resolved = ResolveIncrementMetric(metricsMap).apply(w.child) + val withChild = w.withNewChildren(Seq(resolved)) + setOperationMetrics(withChild, metricsMap) + } + + private def createOperationMetrics(cmd: RowLevelOperation.Command): Map[String, SQLMetric] = { + val sc = SparkContext.getOrCreate() + cmd match { + case RowLevelOperation.Command.UPDATE => + Seq( + "numUpdatedRows", + "numCopiedRows" + ).map { name => + name -> SQLMetrics.createMetric(sc, name) + }.toMap + case _ => Map.empty + } + } + + private def setOperationMetrics( + plan: SparkPlan, + metricsMap: Map[String, SQLMetric]): SparkPlan = plan match { + case r: ReplaceDataExec => r.copy(operationMetrics = metricsMap) + case d: WriteDeltaExec => d.copy(operationMetrics = metricsMap) + case other => other + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala index 89b42b5e6db7b..e821fc3f660da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.{AnalysisException, Row} abstract class DeltaBasedUpdateTableSuiteBase extends UpdateTableSuiteBase { + override protected def deltaUpdate: Boolean = true + test("nullable row ID attrs") { createAndInitTable("pk INT, salary INT, dep STRING", """{ "pk": 1, "salary": 300, "dep": 'hr' } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index ac0bf3bdba9ce..d3a6b61a61b9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.connector import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, TableChange, TableInfo} +import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, InMemoryTable, TableChange, TableInfo} import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue} +import org.apache.spark.sql.connector.write.UpdateSummary import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType} @@ -28,6 +29,24 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { import testImplicits._ + protected def deltaUpdate: Boolean = false + + protected def getUpdateSummary(): UpdateSummary = { + val t = catalog.loadTable(ident).asInstanceOf[InMemoryTable] + t.commits.last.writeSummary.get.asInstanceOf[UpdateSummary] + } + + protected def checkUpdateMetrics( + numUpdatedRows: Long, + numCopiedRows: Long): Unit = { + val summary = getUpdateSummary() + assert(summary.numUpdatedRows() === numUpdatedRows, + s"Expected numUpdatedRows=$numUpdatedRows, got ${summary.numUpdatedRows()}") + val expectedCopied = if (deltaUpdate) 0L else numCopiedRows + assert(summary.numCopiedRows() === expectedCopied, + s"Expected numCopiedRows=$expectedCopied, got ${summary.numCopiedRows()}") + } + test("update table containing added column with default value") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", """{ "pk": 1, "salary": 100, "dep": "hr" } @@ -63,6 +82,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { Row(3, 300, "hr", "explicit-text"), Row(4, 400, "software", "explicit-text"), Row(5, 500, "hr", null))) + + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 1) } test("update table with expression-based default values") { @@ -134,6 +155,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { sql(s"UPDATE $tableNameAsString SET dep = 'invalid' WHERE salary <= 1") checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) + + checkUpdateMetrics(numUpdatedRows = 0, numCopiedRows = 0) } test("update with basic filters") { @@ -148,6 +171,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 100, "invalid") :: Row(2, 200, "software") :: Row(3, 300, "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 1) } test("update with aliases") { @@ -162,6 +187,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, -1, "hr") :: Row(2, 200, "software") :: Row(3, -1, "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 0) } test("update aligns assignments") { @@ -175,6 +202,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 10, 109, "hr") :: Row(2, 22, 222, "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 1) } test("update non-existing records") { @@ -189,6 +218,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 100, "hr") :: Row(2, 200, "hardware") :: Row(3, null, "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 0, numCopiedRows = 0) } test("update without condition") { @@ -203,6 +234,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, -1, "hr") :: Row(2, -1, "hardware") :: Row(3, -1, "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 3, numCopiedRows = 0) } test("update with NULL conditions on partition columns") { @@ -217,12 +250,14 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 100, null) :: Row(2, 200, "hr") :: Row(3, 300, "hardware") :: Nil) + checkUpdateMetrics(numUpdatedRows = 0, numCopiedRows = 0) // should update one matching row with a null-safe condition sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep <=> NULL") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, -1, null) :: Row(2, 200, "hr") :: Row(3, 300, "hardware") :: Nil) + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 0) } test("update with NULL conditions on data columns") { @@ -237,12 +272,14 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, null, "hr") :: Row(2, 200, "hr") :: Row(3, 300, "hardware") :: Nil) + checkUpdateMetrics(numUpdatedRows = 0, numCopiedRows = 0) // should update one matching row with a null-safe condition sql(s"UPDATE $tableNameAsString SET dep = 'invalid' WHERE salary <=> NULL") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, null, "invalid") :: Row(2, 200, "hr") :: Row(3, 300, "hardware") :: Nil) + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 1) } test("update with IN and NOT IN predicates") { @@ -256,16 +293,19 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, -1, "hr") :: Row(2, 200, "hardware") :: Row(3, null, "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 1) sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE salary NOT IN (null, 1)") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, -1, "hr") :: Row(2, 200, "hardware") :: Row(3, null, "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 0, numCopiedRows = 0) sql(s"UPDATE $tableNameAsString SET salary = 100 WHERE salary NOT IN (1, 10)") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 100, "hr") :: Row(2, 100, "hardware") :: Row(3, null, "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 1) } test("update nested struct fields") { @@ -280,12 +320,14 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, Row(-1, Row(Seq(-1), Map("k" -> "v"))), "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 0) // set primitive, array, map columns to NULL (proper casts should be in inserted) sql(s"UPDATE $tableNameAsString SET s.c1 = NULL, s.c2 = NULL WHERE pk = 1") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, Row(null, null), "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 0) // assign an entire struct sql( @@ -295,6 +337,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, Row(1, Row(Seq(1), null)), "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 0) } test("update fields inside NULL structs") { @@ -306,6 +349,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, Row(-1, null), "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 0) } test("update refreshes relation cache") { @@ -341,6 +386,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { Row(3, 200, "hardware") :: Row(4, 300, "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 1) + // verify the view reflects the changes in the table checkAnswer(sql("SELECT * FROM temp"), Nil) } @@ -358,6 +405,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, -1, Row(300, "v1"), "hr") :: Row(2, 200, Row(200, "v2"), "software") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 0) } test("update with IN subqueries") { @@ -385,6 +434,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "invalid") :: Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 1) sql( s"""UPDATE $tableNameAsString @@ -397,6 +447,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "invalid") :: Row(2, 2, "invalid") :: Row(3, null, "invalid") :: Nil) + checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 0) } } @@ -421,6 +472,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "invalid") :: Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 1) } } @@ -447,6 +500,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "hr") :: Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 0, numCopiedRows = 0) sql( s"""UPDATE $tableNameAsString @@ -457,6 +511,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "invalid") :: Row(2, 2, "invalid") :: Row(3, null, "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 1) sql( s"""UPDATE $tableNameAsString @@ -469,6 +524,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "hr") :: Row(2, 2, "hr") :: Row(3, null, "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 0) } } @@ -495,6 +551,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "hr") :: Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 0, numCopiedRows = 0) sql( s"""UPDATE $tableNameAsString t @@ -505,6 +562,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "invalid") :: Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 1) sql( s"""UPDATE $tableNameAsString t @@ -515,6 +573,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "invalid") :: Row(2, 2, "hardware") :: Row(3, null, "invalid") :: Nil) + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 0) sql( s"""UPDATE $tableNameAsString t @@ -527,6 +586,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "invalid") :: Row(2, 2, "hardware") :: Row(3, null, "invalid") :: Nil) + checkUpdateMetrics(numUpdatedRows = 0, numCopiedRows = 0) } } @@ -555,6 +615,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "hr") :: Row(2, 2, "invalid") :: Row(3, null, "hr") :: Nil) + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 0) sql( s"""UPDATE $tableNameAsString t @@ -565,6 +626,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "hr") :: Row(2, 2, "invalid") :: Row(3, null, "invalid") :: Nil) + checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 1) sql( s"""UPDATE $tableNameAsString t @@ -577,6 +639,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "invalid") :: Row(2, 2, "invalid") :: Row(3, null, "invalid") :: Nil) + checkUpdateMetrics(numUpdatedRows = 3, numCopiedRows = 0) } } @@ -601,6 +664,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "invalid") :: Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 1) } } @@ -617,6 +682,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT count(*) FROM $tableNameAsString WHERE value < 2.0"), Row(2) :: Nil) + + checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 1) } test("SPARK-53538: update with nondeterministic assignments and no wholestage codegen") { @@ -636,6 +703,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT count(*) FROM $tableNameAsString WHERE value < 2.0"), Row(2) :: Nil) + + checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 1) } test("update with default values") { @@ -658,6 +727,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 42, "hr") :: Row(2, 2, "software") :: Row(3, 42, "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 0) } test("update with current_timestamp default value using DEFAULT keyword") { @@ -703,6 +774,8 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { Row(1, Row("x ", "y"), "hr") :: Row(2, Row("bbb", "bbb"), "software") :: Row(3, Row("x ", "y"), "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 0) } test("update with NOT NULL checks") { From 7336546082e04c61713fe3290f4cbf0fd071d00c Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Thu, 2 Apr 2026 08:55:44 +0000 Subject: [PATCH 02/20] Add comment --- .../spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index dcf85af1a7ca4..0cd541b663967 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -58,6 +58,9 @@ case class InsertAdaptiveSparkPlan( case _: ExecutedCommandExec => plan case _: CommandResultExec => plan case w: V2ExistingTableWriteExec if w.operationMetrics.nonEmpty => + // We need to create a preprocessing rule with the operationMetrics reference to prevent + // losing and re-creating metric instances. ResolveIncrementMetric resolves + // UnresolvedIncrementMetric expressions with the list of already existing metrics. val saved = extraPreprocessingRules extraPreprocessingRules = Seq(ResolveIncrementMetric(w.operationMetrics)) val result = w.withNewChildren(w.children.map(apply)) From e58f003f2a6020eb30de223686552ae93ae70b7c Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Thu, 2 Apr 2026 18:12:39 +0000 Subject: [PATCH 03/20] Address comments --- .../sql/catalyst/analysis/RewriteUpdateTable.scala | 6 +++--- .../expressions/UnresolvedIncrementMetric.scala | 11 ++++++----- .../spark/sql/execution/metric/IncrementMetric.scala | 5 ++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index c0ca55c524998..0b7567c6ba87b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -71,16 +71,16 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // construct a read relation and include all required metadata columns val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) - // add Filter with chained IncrementMetric expressions to the plan + // add Filter with IncrementMetric expressions to the plan val updatedRowsMetric = UnresolvedIncrementMetricIfThenReturn( condition = cond, returnExpr = Literal.TrueLiteral, metricName = "numUpdatedRows") val copiedRowsMetric = UnresolvedIncrementMetricIfThenReturn( condition = Not(EqualNullSafe(cond, Literal.TrueLiteral)), - returnExpr = updatedRowsMetric, + returnExpr = Literal.TrueLiteral, metricName = "numCopiedRows") - val readRelationWithMetrics = Filter(copiedRowsMetric, readRelation) + val readRelationWithMetrics = Filter(copiedRowsMetric, Filter(updatedRowsMetric, readRelation)) // build a plan with updated and copied over records val updatedAndRemainingRowsPlan = buildReplaceDataUpdateProjection( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala index b5129e59d9045..2acc778db2440 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala @@ -29,7 +29,8 @@ import org.apache.spark.sql.types.DataType * * This is the unresolved form - resolved into IncrementMetricIf by a preparation rule. * - * Marked as Nondeterministic to prevent the optimizer from pruning or reordering it. + * Marked as Nondeterministic to prevent the optimizer from pruning or reordering it, as moving this + * expression can result in miscounted metric values. * Cannot mix in [[Unevaluable]] because both [[Unevaluable]] and [[Nondeterministic]] declare * [[foldable]] as final. * @@ -51,7 +52,7 @@ case class UnresolvedIncrementMetricIf( override def prettyName: String = "unresolved_increment_metric_if" - override def toString: String = s"unresolved_increment_metric_if($condition, $metricName)" + override def toString: String = s"$prettyName($condition, $metricName)" override protected def evalInternal(input: InternalRow): Any = throw QueryExecutionErrors.cannotEvaluateExpressionError(this) @@ -70,7 +71,8 @@ case class UnresolvedIncrementMetricIf( * * This is the unresolved form - resolved into IncrementMetricIfThenReturn by a preparation rule. * - * Marked as Nondeterministic to prevent the optimizer from pruning or reordering it. + * Marked as Nondeterministic to prevent the optimizer from pruning or reordering it, as moving this + * expression can result in miscounted metric values. * Cannot mix in [[Unevaluable]] because both [[Unevaluable]] and [[Nondeterministic]] declare * [[foldable]] as final. * @@ -96,8 +98,7 @@ case class UnresolvedIncrementMetricIfThenReturn( override def prettyName: String = "unresolved_increment_metric_if_then_return" - override def toString: String = - s"unresolved_increment_metric_if_then_return($condition, $returnExpr, $metricName)" + override def toString: String = s"$prettyName($condition, $returnExpr, $metricName)" override protected def evalInternal(input: InternalRow): Any = throw QueryExecutionErrors.cannotEvaluateExpressionError(this) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala index 9ceb8c8a6e15d..71740c4e0296e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala @@ -47,8 +47,7 @@ case class IncrementMetricIf(condition: Expression, metric: SQLMetric) override def prettyName: String = "increment_metric_if" - override def toString: String = - s"increment_metric_if($condition, ${metric.name.getOrElse("metric")})" + override def toString: String = s"$prettyName($condition, ${metric.name.getOrElse("metric")})" override protected def evalInternal(input: InternalRow): Any = { val result = condition.eval(input) @@ -102,7 +101,7 @@ case class IncrementMetricIfThenReturn( override def prettyName: String = "increment_metric_if_then_return" override def toString: String = - s"increment_metric_if_then_return($condition, $returnExpr, ${metric.name.getOrElse("metric")})" + s"$prettyName($condition, $returnExpr, ${metric.name.getOrElse("metric")})" override protected def evalInternal(input: InternalRow): Any = { val condResult = condition.eval(input) From b8fcfb4fe284e52599790e57e29e9a3b8351c56a Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Tue, 7 Apr 2026 08:08:50 +0000 Subject: [PATCH 04/20] Remove IncrementMetric, compute metrics via additional attribute --- .../analysis/RewriteUpdateTable.scala | 40 +++--- .../UnresolvedIncrementMetric.scala | 113 --------------- .../spark/sql/execution/QueryExecution.scala | 1 - .../adaptive/InsertAdaptiveSparkPlan.scala | 19 +-- .../v2/WriteToDataSourceV2Exec.scala | 100 ++++++++++++-- .../execution/metric/IncrementMetric.scala | 129 ------------------ .../metric/ResolveIncrementMetric.scala | 86 ------------ 7 files changed, 104 insertions(+), 384 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/ResolveIncrementMetric.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index 0b7567c6ba87b..aac399dbefcad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression, UnresolvedIncrementMetricIf, UnresolvedIncrementMetricIfThenReturn} -import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, Filter, LogicalPlan, Project, ReplaceData, Union, UpdateTable, WriteDelta} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations @@ -35,6 +35,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap */ object RewriteUpdateTable extends RewriteRowLevelCommand { + private[sql] final val IS_UPDATED_COLUMN: String = "__is_updated" + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u @ UpdateTable(aliasedTable, assignments, cond) if u.resolved && u.rewritable && u.aligned => @@ -71,20 +73,9 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // construct a read relation and include all required metadata columns val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) - // add Filter with IncrementMetric expressions to the plan - val updatedRowsMetric = UnresolvedIncrementMetricIfThenReturn( - condition = cond, - returnExpr = Literal.TrueLiteral, - metricName = "numUpdatedRows") - val copiedRowsMetric = UnresolvedIncrementMetricIfThenReturn( - condition = Not(EqualNullSafe(cond, Literal.TrueLiteral)), - returnExpr = Literal.TrueLiteral, - metricName = "numCopiedRows") - val readRelationWithMetrics = Filter(copiedRowsMetric, Filter(updatedRowsMetric, readRelation)) - // build a plan with updated and copied over records val updatedAndRemainingRowsPlan = buildReplaceDataUpdateProjection( - readRelationWithMetrics, assignments, cond) + readRelation, assignments, cond) // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) @@ -111,18 +102,18 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) // build a plan for updated records that match the condition - val condWithMetric = UnresolvedIncrementMetricIf(cond, "numUpdatedRows") - val matchedRowsPlan = Filter(condWithMetric, readRelation) + val matchedRowsPlan = Filter(cond, readRelation) val updatedRowsPlan = buildReplaceDataUpdateProjection(matchedRowsPlan, assignments) // build a plan that contains unmatched rows in matched groups that must be copied over - val remainingRowFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral)) - val remainingRowFilterWithMetric = - UnresolvedIncrementMetricIf(remainingRowFilter, "numCopiedRows") - val remainingRowsPlan = Filter(remainingRowFilterWithMetric, readRelation) + val remainingRowFilter = Not(EqualNullSafe(cond, TrueLiteral)) + val remainingRowsPlan = Filter(remainingRowFilter, readRelation) + val remainingRowsPlanWithFlag = Project( + remainingRowsPlan.output :+ Alias(FalseLiteral, IS_UPDATED_COLUMN)(), + remainingRowsPlan) // the new state is a union of updated and copied over records - val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan) + val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlanWithFlag) // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) @@ -157,7 +148,9 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { } } - Project(updatedValues, plan) + // add a boolean column to indicate whether each row was updated or copied over + val isUpdatedCol = Alias(EqualNullSafe(cond, TrueLiteral), IS_UPDATED_COLUMN)() + Project(updatedValues :+ isUpdatedCol, plan) } // build a rewrite plan for sources that support row deltas @@ -178,8 +171,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs, rowIdAttrs) // build a plan for updated records that match the condition - val condWithMetric = UnresolvedIncrementMetricIf(cond, "numUpdatedRows") - val matchedRowsPlan = Filter(condWithMetric, readRelation) + val matchedRowsPlan = Filter(cond, readRelation) val rowDeltaPlan = if (operation.representUpdateAsDeleteAndInsert) { buildDeletesAndInserts(matchedRowsPlan, assignments, rowIdAttrs) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala deleted file mode 100644 index 2acc778db2440..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnresolvedIncrementMetric.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.types.DataType - -/** - * Evaluates the boolean [[condition]] and marks a point in the plan where a metric should be - * incremented when the condition is true. Returns the condition's value unchanged. - * - * This is the unresolved form - resolved into IncrementMetricIf by a preparation rule. - * - * Marked as Nondeterministic to prevent the optimizer from pruning or reordering it, as moving this - * expression can result in miscounted metric values. - * Cannot mix in [[Unevaluable]] because both [[Unevaluable]] and [[Nondeterministic]] declare - * [[foldable]] as final. - * - * @param condition the boolean expression to evaluate. - * @param metricName the name of the metric to increment. - */ -case class UnresolvedIncrementMetricIf( - condition: Expression, - metricName: String) - extends UnaryExpression with Nondeterministic { - - override def child: Expression = condition - - override def nullable: Boolean = condition.nullable - - override def dataType: DataType = condition.dataType - - override protected def initializeInternal(partitionIndex: Int): Unit = {} - - override def prettyName: String = "unresolved_increment_metric_if" - - override def toString: String = s"$prettyName($condition, $metricName)" - - override protected def evalInternal(input: InternalRow): Any = - throw QueryExecutionErrors.cannotEvaluateExpressionError(this) - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = - throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this) - - override protected def withNewChildInternal( - newChild: Expression): UnresolvedIncrementMetricIf = - copy(condition = newChild) -} - -/** - * Evaluates the boolean [[condition]], marks a point in the plan where a metric should - * be incremented when the condition is true, then evaluates and returns [[returnExpr]]. - * - * This is the unresolved form - resolved into IncrementMetricIfThenReturn by a preparation rule. - * - * Marked as Nondeterministic to prevent the optimizer from pruning or reordering it, as moving this - * expression can result in miscounted metric values. - * Cannot mix in [[Unevaluable]] because both [[Unevaluable]] and [[Nondeterministic]] declare - * [[foldable]] as final. - * - * @param condition the boolean expression to evaluate. - * @param returnExpr the expression whose value is returned. - * @param metricName the name of the metric to increment. - */ -case class UnresolvedIncrementMetricIfThenReturn( - condition: Expression, - returnExpr: Expression, - metricName: String) - extends Expression with BinaryLike[Expression] with Nondeterministic { - - override def left: Expression = condition - - override def right: Expression = returnExpr - - override def nullable: Boolean = returnExpr.nullable - - override def dataType: DataType = returnExpr.dataType - - override protected def initializeInternal(partitionIndex: Int): Unit = {} - - override def prettyName: String = "unresolved_increment_metric_if_then_return" - - override def toString: String = s"$prettyName($condition, $returnExpr, $metricName)" - - override protected def evalInternal(input: InternalRow): Any = - throw QueryExecutionErrors.cannotEvaluateExpressionError(this) - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = - throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this) - - override protected def withNewChildrenInternal( - newLeft: Expression, - newRight: Expression): UnresolvedIncrementMetricIfThenReturn = - copy(condition = newLeft, returnExpr = newRight) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 036f46ac6409f..f08b561d6ef9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -624,7 +624,6 @@ object QueryExecution { subquery: Boolean): Seq[Rule[SparkPlan]] = { // `AdaptiveSparkPlanExec` is a leaf node. If inserted, all the following rules will be no-op // as the original plan is hidden behind `AdaptiveSparkPlanExec`. - Seq(metric.ResolveIncrementMetrics) ++ adaptiveExecutionRule.toSeq ++ Seq( CoalesceBucketsInJoin, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index 0cd541b663967..d8f850ab2189e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -30,9 +30,8 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} import org.apache.spark.sql.execution.datasources.V1WriteCommand -import org.apache.spark.sql.execution.datasources.v2.{V2CommandExec, V2ExistingTableWriteExec} +import org.apache.spark.sql.execution.datasources.v2.V2CommandExec import org.apache.spark.sql.execution.exchange.Exchange -import org.apache.spark.sql.execution.metric.ResolveIncrementMetric import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperator import org.apache.spark.sql.internal.SQLConf @@ -47,25 +46,12 @@ case class InsertAdaptiveSparkPlan( override def conf: SQLConf = adaptiveExecutionContext.session.sessionState.conf - // Extra preprocessing rules to pass to the next AQE instance (e.g., metric resolution - // rules from a parent V2 write exec). Set before processing children, cleared after. - @transient private var extraPreprocessingRules: Seq[Rule[SparkPlan]] = Nil - override def apply(plan: SparkPlan): SparkPlan = applyInternal(plan, false) private def applyInternal(plan: SparkPlan, isSubquery: Boolean): SparkPlan = plan match { case _ if !conf.adaptiveExecutionEnabled => plan case _: ExecutedCommandExec => plan case _: CommandResultExec => plan - case w: V2ExistingTableWriteExec if w.operationMetrics.nonEmpty => - // We need to create a preprocessing rule with the operationMetrics reference to prevent - // losing and re-creating metric instances. ResolveIncrementMetric resolves - // UnresolvedIncrementMetric expressions with the list of already existing metrics. - val saved = extraPreprocessingRules - extraPreprocessingRules = Seq(ResolveIncrementMetric(w.operationMetrics)) - val result = w.withNewChildren(w.children.map(apply)) - extraPreprocessingRules = saved - result case c: V2CommandExec => c.withNewChildren(c.children.map(apply)) case c: DataWritingCommandExec if !c.cmd.isInstanceOf[V1WriteCommand] || !conf.plannedWriteEnabled => @@ -89,7 +75,8 @@ case class InsertAdaptiveSparkPlan( // Fall back to non-AQE mode if AQE is not supported in any of the sub-queries. val subqueryMap = buildSubqueryMap(plan) val planSubqueriesRule = PlanAdaptiveSubqueries(subqueryMap) - val preprocessingRules = Seq(planSubqueriesRule) ++ extraPreprocessingRules + val preprocessingRules = Seq( + planSubqueriesRule) // Run pre-processing rules. val newPlan = AdaptiveSparkPlanExec.applyPhysicalRules(plan, preprocessingRules) logDebug(s"Adaptive execution enabled for plan: $plan") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index d438a287f0822..367cc66bc7453 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -24,6 +24,7 @@ import org.apache.spark.internal.{Logging, LogKeys} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{InternalRow, ProjectingInternalRow} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.analysis.RewriteUpdateTable.IS_UPDATED_COLUMN import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, TableSpec, UnaryNode} import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections} @@ -312,16 +313,25 @@ case class ReplaceDataExec( refreshCache: () => Unit, projections: ReplaceDataProjections, write: Write, - override val rowLevelCommand: Option[RowLevelOperation.Command], - override val operationMetrics: Map[String, SQLMetric] = Map.empty) + override val rowLevelCommand: Option[RowLevelOperation.Command]) extends V2ExistingTableWriteExec { override def writingTask: WritingSparkTask[_] = { + val metricCounter = rowLevelCommand match { + case Some(RowLevelOperation.Command.UPDATE) => + val isUpdatedIndex = query.output.indexWhere(_.name == IS_UPDATED_COLUMN) + Some(BooleanMetricCounter( + isUpdatedIndex, operationMetrics("numUpdatedRows"), operationMetrics("numCopiedRows"))) + case _ => None + } projections match { case ReplaceDataProjections(dataProj, Some(metadataProj)) => - DataAndMetadataWritingSparkTask(dataProj, metadataProj) + DataAndMetadataWritingSparkTask(dataProj, metadataProj, metricCounter) + // TODO: add test coverage for ReplaceData without metadata attributes + case ReplaceDataProjections(dataProj, None) if metricCounter.isDefined => + DataWritingSparkTask(Some(dataProj), metricCounter) case _ => - DataWritingSparkTask + DataWritingSparkTask() } } @@ -338,15 +348,14 @@ case class WriteDeltaExec( refreshCache: () => Unit, projections: WriteDeltaProjections, write: DeltaWrite, - override val rowLevelCommand: Option[RowLevelOperation.Command] = None, - override val operationMetrics: Map[String, SQLMetric] = Map.empty) + override val rowLevelCommand: Option[RowLevelOperation.Command] = None) extends V2ExistingTableWriteExec { override lazy val writingTask: WritingSparkTask[_] = { if (projections.metadataProjection.isDefined) { - DeltaWithMetadataWritingSparkTask(projections) + DeltaWithMetadataWritingSparkTask(projections, operationMetrics) } else { - DeltaWritingSparkTask(projections) + DeltaWritingSparkTask(projections, operationMetrics) } } @@ -416,7 +425,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { */ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSparkPlanHelper { def query: SparkPlan - def writingTask: WritingSparkTask[_] = DataWritingSparkTask + def writingTask: WritingSparkTask[_] = DataWritingSparkTask() def rowLevelCommand: Option[RowLevelOperation.Command] = None var commitProgress: Option[StreamWriterCommitProgress] = None @@ -425,7 +434,15 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSpa override def output: Seq[Attribute] = Nil protected val customMetrics: Map[String, SQLMetric] = Map.empty - val operationMetrics: Map[String, SQLMetric] = Map.empty + + protected lazy val operationMetrics: Map[String, SQLMetric] = rowLevelCommand match { + case Some(RowLevelOperation.Command.UPDATE) => + Map( + "numUpdatedRows" -> SQLMetrics.createMetric(sparkContext, "number of updated rows"), + "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows") + ) + case _ => Map.empty + } override lazy val metrics = customMetrics ++ operationMetrics @@ -620,13 +637,34 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial } } +/** + * Reads a boolean column at the given ordinal and increments one of two metrics per row. + */ +case class BooleanMetricCounter( + ordinal: Int, + trueMetric: SQLMetric, + falseMetric: SQLMetric) extends Serializable { + def count(row: InternalRow): Unit = { + if (row.getBoolean(ordinal)) { + trueMetric.add(1L) + } else { + falseMetric.add(1L) + } + } +} + case class DataAndMetadataWritingSparkTask( dataProj: ProjectingInternalRow, - metadataProj: ProjectingInternalRow) extends WritingSparkTask[DataWriter[InternalRow]] { + metadataProj: ProjectingInternalRow, + metricCounter: Option[BooleanMetricCounter] = None) + extends WritingSparkTask[DataWriter[InternalRow]] { + override protected def write( writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { while (iter.hasNext) { val row = iter.next() + metricCounter.foreach(_.count(row)) + val operation = row.getInt(0) operation match { @@ -646,18 +684,39 @@ case class DataAndMetadataWritingSparkTask( } } -object DataWritingSparkTask extends WritingSparkTask[DataWriter[InternalRow]] { +case class DataWritingSparkTask( + dataProj: Option[ProjectingInternalRow] = None, + metricCounter: Option[BooleanMetricCounter] = None) + extends WritingSparkTask[DataWriter[InternalRow]] { + override protected def write( writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { - writer.writeAll(iter) + if (dataProj.isEmpty && metricCounter.isEmpty) { + writer.writeAll(iter) + } else { + while (iter.hasNext) { + val row = iter.next() + metricCounter.foreach(_.count(row)) + dataProj match { + case Some(proj) => + proj.project(row) + writer.write(proj) + case None => + writer.write(row) + } + } + } } } case class DeltaWritingSparkTask( - projections: WriteDeltaProjections) extends WritingSparkTask[DeltaWriter[InternalRow]] { + projections: WriteDeltaProjections, + operationMetrics: Map[String, SQLMetric] = Map.empty) + extends WritingSparkTask[DeltaWriter[InternalRow]] { private lazy val rowProjection = projections.rowProjection.orNull private lazy val rowIdProjection = projections.rowIdProjection + private lazy val numUpdatedRows = operationMetrics.get("numUpdatedRows") override protected def write( writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { @@ -666,11 +725,15 @@ case class DeltaWritingSparkTask( val operation = row.getInt(0) operation match { + // When representUpdateAsDeleteAndInsert is true, each logical update is split + // into a DELETE and a REINSERT. Count the DELETE as one updated row. case DELETE_OPERATION => + numUpdatedRows.foreach(_.add(1L)) rowIdProjection.project(row) writer.delete(null, rowIdProjection) case UPDATE_OPERATION => + numUpdatedRows.foreach(_.add(1L)) rowProjection.project(row) rowIdProjection.project(row) writer.update(null, rowIdProjection, rowProjection) @@ -691,11 +754,14 @@ case class DeltaWritingSparkTask( } case class DeltaWithMetadataWritingSparkTask( - projections: WriteDeltaProjections) extends WritingSparkTask[DeltaWriter[InternalRow]] { + projections: WriteDeltaProjections, + operationMetrics: Map[String, SQLMetric] = Map.empty) + extends WritingSparkTask[DeltaWriter[InternalRow]] { private lazy val rowProjection = projections.rowProjection.orNull private lazy val rowIdProjection = projections.rowIdProjection private lazy val metadataProjection = projections.metadataProjection.orNull + private lazy val numUpdatedRows = operationMetrics.get("numUpdatedRows") override protected def write( writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { @@ -704,12 +770,16 @@ case class DeltaWithMetadataWritingSparkTask( val operation = row.getInt(0) operation match { + // When representUpdateAsDeleteAndInsert is true, each logical update is split + // into a DELETE and a REINSERT. Count the DELETE as one updated row. case DELETE_OPERATION => + numUpdatedRows.foreach(_.add(1L)) rowIdProjection.project(row) metadataProjection.project(row) writer.delete(metadataProjection, rowIdProjection) case UPDATE_OPERATION => + numUpdatedRows.foreach(_.add(1L)) rowProjection.project(row) rowIdProjection.project(row) metadataProjection.project(row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala deleted file mode 100644 index 71740c4e0296e..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/IncrementMetric.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.metric - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, Nondeterministic, UnaryExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.types.DataType - -/** - * Evaluates the boolean [[condition]] and increments an SQLMetric when it is true. - * Returns the condition's value unchanged. - * - * This is the resolved form of - * [[org.apache.spark.sql.catalyst.expressions.UnresolvedIncrementMetricIf]]. - * - * @param condition the boolean expression to evaluate. - * @param metric the SQLMetric accumulator to conditionally increment. - */ -case class IncrementMetricIf(condition: Expression, metric: SQLMetric) - extends UnaryExpression with Nondeterministic { - - override def child: Expression = condition - - override def nullable: Boolean = condition.nullable - - override def dataType: DataType = condition.dataType - - override protected def initializeInternal(partitionIndex: Int): Unit = {} - - override def prettyName: String = "increment_metric_if" - - override def toString: String = s"$prettyName($condition, ${metric.name.getOrElse("metric")})" - - override protected def evalInternal(input: InternalRow): Any = { - val result = condition.eval(input) - if (result != null && result.asInstanceOf[Boolean]) { - metric.add(1L) - } - result - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val condEval = condition.genCode(ctx) - val metricRef = ctx.addReferenceObj(metric.name.getOrElse("metric"), metric) - condEval.copy(code = condEval.code + code""" - if (!${condEval.isNull} && ${condEval.value}) { - $metricRef.add(1L); - } - """) - } - - override protected def withNewChildInternal(newChild: Expression): IncrementMetricIf = - copy(condition = newChild) -} - -/** - * Evaluates the boolean [[condition]], increments an SQLMetric when it is true, - * then evaluates and returns [[returnExpr]]. - * - * This is the resolved form of - * [[org.apache.spark.sql.catalyst.expressions.UnresolvedIncrementMetricIfThenReturn]]. - * - * @param condition the boolean expression to evaluate. - * @param returnExpr the expression whose value is returned. - * @param metric the SQLMetric accumulator to conditionally increment. - */ -case class IncrementMetricIfThenReturn( - condition: Expression, - returnExpr: Expression, - metric: SQLMetric) - extends Expression with BinaryLike[Expression] with Nondeterministic { - - override def left: Expression = condition - - override def right: Expression = returnExpr - - override def nullable: Boolean = returnExpr.nullable - - override def dataType: DataType = returnExpr.dataType - - override protected def initializeInternal(partitionIndex: Int): Unit = {} - - override def prettyName: String = "increment_metric_if_then_return" - - override def toString: String = - s"$prettyName($condition, $returnExpr, ${metric.name.getOrElse("metric")})" - - override protected def evalInternal(input: InternalRow): Any = { - val condResult = condition.eval(input) - if (condResult != null && condResult.asInstanceOf[Boolean]) { - metric.add(1L) - } - returnExpr.eval(input) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val condEval = condition.genCode(ctx) - val returnEval = returnExpr.genCode(ctx) - val metricRef = ctx.addReferenceObj(metric.name.getOrElse("metric"), metric) - returnEval.copy(code = condEval.code + code""" - if (!${condEval.isNull} && ${condEval.value}) { - $metricRef.add(1L); - } - """ + returnEval.code) - } - - override protected def withNewChildrenInternal( - newLeft: Expression, - newRight: Expression): IncrementMetricIfThenReturn = - copy(condition = newLeft, returnExpr = newRight) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/ResolveIncrementMetric.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/ResolveIncrementMetric.scala deleted file mode 100644 index 5385c5ea1d632..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/ResolveIncrementMetric.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.metric - -import org.apache.spark.SparkContext -import org.apache.spark.sql.catalyst.expressions.{UnresolvedIncrementMetricIf, UnresolvedIncrementMetricIfThenReturn} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.write.RowLevelOperation -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.datasources.v2.{ReplaceDataExec, V2ExistingTableWriteExec, WriteDeltaExec} - -/** - * Resolves [[UnresolvedIncrementMetricIf]] and [[UnresolvedIncrementMetricIfThenReturn]] - * expressions in a subtree into their resolved counterparts using a provided metrics map. - * - * Used as an AQE preprocessing rule so that resolution survives AQE replanning (which - * re-creates physical plans from logical plans). - * - * @param metricsMap mapping from metric name to SQLMetric accumulator. - */ -case class ResolveIncrementMetric(metricsMap: Map[String, SQLMetric]) - extends Rule[SparkPlan] { - - override def apply(plan: SparkPlan): SparkPlan = { - if (metricsMap.isEmpty) return plan - plan.transformAllExpressions { - case UnresolvedIncrementMetricIf(cond, name) => - IncrementMetricIf(cond, metricsMap(name)) - case UnresolvedIncrementMetricIfThenReturn(cond, ret, name) => - IncrementMetricIfThenReturn(cond, ret, metricsMap(name)) - } - } -} - -/** - * Top-level preparation rule that finds V2 write exec nodes with a `rowLevelCommand`, - * creates operation SQLMetrics, resolves [[UnresolvedIncrementMetricIf]] and - * [[UnresolvedIncrementMetricIfThenReturn]] expressions in the child plan, and stores the - * metrics on the exec node. - */ -object ResolveIncrementMetrics extends Rule[SparkPlan] { - override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case w: V2ExistingTableWriteExec if w.rowLevelCommand.isDefined && w.operationMetrics.isEmpty => - val metricsMap = createOperationMetrics(w.rowLevelCommand.get) - val resolved = ResolveIncrementMetric(metricsMap).apply(w.child) - val withChild = w.withNewChildren(Seq(resolved)) - setOperationMetrics(withChild, metricsMap) - } - - private def createOperationMetrics(cmd: RowLevelOperation.Command): Map[String, SQLMetric] = { - val sc = SparkContext.getOrCreate() - cmd match { - case RowLevelOperation.Command.UPDATE => - Seq( - "numUpdatedRows", - "numCopiedRows" - ).map { name => - name -> SQLMetrics.createMetric(sc, name) - }.toMap - case _ => Map.empty - } - } - - private def setOperationMetrics( - plan: SparkPlan, - metricsMap: Map[String, SQLMetric]): SparkPlan = plan match { - case r: ReplaceDataExec => r.copy(operationMetrics = metricsMap) - case d: WriteDeltaExec => d.copy(operationMetrics = metricsMap) - case other => other - } -} From 5e12e9428dbfa41d23ebd36a9b096fda6c9e2679 Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Tue, 7 Apr 2026 12:36:50 +0000 Subject: [PATCH 05/20] Revert unnecessary change --- .../apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index aac399dbefcad..4718d8c824d4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -106,7 +106,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { val updatedRowsPlan = buildReplaceDataUpdateProjection(matchedRowsPlan, assignments) // build a plan that contains unmatched rows in matched groups that must be copied over - val remainingRowFilter = Not(EqualNullSafe(cond, TrueLiteral)) + val remainingRowFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral)) val remainingRowsPlan = Filter(remainingRowFilter, readRelation) val remainingRowsPlanWithFlag = Project( remainingRowsPlan.output :+ Alias(FalseLiteral, IS_UPDATED_COLUMN)(), From 6eb4407a99c2f470c8aab8317d3bdaaa39aa9bb5 Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Wed, 8 Apr 2026 17:31:11 +0000 Subject: [PATCH 06/20] Address comments --- .../catalyst/analysis/RewriteUpdateTable.scala | 16 ++++++++++------ .../datasources/v2/WriteToDataSourceV2Exec.scala | 12 ++++++------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index 4718d8c824d4b..b3657183c56b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -35,6 +35,11 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap */ object RewriteUpdateTable extends RewriteRowLevelCommand { + /** + * A boolean column added to ReplaceData plans to distinguish updated rows from copied rows. + * The writing task reads this column to increment operation metrics and strip it before passing + * data to the writer. + */ private[sql] final val IS_UPDATED_COLUMN: String = "__is_updated" override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { @@ -107,13 +112,12 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // build a plan that contains unmatched rows in matched groups that must be copied over val remainingRowFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral)) - val remainingRowsPlan = Filter(remainingRowFilter, readRelation) - val remainingRowsPlanWithFlag = Project( - remainingRowsPlan.output :+ Alias(FalseLiteral, IS_UPDATED_COLUMN)(), - remainingRowsPlan) + val remainingRowsPlan = Project( + Alias(FalseLiteral, IS_UPDATED_COLUMN)() +: readRelation.output, + Filter(remainingRowFilter, readRelation)) // the new state is a union of updated and copied over records - val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlanWithFlag) + val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan) // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) @@ -150,7 +154,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // add a boolean column to indicate whether each row was updated or copied over val isUpdatedCol = Alias(EqualNullSafe(cond, TrueLiteral), IS_UPDATED_COLUMN)() - Project(updatedValues :+ isUpdatedCol, plan) + Project(isUpdatedCol +: updatedValues, plan) } // build a rewrite plan for sources that support row deltas diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 367cc66bc7453..a82aef259e4fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -725,10 +725,7 @@ case class DeltaWritingSparkTask( val operation = row.getInt(0) operation match { - // When representUpdateAsDeleteAndInsert is true, each logical update is split - // into a DELETE and a REINSERT. Count the DELETE as one updated row. case DELETE_OPERATION => - numUpdatedRows.foreach(_.add(1L)) rowIdProjection.project(row) writer.delete(null, rowIdProjection) @@ -738,7 +735,10 @@ case class DeltaWritingSparkTask( rowIdProjection.project(row) writer.update(null, rowIdProjection, rowProjection) + // When representUpdateAsDeleteAndInsert is true, each logical update is split + // into a DELETE and a REINSERT. Count the REINSERT as one updated row. case REINSERT_OPERATION => + numUpdatedRows.foreach(_.add(1L)) rowProjection.project(row) writer.reinsert(null, rowProjection) @@ -770,10 +770,7 @@ case class DeltaWithMetadataWritingSparkTask( val operation = row.getInt(0) operation match { - // When representUpdateAsDeleteAndInsert is true, each logical update is split - // into a DELETE and a REINSERT. Count the DELETE as one updated row. case DELETE_OPERATION => - numUpdatedRows.foreach(_.add(1L)) rowIdProjection.project(row) metadataProjection.project(row) writer.delete(metadataProjection, rowIdProjection) @@ -785,7 +782,10 @@ case class DeltaWithMetadataWritingSparkTask( metadataProjection.project(row) writer.update(metadataProjection, rowIdProjection, rowProjection) + // When representUpdateAsDeleteAndInsert is true, each logical update is split + // into a DELETE and a REINSERT. Count the REINSERT as one updated row. case REINSERT_OPERATION => + numUpdatedRows.foreach(_.add(1L)) rowProjection.project(row) metadataProjection.project(row) writer.reinsert(metadataProjection, rowProjection) From 9757ba701bc37dc907e6cb36c2d7c55e0a230b59 Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Tue, 14 Apr 2026 09:24:33 +0000 Subject: [PATCH 07/20] Address comments --- .../datasources/v2/DataSourceV2Strategy.scala | 32 +++++++++++++------ .../v2/WriteToDataSourceV2Exec.scala | 20 ++++++------ 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1c29117818788..1879f4d03abf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -358,17 +358,31 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat throw SparkException.internalError("Unexpected table relation: " + other) } - case rd @ ReplaceData(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, + case rd @ ReplaceData( + _: DataSourceV2Relation, + _, query, + r: DataSourceV2Relation, + projections, _, Some(write)) => - // use the original relation to refresh the cache - ReplaceDataExec(planLater(query), refreshCache(r), projections, write, - Some(rd.operation.command())) :: Nil - - case wd @ WriteDelta(_: DataSourceV2Relation, _, query, r: DataSourceV2Relation, projections, + ReplaceDataExec( + planLater(query), + refreshCache(r), // use the original relation to refresh the cache + projections, + write, + Some(rd.operation.command)) :: Nil + + case wd @ WriteDelta( + _: DataSourceV2Relation, + _, query, + r: DataSourceV2Relation, + projections, Some(write)) => - // use the original relation to refresh the cache - WriteDeltaExec(planLater(query), refreshCache(r), projections, write, - Some(wd.operation.command())) :: Nil + WriteDeltaExec( + planLater(query), + refreshCache(r), // use the original relation to refresh the cache + projections, + write, + Some(wd.operation.command)) :: Nil case MergeRows(isSourceRowPresent, isTargetRowPresent, matchedInstructions, notMatchedInstructions, notMatchedBySourceInstructions, checkCardinality, output, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index a82aef259e4fc..0298c40dc8350 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperation, UpdateSummaryImpl, Write, WriterCommitMessage, WriteSummary} +import org.apache.spark.sql.connector.write.RowLevelOperation.Command._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution, UnaryExecNode} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -318,7 +319,7 @@ case class ReplaceDataExec( override def writingTask: WritingSparkTask[_] = { val metricCounter = rowLevelCommand match { - case Some(RowLevelOperation.Command.UPDATE) => + case Some(UPDATE) => val isUpdatedIndex = query.output.indexWhere(_.name == IS_UPDATED_COLUMN) Some(BooleanMetricCounter( isUpdatedIndex, operationMetrics("numUpdatedRows"), operationMetrics("numCopiedRows"))) @@ -436,11 +437,10 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSpa protected val customMetrics: Map[String, SQLMetric] = Map.empty protected lazy val operationMetrics: Map[String, SQLMetric] = rowLevelCommand match { - case Some(RowLevelOperation.Command.UPDATE) => + case Some(UPDATE) => Map( "numUpdatedRows" -> SQLMetrics.createMetric(sparkContext, "number of updated rows"), - "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows") - ) + "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows")) case _ => Map.empty } @@ -516,7 +516,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSpa private def getWriteSummary(query: SparkPlan): Option[WriteSummary] = { rowLevelCommand.flatMap { - case RowLevelOperation.Command.MERGE => + case MERGE => collectFirst(query) { case m: MergeRowsExec => m }.map { n => val metrics = n.metrics MergeSummaryImpl( @@ -527,15 +527,13 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSpa metrics.get("numTargetRowsMatchedUpdated").map(_.value).getOrElse(-1L), metrics.get("numTargetRowsMatchedDeleted").map(_.value).getOrElse(-1L), metrics.get("numTargetRowsNotMatchedBySourceUpdated").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsNotMatchedBySourceDeleted").map(_.value).getOrElse(-1L) - ) + metrics.get("numTargetRowsNotMatchedBySourceDeleted").map(_.value).getOrElse(-1L)) } - case RowLevelOperation.Command.UPDATE => + case UPDATE => Some(UpdateSummaryImpl( operationMetrics.get("numUpdatedRows").map(_.value).get, - operationMetrics.get("numCopiedRows").map(_.value).get - )) - case RowLevelOperation.Command.DELETE => + operationMetrics.get("numCopiedRows").map(_.value).get)) + case DELETE => None } } From 072df0fd32501145cf65ca7416e7bbaf813398ef Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Tue, 14 Apr 2026 11:02:32 +0000 Subject: [PATCH 08/20] Add 2 more tests --- .../spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala index e821fc3f660da..c2db54f8f724b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala @@ -51,6 +51,8 @@ abstract class DeltaBasedUpdateTableSuiteBase extends UpdateTableSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(10, 1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 0) } test("update with nondeterministic conditions") { @@ -89,5 +91,7 @@ abstract class DeltaBasedUpdateTableSuiteBase extends UpdateTableSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, -1, -1, "invalid") :: Row(2, 2, 200, "software") :: Row(3, 3, 300, "hr") :: Nil) + + checkUpdateMetrics(numUpdatedRows = 1, numCopiedRows = 0) } } From 2b100a3abee1ce4c2ff30e0fa5765e8af9004a65 Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Tue, 14 Apr 2026 16:54:32 +0000 Subject: [PATCH 09/20] -1 if missing, add RowLevelWriteExec --- .../sql/connector/write/UpdateSummary.java | 4 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/WriteToDataSourceV2Exec.scala | 90 +++++++++++-------- 3 files changed, 57 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/UpdateSummary.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/UpdateSummary.java index ef7fc4534811f..99e9fcc1003ad 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/UpdateSummary.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/UpdateSummary.java @@ -28,12 +28,12 @@ public interface UpdateSummary extends WriteSummary { /** - * Returns the number of rows updated. + * Returns the number of rows updated, or -1 if not found. */ long numUpdatedRows(); /** - * Returns the number of rows copied unmodified. + * Returns the number of rows copied unmodified, or -1 if not found. */ long numCopiedRows(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1879f4d03abf0..70c6ff8cbf9aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -369,7 +369,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat refreshCache(r), // use the original relation to refresh the cache projections, write, - Some(rd.operation.command)) :: Nil + rd.operation.command) :: Nil case wd @ WriteDelta( _: DataSourceV2Relation, @@ -382,7 +382,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat refreshCache(r), // use the original relation to refresh the cache projections, write, - Some(wd.operation.command)) :: Nil + wd.operation.command) :: Nil case MergeRows(isSourceRowPresent, isTargetRowPresent, matchedInstructions, notMatchedInstructions, notMatchedBySourceInstructions, checkCardinality, output, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 0298c40dc8350..fe0de11836f94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -314,12 +314,12 @@ case class ReplaceDataExec( refreshCache: () => Unit, projections: ReplaceDataProjections, write: Write, - override val rowLevelCommand: Option[RowLevelOperation.Command]) - extends V2ExistingTableWriteExec { + rowLevelCommand: RowLevelOperation.Command) + extends RowLevelWriteExec { override def writingTask: WritingSparkTask[_] = { val metricCounter = rowLevelCommand match { - case Some(UPDATE) => + case UPDATE => val isUpdatedIndex = query.output.indexWhere(_.name == IS_UPDATED_COLUMN) Some(BooleanMetricCounter( isUpdatedIndex, operationMetrics("numUpdatedRows"), operationMetrics("numCopiedRows"))) @@ -349,8 +349,8 @@ case class WriteDeltaExec( refreshCache: () => Unit, projections: WriteDeltaProjections, write: DeltaWrite, - override val rowLevelCommand: Option[RowLevelOperation.Command] = None) - extends V2ExistingTableWriteExec { + rowLevelCommand: RowLevelOperation.Command) + extends RowLevelWriteExec { override lazy val writingTask: WritingSparkTask[_] = { if (projections.metadataProjection.isDefined) { @@ -421,13 +421,58 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { } } +/** + * A trait for row-level write operations (UPDATE, DELETE, MERGE) that carry the command. + */ +trait RowLevelWriteExec extends V2ExistingTableWriteExec { + def rowLevelCommand: RowLevelOperation.Command + + override lazy val operationMetrics: Map[String, SQLMetric] = rowLevelCommand match { + case UPDATE => + Map( + "numUpdatedRows" -> SQLMetrics.createMetric(sparkContext, "number of updated rows"), + "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows")) + case _ => Map.empty + } + + /** + * Returns the value of the named metric, or -1 if the metric is not found. + */ + private def getMetricValue(metrics: Map[String, SQLMetric], name: String): Long = { + metrics.get(name).map(_.value).getOrElse(-1L) + } + + override protected def getWriteSummary(query: SparkPlan): Option[WriteSummary] = { + rowLevelCommand match { + case MERGE => + collectFirst(query) { case m: MergeRowsExec => m }.map { n => + val metrics = n.metrics + MergeSummaryImpl( + getMetricValue(metrics, "numTargetRowsCopied"), + getMetricValue(metrics, "numTargetRowsDeleted"), + getMetricValue(metrics, "numTargetRowsUpdated"), + getMetricValue(metrics, "numTargetRowsInserted"), + getMetricValue(metrics, "numTargetRowsMatchedUpdated"), + getMetricValue(metrics, "numTargetRowsMatchedDeleted"), + getMetricValue(metrics, "numTargetRowsNotMatchedBySourceUpdated"), + getMetricValue(metrics, "numTargetRowsNotMatchedBySourceDeleted")) + } + case UPDATE => + Some(UpdateSummaryImpl( + getMetricValue(operationMetrics, "numUpdatedRows"), + getMetricValue(operationMetrics, "numCopiedRows"))) + case DELETE => + None + } + } +} + /** * The base physical plan for writing data into data source v2. */ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSparkPlanHelper { def query: SparkPlan def writingTask: WritingSparkTask[_] = DataWritingSparkTask() - def rowLevelCommand: Option[RowLevelOperation.Command] = None var commitProgress: Option[StreamWriterCommitProgress] = None @@ -435,14 +480,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSpa override def output: Seq[Attribute] = Nil protected val customMetrics: Map[String, SQLMetric] = Map.empty - - protected lazy val operationMetrics: Map[String, SQLMetric] = rowLevelCommand match { - case Some(UPDATE) => - Map( - "numUpdatedRows" -> SQLMetrics.createMetric(sparkContext, "number of updated rows"), - "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows")) - case _ => Map.empty - } + def operationMetrics: Map[String, SQLMetric] = Map.empty override lazy val metrics = customMetrics ++ operationMetrics @@ -514,29 +552,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSpa Nil } - private def getWriteSummary(query: SparkPlan): Option[WriteSummary] = { - rowLevelCommand.flatMap { - case MERGE => - collectFirst(query) { case m: MergeRowsExec => m }.map { n => - val metrics = n.metrics - MergeSummaryImpl( - metrics.get("numTargetRowsCopied").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsDeleted").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsUpdated").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsInserted").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsMatchedUpdated").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsMatchedDeleted").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsNotMatchedBySourceUpdated").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsNotMatchedBySourceDeleted").map(_.value).getOrElse(-1L)) - } - case UPDATE => - Some(UpdateSummaryImpl( - operationMetrics.get("numUpdatedRows").map(_.value).get, - operationMetrics.get("numCopiedRows").map(_.value).get)) - case DELETE => - None - } - } + protected def getWriteSummary(query: SparkPlan): Option[WriteSummary] = None } trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serializable { From f54923bf63c51626a97fc5290de2db5ee01c8bc1 Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Tue, 14 Apr 2026 18:47:19 +0000 Subject: [PATCH 10/20] Replace __is_updated with operation column --- .../analysis/RewriteRowLevelCommand.scala | 25 ++++- .../analysis/RewriteUpdateTable.scala | 44 ++++---- .../sql/catalyst/util/RowDeltaUtils.scala | 2 + .../v2/WriteToDataSourceV2Exec.scala | 106 +++++++++--------- 4 files changed, 97 insertions(+), 80 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index c5b81dec87c96..0afdecfb46753 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -22,7 +22,8 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ProjectingInternalRow import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, MetadataAttribute, NamedExpression, V2ExpressionUtils} -import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project} +import org.apache.spark.sql.catalyst.expressions.If +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{ReplaceDataProjections, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ @@ -181,21 +182,33 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { } protected def addOperationColumn(operation: Int, plan: LogicalPlan): LogicalPlan = { - val operationType = Alias(Literal(operation, IntegerType), OPERATION_COLUMN)() + addOperationColumn(Literal(operation, IntegerType), plan) + } + + protected def addOperationColumn(operationExpr: Expression, plan: LogicalPlan): LogicalPlan = { + val operationType = Alias(operationExpr, OPERATION_COLUMN)() Project(operationType +: plan.output, plan) } + private final val REPLACE_DATA_OPERATIONS_WITH_ROW = Set( + WRITE_WITH_METADATA_OPERATION, WRITE_OPERATION, + WRITE_UPDATED_OPERATION, WRITE_COPIED_OPERATION) + + private final val REPLACE_DATA_OPERATIONS_WITH_METADATA = Set( + WRITE_WITH_METADATA_OPERATION, + WRITE_UPDATED_OPERATION, WRITE_COPIED_OPERATION) + protected def buildReplaceDataProjections( plan: LogicalPlan, rowAttrs: Seq[Attribute], metadataAttrs: Seq[Attribute]): ReplaceDataProjections = { val outputs = extractOutputs(plan) - val outputsWithRow = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION, WRITE_OPERATION)) + val outputsWithRow = filterOutputs(outputs, REPLACE_DATA_OPERATIONS_WITH_ROW) val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs) val metadataProjection = if (metadataAttrs.nonEmpty) { - val outputsWithMetadata = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION)) + val outputsWithMetadata = filterOutputs(outputs, REPLACE_DATA_OPERATIONS_WITH_METADATA) Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs)) } else { None @@ -234,6 +247,7 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { private def extractOutputs(plan: LogicalPlan): Seq[Seq[Expression]] = { plan match { case p: Project => Seq(p.projectList) + case u: Union => extractOutputs(u.children.head) case e: Expand => e.projections case m: MergeRows => m.outputs case _ => throw SparkException.internalError("Can't extract outputs from plan: " + plan) @@ -246,6 +260,9 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { outputs.filter { case Literal(operation: Integer, _) +: _ => operations.contains(operation) case Alias(Literal(operation: Integer, _), _) +: _ => operations.contains(operation) + // handle conditional operation column (e.g., If-based for UPDATE) + case Alias(If(_, Literal(op1: Integer, _), Literal(op2: Integer, _)), _) +: _ => + operations.contains(op1) || operations.contains(op2) case other => throw SparkException.internalError("Can't determine operation: " + other) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index b3657183c56b0..5e7a34191fb5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression} -import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, Filter, LogicalPlan, Project, ReplaceData, Union, UpdateTable, WriteDelta} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations @@ -35,13 +35,6 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap */ object RewriteUpdateTable extends RewriteRowLevelCommand { - /** - * A boolean column added to ReplaceData plans to distinguish updated rows from copied rows. - * The writing task reads this column to increment operation metrics and strip it before passing - * data to the writer. - */ - private[sql] final val IS_UPDATED_COLUMN: String = "__is_updated" - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u @ UpdateTable(aliasedTable, assignments, cond) if u.resolved && u.rewritable && u.aligned => @@ -78,13 +71,18 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // construct a read relation and include all required metadata columns val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + // add a conditional operation column to distinguish updated and copied rows + val operationExpr = + If(EqualNullSafe(cond, TrueLiteral), + Literal(WRITE_UPDATED_OPERATION, IntegerType), + Literal(WRITE_COPIED_OPERATION, IntegerType)) + val readRelationWithOp = addOperationColumn(operationExpr, readRelation) + // build a plan with updated and copied over records - val updatedAndRemainingRowsPlan = buildReplaceDataUpdateProjection( - readRelation, assignments, cond) + val query = buildReplaceDataUpdateProjection(readRelationWithOp, assignments, cond) // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, updatedAndRemainingRowsPlan) val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) @@ -108,20 +106,17 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // build a plan for updated records that match the condition val matchedRowsPlan = Filter(cond, readRelation) - val updatedRowsPlan = buildReplaceDataUpdateProjection(matchedRowsPlan, assignments) + val matchedWithOp = addOperationColumn(WRITE_UPDATED_OPERATION, matchedRowsPlan) + val updatedRowsPlan = buildReplaceDataUpdateProjection(matchedWithOp, assignments) // build a plan that contains unmatched rows in matched groups that must be copied over - val remainingRowFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral)) - val remainingRowsPlan = Project( - Alias(FalseLiteral, IS_UPDATED_COLUMN)() +: readRelation.output, + val remainingRowFilter = Not(EqualNullSafe(cond, TrueLiteral)) + val remainingRowsPlan = addOperationColumn(WRITE_COPIED_OPERATION, Filter(remainingRowFilter, readRelation)) // the new state is a union of updated and copied over records - val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan) - - // build a plan to replace read groups in the table + val query = Union(updatedRowsPlan, remainingRowsPlan) val writeRelation = relation.copy(table = operationTable) - val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, updatedAndRemainingRowsPlan) val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) @@ -133,10 +128,11 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { assignments: Seq[Assignment], cond: Expression = TrueLiteral): LogicalPlan = { - // the plan output may include metadata columns at the end - // that's why the number of assignments may not match the number of plan output columns + // the first column is always the operation column, followed by data and optional metadata columns + // preserve the operation column expression and apply updates to the remaining columns + val Project(operationCol +: _, _) = plan val assignedValues = assignments.map(_.value) - val updatedValues = plan.output.zipWithIndex.map { case (attr, index) => + val updatedValues = plan.output.tail.zipWithIndex.map { case (attr, index) => if (index < assignments.size) { val assignedExpr = assignedValues(index) val updatedValue = If(cond, assignedExpr, attr) @@ -152,9 +148,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { } } - // add a boolean column to indicate whether each row was updated or copied over - val isUpdatedCol = Alias(EqualNullSafe(cond, TrueLiteral), IS_UPDATED_COLUMN)() - Project(isUpdatedCol +: updatedValues, plan) + Project(operationCol +: updatedValues, plan) } // build a rewrite plan for sources that support row deltas diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala index 72baad069b180..2648bce5d340d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala @@ -28,5 +28,7 @@ object RowDeltaUtils { final val REINSERT_OPERATION: Int = 4 final val WRITE_OPERATION: Int = 5 final val WRITE_WITH_METADATA_OPERATION: Int = 6 + final val WRITE_UPDATED_OPERATION: Int = 7 + final val WRITE_COPIED_OPERATION: Int = 8 final val ORIGINAL_ROW_ID_VALUE_PREFIX: String = "__original_row_id_" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index fe0de11836f94..d2489215d8b79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -24,11 +24,10 @@ import org.apache.spark.internal.{Logging, LogKeys} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{InternalRow, ProjectingInternalRow} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException -import org.apache.spark.sql.catalyst.analysis.RewriteUpdateTable.IS_UPDATED_COLUMN import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, TableSpec, UnaryNode} import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections} -import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION} +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_COPIED_OPERATION, WRITE_OPERATION, WRITE_UPDATED_OPERATION, WRITE_WITH_METADATA_OPERATION} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric @@ -318,21 +317,11 @@ case class ReplaceDataExec( extends RowLevelWriteExec { override def writingTask: WritingSparkTask[_] = { - val metricCounter = rowLevelCommand match { - case UPDATE => - val isUpdatedIndex = query.output.indexWhere(_.name == IS_UPDATED_COLUMN) - Some(BooleanMetricCounter( - isUpdatedIndex, operationMetrics("numUpdatedRows"), operationMetrics("numCopiedRows"))) - case _ => None - } projections match { case ReplaceDataProjections(dataProj, Some(metadataProj)) => - DataAndMetadataWritingSparkTask(dataProj, metadataProj, metricCounter) - // TODO: add test coverage for ReplaceData without metadata attributes - case ReplaceDataProjections(dataProj, None) if metricCounter.isDefined => - DataWritingSparkTask(Some(dataProj), metricCounter) - case _ => - DataWritingSparkTask() + DataAndMetadataWritingSparkTask(dataProj, metadataProj, operationMetrics) + case ReplaceDataProjections(dataProj, None) => + DataWithProjectionWritingSparkTask(dataProj, operationMetrics) } } @@ -472,7 +461,7 @@ trait RowLevelWriteExec extends V2ExistingTableWriteExec { */ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSparkPlanHelper { def query: SparkPlan - def writingTask: WritingSparkTask[_] = DataWritingSparkTask() + def writingTask: WritingSparkTask[_] = DataWritingSparkTask var commitProgress: Option[StreamWriterCommitProgress] = None @@ -651,37 +640,34 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial } } -/** - * Reads a boolean column at the given ordinal and increments one of two metrics per row. - */ -case class BooleanMetricCounter( - ordinal: Int, - trueMetric: SQLMetric, - falseMetric: SQLMetric) extends Serializable { - def count(row: InternalRow): Unit = { - if (row.getBoolean(ordinal)) { - trueMetric.add(1L) - } else { - falseMetric.add(1L) - } - } -} - case class DataAndMetadataWritingSparkTask( dataProj: ProjectingInternalRow, metadataProj: ProjectingInternalRow, - metricCounter: Option[BooleanMetricCounter] = None) + operationMetrics: Map[String, SQLMetric] = Map.empty) extends WritingSparkTask[DataWriter[InternalRow]] { + private lazy val numUpdatedRows = operationMetrics.get("numUpdatedRows") + private lazy val numCopiedRows = operationMetrics.get("numCopiedRows") + override protected def write( writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { while (iter.hasNext) { val row = iter.next() - metricCounter.foreach(_.count(row)) - val operation = row.getInt(0) operation match { + case WRITE_UPDATED_OPERATION => + numUpdatedRows.foreach(_.add(1L)) + dataProj.project(row) + metadataProj.project(row) + writer.write(metadataProj, dataProj) + + case WRITE_COPIED_OPERATION => + numCopiedRows.foreach(_.add(1L)) + dataProj.project(row) + metadataProj.project(row) + writer.write(metadataProj, dataProj) + case WRITE_WITH_METADATA_OPERATION => dataProj.project(row) metadataProj.project(row) @@ -698,31 +684,49 @@ case class DataAndMetadataWritingSparkTask( } } -case class DataWritingSparkTask( - dataProj: Option[ProjectingInternalRow] = None, - metricCounter: Option[BooleanMetricCounter] = None) +case class DataWithProjectionWritingSparkTask( + dataProj: ProjectingInternalRow, + operationMetrics: Map[String, SQLMetric] = Map.empty) extends WritingSparkTask[DataWriter[InternalRow]] { + private lazy val numUpdatedRows = operationMetrics.get("numUpdatedRows") + private lazy val numCopiedRows = operationMetrics.get("numCopiedRows") + override protected def write( writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { - if (dataProj.isEmpty && metricCounter.isEmpty) { - writer.writeAll(iter) - } else { - while (iter.hasNext) { - val row = iter.next() - metricCounter.foreach(_.count(row)) - dataProj match { - case Some(proj) => - proj.project(row) - writer.write(proj) - case None => - writer.write(row) - } + while (iter.hasNext) { + val row = iter.next() + val operation = row.getInt(0) + + operation match { + case WRITE_UPDATED_OPERATION => + numUpdatedRows.foreach(_.add(1L)) + dataProj.project(row) + writer.write(dataProj) + + case WRITE_COPIED_OPERATION => + numCopiedRows.foreach(_.add(1L)) + dataProj.project(row) + writer.write(dataProj) + + case WRITE_WITH_METADATA_OPERATION | WRITE_OPERATION => + dataProj.project(row) + writer.write(dataProj) + + case other => + throw new SparkException(s"Unexpected operation ID: $other") } } } } +object DataWritingSparkTask extends WritingSparkTask[DataWriter[InternalRow]] { + override protected def write( + writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { + writer.writeAll(iter) + } +} + case class DeltaWritingSparkTask( projections: WriteDeltaProjections, operationMetrics: Map[String, SQLMetric] = Map.empty) From c70e1da4a217e88fac151241f0c373db4d01596b Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Thu, 16 Apr 2026 15:25:33 +0000 Subject: [PATCH 11/20] Fix ReplaceData DML without metadata attributes not projecting out the operation column --- .../analysis/RewriteDeleteFromTable.scala | 3 +- .../analysis/RewriteMergeIntoTable.scala | 6 +- .../analysis/RewriteUpdateTable.scala | 6 +- .../v2/WriteToDataSourceV2Exec.scala | 78 +++++++++++++------ 4 files changed, 66 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala index 13cfc6b73ccbc..92a2e447c4216 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala @@ -86,7 +86,8 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand { // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, remainingRowsPlan) + val writeOp = if (metadataAttrs.nonEmpty) WRITE_WITH_METADATA_OPERATION else WRITE_OPERATION + val query = addOperationColumn(writeOp, remainingRowsPlan) val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index 8ff734c7a9a09..e9bf565e49b67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -202,7 +202,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper // that's why an extra unconditional instruction that would produce the original row is added // as the last MATCHED and NOT MATCHED BY SOURCE instruction // this logic is specific to data sources that replace groups of data - val carryoverRowsOutput = Literal(WRITE_WITH_METADATA_OPERATION) +: targetTable.output + val writeOp = if (metadataAttrs.nonEmpty) WRITE_WITH_METADATA_OPERATION else WRITE_OPERATION + val carryoverRowsOutput = Literal(writeOp) +: targetTable.output val keepCarryoverRowsInstruction = Keep(Copy, TrueLiteral, carryoverRowsOutput) val matchedInstructions = matchedActions.map { action => @@ -439,7 +440,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper case UpdateAction(cond, assignments, _) => val rowValues = assignments.map(_.value) val metadataValues = nullifyMetadataOnUpdate(metadataAttrs) - val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues ++ metadataValues + val writeOp = if (metadataAttrs.nonEmpty) WRITE_WITH_METADATA_OPERATION else WRITE_OPERATION + val output = Seq(Literal(writeOp)) ++ rowValues ++ metadataValues Keep(Update, cond.getOrElse(TrueLiteral), output) case DeleteAction(cond) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index caf7579da889a..79cc24f9cf359 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -77,7 +77,8 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, updatedAndRemainingRowsPlan) + val writeOp = if (metadataAttrs.nonEmpty) WRITE_WITH_METADATA_OPERATION else WRITE_OPERATION + val query = addOperationColumn(writeOp, updatedAndRemainingRowsPlan) val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) @@ -112,7 +113,8 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val query = addOperationColumn(WRITE_WITH_METADATA_OPERATION, updatedAndRemainingRowsPlan) + val writeOp = if (metadataAttrs.nonEmpty) WRITE_WITH_METADATA_OPERATION else WRITE_OPERATION + val query = addOperationColumn(writeOp, updatedAndRemainingRowsPlan) val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 5a2da729c1b52..6124694878b53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -311,14 +311,14 @@ case class ReplaceDataExec( query: SparkPlan, refreshCache: () => Unit, projections: ReplaceDataProjections, - write: Write) extends V2ExistingTableWriteExec { + write: Write) extends RowLevelWriteExec { override def writingTask: WritingSparkTask[_] = { - projections match { - case ReplaceDataProjections(dataProj, Some(metadataProj)) => - DataAndMetadataWritingSparkTask(dataProj, metadataProj) - case _ => - DataWritingSparkTask + projections.metadataProjection match { + case Some(metadataProj) => + DataAndMetadataWritingSparkTask(projections.rowProjection, metadataProj) + case None => + DataWithProjectionWritingSparkTask(projections.rowProjection) } } @@ -334,7 +334,7 @@ case class WriteDeltaExec( query: SparkPlan, refreshCache: () => Unit, projections: WriteDeltaProjections, - write: DeltaWrite) extends V2ExistingTableWriteExec { + write: DeltaWrite) extends RowLevelWriteExec { override lazy val writingTask: WritingSparkTask[_] = { if (projections.metadataProjection.isDefined) { @@ -405,6 +405,33 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { } } +/** + * A trait for row-level write operations (UPDATE, DELETE, MERGE) that carry the command. + */ +trait RowLevelWriteExec extends V2ExistingTableWriteExec { + /** + * Returns the value of the named metric, or -1 if the metric is not found. + */ + private def getMetricValue(metrics: Map[String, SQLMetric], name: String): Long = { + metrics.get(name).map(_.value).getOrElse(-1L) + } + + override protected def getWriteSummary(query: SparkPlan): Option[WriteSummary] = { + collectFirst(query) { case m: MergeRowsExec => m }.map { n => + val metrics = n.metrics + MergeSummaryImpl( + getMetricValue(metrics, "numTargetRowsCopied"), + getMetricValue(metrics, "numTargetRowsDeleted"), + getMetricValue(metrics, "numTargetRowsUpdated"), + getMetricValue(metrics, "numTargetRowsInserted"), + getMetricValue(metrics, "numTargetRowsMatchedUpdated"), + getMetricValue(metrics, "numTargetRowsMatchedDeleted"), + getMetricValue(metrics, "numTargetRowsNotMatchedBySourceUpdated"), + getMetricValue(metrics, "numTargetRowsNotMatchedBySourceDeleted")) + } + } +} + /** * The base physical plan for writing data into data source v2. */ @@ -489,21 +516,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSpa Nil } - private def getWriteSummary(query: SparkPlan): Option[WriteSummary] = { - collectFirst(query) { case m: MergeRowsExec => m }.map { n => - val metrics = n.metrics - MergeSummaryImpl( - metrics.get("numTargetRowsCopied").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsDeleted").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsUpdated").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsInserted").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsMatchedUpdated").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsMatchedDeleted").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsNotMatchedBySourceUpdated").map(_.value).getOrElse(-1L), - metrics.get("numTargetRowsNotMatchedBySourceDeleted").map(_.value).getOrElse(-1L) - ) - } - } + protected def getWriteSummary(query: SparkPlan): Option[WriteSummary] = None } trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serializable { @@ -628,6 +641,27 @@ case class DataAndMetadataWritingSparkTask( } } +case class DataWithProjectionWritingSparkTask( + dataProj: ProjectingInternalRow) extends WritingSparkTask[DataWriter[InternalRow]] { + + override protected def write( + writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { + while (iter.hasNext) { + val row = iter.next() + val operation = row.getInt(0) + + operation match { + case WRITE_OPERATION => + dataProj.project(row) + writer.write(dataProj) + + case other => + throw new SparkException(s"Unexpected operation ID: $other") + } + } + } +} + object DataWritingSparkTask extends WritingSparkTask[DataWriter[InternalRow]] { override protected def write( writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { From 8b376047c2e36ec0e295112b485ce897ec4a3a86 Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Thu, 16 Apr 2026 16:02:27 +0000 Subject: [PATCH 12/20] Rename WRITE_OPERATION and WRITE_WITH_METADATA_OPERATION --- .../sql/catalyst/analysis/RewriteDeleteFromTable.scala | 3 +-- .../sql/catalyst/analysis/RewriteMergeIntoTable.scala | 10 ++++------ .../sql/catalyst/analysis/RewriteRowLevelCommand.scala | 5 +++-- .../sql/catalyst/analysis/RewriteUpdateTable.scala | 6 ++---- .../apache/spark/sql/catalyst/util/RowDeltaUtils.scala | 4 ++-- .../datasources/v2/WriteToDataSourceV2Exec.scala | 8 ++++---- 6 files changed, 16 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala index 92a2e447c4216..5a1ea61b38e30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala @@ -86,8 +86,7 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand { // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val writeOp = if (metadataAttrs.nonEmpty) WRITE_WITH_METADATA_OPERATION else WRITE_OPERATION - val query = addOperationColumn(writeOp, remainingRowsPlan) + val query = addOperationColumn(WRITE_OPERATION, remainingRowsPlan) val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index e9bf565e49b67..bec35bdac4911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta} import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update} -import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION} +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITHOUT_METADATA_OPERATION} import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta} import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE @@ -202,8 +202,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper // that's why an extra unconditional instruction that would produce the original row is added // as the last MATCHED and NOT MATCHED BY SOURCE instruction // this logic is specific to data sources that replace groups of data - val writeOp = if (metadataAttrs.nonEmpty) WRITE_WITH_METADATA_OPERATION else WRITE_OPERATION - val carryoverRowsOutput = Literal(writeOp) +: targetTable.output + val carryoverRowsOutput = Literal(WRITE_OPERATION) +: targetTable.output val keepCarryoverRowsInstruction = Keep(Copy, TrueLiteral, carryoverRowsOutput) val matchedInstructions = matchedActions.map { action => @@ -440,8 +439,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper case UpdateAction(cond, assignments, _) => val rowValues = assignments.map(_.value) val metadataValues = nullifyMetadataOnUpdate(metadataAttrs) - val writeOp = if (metadataAttrs.nonEmpty) WRITE_WITH_METADATA_OPERATION else WRITE_OPERATION - val output = Seq(Literal(writeOp)) ++ rowValues ++ metadataValues + val output = Seq(Literal(WRITE_OPERATION)) ++ rowValues ++ metadataValues Keep(Update, cond.getOrElse(TrueLiteral), output) case DeleteAction(cond) => @@ -450,7 +448,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper case InsertAction(cond, assignments) => val rowValues = assignments.map(_.value) val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType)) - val output = Seq(Literal(WRITE_OPERATION)) ++ rowValues ++ metadataValues + val output = Seq(Literal(WRITE_WITHOUT_METADATA_OPERATION)) ++ rowValues ++ metadataValues Keep(Insert, cond.getOrElse(TrueLiteral), output) case other => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index c5b81dec87c96..eef488a25e0d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -191,11 +191,12 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { metadataAttrs: Seq[Attribute]): ReplaceDataProjections = { val outputs = extractOutputs(plan) - val outputsWithRow = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION, WRITE_OPERATION)) + val outputsWithRow = filterOutputs(outputs, + Set(WRITE_OPERATION, WRITE_WITHOUT_METADATA_OPERATION)) val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs) val metadataProjection = if (metadataAttrs.nonEmpty) { - val outputsWithMetadata = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION)) + val outputsWithMetadata = filterOutputs(outputs, Set(WRITE_OPERATION)) Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index 79cc24f9cf359..96d1e8e9cd06f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -77,8 +77,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val writeOp = if (metadataAttrs.nonEmpty) WRITE_WITH_METADATA_OPERATION else WRITE_OPERATION - val query = addOperationColumn(writeOp, updatedAndRemainingRowsPlan) + val query = addOperationColumn(WRITE_OPERATION, updatedAndRemainingRowsPlan) val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) @@ -113,8 +112,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val writeOp = if (metadataAttrs.nonEmpty) WRITE_WITH_METADATA_OPERATION else WRITE_OPERATION - val query = addOperationColumn(writeOp, updatedAndRemainingRowsPlan) + val query = addOperationColumn(WRITE_OPERATION, updatedAndRemainingRowsPlan) val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala index 72baad069b180..9bdbcf899f87f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala @@ -26,7 +26,7 @@ object RowDeltaUtils { final val UPDATE_OPERATION: Int = 2 final val INSERT_OPERATION: Int = 3 final val REINSERT_OPERATION: Int = 4 - final val WRITE_OPERATION: Int = 5 - final val WRITE_WITH_METADATA_OPERATION: Int = 6 + final val WRITE_WITHOUT_METADATA_OPERATION: Int = 5 + final val WRITE_OPERATION: Int = 6 final val ORIGINAL_ROW_ID_VALUE_PREFIX: String = "__original_row_id_" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 6124694878b53..94e6e750a2433 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, TableSpec, UnaryNode} import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections} -import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION} +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITHOUT_METADATA_OPERATION} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric @@ -625,12 +625,12 @@ case class DataAndMetadataWritingSparkTask( val operation = row.getInt(0) operation match { - case WRITE_WITH_METADATA_OPERATION => + case WRITE_OPERATION => dataProj.project(row) metadataProj.project(row) writer.write(metadataProj, dataProj) - case WRITE_OPERATION => + case WRITE_WITHOUT_METADATA_OPERATION => dataProj.project(row) writer.write(dataProj) @@ -651,7 +651,7 @@ case class DataWithProjectionWritingSparkTask( val operation = row.getInt(0) operation match { - case WRITE_OPERATION => + case WRITE_OPERATION | WRITE_WITHOUT_METADATA_OPERATION => dataProj.project(row) writer.write(dataProj) From f69f4e6b825863798153c9911a8af27f401ccdf2 Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Thu, 16 Apr 2026 16:29:42 +0000 Subject: [PATCH 13/20] Tests without metadata attributes --- .../InMemoryRowLevelOperationTable.scala | 92 ++++++++++++------- ...aBasedNoMetadataDeleteFromTableSuite.scala | 30 ++++++ ...taBasedNoMetadataMergeIntoTableSuite.scala | 30 ++++++ ...DeltaBasedNoMetadataUpdateTableSuite.scala | 28 ++++++ ...pBasedNoMetadataDeleteFromTableSuite.scala | 27 ++++++ ...upBasedNoMetadataMergeIntoTableSuite.scala | 27 ++++++ ...GroupBasedNoMetadataUpdateTableSuite.scala | 27 ++++++ 7 files changed, 228 insertions(+), 33 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataDeleteFromTableSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataMergeIntoTableSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataUpdateTableSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataDeleteFromTableSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataMergeIntoTableSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataUpdateTableSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index 2f7cad5992153..91e899bc1169e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -51,6 +51,8 @@ class InMemoryRowLevelOperationTable( private final val INDEX_COLUMN_REF = FieldReference(IndexColumn.name) private final val SUPPORTS_DELTAS = "supports-deltas" private final val SPLIT_UPDATES = "split-updates" + private final val NO_METADATA = "no-metadata" + private final val noMetadata = properties.getOrDefault(NO_METADATA, "false") == "true" // used in row-level operation tests to verify replaced partitions var replacedPartitions: Seq[Seq[Any]] = Seq.empty @@ -73,7 +75,11 @@ class InMemoryRowLevelOperationTable( var configuredScan: InMemoryBatchScan = _ override def requiredMetadataAttributes(): Array[NamedReference] = { - Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF) + if (noMetadata) { + Array.empty + } else { + Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF) + } } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { @@ -89,22 +95,29 @@ class InMemoryRowLevelOperationTable( override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { lastWriteInfo = info new WriteBuilder { - override def build(): Write = new Write with RequiresDistributionAndOrdering { - override def requiredDistribution: Distribution = { - Distributions.clustered(Array(PARTITION_COLUMN_REF)) + override def build(): Write = if (noMetadata) { + new Write { + override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan) + override def description: String = "InMemoryWrite" } - - override def requiredOrdering: Array[SortOrder] = { - Array[SortOrder]( - LogicalExpressions.sort( - PARTITION_COLUMN_REF, - SortDirection.ASCENDING, - SortDirection.ASCENDING.defaultNullOrdering())) + } else { + new Write with RequiresDistributionAndOrdering { + override def requiredDistribution: Distribution = { + Distributions.clustered(Array(PARTITION_COLUMN_REF)) + } + + override def requiredOrdering: Array[SortOrder] = { + Array[SortOrder]( + LogicalExpressions.sort( + PARTITION_COLUMN_REF, + SortDirection.ASCENDING, + SortDirection.ASCENDING.defaultNullOrdering())) + } + + override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan) + + override def description: String = "InMemoryWrite" } - - override def toBatch: BatchWrite = PartitionBasedReplaceData(configuredScan) - - override def description: String = "InMemoryWrite" } } } @@ -138,7 +151,11 @@ class InMemoryRowLevelOperationTable( private final val PK_COLUMN_REF = FieldReference("pk") override def requiredMetadataAttributes(): Array[NamedReference] = { - Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF) + if (noMetadata) { + Array.empty + } else { + Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF) + } } override def rowId(): Array[NamedReference] = Array(PK_COLUMN_REF) @@ -150,22 +167,28 @@ class InMemoryRowLevelOperationTable( override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = { lastWriteInfo = info new DeltaWriteBuilder { - override def build(): DeltaWrite = new DeltaWrite with RequiresDistributionAndOrdering { - - override def requiredDistribution(): Distribution = { - Distributions.clustered(Array(PARTITION_COLUMN_REF)) + override def build(): DeltaWrite = if (noMetadata) { + new DeltaWrite { + override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite } - - override def requiredOrdering(): Array[SortOrder] = { - Array[SortOrder]( - LogicalExpressions.sort( - PARTITION_COLUMN_REF, - SortDirection.ASCENDING, - SortDirection.ASCENDING.defaultNullOrdering()) - ) + } else { + new DeltaWrite with RequiresDistributionAndOrdering { + + override def requiredDistribution(): Distribution = { + Distributions.clustered(Array(PARTITION_COLUMN_REF)) + } + + override def requiredOrdering(): Array[SortOrder] = { + Array[SortOrder]( + LogicalExpressions.sort( + PARTITION_COLUMN_REF, + SortDirection.ASCENDING, + SortDirection.ASCENDING.defaultNullOrdering()) + ) + } + + override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite } - - override def toBatch: DeltaBatchWrite = TestDeltaBatchWrite } } } @@ -208,7 +231,8 @@ private class DeltaBufferWriter(schema: StructType) extends BufferWriter(schema) override def delete(meta: InternalRow, id: InternalRow): Unit = { val pk = id.getInt(0) buffer.deletes += pk - val logEntry = new GenericInternalRow(Array[Any](DELETE, pk, meta.copy(), null)) + val metaCopy = if (meta != null) meta.copy() else null + val logEntry = new GenericInternalRow(Array[Any](DELETE, pk, metaCopy, null)) buffer.log += logEntry } @@ -216,13 +240,15 @@ private class DeltaBufferWriter(schema: StructType) extends BufferWriter(schema) val pk = id.getInt(0) buffer.deletes += pk buffer.rows.append(row.copy()) - val logEntry = new GenericInternalRow(Array[Any](UPDATE, pk, meta.copy(), row.copy())) + val metaCopy = if (meta != null) meta.copy() else null + val logEntry = new GenericInternalRow(Array[Any](UPDATE, pk, metaCopy, row.copy())) buffer.log += logEntry } override def reinsert(meta: InternalRow, row: InternalRow): Unit = { buffer.rows.append(row.copy()) - val logEntry = new GenericInternalRow(Array[Any](REINSERT, null, meta.copy(), row.copy())) + val metaCopy = if (meta != null) meta.copy() else null + val logEntry = new GenericInternalRow(Array[Any](REINSERT, null, metaCopy, row.copy())) buffer.log += logEntry } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataDeleteFromTableSuite.scala new file mode 100644 index 0000000000000..73407d640923a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataDeleteFromTableSuite.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +class DeltaBasedNoMetadataDeleteFromTableSuite extends DeleteFromTableSuiteBase { + + override protected def extraTableProps: java.util.Map[String, String] = { + val props = new java.util.HashMap[String, String]() + props.put("supports-deltas", "true") + props.put("no-metadata", "true") + props + } + + override def enforceCheckConstraintOnDelete: Boolean = false +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataMergeIntoTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataMergeIntoTableSuite.scala new file mode 100644 index 0000000000000..d6e1484253135 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataMergeIntoTableSuite.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +class DeltaBasedNoMetadataMergeIntoTableSuite extends MergeIntoTableSuiteBase { + + override protected def deltaMerge = true + + override protected def extraTableProps: java.util.Map[String, String] = { + val props = new java.util.HashMap[String, String]() + props.put("supports-deltas", "true") + props.put("no-metadata", "true") + props + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataUpdateTableSuite.scala new file mode 100644 index 0000000000000..15ff7688a26b1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataUpdateTableSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +class DeltaBasedNoMetadataUpdateTableSuite extends UpdateTableSuiteBase { + + override protected def extraTableProps: java.util.Map[String, String] = { + val props = new java.util.HashMap[String, String]() + props.put("supports-deltas", "true") + props.put("no-metadata", "true") + props + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataDeleteFromTableSuite.scala new file mode 100644 index 0000000000000..6ac4f6e32fb18 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataDeleteFromTableSuite.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +class GroupBasedNoMetadataDeleteFromTableSuite extends DeleteFromTableSuiteBase { + + override protected def extraTableProps: java.util.Map[String, String] = { + val props = new java.util.HashMap[String, String]() + props.put("no-metadata", "true") + props + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataMergeIntoTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataMergeIntoTableSuite.scala new file mode 100644 index 0000000000000..5feadcdec23e3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataMergeIntoTableSuite.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +class GroupBasedNoMetadataMergeIntoTableSuite extends MergeIntoTableSuiteBase { + + override protected def extraTableProps: java.util.Map[String, String] = { + val props = new java.util.HashMap[String, String]() + props.put("no-metadata", "true") + props + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataUpdateTableSuite.scala new file mode 100644 index 0000000000000..31db56b8d5594 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedNoMetadataUpdateTableSuite.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +class GroupBasedNoMetadataUpdateTableSuite extends UpdateTableSuiteBase { + + override protected def extraTableProps: java.util.Map[String, String] = { + val props = new java.util.HashMap[String, String]() + props.put("no-metadata", "true") + props + } +} From f70cc22b334505741614fd73464afb4c9701cdf8 Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Thu, 16 Apr 2026 17:04:18 +0000 Subject: [PATCH 14/20] Fix comment --- .../sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 94e6e750a2433..082e2ccfb4819 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -406,7 +406,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { } /** - * A trait for row-level write operations (UPDATE, DELETE, MERGE) that carry the command. + * A trait for row-level write operations (UPDATE, DELETE, MERGE). */ trait RowLevelWriteExec extends V2ExistingTableWriteExec { /** From cae5e1e79e201daa4628c7996095adadbdb01cc2 Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Fri, 17 Apr 2026 12:49:08 +0000 Subject: [PATCH 15/20] Remove WRITE_OPERATION and instead use fine-grained operations --- .../analysis/RewriteDeleteFromTable.scala | 2 +- .../analysis/RewriteMergeIntoTable.scala | 8 ++-- .../analysis/RewriteRowLevelCommand.scala | 40 +++++++++++-------- .../analysis/RewriteUpdateTable.scala | 20 ++++++---- .../sql/catalyst/util/RowDeltaUtils.scala | 3 +- .../v2/WriteToDataSourceV2Exec.scala | 8 ++-- 6 files changed, 45 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala index 5a1ea61b38e30..f8881e2077103 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDeleteFromTable.scala @@ -86,7 +86,7 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand { // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val query = addOperationColumn(WRITE_OPERATION, remainingRowsPlan) + val query = addOperationColumn(COPY_OPERATION, remainingRowsPlan) val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index bec35bdac4911..f21f53a28300d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta} import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update} -import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITHOUT_METADATA_OPERATION} +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{COPY_OPERATION, INSERT_OPERATION, OPERATION_COLUMN, UPDATE_OPERATION} import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta} import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE @@ -202,7 +202,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper // that's why an extra unconditional instruction that would produce the original row is added // as the last MATCHED and NOT MATCHED BY SOURCE instruction // this logic is specific to data sources that replace groups of data - val carryoverRowsOutput = Literal(WRITE_OPERATION) +: targetTable.output + val carryoverRowsOutput = Literal(COPY_OPERATION) +: targetTable.output val keepCarryoverRowsInstruction = Keep(Copy, TrueLiteral, carryoverRowsOutput) val matchedInstructions = matchedActions.map { action => @@ -439,7 +439,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper case UpdateAction(cond, assignments, _) => val rowValues = assignments.map(_.value) val metadataValues = nullifyMetadataOnUpdate(metadataAttrs) - val output = Seq(Literal(WRITE_OPERATION)) ++ rowValues ++ metadataValues + val output = Seq(Literal(UPDATE_OPERATION)) ++ rowValues ++ metadataValues Keep(Update, cond.getOrElse(TrueLiteral), output) case DeleteAction(cond) => @@ -448,7 +448,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper case InsertAction(cond, assignments) => val rowValues = assignments.map(_.value) val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType)) - val output = Seq(Literal(WRITE_WITHOUT_METADATA_OPERATION)) ++ rowValues ++ metadataValues + val output = Seq(Literal(INSERT_OPERATION)) ++ rowValues ++ metadataValues Keep(Insert, cond.getOrElse(TrueLiteral), output) case other => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index eef488a25e0d9..f0592eac82844 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -21,8 +21,8 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ProjectingInternalRow -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, MetadataAttribute, NamedExpression, V2ExpressionUtils} -import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, If, Literal, MetadataAttribute, NamedExpression, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{ReplaceDataProjections, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ @@ -38,11 +38,11 @@ import org.apache.spark.util.ArrayImplicits._ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { - private final val DELTA_OPERATIONS_WITH_ROW = - Set(UPDATE_OPERATION, REINSERT_OPERATION, INSERT_OPERATION) - private final val DELTA_OPERATIONS_WITH_METADATA = - Set(DELETE_OPERATION, UPDATE_OPERATION, REINSERT_OPERATION) - private final val DELTA_OPERATIONS_WITH_ROW_ID = + private final val OPERATIONS_WITH_ROW = + Set(UPDATE_OPERATION, REINSERT_OPERATION, INSERT_OPERATION, COPY_OPERATION) + private final val OPERATIONS_WITH_METADATA = + Set(DELETE_OPERATION, UPDATE_OPERATION, REINSERT_OPERATION, COPY_OPERATION) + private final val OPERATIONS_WITH_ROW_ID = Set(DELETE_OPERATION, UPDATE_OPERATION) protected def groupFilterEnabled: Boolean = conf.runtimeRowLevelOperationGroupFilterEnabled @@ -181,7 +181,11 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { } protected def addOperationColumn(operation: Int, plan: LogicalPlan): LogicalPlan = { - val operationType = Alias(Literal(operation, IntegerType), OPERATION_COLUMN)() + addOperationColumn(Literal(operation, IntegerType), plan) + } + + protected def addOperationColumn(operation: Expression, plan: LogicalPlan): LogicalPlan = { + val operationType = Alias(operation, OPERATION_COLUMN)() Project(operationType +: plan.output, plan) } @@ -191,12 +195,11 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { metadataAttrs: Seq[Attribute]): ReplaceDataProjections = { val outputs = extractOutputs(plan) - val outputsWithRow = filterOutputs(outputs, - Set(WRITE_OPERATION, WRITE_WITHOUT_METADATA_OPERATION)) + val outputsWithRow = filterOutputs(outputs, OPERATIONS_WITH_ROW) val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs) val metadataProjection = if (metadataAttrs.nonEmpty) { - val outputsWithMetadata = filterOutputs(outputs, Set(WRITE_OPERATION)) + val outputsWithMetadata = filterOutputs(outputs, OPERATIONS_WITH_METADATA) Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs)) } else { None @@ -213,17 +216,17 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { val outputs = extractOutputs(plan) val rowProjection = if (rowAttrs.nonEmpty) { - val outputsWithRow = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW) + val outputsWithRow = filterOutputs(outputs, OPERATIONS_WITH_ROW) Some(newLazyProjection(plan, outputsWithRow, rowAttrs)) } else { None } - val outputsWithRowId = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW_ID) + val outputsWithRowId = filterOutputs(outputs, OPERATIONS_WITH_ROW_ID) val rowIdProjection = newLazyRowIdProjection(plan, outputsWithRowId, rowIdAttrs) val metadataProjection = if (metadataAttrs.nonEmpty) { - val outputsWithMetadata = filterOutputs(outputs, DELTA_OPERATIONS_WITH_METADATA) + val outputsWithMetadata = filterOutputs(outputs, OPERATIONS_WITH_METADATA) Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs)) } else { None @@ -237,6 +240,7 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { case p: Project => Seq(p.projectList) case e: Expand => e.projections case m: MergeRows => m.outputs + case u: Union => u.children.flatMap(extractOutputs) case _ => throw SparkException.internalError("Can't extract outputs from plan: " + plan) } } @@ -244,11 +248,13 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { private def filterOutputs( outputs: Seq[Seq[Expression]], operations: Set[Int]): Seq[Seq[Expression]] = { - outputs.filter { - case Literal(operation: Integer, _) +: _ => operations.contains(operation) - case Alias(Literal(operation: Integer, _), _) +: _ => operations.contains(operation) + def matches(expr: Expression): Boolean = expr match { + case Literal(operation: Integer, _) => operations.contains(operation) + case Alias(child, _) => matches(child) + case If(_, trueValue, falseValue) => matches(trueValue) && matches(falseValue) case other => throw SparkException.internalError("Can't determine operation: " + other) } + outputs.filter(output => matches(output.head)) } private def newLazyProjection( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index 96d1e8e9cd06f..05095e686f597 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -72,12 +72,13 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) // build a plan with updated and copied over records - val updatedAndRemainingRowsPlan = buildReplaceDataUpdateProjection( - readRelation, assignments, cond) + // the conditional operation column needs to be added in the same Projection as cond is + // referencing attributes before the update + val writeOp = If(cond, Literal(UPDATE_OPERATION), Literal(COPY_OPERATION)) + val query = buildReplaceDataUpdateProjection(readRelation, assignments, writeOp, cond) // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val query = addOperationColumn(WRITE_OPERATION, updatedAndRemainingRowsPlan) val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) @@ -101,18 +102,19 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // build a plan for updated records that match the condition val matchedRowsPlan = Filter(cond, readRelation) - val updatedRowsPlan = buildReplaceDataUpdateProjection(matchedRowsPlan, assignments) + val updatedRowsPlan = buildReplaceDataUpdateProjection( + matchedRowsPlan, assignments, Literal(UPDATE_OPERATION)) // build a plan that contains unmatched rows in matched groups that must be copied over val remainingRowFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral)) - val remainingRowsPlan = Filter(remainingRowFilter, readRelation) + val remainingRowsPlan = addOperationColumn(COPY_OPERATION, + Filter(remainingRowFilter, readRelation)) // the new state is a union of updated and copied over records - val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan) + val query = Union(updatedRowsPlan, remainingRowsPlan) // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val query = addOperationColumn(WRITE_OPERATION, updatedAndRemainingRowsPlan) val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) @@ -122,6 +124,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { private def buildReplaceDataUpdateProjection( plan: LogicalPlan, assignments: Seq[Assignment], + operation: Expression, cond: Expression = TrueLiteral): LogicalPlan = { // the plan output may include metadata columns at the end @@ -143,7 +146,8 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { } } - Project(updatedValues, plan) + val operationCol = Alias(operation, OPERATION_COLUMN)() + Project(operationCol +: updatedValues, plan) } // build a rewrite plan for sources that support row deltas diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala index 9bdbcf899f87f..8b86d530550b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RowDeltaUtils.scala @@ -26,7 +26,6 @@ object RowDeltaUtils { final val UPDATE_OPERATION: Int = 2 final val INSERT_OPERATION: Int = 3 final val REINSERT_OPERATION: Int = 4 - final val WRITE_WITHOUT_METADATA_OPERATION: Int = 5 - final val WRITE_OPERATION: Int = 6 + final val COPY_OPERATION: Int = 5 final val ORIGINAL_ROW_ID_VALUE_PREFIX: String = "__original_row_id_" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 082e2ccfb4819..98e0c6f66deaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, TableSpec, UnaryNode} import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections} -import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION, WRITE_OPERATION, WRITE_WITHOUT_METADATA_OPERATION} +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{COPY_OPERATION, DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric @@ -625,12 +625,12 @@ case class DataAndMetadataWritingSparkTask( val operation = row.getInt(0) operation match { - case WRITE_OPERATION => + case UPDATE_OPERATION | COPY_OPERATION => dataProj.project(row) metadataProj.project(row) writer.write(metadataProj, dataProj) - case WRITE_WITHOUT_METADATA_OPERATION => + case INSERT_OPERATION => dataProj.project(row) writer.write(dataProj) @@ -651,7 +651,7 @@ case class DataWithProjectionWritingSparkTask( val operation = row.getInt(0) operation match { - case WRITE_OPERATION | WRITE_WITHOUT_METADATA_OPERATION => + case UPDATE_OPERATION | COPY_OPERATION | INSERT_OPERATION => dataProj.project(row) writer.write(dataProj) From 268c29c8df2d463b8414a0f4571907d3e8dd190b Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Fri, 17 Apr 2026 13:56:16 +0000 Subject: [PATCH 16/20] Resolve conflicts --- .../sql/catalyst/analysis/RewriteRowLevelCommand.scala | 9 --------- .../connector/DeltaBasedNoMetadataUpdateTableSuite.scala | 2 ++ 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index dd5cc25eb6d3b..f0592eac82844 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -189,14 +189,6 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { Project(operationType +: plan.output, plan) } - private final val REPLACE_DATA_OPERATIONS_WITH_ROW = Set( - WRITE_WITH_METADATA_OPERATION, WRITE_OPERATION, - WRITE_UPDATED_OPERATION, WRITE_COPIED_OPERATION) - - private final val REPLACE_DATA_OPERATIONS_WITH_METADATA = Set( - WRITE_WITH_METADATA_OPERATION, - WRITE_UPDATED_OPERATION, WRITE_COPIED_OPERATION) - protected def buildReplaceDataProjections( plan: LogicalPlan, rowAttrs: Seq[Attribute], @@ -246,7 +238,6 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { private def extractOutputs(plan: LogicalPlan): Seq[Seq[Expression]] = { plan match { case p: Project => Seq(p.projectList) - case u: Union => extractOutputs(u.children.head) case e: Expand => e.projections case m: MergeRows => m.outputs case u: Union => u.children.flatMap(extractOutputs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataUpdateTableSuite.scala index 15ff7688a26b1..bf657806a30de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataUpdateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataUpdateTableSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.connector class DeltaBasedNoMetadataUpdateTableSuite extends UpdateTableSuiteBase { + override protected def deltaUpdate: Boolean = true + override protected def extraTableProps: java.util.Map[String, String] = { val props = new java.util.HashMap[String, String]() props.put("supports-deltas", "true") From 69b2c42d8705059b44c070dd61ee2be404d29d0e Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Sat, 18 Apr 2026 08:18:13 +0000 Subject: [PATCH 17/20] Address comments --- .../catalyst/analysis/RewriteRowLevelCommand.scala | 6 +----- .../sql/catalyst/analysis/RewriteUpdateTable.scala | 12 ++++-------- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index f0592eac82844..48c48eb323bd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -181,11 +181,7 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { } protected def addOperationColumn(operation: Int, plan: LogicalPlan): LogicalPlan = { - addOperationColumn(Literal(operation, IntegerType), plan) - } - - protected def addOperationColumn(operation: Expression, plan: LogicalPlan): LogicalPlan = { - val operationType = Alias(operation, OPERATION_COLUMN)() + val operationType = Alias(Literal(operation, IntegerType), OPERATION_COLUMN)() Project(operationType +: plan.output, plan) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index 05095e686f597..3c41b6bfa5683 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -72,10 +72,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) // build a plan with updated and copied over records - // the conditional operation column needs to be added in the same Projection as cond is - // referencing attributes before the update - val writeOp = If(cond, Literal(UPDATE_OPERATION), Literal(COPY_OPERATION)) - val query = buildReplaceDataUpdateProjection(readRelation, assignments, writeOp, cond) + val query = buildReplaceDataUpdateProjection(readRelation, assignments, cond) // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) @@ -102,8 +99,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { // build a plan for updated records that match the condition val matchedRowsPlan = Filter(cond, readRelation) - val updatedRowsPlan = buildReplaceDataUpdateProjection( - matchedRowsPlan, assignments, Literal(UPDATE_OPERATION)) + val updatedRowsPlan = buildReplaceDataUpdateProjection(matchedRowsPlan, assignments) // build a plan that contains unmatched rows in matched groups that must be copied over val remainingRowFilter = Not(EqualNullSafe(cond, Literal.TrueLiteral)) @@ -124,7 +120,6 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { private def buildReplaceDataUpdateProjection( plan: LogicalPlan, assignments: Seq[Assignment], - operation: Expression, cond: Expression = TrueLiteral): LogicalPlan = { // the plan output may include metadata columns at the end @@ -146,7 +141,8 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { } } - val operationCol = Alias(operation, OPERATION_COLUMN)() + val writeOp = If(cond, Literal(UPDATE_OPERATION), Literal(COPY_OPERATION)) + val operationCol = Alias(writeOp, OPERATION_COLUMN)() Project(operationCol +: updatedValues, plan) } From 551a701f6203a4a62abf5cf87ac0b101671b3b68 Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Fri, 17 Apr 2026 15:56:53 +0000 Subject: [PATCH 18/20] DELETE metrics for WriteDelta --- .../sql/connector/write/DeleteSummary.java | 39 +++++++++++++++++++ .../connector/write/DeleteSummaryImpl.scala | 27 +++++++++++++ .../v2/WriteToDataSourceV2Exec.scala | 14 ++++++- .../connector/DeleteFromTableSuiteBase.scala | 29 ++++++++++++++ .../DeltaBasedDeleteFromTableSuite.scala | 2 + ...aBasedNoMetadataDeleteFromTableSuite.scala | 2 + 6 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeleteSummary.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/DeleteSummaryImpl.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeleteSummary.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeleteSummary.java new file mode 100644 index 0000000000000..76ece79b09a96 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DeleteSummary.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write; + +import org.apache.spark.annotation.Evolving; + +/** + * Provides an informational summary of the DELETE operation producing write. + * + * @since 4.2.0 + */ +@Evolving +public interface DeleteSummary extends WriteSummary { + + /** + * Returns the number of rows deleted, or -1 if not found. + */ + long numDeletedRows(); + + /** + * Returns the number of rows copied unmodified, or -1 if not found. + */ + long numCopiedRows(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/DeleteSummaryImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/DeleteSummaryImpl.scala new file mode 100644 index 0000000000000..b96bf86a57681 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/DeleteSummaryImpl.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write + +/** + * Implementation of [[DeleteSummary]] that provides DELETE operation summary. + */ +private[sql] case class DeleteSummaryImpl( + numDeletedRows: Long, + numCopiedRows: Long) + extends DeleteSummary { +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 2be3a6ee246c0..2ccae98efc948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{COPY_OPERATION, DELETE_ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric -import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperation, UpdateSummaryImpl, Write, WriterCommitMessage, WriteSummary} +import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeleteSummaryImpl, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperation, UpdateSummaryImpl, Write, WriterCommitMessage, WriteSummary} import org.apache.spark.sql.connector.write.RowLevelOperation.Command._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution, UnaryExecNode} @@ -426,6 +426,10 @@ trait RowLevelWriteExec extends V2ExistingTableWriteExec { Map( "numUpdatedRows" -> SQLMetrics.createMetric(sparkContext, "number of updated rows"), "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows")) + case DELETE => + Map( + "numDeletedRows" -> SQLMetrics.createMetric(sparkContext, "number of deleted rows"), + "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows")) case _ => Map.empty } @@ -456,7 +460,9 @@ trait RowLevelWriteExec extends V2ExistingTableWriteExec { getMetricValue(operationMetrics, "numUpdatedRows"), getMetricValue(operationMetrics, "numCopiedRows"))) case DELETE => - None + Some(DeleteSummaryImpl( + getMetricValue(operationMetrics, "numDeletedRows"), + getMetricValue(operationMetrics, "numCopiedRows"))) } } } @@ -735,6 +741,7 @@ case class DeltaWritingSparkTask( private lazy val rowProjection = projections.rowProjection.orNull private lazy val rowIdProjection = projections.rowIdProjection private lazy val numUpdatedRows = operationMetrics.get("numUpdatedRows") + private lazy val numDeletedRows = operationMetrics.get("numDeletedRows") override protected def write( writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { @@ -744,6 +751,7 @@ case class DeltaWritingSparkTask( operation match { case DELETE_OPERATION => + numDeletedRows.foreach(_.add(1L)) rowIdProjection.project(row) writer.delete(null, rowIdProjection) @@ -780,6 +788,7 @@ case class DeltaWithMetadataWritingSparkTask( private lazy val rowIdProjection = projections.rowIdProjection private lazy val metadataProjection = projections.metadataProjection.orNull private lazy val numUpdatedRows = operationMetrics.get("numUpdatedRows") + private lazy val numDeletedRows = operationMetrics.get("numDeletedRows") override protected def write( writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = { @@ -789,6 +798,7 @@ case class DeltaWithMetadataWritingSparkTask( operation match { case DELETE_OPERATION => + numDeletedRows.foreach(_.add(1L)) rowIdProjection.project(row) metadataProjection.project(row) writer.delete(metadataProjection, rowIdProjection) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index 0f7f4cefe2feb..938f0c095f836 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.CheckInvariant import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.connector.catalog.InMemoryTable +import org.apache.spark.sql.connector.write.DeleteSummary import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec, WriteDeltaExec} abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { @@ -28,6 +30,25 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { protected def enforceCheckConstraintOnDelete: Boolean = true + protected def deltaDelete: Boolean = false + + protected def getDeleteSummary(): DeleteSummary = { + val t = catalog.loadTable(ident).asInstanceOf[InMemoryTable] + t.commits.last.writeSummary.get.asInstanceOf[DeleteSummary] + } + + protected def checkDeleteMetrics( + numDeletedRows: Long, + numCopiedRows: Long): Unit = { + val summary = getDeleteSummary() + val expectedDeleted = if (deltaDelete) numDeletedRows else 0L + assert(summary.numDeletedRows() === expectedDeleted, + s"Expected numDeletedRows=$expectedDeleted, got ${summary.numDeletedRows()}") + val expectedCopied = if (deltaDelete) 0L else numCopiedRows + assert(summary.numCopiedRows() === expectedCopied, + s"Expected numCopiedRows=$expectedCopied, got ${summary.numCopiedRows()}") + } + test("delete from table containing added column with default value") { createAndInitTable("pk INT NOT NULL, dep STRING", """{ "pk": 1, "dep": "hr" }""") @@ -151,6 +172,8 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { sql(s"DELETE FROM $tableNameAsString WHERE id <= 1") checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) + + checkDeleteMetrics(numDeletedRows = 0, numCopiedRows = 0) } test("delete with basic filters") { @@ -165,6 +188,8 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) } test("delete with aliases") { @@ -177,6 +202,8 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { sql(s"DELETE FROM $tableNameAsString AS t WHERE t.id <= 1 OR t.dep = 'hr'") checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, "software") :: Nil) + + checkDeleteMetrics(numDeletedRows = 2, numCopiedRows = 0) } test("delete with IN predicates") { @@ -191,6 +218,8 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, "software") :: Row(3, null, "hr") :: Nil) + + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) } test("delete with NOT IN predicates") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala index 9046123ddbd3f..3d3b37705d0b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala @@ -25,6 +25,8 @@ class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { import testImplicits._ + override protected def deltaDelete: Boolean = true + override protected lazy val extraTableProps: java.util.Map[String, String] = { val props = new java.util.HashMap[String, String]() props.put("supports-deltas", "true") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataDeleteFromTableSuite.scala index 73407d640923a..b33cf87402b3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedNoMetadataDeleteFromTableSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.connector class DeltaBasedNoMetadataDeleteFromTableSuite extends DeleteFromTableSuiteBase { + override protected def deltaDelete: Boolean = true + override protected def extraTableProps: java.util.Map[String, String] = { val props = new java.util.HashMap[String, String]() props.put("supports-deltas", "true") From 0d6c5f42e00be4dd4597e100706d8a8573c9ac26 Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Fri, 17 Apr 2026 16:33:21 +0000 Subject: [PATCH 19/20] DELETE metrics for ReplaceData --- .../v2/WriteToDataSourceV2Exec.scala | 24 ++++++++++++-- .../connector/DeleteFromTableSuiteBase.scala | 31 +++++++++++++++++-- .../DeltaBasedDeleteFromTableSuite.scala | 4 +++ .../GroupBasedDeleteFromTableSuite.scala | 6 ++++ 4 files changed, 60 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 2ccae98efc948..4577712a3bc6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{COPY_OPERATION, DELETE_ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric -import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeleteSummaryImpl, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperation, UpdateSummaryImpl, Write, WriterCommitMessage, WriteSummary} +import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeleteSummaryImpl, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperation, RowLevelOperationTable, UpdateSummaryImpl, Write, WriterCommitMessage, WriteSummary} import org.apache.spark.sql.connector.write.RowLevelOperation.Command._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution, UnaryExecNode} @@ -334,6 +334,26 @@ case class ReplaceDataExec( override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = { copy(query = newChild) } + + override protected def getWriteSummary(query: SparkPlan): Option[WriteSummary] = { + if (rowLevelCommand == DELETE) { + // DELETE ReplaceData plans filter out the deleted rows early in the plan, and they don't + // reach this node. We need to calculate this value as numScannedRows - numCopiedRows. + val numScannedRows = collectFirst(query) { + case b: BatchScanExec if b.table.isInstanceOf[RowLevelOperationTable] => + getMetricValue(b.metrics, "numOutputRows") + } + val numCopiedRows = getMetricValue(metrics, "numCopiedRows") + val numDeletedRows = if (numScannedRows.exists(_ >= 0) && numCopiedRows >= 0) { + numScannedRows.get - numCopiedRows + } else { + // One of the metrics couldn't be found, also mark numDeletedRows as not found. + -1L + } + metrics("numDeletedRows").set(numDeletedRows) + } + super.getWriteSummary(query) + } } /** @@ -436,7 +456,7 @@ trait RowLevelWriteExec extends V2ExistingTableWriteExec { /** * Returns the value of the named metric, or -1 if the metric is not found. */ - private def getMetricValue(metrics: Map[String, SQLMetric], name: String): Long = { + protected def getMetricValue(metrics: Map[String, SQLMetric], name: String): Long = { metrics.get(name).map(_.value).getOrElse(-1L) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index 938f0c095f836..2682487e51ba0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -41,9 +41,8 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { numDeletedRows: Long, numCopiedRows: Long): Unit = { val summary = getDeleteSummary() - val expectedDeleted = if (deltaDelete) numDeletedRows else 0L - assert(summary.numDeletedRows() === expectedDeleted, - s"Expected numDeletedRows=$expectedDeleted, got ${summary.numDeletedRows()}") + assert(summary.numDeletedRows() === numDeletedRows, + s"Expected numDeletedRows=$numDeletedRows, got ${summary.numDeletedRows()}") val expectedCopied = if (deltaDelete) 0L else numCopiedRows assert(summary.numCopiedRows() === expectedCopied, s"Expected numCopiedRows=$expectedCopied, got ${summary.numCopiedRows()}") @@ -86,6 +85,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { Row(3, "software", "initial-text"), Row(4, "hr", "initial-text"), Row(6, "hr", "new-text"))) + checkDeleteMetrics(numDeletedRows = 2, numCopiedRows = 3) } test("delete from table with table constraints") { @@ -111,10 +111,13 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq(Row(2, 4, "eng"), Row(3, 6, "eng"))) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 0) + sql(s"DELETE FROM $tableNameAsString WHERE pk >=3") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq(Row(2, 4, "eng"))) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) } test("delete from table containing struct column with default value") { @@ -234,12 +237,14 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "hr") :: Row(2, 2, "software") :: Row(3, null, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 0, numCopiedRows = 0) sql(s"DELETE FROM $tableNameAsString WHERE id NOT IN (1, 10)") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "hr") :: Row(3, null, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 0) } test("delete with conditions on nested columns") { @@ -253,10 +258,12 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, Row(2, "v2"), "software") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 0) sql(s"DELETE FROM $tableNameAsString t WHERE t.complex.c1 = id") checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 0) } test("delete with IN subqueries") { @@ -284,6 +291,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) append("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 4, "id": 1, "dep": "hr" } @@ -305,6 +313,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(5, -1, "hr") :: Row(4, 1, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 2, numCopiedRows = 2) append("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 6, "id": null, "dep": "hr" } @@ -326,6 +335,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(5, -1, "hr") :: Row(4, 1, "hr") :: Row(6, null, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 3) } } @@ -349,6 +359,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) } } @@ -375,6 +386,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "hr") :: Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 0, numCopiedRows = 0) sql( s"""DELETE FROM $tableNameAsString @@ -383,6 +395,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { |""".stripMargin) checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(3, null, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 2, numCopiedRows = 1) append("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 4, "id": 1, "dep": "hr" } @@ -403,6 +416,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { |""".stripMargin) checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(5, 2, "hardware") :: Nil) + checkDeleteMetrics(numDeletedRows = 3, numCopiedRows = 0) sql( s"""DELETE FROM $tableNameAsString @@ -449,6 +463,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "hr") :: Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 0, numCopiedRows = 0) sql( s"""DELETE FROM $tableNameAsString t @@ -459,6 +474,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) sql( s"""DELETE FROM $tableNameAsString t @@ -469,6 +485,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, "hardware") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 0) sql( s"""DELETE FROM $tableNameAsString t @@ -481,6 +498,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, "hardware") :: Nil) + checkDeleteMetrics(numDeletedRows = 0, numCopiedRows = 0) } } @@ -509,6 +527,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "hr") :: Row(3, null, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 0) sql( s"""DELETE FROM $tableNameAsString t @@ -517,6 +536,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { |""".stripMargin) checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Row(1, 1, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) sql( s"""DELETE FROM $tableNameAsString t @@ -527,6 +547,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { |""".stripMargin) checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 0) } } @@ -550,6 +571,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, "hardware") :: Row(3, null, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) } } @@ -582,6 +604,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(3, 2, "hardware") :: Row(4, 3, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 2, numCopiedRows = 2) // verify the view reflects the changes in the table checkAnswer(sql("SELECT * FROM temp"), Nil) @@ -627,6 +650,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, 200) :: Nil) + checkDeleteMetrics(numDeletedRows = 2, numCopiedRows = 0) } test("delete with subquery cannot be converted into delete with filters") { @@ -646,6 +670,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, 200) :: Row(3, 3, 100) :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala index 3d3b37705d0b3..15d259d44a4fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala @@ -54,6 +54,7 @@ class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) checkLastWriteLog(deleteWriteLogEntry(id = 1, metadata = Row("hr", null))) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 0) } test("delete with subquery handles metadata columns correctly") { @@ -85,6 +86,7 @@ class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) checkLastWriteLog(deleteWriteLogEntry(id = 1, metadata = Row("hr", null))) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 0) } } @@ -138,6 +140,7 @@ class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, "us", "software") :: Row(3, 3, "canada", "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 0) } test("delete does not double plan table") { @@ -164,5 +167,6 @@ class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Row(2, 2, 150, "software") :: Row(3, 3, 120, "hr") :: Nil) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala index 2f922295010ff..3889b0d172adc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala @@ -47,6 +47,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { checkLastWriteLog( writeWithMetadataLogEntry(metadata = Row("hr", 1), data = Row(3, 3, "hr")), writeWithMetadataLogEntry(metadata = Row("hr", 2), data = Row(4, 4, "hr"))) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 2) } test("delete with nondeterministic conditions") { @@ -86,6 +87,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) checkReplacedPartitions(Seq("hr")) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) } test("delete with subqueries and runtime group filtering") { @@ -118,6 +120,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { Row(1, 300, "hr") :: Row(3, 120, "hr") :: Row(4, 150, "software") :: Nil) checkReplacedPartitions(Seq("software")) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) } } @@ -166,6 +169,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) checkReplacedPartitions(Seq("hr")) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) } } @@ -199,6 +203,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) checkReplacedPartitions(Seq("hr")) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 1) } } @@ -229,6 +234,7 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) checkReplacedPartitions(Seq("software", "hr")) + checkDeleteMetrics(numDeletedRows = 1, numCopiedRows = 2) } } } From 670a0c40a0abbbbe4eaee491bd057bf0e796ab7e Mon Sep 17 00:00:00 2001 From: Ziya Mukhtarov Date: Wed, 22 Apr 2026 07:29:53 +0000 Subject: [PATCH 20/20] Resolve conflicts --- .../sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index e93cdb7756270..6bb1eb6f4b6d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -436,7 +436,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { } /** - * A trait for row-level write operations (UPDATE, DELETE, MERGE) that carry the command. + * A trait for row-level write operations (UPDATE, DELETE, MERGE). */ trait RowLevelWriteExec extends V2ExistingTableWriteExec { def rowLevelCommand: RowLevelOperation.Command