diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala index 4fb9a48a3e00..d0165cb6ffcc 100644 --- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.AlignRowLevelCommandAssignments import org.apache.spark.sql.catalyst.analysis.CheckMergeIntoTableConditions import org.apache.spark.sql.catalyst.analysis.MergeIntoIcebergTableResolutionCheck import org.apache.spark.sql.catalyst.analysis.ProcedureArgumentCoercion +import org.apache.spark.sql.catalyst.analysis.RemoveUnusedMetadataColumns import org.apache.spark.sql.catalyst.analysis.ResolveMergeIntoTableReferences import org.apache.spark.sql.catalyst.analysis.ResolveProcedures import org.apache.spark.sql.catalyst.analysis.RewriteDeleteFromIcebergTable @@ -55,6 +56,7 @@ class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) { extensions.injectResolutionRule { _ => RewriteDeleteFromIcebergTable } extensions.injectResolutionRule { _ => RewriteUpdateTable } extensions.injectResolutionRule { _ => RewriteMergeIntoTable } + extensions.injectResolutionRule { _ => RemoveUnusedMetadataColumns } extensions.injectCheckRule { _ => MergeIntoIcebergTableResolutionCheck } extensions.injectCheckRule { _ => AlignedRowLevelIcebergCommandCheck } diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RemoveUnusedMetadataColumns.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RemoveUnusedMetadataColumns.scala new file mode 100644 index 000000000000..5113f6112cf0 --- /dev/null +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RemoveUnusedMetadataColumns.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.iceberg.spark.source.SparkTable +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +/** + * A rule to add a projection on top of the source table of MergeRows to remove unnecessary + * metadata columns reading. + */ +object RemoveUnusedMetadataColumns extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case mergeRows: MergeRows if mergeRows.resolved => + mergeRows.child match { + case join @ Join(left, right, _, _, _) if isIcebergSourcePlan(left, conf.resolver) => + val projected = projectWithoutMetadata(left) + mergeRows.withNewChildren(Seq(join.withNewChildren(Seq(projected, right)))) + + case join @ Join(left, right, _, _, _) if isIcebergSourcePlan(right, conf.resolver) => + val projected = projectWithoutMetadata(right) + mergeRows.withNewChildren(Seq(join.withNewChildren(Seq(left, projected)))) + + case _ => mergeRows + } + } + + private def projectWithoutMetadata(plan: LogicalPlan): LogicalPlan = { + + import org.apache.spark.sql.catalyst.util.MetadataColumnHelper + + val outputWithoutMetadata = plan.output.filterNot(col => col.isMetadataCol) + if (outputWithoutMetadata.nonEmpty && outputWithoutMetadata.size != plan.output.size) { + Project(outputWithoutMetadata, plan) + } else { + plan + } + } + + private def isIcebergSourcePlan(plan: LogicalPlan, resolve: Resolver): Boolean = { + plan.output.exists(col => resolve(col.name, RewriteMergeIntoTable.ROW_FROM_SOURCE)) && + onlyHasIcebergRelation(plan) + } + + private def onlyHasIcebergRelation(plan: LogicalPlan): Boolean = { + val icebergRelations = plan.collect { + case node: LeafNode => + node + } + + icebergRelations.forall { + case r: DataSourceV2Relation if r.table.isInstanceOf[SparkTable] => + true + case _ => false + } + } +} diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index c367b07c701a..18acc7ed6c23 100644 --- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -78,7 +78,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap */ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with PredicateHelper { - private final val ROW_FROM_SOURCE = "__row_from_source" + private[spark] final val ROW_FROM_SOURCE = "__row_from_source" private final val ROW_FROM_TARGET = "__row_from_target" private final val ROW_ID = "__row_id" diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java index dc1e96be48a1..30d642057469 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java @@ -62,6 +62,10 @@ import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.plans.logical.CommandResult; +import org.apache.spark.sql.catalyst.plans.logical.Join; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.util.package$; import org.apache.spark.sql.execution.SparkPlan; import org.apache.spark.sql.internal.SQLConf; import org.assertj.core.api.Assertions; @@ -70,6 +74,8 @@ import org.junit.Assume; import org.junit.BeforeClass; import org.junit.Test; +import scala.PartialFunction; +import scala.collection.Seq; public abstract class TestMerge extends SparkRowLevelOperationsTestBase { @@ -2540,6 +2546,61 @@ public void testMergeToWapBranchWithTableBranchIdentifier() { branch))); } + @Test + public void testRemoveUnusedMetadataColumns() { + createAndInitTable( + "id INT, v STRING", "{ \"id\": 1, \"v\": \"v1\" }\n" + "{ \"id\": 2, \"v\": \"v2\" }"); + + String mergeSql = + String.format( + "MERGE INTO %s t USING %s s " + + "ON t.id == s.id " + + "WHEN MATCHED AND t.id = 1 THEN " + + " UPDATE SET v = 'x' " + + "WHEN NOT MATCHED THEN " + + " INSERT *", + commitTarget(), commitTarget()); + + LogicalPlan optimized = spark.sql(mergeSql).queryExecution().optimizedPlan(); + CommandResult commandResult = (CommandResult) optimized; + Seq sourcePlans = + commandResult + .commandLogicalPlan() + .collect( + new PartialFunction() { + @Override + public LogicalPlan apply(LogicalPlan plan) { + Join join = (Join) plan; + if (isSourcePlan(join.left())) { + return join.left(); + } else { + return join.right(); + } + } + + @Override + public boolean isDefinedAt(LogicalPlan plan) { + if (!(plan instanceof Join)) { + return false; + } + + Join join = (Join) plan; + return isSourcePlan(join.left()) || isSourcePlan(join.right()); + } + + private boolean isSourcePlan(LogicalPlan plan) { + return plan.output().exists(col -> col.name().equals("__row_from_source")); + } + }); + + Assertions.assertThat(sourcePlans.size()).isEqualTo(1); + LogicalPlan sourcePlan = sourcePlans.head(); + String metadataKey = package$.MODULE$.METADATA_COL_ATTR_KEY(); + boolean containsMetadataColumns = + sourcePlan.output().find(col -> col.metadata().contains(metadataKey)).nonEmpty(); + Assertions.assertThat(containsMetadataColumns).isFalse(); + } + private void checkJoinAndFilterConditions(String query, String join, String icebergFilters) { // disable runtime filtering for easier validation withSQLConf(