From d0489346d2ff927b58fb42f44cf92feba53bd27a Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 29 Apr 2026 06:28:53 +0000 Subject: [PATCH 01/33] V1 --- .../catalog/TransactionalCatalogPlugin.java | 39 ++ .../catalog/transactions/Transaction.java | 78 ++++ .../catalog/transactions/TransactionInfo.java | 30 ++ .../sql/catalyst/analysis/Analyzer.scala | 29 +- .../analysis/RelationResolution.scala | 15 +- .../UnresolveTransactionRelations.scala | 56 +++ .../catalyst/analysis/V2TableReference.scala | 28 +- .../catalyst/plans/logical/statements.scala | 2 +- .../catalyst/plans/logical/v2Commands.scala | 23 +- .../transactions/TransactionUtils.scala | 55 +++ .../connector/catalog/CatalogManager.scala | 8 +- .../sql/connector/catalog/LookupCatalog.scala | 13 + .../TransactionAwareCatalogManager.scala | 57 +++ .../transactions/TransactionInfoImpl.scala | 20 + .../transactions/TransactionUtilsSuite.scala | 124 +++++ .../connector/catalog/InMemoryBaseTable.scala | 3 + ...nMemoryRowLevelOperationTableCatalog.scala | 15 +- .../sql/connector/catalog/InMemoryTable.scala | 15 +- .../catalog/InMemoryTableCatalog.scala | 4 + .../spark/sql/connector/catalog/txns.scala | 147 ++++++ .../apache/spark/sql/classic/Catalog.scala | 17 + .../spark/sql/execution/CacheManager.scala | 3 + .../spark/sql/execution/QueryExecution.scala | 131 +++++- .../datasources/v2/DataSourceV2Strategy.scala | 3 - .../datasources/v2/DeleteFromTableExec.scala | 9 +- .../execution/datasources/v2/V2Writes.scala | 2 +- .../v2/WriteToDataSourceV2Exec.scala | 36 +- .../AppendDataTransactionSuite.scala | 228 ++++++++++ .../connector/DeleteFromTableSuiteBase.scala | 184 +++++++- .../DeltaBasedDeleteFromTableSuite.scala | 2 + .../connector/MergeIntoDataFrameSuite.scala | 71 ++- .../connector/MergeIntoTableSuiteBase.scala | 427 +++++++++++++++++- .../RowLevelOperationSuiteBase.scala | 45 +- .../sql/connector/UpdateTableSuiteBase.scala | 321 ++++++++++++- .../benchmark/AnalyzerBenchmark.scala | 118 +++++ 35 files changed, 2291 insertions(+), 67 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java new file mode 100644 index 0000000000000..34a4fc68e9649 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +import org.apache.spark.sql.connector.catalog.transactions.Transaction; +import org.apache.spark.sql.connector.catalog.transactions.TransactionInfo; + +/** + * A {@link CatalogPlugin} that supports transactions. + *

+ * Catalogs that implement this interface opt in to transactional query execution. A catalog + * implementing this interface is responsible for starting transactions. + * + * @since 4.2.0 + */ +public interface TransactionalCatalogPlugin extends CatalogPlugin { + + /** + * Begins a new transaction and returns a {@link Transaction} representing it. + * + * @param info metadata about the transaction being started. + */ + Transaction beginTransaction(TransactionInfo info); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java new file mode 100644 index 0000000000000..80513aff31506 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.transactions; + +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin; + +import java.io.Closeable; + +/** + * Represents a transaction. + *

+ * Spark begins a transaction with {@link TransactionalCatalogPlugin#beginTransaction} and + * executes read/write operations against the transaction's catalog. On success, Spark + * calls {@link #commit()}; on failure, Spark calls {@link #abort()}. In both cases Spark + * subsequently calls {@link #close()} to release resources. + * + * @since 4.2.0 + */ +public interface Transaction extends Closeable { + + /** + * Returns the catalog associated with this transaction. This catalog is responsible for tracking + * read/write operations that occur within the boundaries of a transaction. This allows + * connectors to perform conflict resolution at commit time. + */ + CatalogPlugin catalog(); + + /** + * Commits the transaction. All writes performed under it become visible to other readers. + *

+ * The connector is responsible for detecting and resolving conflicting commits or throwing + * an exception if resolution is not possible. + *

+ * This method will be called exactly once per transaction. Spark calls {@link #close()} + * immediately after this method returns. + * + * @throws IllegalStateException if the transaction has already been committed or aborted. + */ + void commit(); + + /** + * Aborts the transaction, discarding any staged changes. + *

+ * This method must be idempotent. If the transaction has already been committed or aborted, + * invoking it must have no effect. + *

+ * Spark calls {@link #close()} immediately after this method returns. + */ + void abort(); + + /** + * Releases any resources held by this transaction. + *

+ * Spark always calls this method after {@link #commit()} or {@link #abort()}, regardless of + * whether those methods succeed or not. + *

+ * This method must be idempotent. If the transaction has already been closed, + * invoking it must have no effect. + */ + @Override + void close(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java new file mode 100644 index 0000000000000..a9c17d4b88274 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.transactions; + +/** + * Metadata about a transaction. + * + * @since 4.2.0 + */ +public interface TransactionInfo { + /** + * Returns a unique identifier for this transaction. + */ + String id(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 323a7db9c7ad7..333ac8817be07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -351,6 +351,31 @@ class Analyzer( } } + /** + * Returns a copy of this analyzer that uses the given [[CatalogManager]] for all catalog + * lookups. All other configuration (extended rules, checks, etc.) is preserved. Used by + * [[QueryExecution]] to create a per-query analyzer for transactional queries so that + * transaction-aware catalog resolution is an instance-level property rather than thread-local + * state. + */ + def withCatalogManager(newCatalogManager: CatalogManager): Analyzer = { + val self = this + new Analyzer(newCatalogManager, sharedRelationCache) { + override val hintResolutionRules: Seq[Rule[LogicalPlan]] = self.hintResolutionRules + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = self.extendedResolutionRules + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = self.postHocResolutionRules + override val extendedCheckRules: Seq[LogicalPlan => Unit] = self.extendedCheckRules + override val singlePassResolverExtensions: Seq[ResolverExtension] = + self.singlePassResolverExtensions + override val singlePassMetadataResolverExtensions: Seq[ResolverExtension] = + self.singlePassMetadataResolverExtensions + override val singlePassPostHocResolutionRules: Seq[Rule[LogicalPlan]] = + self.singlePassPostHocResolutionRules + override val singlePassExtendedResolutionChecks: Seq[LogicalPlan => Unit] = + self.singlePassExtendedResolutionChecks + } + } + override def execute(plan: LogicalPlan): LogicalPlan = { AnalysisContext.withNewAnalysisContext { executeSameContext(plan) @@ -458,7 +483,9 @@ class Analyzer( Batch("Simple Sanity Check", Once, LookupFunctions), Batch("Keep Legacy Outputs", Once, - KeepLegacyOutputs) + KeepLegacyOutputs), + Batch("Unresolve Relations", Once, + new UnresolveTransactionRelations(catalogManager)) ) override def batches: Seq[Batch] = earlyBatches ++ Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index 7a5077a8a3e11..8bcfb041d1a83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -478,6 +478,8 @@ class RelationResolution( } } + // TODO: how to validate the output is compatible? + // TODO: what shall we do if the output mismatches (schema changes?) def resolveReference(ref: V2TableReference): LogicalPlan = { val relation = getOrLoadRelation(ref) val planId = ref.getTagValue(LogicalPlan.PLAN_ID_TAG) @@ -485,6 +487,11 @@ class RelationResolution( } private def getOrLoadRelation(ref: V2TableReference): LogicalPlan = { + // Skip cache when a transaction is active. + if (catalogManager.transaction.isDefined) { + return loadRelation(ref) + } + val key = toCacheKey(ref.catalog, ref.identifier) relationCache.get(key) match { case Some(cached) => @@ -497,9 +504,13 @@ class RelationResolution( } private def loadRelation(ref: V2TableReference): LogicalPlan = { - val table = ref.catalog.loadTable(ref.identifier) + val resolvedCatalog = catalogManager.catalog(ref.catalog.name).asTableCatalog + val table = resolvedCatalog.loadTable(ref.identifier) + // val table = ref.catalog.loadTable(ref.identifier) V2TableReferenceUtils.validateLoadedTable(table, ref) - ref.toRelation(table) + // ref.toRelation(table) + DataSourceV2Relation( + table, ref.output, Some(resolvedCatalog), Some(ref.identifier), ref.options) } private def adaptCachedRelation(cached: LogicalPlan, ref: V2TableReference): LogicalPlan = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala new file mode 100644 index 0000000000000..4b175dd44ef08 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TransactionalWrite} +import org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.allowInvokingTransformsInAnalyzer +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +class UnresolveTransactionRelations(val catalogManager: CatalogManager) + extends Rule[LogicalPlan] with LookupCatalog { + + override def apply(plan: LogicalPlan): LogicalPlan = + catalogManager.transaction match { + case Some(transaction) => + allowInvokingTransformsInAnalyzer { + plan.transform { + case tw: TransactionalWrite => + unresolveRelations(tw, transaction.catalog) + } + } + case _ => plan + } + + private def unresolveRelations( + plan: LogicalPlan, + catalog: CatalogPlugin): LogicalPlan = { + plan transform { + case r: DataSourceV2Relation if isLoadedFromCatalog(r, catalog) => + V2TableReference.createForRelation(r, Seq.empty) + } + } + + private def isLoadedFromCatalog( + relation: DataSourceV2Relation, + catalog: CatalogPlugin): Boolean = { + // relation.catalog.exists(_ eq catalog) && relation.identifier.isDefined + relation.catalog.exists(_.name == catalog.name) && relation.identifier.isDefined + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index 85c36d452b309..a2379f33e14ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.V2TableReference.Context import org.apache.spark.sql.catalyst.analysis.V2TableReference.TableInfo import org.apache.spark.sql.catalyst.analysis.V2TableReference.TemporaryViewContext +import org.apache.spark.sql.catalyst.analysis.V2TableReference.TestContext import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics @@ -37,7 +38,7 @@ import org.apache.spark.sql.connector.catalog.V2TableUtil import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.sql.util.SchemaValidationMode.ALLOW_NEW_TOP_LEVEL_FIELDS +import org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_TOP_LEVEL_FIELDS, PROHIBIT_CHANGES} import org.apache.spark.util.ArrayImplicits._ /** @@ -84,11 +85,19 @@ private[sql] object V2TableReference { sealed trait Context case class TemporaryViewContext(viewName: Seq[String]) extends Context + // TODO(achatzis): Fix naming and complete implementation. + case class TestContext(tableName: Seq[String]) extends Context def createForTempView(relation: DataSourceV2Relation, viewName: Seq[String]): V2TableReference = { create(relation, TemporaryViewContext(viewName)) } + def createForRelation( + relation: DataSourceV2Relation, + relationName: Seq[String]): V2TableReference = { + create(relation, TestContext(relationName)) + } + private def create(relation: DataSourceV2Relation, context: Context): V2TableReference = { val ref = V2TableReference( relation.catalog.get.asTableCatalog, @@ -110,11 +119,28 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { ref.context match { case ctx: TemporaryViewContext => validateLoadedTableInTempView(table, ref, ctx) + case _: TestContext => + validateLoadedTableInTransaction(table, ref) case ctx => throw SparkException.internalError(s"Unknown table ref context: ${ctx.getClass.getName}") } } + private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = { + val dataErrors = V2TableUtil.validateCapturedColumns( + table, + ref.info.columns, + mode = PROHIBIT_CHANGES) + if (dataErrors.nonEmpty) { + throw QueryCompilationErrors.columnsChangedAfterAnalysis(ref.name, dataErrors) + } + + val metaErrors = V2TableUtil.validateCapturedMetadataColumns(table, ref.info.metadataColumns) + if (metaErrors.nonEmpty) { + throw QueryCompilationErrors.metadataColumnsChangedAfterAnalysis(ref.name, metaErrors) + } + } + private def validateLoadedTableInTempView( table: Table, ref: V2TableReference, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index c38377582c156..fb54af2344d1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -188,7 +188,7 @@ case class InsertIntoStatement( byName: Boolean = false, replaceCriteriaOpt: Option[InsertReplaceCriteria] = None, withSchemaEvolution: Boolean = false) - extends UnaryParsedStatement { + extends UnaryParsedStatement with TransactionalWrite { require(overwrite || !ifPartitionNotExists, "IF NOT EXISTS is only valid in INSERT OVERWRITE") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 0eded2d9dbdf9..9d419ad668462 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -157,7 +157,7 @@ case class AppendData( isByName: Boolean, withSchemaEvolution: Boolean, write: Option[Write] = None, - analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand { + analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT) override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery) override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable) @@ -205,7 +205,7 @@ case class OverwriteByExpression( isByName: Boolean, withSchemaEvolution: Boolean, write: Option[Write] = None, - analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand { + analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE) override lazy val resolved: Boolean = { @@ -265,7 +265,7 @@ case class OverwritePartitionsDynamic( writeOptions: Map[String, String], isByName: Boolean, withSchemaEvolution: Boolean, - write: Option[Write] = None) extends V2WriteCommand { + write: Option[Write] = None) extends V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE) override def withNewQuery(newQuery: LogicalPlan): OverwritePartitionsDynamic = { @@ -956,7 +956,8 @@ object DescribeColumn { */ case class DeleteFromTable( table: LogicalPlan, - condition: Expression) extends UnaryCommand with SupportsSubquery { + condition: Expression) + extends UnaryCommand with TransactionalWrite with SupportsSubquery { override def child: LogicalPlan = table override protected def withNewChildInternal(newChild: LogicalPlan): DeleteFromTable = copy(table = newChild) @@ -978,7 +979,8 @@ case class DeleteFromTableWithFilters( case class UpdateTable( table: LogicalPlan, assignments: Seq[Assignment], - condition: Option[Expression]) extends UnaryCommand with SupportsSubquery { + condition: Option[Expression]) + extends UnaryCommand with TransactionalWrite with SupportsSubquery { lazy val aligned: Boolean = AssignmentUtils.aligned(table.output, assignments) @@ -1011,8 +1013,13 @@ case class MergeIntoTable( notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction], withSchemaEvolution: Boolean) - extends BinaryCommand with WriteWithSchemaEvolution with SupportsSubquery { + extends BinaryCommand + with WriteWithSchemaEvolution + with TransactionalWrite + with SupportsSubquery { + // Implements SupportsSchemaEvolution.table. + // Implements TransactionalWrite.table, identifying the MERGE target as the table being written. override val table: LogicalPlan = EliminateSubqueryAliases(targetTable) override def withNewTable(newTable: NamedRelation): MergeIntoTable = { @@ -1272,6 +1279,10 @@ case class Assignment(key: Expression, value: Expression) extends Expression newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight) } +trait TransactionalWrite extends LogicalPlan { + def table: LogicalPlan +} + /** * The logical plan of the DROP TABLE command. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala new file mode 100644 index 0000000000000..d160aafdea34e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.transactions + +import java.util.UUID + +import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfoImpl} +import org.apache.spark.util.Utils + +object TransactionUtils { + def commit(transaction: Transaction): Unit = { + Utils.tryWithSafeFinally { + transaction.commit() + } { + transaction.close() + } + } + + def abort(transaction: Transaction): Unit = { + Utils.tryWithSafeFinally { + transaction.abort() + } { + transaction.close() + } + } + + def beginTransaction(catalog: TransactionalCatalogPlugin): Transaction = { + val info = TransactionInfoImpl(id = UUID.randomUUID.toString) + val transaction = catalog.beginTransaction(info) + if (transaction.catalog.name != catalog.name) { + abort(transaction) + throw new IllegalStateException( + s"""Transaction catalog name (${transaction.catalog.name}) + |must match original catalog name (${catalog.name}). + |""".stripMargin) + } + transaction + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index 3f5afd9ce0de7..c851e931aad4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, TempVariableManager} import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -39,7 +40,7 @@ import org.apache.spark.sql.internal.SQLConf // need to track current database at all. private[sql] class CatalogManager( - defaultSessionCatalog: CatalogPlugin, + val defaultSessionCatalog: CatalogPlugin, val v1SessionCatalog: SessionCatalog) extends SQLConfHelper with Logging { import CatalogManager.SESSION_CATALOG_NAME import CatalogV2Util._ @@ -57,6 +58,11 @@ class CatalogManager( } } + def transaction: Option[Transaction] = None + + def withTransaction(transaction: Transaction): CatalogManager = + new TransactionAwareCatalogManager(this, transaction) + def isCatalogRegistered(name: String): Boolean = { try { catalog(name) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala index 203cfc23452a8..fbb2938fd3da2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.connector.catalog import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.TransactionalWrite import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -163,4 +165,15 @@ private[sql] trait LookupCatalog extends Logging { } } } + + object TransactionalWrite { + def unapply(write: TransactionalWrite): Option[TransactionalCatalogPlugin] = { + EliminateSubqueryAliases(write.table) match { + case UnresolvedRelation(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _, _) => + Some(c) + case _ => + None + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala new file mode 100644 index 0000000000000..9403219f596da --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import org.apache.spark.sql.connector.catalog.transactions.Transaction + +/** + * A [[CatalogManager]] decorator that redirects catalog lookups to the transaction's catalog + * instance when names match, ensuring table loads during analysis are scoped to the transaction. + * All mutable state (current catalog, current namespace, loaded catalogs) is delegated to the + * wrapped [[CatalogManager]]. + */ +// TODO: Consider extracting a CatalogManager trait that both the real +// implementation and the decorator implement +private[sql] class TransactionAwareCatalogManager( + delegate: CatalogManager, + txn: Transaction) + extends CatalogManager(delegate.defaultSessionCatalog, delegate.v1SessionCatalog) { + + override def transaction: Option[Transaction] = Some(txn) + + override def catalog(name: String): CatalogPlugin = { + val resolved = delegate.catalog(name) + if (txn.catalog.name() == resolved.name()) txn.catalog else resolved + } + + override def currentCatalog: CatalogPlugin = { + val c = delegate.currentCatalog + if (txn.catalog.name() == c.name()) txn.catalog else c + } + + override def currentNamespace: Array[String] = delegate.currentNamespace + + override def setCurrentNamespace(namespace: Array[String]): Unit = + delegate.setCurrentNamespace(namespace) + + override def setCurrentCatalog(catalogName: String): Unit = + delegate.setCurrentCatalog(catalogName) + + override def listCatalogs(pattern: Option[String]): Seq[String] = + delegate.listCatalogs(pattern) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala new file mode 100644 index 0000000000000..4cb53da0a59e2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.transactions + +case class TransactionInfoImpl(id: String) extends TransactionInfo diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala new file mode 100644 index 0000000000000..d409316e667b1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.transactions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, TransactionalCatalogPlugin} +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class TransactionUtilsSuite extends SparkFunSuite { + val testCatalogName = "test_catalog" + + // --- Helpers --------------------------------------------------------------- + private def mockCatalog(catalogName: String): CatalogPlugin = new CatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = catalogName + } + + private val emptyFunction = () => () + private class TestTransaction( + catalogName: String, + onCommit: () => Unit = emptyFunction, + onAbort: () => Unit = emptyFunction, + onClose: () => Unit = emptyFunction) extends Transaction { + var committed = false + var aborted = false + var closed = false + + override def catalog(): CatalogPlugin = mockCatalog(catalogName) + override def commit(): Unit = { committed = true; onCommit() } + override def abort(): Unit = { aborted = true; onAbort() } + override def close(): Unit = { closed = true; onClose() } + } + + private def mockTransactionalCatalog( + catalogName: String, + txnCatalogName: String = null): TransactionalCatalogPlugin = { + val resolvedTxnCatalogName = Option(txnCatalogName).getOrElse(catalogName) + new TransactionalCatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = catalogName + override def beginTransaction(info: TransactionInfo): Transaction = + new TestTransaction(resolvedTxnCatalogName) + } + } + + // --- Commit ---------------------------------------------------------------- + test("commit: calls commit then close") { + val txn = new TestTransaction(testCatalogName) + TransactionUtils.commit(txn) + assert(txn.committed) + assert(txn.closed) + } + + test("commit: close is called even if commit fails") { + val txn = new TestTransaction( + testCatalogName, onCommit = () => throw new RuntimeException("commit failed")) + intercept[RuntimeException] { TransactionUtils.commit(txn) } + assert(txn.closed) + } + + // --- Abort ----------------------------------------------------------------- + test("abort: calls abort then close") { + val txn = new TestTransaction(testCatalogName) + TransactionUtils.abort(txn) + assert(txn.aborted) + assert(txn.closed) + } + + test("abort: close is called even if abort fails") { + val txn = new TestTransaction(testCatalogName, + onAbort = () => throw new RuntimeException("abort failed")) + intercept[RuntimeException] { TransactionUtils.abort(txn) } + assert(txn.closed) + } + + // --- Begin Transaction ----------------------------------------------------- + test("beginTransaction: returns transaction when catalog names match") { + val catalog = mockTransactionalCatalog(testCatalogName) + val txn = TransactionUtils.beginTransaction(catalog) + assert(txn.catalog().name() == testCatalogName) + } + + test("beginTransaction: fails when transaction catalog name does not match") { + val catalog = mockTransactionalCatalog(catalogName = testCatalogName, txnCatalogName = "other") + val e = intercept[IllegalStateException] { + TransactionUtils.beginTransaction(catalog) + } + assert(e.getMessage.contains("other")) + assert(e.getMessage.contains(testCatalogName)) + } + + test("beginTransaction: aborts and closes transaction on catalog name mismatch") { + var aborted = false + var closed = false + val catalog = new TransactionalCatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = testCatalogName + override def beginTransaction(info: TransactionInfo): Transaction = + new TestTransaction( + "other", + onAbort = () => { aborted = true }, + onClose = () => { closed = true }) + } + intercept[IllegalStateException] { TransactionUtils.beginTransaction(catalog) } + assert(aborted) + assert(closed) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index fd2c0f6e9c2ec..af0860664312c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -94,6 +94,8 @@ abstract class InMemoryBaseTable( validatedTableVersion = version } + protected def recordScanEvent(filters: Array[Filter]): Unit = {} + protected object PartitionKeyColumn extends MetadataColumn { override def name: String = "_partition" override def dataType: DataType = StringType @@ -455,6 +457,7 @@ abstract class InMemoryBaseTable( if (evaluableFilters.nonEmpty) { scan.filter(evaluableFilters) } + recordScanEvent(_pushedFilters) scan } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index bbb9041bab37c..27231447a1273 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -18,10 +18,23 @@ package org.apache.spark.sql.connector.catalog import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} -class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog { +class InMemoryRowLevelOperationTableCatalog + extends InMemoryTableCatalog with TransactionalCatalogPlugin { import CatalogV2Implicits._ + var transaction: Txn = _ + // Tracks the last completed transaction for test assertions; cleared when a new one begins. + var lastTransaction: Txn = _ + + override def beginTransaction(info: TransactionInfo): Transaction = { + assert(transaction == null || transaction.currentState != Active) + this.transaction = new Txn(new TxnTableCatalog(this)) + this.lastTransaction = transaction + transaction + } + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { if (tables.containsKey(ident)) { throw new TableAlreadyExistsException(ident.asMultipartIdentifier) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index d5738475031dc..2f3c65924d6a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ /** @@ -215,6 +216,16 @@ class InMemoryTable( object InMemoryTable { + // V1 filter values (from PredicateUtils.toV1) are Scala types (e.g. String), but partition + // keys stored in dataMap are Catalyst internal types (e.g. UTF8String). Normalize both sides + // before comparing so that string partitions work correctly. + private def valuesEqual(filterValue: Any, partitionValue: Any): Boolean = + (filterValue, partitionValue) match { + case (s: String, u: UTF8String) => u.toString == s + case (u: UTF8String, s: String) => u.toString == s + case _ => filterValue == partitionValue + } + def filtersToKeys( keys: Iterable[Seq[Any]], partitionNames: Seq[String], @@ -222,7 +233,7 @@ object InMemoryTable { keys.filter { partValues => filters.flatMap(splitAnd).forall { case EqualTo(attr, value) => - value == InMemoryBaseTable.extractValue(attr, partitionNames, partValues) + valuesEqual(value, InMemoryBaseTable.extractValue(attr, partitionNames, partValues)) case EqualNullSafe(attr, value) => val attrVal = InMemoryBaseTable.extractValue(attr, partitionNames, partValues) if (attrVal == null && value == null) { @@ -230,7 +241,7 @@ object InMemoryTable { } else if (attrVal == null || value == null) { false } else { - value == attrVal + valuesEqual(value, attrVal) } case IsNull(attr) => null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index ff7995ad6697e..c7195b512b8d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -59,6 +59,10 @@ class BasicInMemoryTableCatalog extends TableCatalog { tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray } + def loadTableAs[T <: Table](ident: Identifier): T = { + loadTable(ident).asInstanceOf[T] + } + // load table for scans override def loadTable(ident: Identifier): Table = { Option(tables.get(ident)) match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala new file mode 100644 index 0000000000000..4feb89c78f56c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.util +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.connector.catalog.transactions.Transaction +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +sealed trait TransactionState +case object Active extends TransactionState +case object Committed extends TransactionState +case object Aborted extends TransactionState + +class Txn(override val catalog: TxnTableCatalog) extends Transaction { + + private[this] var state: TransactionState = Active + private[this] var closed: Boolean = false + + def currentState: TransactionState = state + + def isClosed: Boolean = closed + + override def commit(): Unit = { + if (closed) throw new IllegalStateException("Can't commit, already closed") + catalog.commit() + this.state = Committed + } + + override def abort(): Unit = { + if (state == Committed || state == Aborted) return + // if (closed) throw new IllegalStateException("Can't abort, already closed") + this.state = Aborted + } + + override def close(): Unit = { + catalog.clearActiveTransaction() + this.closed = true + } +} + +// a special table used in row-level operation transactions +// it inherits data from the base table upon construction and +// propagates staged transaction state back after an explicit commit +class TxnTable(val delegate: InMemoryRowLevelOperationTable) + extends InMemoryRowLevelOperationTable( + delegate.name, + delegate.schema, + delegate.partitioning, + delegate.properties, + delegate.constraints) { + + // TODO(achatzis): Rethink how schema evolution works on top of transactions. + alterTableWithData(delegate.data, schema) + + // a tracker of filters used in each scan + // achatzis: Non-deterministic filters? + val scanEvents = new ArrayBuffer[Array[Filter]]() + + override protected def recordScanEvent(filters: Array[Filter]): Unit = { + scanEvents += filters + } + + def commit(): Unit = { + delegate.dataMap.clear() + // TODO(achatzis): Rethink how schema evolution works on top of transactions. + delegate.alterTableWithData(data, delegate.schema) + delegate.replacedPartitions = replacedPartitions + delegate.lastWriteInfo = lastWriteInfo + delegate.lastWriteLog = lastWriteLog + delegate.commits ++= commits + delegate.increaseVersion() + } +} + +// a special table catalog used in row-level operation transactions +// table changes are initially staged in memory and propagated only after an explicit commit +class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends TableCatalog { + + private val tables: util.Map[Identifier, TxnTable] = new ConcurrentHashMap[Identifier, TxnTable]() + + override def name: String = delegate.name + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {} + + override def listTables(namespace: Array[String]): Array[Identifier] = { + throw new UnsupportedOperationException() + } + + override def loadTable(ident: Identifier): Table = { + tables.computeIfAbsent(ident, _ => { + val table = delegate.loadTableAs[InMemoryRowLevelOperationTable](ident) + new TxnTable(table) + }) + } + + override def alterTable(ident: Identifier, changes: TableChange*): Table = { + val newDelegateTable = delegate.alterTable(ident, changes: _*) + // Compute again if absent. + tables.remove(ident) + newDelegateTable + } + + override def dropTable(ident: Identifier): Boolean = { + throw new UnsupportedOperationException() + } + + override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { + throw new UnsupportedOperationException() + } + + def commit(): Unit = { + tables.values.forEach(table => table.commit()) + } + + def clearActiveTransaction(): Unit = { + delegate.transaction = null + } + + override def equals(obj: Any): Boolean = { + obj match { + case that: CatalogPlugin => this.name == that.name + case _ => false + } + } + + override def hashCode(): Int = name.hashCode() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala index 9a5aed333a4c2..5bd5d20692aed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala @@ -933,6 +933,23 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog with Logging { // caches referencing this relation. If this relation is cached as an InMemoryRelation, // this will clear the relation cache and caches of all its dependents. CommandUtils.recacheTableOrView(sparkSession, relation) + /* + EliminateSubqueryAliases(relation) match { + case r @ ExtractV2CatalogAndIdentifier(catalog, ident) if r.timeTravelSpec.isEmpty => + val nameParts = ident.toQualifiedNameParts(catalog) + sparkSession.sharedState.cacheManager.recacheTableOrView(sparkSession, nameParts) + case _ => + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation) + */ + /* + relation match { + case r: DataSourceV2Relation if r.catalog.isDefined && r.identifier.isDefined => + val nameParts = r.identifier.get.toQualifiedNameParts(r.catalog.get) + sparkSession.sharedState.cacheManager.recacheTableOrView(sparkSession, nameParts) + case _ => + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation) + } + */ } private def resolveRelation(tableName: String): LogicalPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 3f92f24156d3c..66f406d39f263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -260,6 +260,9 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { val nameInCache = v2Ident.toQualifiedNameParts(catalog) isSameName(name, nameInCache, resolver) && (includeTimeTravel || timeTravelSpec.isEmpty) + // case r: TableReference => + // isSameName(name, r.identifier.toQualifiedNameParts(r.catalog), resolver) + case v: View => isSameName(name, v.desc.identifier.nameParts, resolver) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index c0ab906de4841..9413354907678 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -32,18 +32,21 @@ import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} -import org.apache.spark.sql.catalyst.analysis.{LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker} import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, WithCTE} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, UnresolvedWith, WithCTE} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.connector.catalog.LookupCatalog +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.execution.SQLExecution.EXECUTION_ROOT_ID_KEY import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan} import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan} -import org.apache.spark.sql.execution.datasources.v2.V2TableRefreshUtil +import org.apache.spark.sql.execution.datasources.v2.{TransactionalExec, V2TableRefreshUtil} import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery @@ -69,7 +72,8 @@ class QueryExecution( val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL, val shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None, val refreshPhaseEnabled: Boolean = true, - val queryId: UUID = UUIDv7Generator.generate()) extends Logging { + val queryId: UUID = UUIDv7Generator.generate(), + val analyzerOpt: Option[Analyzer] = None) extends LookupCatalog { val id: Long = QueryExecution.nextExecutionId @@ -79,6 +83,8 @@ class QueryExecution( // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner + protected val catalogManager = sparkSession.sessionState.catalogManager + /** * Check whether the query represented by this QueryExecution is a SQL script. * @return True if the query is a SQL script, False otherwise. @@ -90,6 +96,46 @@ class QueryExecution( logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyExpression]))) } + + // 1. At the pre-Analyzed plan we look for nodes that implement the TransactionalWrite trait. + // When a plan contains such a node we initiate a transaction. Note, we should never start + // a transaction for operations that are not executed, e.g. EXPLAIN. + // 2. Create an analyzer clone with a transaction aware Catalog Manager. The latter is the single + // choke point of all catalog access, and it is also the transaction context carrier. + // This is then passed to all rules during analysis that need to check the catalog. Rules + // that are specifically interested in transactionality can access the transaction directly + // from the Catalog Manager. The transaction catalog, is potentially the place where connectors + // should keep state about the reads (tables+predicates) that occurred during the transaction. + // 3. The analyzer instance is passed to nested Query Execution instances. These need to respect + // the open transaction instead of creating their own. + private lazy val transactionOpt: Option[Transaction] = + // Always inherit an active transaction from the outer analyzer, regardless of mode. + analyzerOpt.flatMap(_.catalogManager.transaction).orElse { + // Only begin a new transaction for outer QEs that lead to execution. + if (mode != CommandExecutionMode.SKIP) { + val catalog = logical match { + case UnresolvedWith(TransactionalWrite(c), _, _) => Some(c) + case TransactionalWrite(c) => Some(c) + case _ => None + } + catalog.map(TransactionUtils.beginTransaction) + } else { + None + } + } + + // Per-query analyzer: uses a transaction-aware CatalogManager when a transaction is active, + // so that all catalog lookups and rule applications during analysis see the correct state + // without relying on thread-local context. + private lazy val analyzer: Analyzer = analyzerOpt.getOrElse { + transactionOpt match { + case Some(txn) => + sparkSession.sessionState.analyzer.withCatalogManager(catalogManager.withTransaction(txn)) + case None => + sparkSession.sessionState.analyzer + } + } + def assertAnalyzed(): Unit = { try { analyzed @@ -102,7 +148,7 @@ class QueryExecution( } } - def assertSupported(): Unit = { + def assertSupported(): Unit = executeWithTransactionContext { if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForBatch(analyzed) } @@ -141,7 +187,7 @@ class QueryExecution( try { val plan = executePhase(QueryPlanningTracker.ANALYSIS) { // We can't clone `logical` here, which will reset the `_analyzed` flag. - sparkSession.sessionState.analyzer.executeAndCheck(sqlScriptExecuted, tracker) + analyzer.executeAndCheck(sqlScriptExecuted, tracker) } tracker.setAnalyzed(plan) plan @@ -152,7 +198,9 @@ class QueryExecution( } } - def analyzed: LogicalPlan = lazyAnalyzed.get + def analyzed: LogicalPlan = executeWithTransactionContext { + lazyAnalyzed.get + } private val lazyCommandExecuted = LazyTry { mode match { @@ -162,7 +210,9 @@ class QueryExecution( } } - def commandExecuted: LogicalPlan = lazyCommandExecuted.get + def commandExecuted: LogicalPlan = executeWithTransactionContext { + lazyCommandExecuted.get + } private def commandExecutionName(command: Command): String = command match { case _: CreateTableAsSelect => "create" @@ -184,7 +234,8 @@ class QueryExecution( // for eagerly executed commands we mark this place as beginning of execution. tracker.setReadyForExecution() val (qe, result) = QueryExecution.runCommand( - sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode)) + sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode), + analyzerOpt = Some(analyzer)) CommandResult( qe.analyzed.output, qe.commandExecuted, @@ -222,7 +273,9 @@ class QueryExecution( } // The plan that has been normalized by custom rules, so that it's more likely to hit cache. - def normalized: LogicalPlan = lazyNormalized.get + def normalized: LogicalPlan = executeWithTransactionContext { + lazyNormalized.get + } private val lazyWithCachedData = LazyTry { sparkSession.withActive { @@ -230,11 +283,19 @@ class QueryExecution( assertSupported() // clone the plan to avoid sharing the plan instance between different stages like analyzing, // optimizing and planning. - sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + val plan = normalized.clone() + // During a transaction, skip cache substitution. useCachedData replaces DataSourceV2Relation + // nodes (loaded via the transaction catalog) with InMemoryRelation, which bypasses read + // tracking in the transaction catalog and may serve stale data. + // if (transactionOpt.isDefined) plan + // else sparkSession.sharedState.cacheManager.useCachedData(plan) + sparkSession.sharedState.cacheManager.useCachedData(plan) } } - def withCachedData: LogicalPlan = lazyWithCachedData.get + def withCachedData: LogicalPlan = executeWithTransactionContext { + lazyWithCachedData.get + } def assertCommandExecuted(): Unit = commandExecuted @@ -256,7 +317,9 @@ class QueryExecution( } } - def optimizedPlan: LogicalPlan = lazyOptimizedPlan.get + def optimizedPlan: LogicalPlan = executeWithTransactionContext { + lazyOptimizedPlan.get + } def assertOptimized(): Unit = optimizedPlan @@ -264,14 +327,21 @@ class QueryExecution( // We need to materialize the optimizedPlan here because sparkPlan is also tracked under // the planning phase assertOptimized() - executePhase(QueryPlanningTracker.PLANNING) { + val plan = executePhase(QueryPlanningTracker.PLANNING) { // Clone the logical plan here, in case the planner rules change the states of the logical // plan. QueryExecution.createSparkPlan(planner, optimizedPlan.clone()) } + transactionOpt match { + case Some(txn) => + plan.transformDown { case w: TransactionalExec => w.withTransaction(Some(txn)) } + case None => plan + } } - def sparkPlan: SparkPlan = lazySparkPlan.get + def sparkPlan: SparkPlan = executeWithTransactionContext { + lazySparkPlan.get + } def assertSparkPlanPrepared(): Unit = sparkPlan @@ -292,7 +362,9 @@ class QueryExecution( // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - def executedPlan: SparkPlan = lazyExecutedPlan.get + def executedPlan: SparkPlan = executeWithTransactionContext { + lazyExecutedPlan.get + } def assertExecutedPlanPrepared(): Unit = executedPlan @@ -310,7 +382,9 @@ class QueryExecution( * Given QueryExecution is not a public class, end users are discouraged to use this: please * use `Dataset.rdd` instead where conversion will be applied. */ - def toRdd: RDD[InternalRow] = lazyToRdd.get + def toRdd: RDD[InternalRow] = executeWithTransactionContext { + lazyToRdd.get + } private val observedMetricsLock = new Object @@ -535,6 +609,23 @@ class QueryExecution( } } + /** + * Execute the given block with the transaction context if exists. If there is an exception thrown + * during the execution, the transaction will be aborted. + * + * Note 1: The transaction is not committed in this method. The caller should commit the + * transaction if the execution is successful. + * + * Note 2: In some cases, post commit execution might generate an exception. The abort operation + * should be no-op in this case. + */ + private def executeWithTransactionContext[T](block: => T): T = transactionOpt match { + case Some(transaction) => + try block + catch { case e: Throwable => TransactionUtils.abort(transaction); throw e } + case None => block + } + /** A special namespace for commands that can be used to debug query execution. */ // scalastyle:off object debug { @@ -819,14 +910,16 @@ object QueryExecution { name: String, refreshPhaseEnabled: Boolean = true, mode: CommandExecutionMode.Value = CommandExecutionMode.SKIP, - shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None) + shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None, + analyzerOpt: Option[Analyzer] = None) : (QueryExecution, Array[InternalRow]) = { val qe = new QueryExecution( sparkSession, command, mode = mode, shuffleCleanupModeOpt = shuffleCleanupModeOpt, - refreshPhaseEnabled = refreshPhaseEnabled) + refreshPhaseEnabled = refreshPhaseEnabled, + analyzerOpt = analyzerOpt) val result = QueryExecution.withInternalError(s"Executed $name failed.") { SQLExecution.withNewExecutionId(qe, Some(name)) { qe.executedPlan.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index e03928867e24d..c73f4ad9eded9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -63,9 +63,6 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat private def hadoopConf = session.sessionState.newHadoopConf() - // recaches all cache entries without time travel for the given table - // after a write operation that moves the state of the table forward (e.g. append, overwrite) - // this method accounts for V2 tables loaded via TableProvider (no catalog/identifier) private def refreshCache(r: DataSourceV2Relation)(): Unit = r match { case ExtractV2CatalogAndIdentifier(catalog, ident) => val nameParts = ident.toQualifiedNameParts(catalog) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala index 8d5ee6038e80f..c6b1bae89b156 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala @@ -19,16 +19,23 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.connector.catalog.SupportsDeleteV2 +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.connector.expressions.filter.Predicate case class DeleteFromTableExec( table: SupportsDeleteV2, condition: Array[Predicate], - refreshCache: () => Unit) extends LeafV2CommandExec { + refreshCache: () => Unit, + transaction: Option[Transaction] = None) extends LeafV2CommandExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): DeleteFromTableExec = + copy(transaction = txn) override protected def run(): Seq[InternalRow] = { table.deleteWhere(condition) + transaction.foreach(TransactionUtils.commit) refreshCache() Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index d8e871bcf4824..0249e5b49c9bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -96,7 +96,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { options, relationOpt.map(r => r.options.asCaseSensitiveMap.asScala.toMap).getOrElse(Map.empty)) val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId = queryId) - val write = buildWriteForMicroBatch(table, writeBuilder, outputMode) + val write = buildWriteForMicroBatch(tableDataSourceV2Strategy, writeBuilder, outputMode) val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming) val customMetrics = write.supportedCustomMetrics.toImmutableArraySeq val funCatalogOpt = relationOpt.flatMap(_.funCatalog) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 2071024c5b7e5..e2851c3187f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -26,9 +26,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, ProjectingInternalRow} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, TableSpec, UnaryNode} +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{COPY_OPERATION, DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege} +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeleteSummaryImpl, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperation, RowLevelOperationTable, UpdateSummaryImpl, Write, WriterCommitMessage, WriteSummary} @@ -273,7 +275,9 @@ case class AppendDataExec( query: SparkPlan, refreshCache: () => Unit, write: Write, - tableName: String) extends V2ExistingTableWriteExec { + tableName: String, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): AppendDataExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): AppendDataExec = copy(query = newChild) } @@ -292,7 +296,10 @@ case class OverwriteByExpressionExec( query: SparkPlan, refreshCache: () => Unit, write: Write, - tableName: String) extends V2ExistingTableWriteExec { + tableName: String, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): OverwriteByExpressionExec = + copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): OverwriteByExpressionExec = copy(query = newChild) } @@ -310,7 +317,10 @@ case class OverwritePartitionsDynamicExec( query: SparkPlan, refreshCache: () => Unit, write: Write, - tableName: String) extends V2ExistingTableWriteExec { + tableName: String, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): OverwritePartitionsDynamicExec = + copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): OverwritePartitionsDynamicExec = copy(query = newChild) } @@ -324,7 +334,8 @@ case class ReplaceDataExec( projections: ReplaceDataProjections, write: Write, rowLevelCommand: RowLevelOperation.Command, - tableName: String) extends RowLevelWriteExec { + tableName: String, + transaction: Option[Transaction] = None) extends RowLevelWriteExec { override def writingTask: WritingSparkTask[_] = { projections.metadataProjection match { @@ -335,6 +346,7 @@ case class ReplaceDataExec( } } + override def withTransaction(txn: Option[Transaction]): ReplaceDataExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = { copy(query = newChild) } @@ -369,7 +381,8 @@ case class WriteDeltaExec( projections: WriteDeltaProjections, write: DeltaWrite, rowLevelCommand: RowLevelOperation.Command, - tableName: String) extends RowLevelWriteExec { + tableName: String, + transaction: Option[Transaction] = None) extends RowLevelWriteExec { override lazy val writingTask: WritingSparkTask[_] = { if (projections.metadataProjection.isDefined) { @@ -379,6 +392,7 @@ case class WriteDeltaExec( } } + override def withTransaction(txn: Option[Transaction]): WriteDeltaExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): WriteDeltaExec = { copy(query = newChild) } @@ -406,7 +420,16 @@ case class WriteToDataSourceV2Exec( copy(query = newChild) } -trait V2ExistingTableWriteExec extends V2TableWriteExec { +/** + * Trait for physical plan nodes that write to an existing table as part of a transaction. + * The [[transaction]] is injected post-planning by [[QueryExecution]]. + */ +trait TransactionalExec extends SparkPlan { + def transaction: Option[Transaction] + def withTransaction(txn: Option[Transaction]): SparkPlan +} + +trait V2ExistingTableWriteExec extends V2TableWriteExec with TransactionalExec { def refreshCache: () => Unit def write: Write def tableName: String @@ -426,6 +449,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { } finally { postDriverMetrics() } + transaction.foreach(TransactionUtils.commit) refreshCache() writtenRows } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala new file mode 100644 index 0000000000000..1c8e7fc5a0fd4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.Committed +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf + +class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { + + test("writeTo append with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // create a source on top of itself that will be fully resolved and analyzed + val sourceDF = spark.table(tableNameAsString) + .where("pk == 1") + .select(col("pk") + 10 as "pk", col("salary"), col("dep")) + sourceDF.queryExecution.assertAnalyzed() + + // append data using the DataFrame API + val (txn, txnTables) = executeTransaction { + sourceDF.writeTo(tableNameAsString).append() + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + // check the source scan was tracked via the transaction catalog + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size >= 1) + + // check data was appended correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 100, "hr"))) // appended + } + + test("SQL INSERT INTO with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // SQL INSERT INTO using VALUES + val (txn, _) = executeTransaction { + sql(s"INSERT INTO $tableNameAsString VALUES (3, 300, 'hr'), (4, 400, 'finance')") + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + + // check data was inserted correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 300, "hr"), + Row(4, 400, "finance"))) + } + + test("SQL INSERT OVERWRITE with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // INSERT OVERWRITE with static partition predicate -> OverwriteByExpression + val (txn, _) = executeTransaction { + sql(s"""INSERT OVERWRITE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT pk + 10, salary FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(11, 100, "hr"), // overwritten + Row(13, 300, "hr"))) // overwritten + } + + test("SQL INSERT OVERWRITE dynamic partition with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // INSERT OVERWRITE with dynamic partitioning -> OverwritePartitionsDynamic + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> "dynamic") { + val (txn, _) = executeTransaction { + sql(s"""INSERT OVERWRITE $tableNameAsString + |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged (different partition) + Row(11, 100, "hr"), // overwrote hr partition + Row(13, 300, "hr"))) // overwrote hr partition + } + } + + test("writeTo overwrite with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // overwrite using a condition that covers the hr partition -> OverwriteByExpression + val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). + toDF("pk", "salary", "dep") + + val (txn, _) = executeTransaction { + sourceDF.writeTo(tableNameAsString).overwrite(col("dep") === "hr") + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged (different partition) + Row(11, 999, "hr"), // overwrote hr partition + Row(12, 888, "hr"))) // overwrote hr partition + } + + test("writeTo overwritePartitions with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // overwrite partitions dynamically -> OverwritePartitionsDynamic + val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). + toDF("pk", "salary", "dep") + + val (txn, _) = executeTransaction { + sourceDF.writeTo(tableNameAsString).overwritePartitions() + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged (different partition) + Row(11, 999, "hr"), // overwrote hr partition + Row(12, 888, "hr"))) // overwrote hr partition + } + + test("SQL INSERT INTO SELECT with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // SQL INSERT INTO using SELECT from the same table (self-insert) + val (txn, txnTables) = executeTransaction { + sql(s"""INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + // check data was inserted correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 300, "hr"), + Row(11, 100, "hr"), // inserted from pk=1 + Row(13, 300, "hr"))) // inserted from pk=3 + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index adc88f5a54a07..1e18928c10856 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.connector -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions.CheckInvariant import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} import org.apache.spark.sql.connector.catalog.InMemoryTable import org.apache.spark.sql.connector.write.DeleteSummary import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec, WriteDeltaExec} +import org.apache.spark.sql.sources abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { @@ -773,6 +775,186 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { } } + test("delete with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val exception = intercept[AnalysisException] { + sql(s"DELETE FROM $tableNameAsString WHERE invalid_column = 1") + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("delete with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // simple predicate delete: goes through SupportsDelete.deleteWhere (no Spark-side scan) + val (txn, _) = executeTransaction { + sql(s"DELETE FROM $tableNameAsString WHERE dep = 'hr'") + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(2, 200, "software"))) + } + + test("delete with subquery on source table and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + val (txn, txnTables) = executeTransaction { + sql( + s"""DELETE FROM $tableNameAsString + |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 'hr') + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numSubquerySourceScans == expectedNumSourceScans) + + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaDelete) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in subquery result) + } + + test("delete with CTE and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk FROM $sourceNameAsString WHERE dep = 'hr' + |) + |DELETE FROM $tableNameAsString + |WHERE pk IN (SELECT pk FROM cte) + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaDelete) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in source) + } + + test("delete using view with transactional checks") { + withView("temp_view") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + sql( + s"""CREATE VIEW temp_view AS + |SELECT pk FROM $sourceNameAsString WHERE dep = 'hr' + |""".stripMargin) + + val (txn, txnTables) = executeTransaction { + sql(s"DELETE FROM $tableNameAsString WHERE pk IN (SELECT pk FROM temp_view)") + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaDelete) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in source) + } + } + + test("EXPLAIN DELETE SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"EXPLAIN DELETE FROM $tableNameAsString WHERE dep = 'hr'") + + // EXPLAIN should not start a new transaction + assert(catalog.transaction === null) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + } + private def executeDeleteWithFilters(query: String): Unit = { val executedPlan = executeAndKeepPlan { sql(query) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala index 15d259d44a4fd..9b630b25f658e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala @@ -35,6 +35,8 @@ class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { override def enforceCheckConstraintOnDelete: Boolean = false + override protected def deltaDelete: Boolean = true + test("delete handles metadata columns correctly") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala index e1c574ec7ba65..687aae91438da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.{sources, Column, Row} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.Row import org.apache.spark.sql.classic.MergeIntoWriter -import org.apache.spark.sql.connector.catalog.Column +import org.apache.spark.sql.connector.catalog.Committed import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableInfo import org.apache.spark.sql.functions._ @@ -31,6 +31,71 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { import testImplicits._ + private def targetTableCol(colName: String): Column = { + col(tableNameAsString + "." + colName) + } + + test("self merge with transactional checks") { + // create table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create a source on top of itself that will be fully resolved and analyzed + val sourceDF = spark.table(tableNameAsString) + .where("salary == 100") + .as("source") + sourceDF.queryExecution.assertAnalyzed() + + // merge into table using the source on top of itself + val (txn, txnTables) = executeTransaction { + sourceDF + .mergeInto( + tableNameAsString, + $"source.pk" === targetTableCol("pk") && targetTableCol("dep") === "hr") + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .whenNotMatched() + .insertAll() + .merge() + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + // check all table scans + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size == 4) + + // check table scans as MERGE target + val numTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numTargetScans == 2) + + // check table scans as MERGE source + val numSourceScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("salary", 100) => true + case _ => false + } + assert(numSourceScans == 2) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged + + // TODO Achatzis check version. + } + test("merge into empty table with NOT MATCHED clause") { withTempView("source") { createTable("pk INT NOT NULL, salary INT, dep STRING") @@ -979,6 +1044,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { } test("SPARK-54157: version is refreshed when source is V2 table") { + import org.apache.spark.sql.connector.catalog.Column val sourceTable = "cat.ns1.source_table" withTable(sourceTable) { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -1026,6 +1092,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { } test("SPARK-54444: any schema changes after analysis are prohibited") { + import org.apache.spark.sql.connector.catalog.Column val sourceTable = "cat.ns1.source_table" withTable(sourceTable) { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 069781e40d8c2..e14d0a2571bb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not} import org.apache.spark.sql.catalyst.optimizer.BuildLeft -import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, InMemoryTable, TableInfo} +import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, TableInfo} import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.write.MergeSummary import org.apache.spark.sql.execution.SparkPlan @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.MergeRowsExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources import org.apache.spark.sql.types.{IntegerType, StringType} abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase @@ -38,6 +39,305 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase protected def deltaMerge: Boolean = false + test("self merge with transactional checks") { + // create table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // merge into table using a subquery on top of itself + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING (SELECT * FROM $tableNameAsString WHERE salary = 100) s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + // check all table scans + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumScans = if (deltaMerge) 2 else 4 + assert(targetTxnTable.scanEvents.size == expectedNumScans) + + // check table scans as MERGE target + val numTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumTargetScans = if (deltaMerge) 1 else 2 + assert(numTargetScans == expectedNumTargetScans) + + // check table scans as MERGE source + val numSourceScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("salary", 100) => true + case _ => false + } + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(numSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged + } + + test("merge into table with analysis failure and transactional checks") { + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')") + + val exception = intercept[AnalysisException] { + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET salary = s.invalid_column + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("merge into table using view with transactional checks") { + withView("temp_view") { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)") + + // create view on top of source and target tables + sql( + s"""CREATE VIEW temp_view AS + |SELECT s.pk, s.salary, t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // merge into target table using view + val (txn, txnTables) = executeTransaction { + sql(s"""MERGE INTO $tableNameAsString t + |USING temp_view s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 2 else 4 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumMergeTargetScans = if (deltaMerge) 1 else 2 + assert(numMergeTargetScans == expectedNumMergeTargetScans) + + // check target table scans in view as MERGE source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaMerge) 1 else 2 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view as MERGE source (no predicate) + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 400, "pending"))) // new + } + } + + test("merge into table using nested view with transactional checks") { + withView("base_view", "nested_view") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)") + + // create base view + sql( + s"""CREATE VIEW base_view AS + |SELECT s.pk, s.salary, t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // create nested view on top of base view + sql( + s"""CREATE VIEW nested_view AS + |SELECT * FROM base_view + |""".stripMargin) + + // merge into target table using nested view + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING nested_view s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 2 else 4 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumMergeTargetScans = if (deltaMerge) 1 else 2 + assert(numMergeTargetScans == expectedNumMergeTargetScans) + + // check target table scans in view as MERGE source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaMerge) 1 else 2 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view as MERGE source (no predicate) + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 400, "pending"))) // new + } + } + } + + test("merge into table rewritten as INSERT with transactional checks") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT, value STRING, dep STRING", + """{ "pk": 1, "value": "a", "dep": "hr" } + |{ "pk": 2, "value": "b", "dep": "finance" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT, value STRING, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (3, 'c', 'hr'), (4, 'd', 'software')") + + // merge into target with only WHEN NOT MATCHED clauses (rewritten as insert) + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk + |WHEN NOT MATCHED AND s.pk < 4 THEN + | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_low'), s.dep) + |WHEN NOT MATCHED AND s.pk >= 4 THEN + | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_high'), s.dep) + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size == 1) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size == 1) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, "a", "hr"), // unchanged + Row(2, "b", "finance"), // unchanged + Row(3, "c_low", "hr"), // inserted via first NOT MATCHED clause + Row(4, "d_high", "software"))) // inserted via second NOT MATCHED clause + } + } + test("merge into table with expression-based default values") { val columns = Array( Column.create("pk", IntegerType), @@ -770,6 +1070,129 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } + test("merge with CTE with transactional checks") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // merge into target table using CTE + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk, salary + 50 AS salary, dep + | FROM $sourceNameAsString + | WHERE salary > 100 + |) + |MERGE INTO $tableNameAsString t + |USING cte s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 1 else 2 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numMergeTargetScans == expectedNumTargetScans) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check source table scans in CTE (salary > 100) + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.GreaterThan("salary", 100) => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 200, "hr"), // updated (150 + 50) + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 450, "pending"))) // inserted (400 + 50) + } + } + + test("merge with cached source and transactional checks") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create and populate source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')") + + // Cache source table before the transaction. Make sure when the transation is active the + // catalog still creates a transaction table. + spark.table(sourceNameAsString).cache() + + try { + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + // both target and source must have been read through the transaction catalog + assert(txnTables.size == 2) + assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + assert(txnTables(tableNameAsString).scanEvents.nonEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "support"), // matched and updated + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged (no match in source) + Row(4, 400, "pending"))) // not matched, inserted + } finally { + spark.catalog.uncacheTable(sourceNameAsString) + } + } + } + test("merge with subquery as source") { withTempView("source") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -2322,6 +2745,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(query) } assert(e.getMessage.contains("ON search condition of the MERGE statement")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) } private def assertMetric( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 79387821bf087..eb449cdaa449f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -28,16 +28,16 @@ import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expr import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReplaceData, WriteDelta} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, TableInfo, Update, Write} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, Table, TableInfo, Txn, TxnTable, Update, Write} import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.write.RowLevelOperationTable import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructField, StructType} -import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -82,6 +82,7 @@ abstract class RowLevelOperationSuiteBase protected val namespace: Array[String] = Array("ns1") protected val ident: Identifier = Identifier.of(namespace, "test_table") protected val tableNameAsString: String = "cat." + ident.toString + protected val sourceNameAsString: String = "cat.ns1.source_table" protected def extraTableProps: java.util.Map[String, String] = { Collections.emptyMap[String, String] @@ -133,24 +134,36 @@ abstract class RowLevelOperationSuiteBase } } - // executes an operation and keeps the executed plan - protected def executeAndKeepPlan(func: => Unit): SparkPlan = { - var executedPlan: SparkPlan = null - - val listener = new QueryExecutionListener { - override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - executedPlan = qe.executedPlan - } - override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { - } + protected def executeTransaction(func: => Unit): (Txn, Map[String, TxnTable]) = { + val qe = execute(func) + val tables = collectWithSubqueries(qe.executedPlan) { + case BatchScanExec(_, _, _, _, table: TxnTable, _) => + table + case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => + table } - spark.listenerManager.register(listener) + (catalog.lastTransaction, indexByName(tables)) + } - func + private def indexByName[T <: Table](tables: Seq[T]): Map[String, T] = { + tables.groupBy(_.name).map { + case (name, sameNameTables) => + val Seq(table) = sameNameTables.distinct + name -> table + } + } - sparkContext.listenerBus.waitUntilEmpty() + // executes an operation and keeps the executed plan + protected def executeAndKeepPlan(func: => Unit): SparkPlan = { + val qe = execute(func) + stripAQEPlan(qe.executedPlan) + } - stripAQEPlan(executedPlan) + private def execute(func: => Unit): QueryExecution = { + withQueryExecutionsCaptured(spark)(func) match { + case Seq(qe) => qe + case other => fail(s"expected only one query execution, but got ${other.size}") + } } // executes an operation and extracts conditions from ReplaceData or WriteDelta diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index d32a1e5c7f561..34cc6efb6db9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.connector import org.apache.spark.SparkRuntimeException -import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, InMemoryTable, TableChange, TableInfo} +import org.apache.spark.sql.{sources, AnalysisException, Row} +import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, TableChange, TableInfo} import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.write.UpdateSummary import org.apache.spark.sql.internal.SQLConf @@ -867,4 +867,321 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { Row(5))) } } + + test("update with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val exception = intercept[AnalysisException] { + sql(s"UPDATE $tableNameAsString SET invalid_column = -1") + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("update with CTE and transactional checks") { + // create table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using CTE + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk, salary + 50 AS adjusted_salary, dep + | FROM $sourceNameAsString + | WHERE salary > 100 + |) + |UPDATE $tableNameAsString t + |SET salary = -1 + |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM cte WHERE cte.pk = t.pk) + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans for UPDATE condition (dep = 'hr') + val numUpdateTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numUpdateTargetScans == expectedNumTargetScans) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check source table scans in CTE (salary > 100) + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.GreaterThan("salary", 100) => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (no matching pk in source) + } + + test("update with subquery on source table and transactional checks") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using an uncorrelated IN subquery that reads from a transactional catalog table + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET salary = -1 + |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 'hr') + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check source table was scanned correctly (dep = 'hr' filter in the subquery) + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numSubquerySourceScans == expectedNumSourceScans) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated (pk 1 is in subquery result) + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in subquery result) + } + + test("update with uncorrelated scalar subquery on source table and transactional checks") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using an uncorrelated scalar subquery in the SET clause that reads from a + // transactional catalog table; scalar subqueries are executed as SubqueryExec at runtime + // and cannot be rewritten as joins + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET salary = (SELECT max(salary) FROM $sourceNameAsString WHERE dep = 'hr') + |WHERE dep = 'hr' + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check source table was scanned via the transaction catalog + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.nonEmpty) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // check target table was scanned via the transaction catalog + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.nonEmpty) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "hr"), // updated (max salary in source for 'hr' is 150) + Row(2, 200, "software"), // unchanged + Row(3, 150, "hr"))) // updated + } + + test("update with constraint violation and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val exception = intercept[SparkRuntimeException] { + executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET pk = NULL + |WHERE dep = 'hr' + |""".stripMargin) // NULL violates NOT NULL constraint + } + } + + assert(exception.getMessage.contains("NOT_NULL_ASSERT_VIOLATION")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("update using view with transactional checks") { + withView("temp_view") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)") + + // create view on top of source and target tables + sql( + s"""CREATE VIEW temp_view AS + |SELECT s.pk, s.salary, t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // update target table using view + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString t + |SET salary = -1 + |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM temp_view v WHERE v.pk = t.pk) + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 2 else 7 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as UPDATE target (dep = 'hr') + val numUpdateTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumUpdateTargetScans = if (deltaUpdate) 1 else 3 + assert(numUpdateTargetScans == expectedNumUpdateTargetScans) + + // check target table scans in view as source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaUpdate) 1 else 4 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated from view + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (no matching pk in source) + } + } + + test("df.explain() on update with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // NOTE: df.explain() on a DML command actually executes the write. + // TODO(achatzis): This is existing behavior but check why this is OK. Shouldn't sql() be lazy? + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'").explain() + + assert(catalog.lastTransaction != null) + assert(catalog.lastTransaction.currentState == Committed) + assert(catalog.lastTransaction.isClosed) + + // the UPDATE was actually executed, not just planned + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated + Row(2, 200, "software"))) // unchanged + } + + test("EXPLAIN UPDATE SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // EXPLAIN UPDATE only plans the command, it does not execute the write. + sql(s"EXPLAIN UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + // A transaction should not have started at all. + assert(catalog.transaction === null) + + // The UPDATE was not executed. Data is unchanged. + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala new file mode 100644 index 0000000000000..141d5966b4b6c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.concurrent.duration._ + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.classic +import org.apache.spark.sql.execution.QueryExecution + +/** + * Benchmark to measure the overhead of cloning the analyzer for transactional query execution. + * Each transactional query creates a new [[Analyzer]] instance via + * [[Analyzer.withCatalogManager]], which shares all rules with the original but carries a + * per-query [[org.apache.spark.sql.connector.catalog.CatalogManager]]. This benchmark checks + * whether the cloning introduces measurable overhead. + * + * To run this benchmark: + * {{{ + * build/sbt "sql/Test/runMain " + * }}} + */ +object AnalyzerBenchmark extends SqlBasedBenchmark { + + private val numRows = 100 + private val queries = Seq( + "simple select" -> "SELECT id, val FROM t1", + "join" -> "SELECT t1.id, t2.val FROM t1 JOIN t2 ON t1.id = t2.id", + "wide schema" -> s"SELECT ${(1 to 100).map(i => s"col_$i").mkString(", ")} FROM wide_t" + ) + + private def setupTables(): Unit = { + spark.range(numRows).selectExpr("id", "id * 2 as val").createOrReplaceTempView("t1") + spark.range(numRows).selectExpr("id", "id * 3 as val").createOrReplaceTempView("t2") + spark.range(numRows) + .selectExpr((1 to numRows).map(i => s"id as col_$i"): _*) + .createOrReplaceTempView("wide_t") + } + + /** + * Measures analysis time for a pre-parsed plan, comparing the session analyzer against a + * cloned analyzer created via [[Analyzer.withCatalogManager]]. + * + * Two cases: + * - "session analyzer" : baseline, uses the session analyzer directly. + * - "cloned analyzer (per query)": analyzer cloned every iteration; reflects the full + * per-transactional-query cost (clone + analysis). + */ + def analysisBenchmark(): Unit = { + for ((name, sql) <- queries) { + runBenchmark(s"analysis overhead $name") { + val benchmark = new Benchmark( + name = s"analysis overhead $name", + // Per row measurements are not meaningful here. + valuesPerIteration = numRows, + minTime = 10.seconds, + output = output) + val catalogManager = spark.sessionState.catalogManager + + benchmark.addCase("session analyzer") { _ => + val plan = spark.sessionState.sqlParser.parsePlan(sql) + new QueryExecution(spark.asInstanceOf[classic.SparkSession], plan).analyzed + } + + benchmark.addCase("cloned analyzer (per query)") { _ => + val cloned = spark.sessionState.analyzer.withCatalogManager(catalogManager) + val plan = spark.sessionState.sqlParser.parsePlan(sql) + new QueryExecution(spark.asInstanceOf[classic.SparkSession], + plan, analyzerOpt = Some(cloned)).analyzed + } + + benchmark.run() + } + } + } + + /** + * Micro-benchmark for [[Analyzer.withCatalogManager]] in isolation: measures the cost of + * instantiating the anonymous [[Analyzer]] subclass, independent of analysis work. + */ + def cloneCostBenchmark(): Unit = { + runBenchmark("analyzer clone cost") { + val numRows = 1 // Per row measurements are not meaningful here. + val benchmark = new Benchmark( + name = "analyzer clone cost", + valuesPerIteration = numRows, + output = output) + val catalogManager = spark.sessionState.catalogManager + + benchmark.addCase("withCatalogManager") { _ => + spark.sessionState.analyzer.withCatalogManager(catalogManager) + } + + benchmark.run() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + setupTables() + cloneCostBenchmark() + analysisBenchmark() + } +} From f44c3c0f21371ee445b3f0c1420392455fc2fe7b Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Thu, 2 Apr 2026 13:52:01 +0000 Subject: [PATCH 02/33] Fix delete failures --- .../sql/connector/DeleteFromTableSuiteBase.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index 1e18928c10856..0bca12b315515 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -832,7 +832,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txnTables.size == 2) val sourceTxnTable = txnTables(sourceNameAsString) - val expectedNumSourceScans = if (deltaDelete) 1 else 4 + val expectedNumSourceScans = if (deltaDelete) 1 else 2 assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count { @@ -842,7 +842,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(numSubquerySourceScans == expectedNumSourceScans) val targetTxnTable = txnTables(tableNameAsString) - val expectedNumTargetScans = if (deltaDelete) 1 else 3 + val expectedNumTargetScans = if (deltaDelete) 1 else 2 assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) checkAnswer( @@ -877,11 +877,11 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txnTables.size == 2) val targetTxnTable = txnTables(tableNameAsString) - val expectedNumTargetScans = if (deltaDelete) 1 else 3 + val expectedNumTargetScans = if (deltaDelete) 1 else 2 assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) val sourceTxnTable = txnTables(sourceNameAsString) - val expectedNumSourceScans = if (deltaDelete) 1 else 4 + val expectedNumSourceScans = if (deltaDelete) 1 else 2 assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { @@ -922,11 +922,11 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txnTables.size == 2) val targetTxnTable = txnTables(tableNameAsString) - val expectedNumTargetScans = if (deltaDelete) 1 else 3 + val expectedNumTargetScans = if (deltaDelete) 1 else 2 assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) val sourceTxnTable = txnTables(sourceNameAsString) - val expectedNumSourceScans = if (deltaDelete) 1 else 4 + val expectedNumSourceScans = if (deltaDelete) 1 else 2 assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) checkAnswer( From 934f20f39f4b5bb2300dbf15a75daf29265e77a9 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Thu, 9 Apr 2026 19:05:49 +0000 Subject: [PATCH 03/33] Cleaning pass 1 --- .../sql/catalyst/analysis/Analyzer.scala | 12 ++++---- .../analysis/RelationResolution.scala | 11 +++++-- .../UnresolveTransactionRelations.scala | 12 ++++++-- .../catalyst/analysis/V2TableReference.scala | 18 +++++------- .../catalyst/plans/logical/statements.scala | 4 +++ .../catalyst/plans/logical/v2Commands.scala | 20 +++++++++---- .../sql/connector/catalog/LookupCatalog.scala | 4 ++- ...nMemoryRowLevelOperationTableCatalog.scala | 7 +++-- .../sql/connector/catalog/InMemoryTable.scala | 5 ++-- .../catalog/InMemoryTableCatalog.scala | 4 --- .../spark/sql/connector/catalog/txns.scala | 29 +++++++++++-------- .../apache/spark/sql/classic/Catalog.scala | 17 ----------- .../spark/sql/execution/CacheManager.scala | 3 -- 13 files changed, 76 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 333ac8817be07..fdd7d09356d8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -354,9 +354,8 @@ class Analyzer( /** * Returns a copy of this analyzer that uses the given [[CatalogManager]] for all catalog * lookups. All other configuration (extended rules, checks, etc.) is preserved. Used by - * [[QueryExecution]] to create a per-query analyzer for transactional queries so that - * transaction-aware catalog resolution is an instance-level property rather than thread-local - * state. + * [[QueryExecution]] to create a per-query analyzer for transactional operations for + * transaction-aware catalog resolution. */ def withCatalogManager(newCatalogManager: CatalogManager): Analyzer = { val self = this @@ -1056,9 +1055,10 @@ class Analyzer( } } - // Resolve V2TableReference nodes in a plan. V2TableReference is only created for temp views - // (via V2TableReference.createForTempView), so we only need to resolve it when returning - // the plan of temp views (in resolveViews and unwrapRelationPlan). + // Resolve V2TableReference nodes created for: + // 1 Temp views (via createForTempView). + // 2. Transaction references (via createForTransaction). These are resolved by a + // separate analysis batch in the transaction-aware analyzer instance. private def resolveTableReferences(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsUp { case r: V2TableReference => relationResolution.resolveReference(r) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index 8bcfb041d1a83..05d5394c9dfa9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -504,13 +504,18 @@ class RelationResolution( } private def loadRelation(ref: V2TableReference): LogicalPlan = { + // Resolve catalog. When a transaction is active we return the transaction + // aware catalog instance. val resolvedCatalog = catalogManager.catalog(ref.catalog.name).asTableCatalog val table = resolvedCatalog.loadTable(ref.identifier) - // val table = ref.catalog.loadTable(ref.identifier) V2TableReferenceUtils.validateLoadedTable(table, ref) - // ref.toRelation(table) + // Create relation with resolved Catalog. DataSourceV2Relation( - table, ref.output, Some(resolvedCatalog), Some(ref.identifier), ref.options) + table = table, + output = ref.output, + catalog = Some(resolvedCatalog), + identifier = Some(ref.identifier), + options = ref.options) } private def adaptCachedRelation(cached: LogicalPlan, ref: V2TableReference): LogicalPlan = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala index 4b175dd44ef08..0e344173d7892 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala @@ -23,6 +23,15 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +/** + * When a transaction is active, converts resolved [[DataSourceV2Relation]] nodes back to + * [[V2TableReference]] placeholders for all relations loaded by a catalog with the same + * name as the transaction catalog. + * + * This forces re-resolution of those relations against the transaction's catalog, which + * intercepts [[TableCatalog#loadTable]] calls to track which tables are read as part of + * the transaction. + */ class UnresolveTransactionRelations(val catalogManager: CatalogManager) extends Rule[LogicalPlan] with LookupCatalog { @@ -43,14 +52,13 @@ class UnresolveTransactionRelations(val catalogManager: CatalogManager) catalog: CatalogPlugin): LogicalPlan = { plan transform { case r: DataSourceV2Relation if isLoadedFromCatalog(r, catalog) => - V2TableReference.createForRelation(r, Seq.empty) + V2TableReference.createForTransaction(r) } } private def isLoadedFromCatalog( relation: DataSourceV2Relation, catalog: CatalogPlugin): Boolean = { - // relation.catalog.exists(_ eq catalog) && relation.identifier.isDefined relation.catalog.exists(_.name == catalog.name) && relation.identifier.isDefined } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index a2379f33e14ee..76226056ffe65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.V2TableReference.Context import org.apache.spark.sql.catalyst.analysis.V2TableReference.TableInfo import org.apache.spark.sql.catalyst.analysis.V2TableReference.TemporaryViewContext -import org.apache.spark.sql.catalyst.analysis.V2TableReference.TestContext +import org.apache.spark.sql.catalyst.analysis.V2TableReference.TransactionContext import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics @@ -85,17 +85,15 @@ private[sql] object V2TableReference { sealed trait Context case class TemporaryViewContext(viewName: Seq[String]) extends Context - // TODO(achatzis): Fix naming and complete implementation. - case class TestContext(tableName: Seq[String]) extends Context + /** Context for relations that are re-resolved through a transaction catalog. */ + case object TransactionContext extends Context def createForTempView(relation: DataSourceV2Relation, viewName: Seq[String]): V2TableReference = { create(relation, TemporaryViewContext(viewName)) } - def createForRelation( - relation: DataSourceV2Relation, - relationName: Seq[String]): V2TableReference = { - create(relation, TestContext(relationName)) + def createForTransaction(relation: DataSourceV2Relation): V2TableReference = { + create(relation, TransactionContext) } private def create(relation: DataSourceV2Relation, context: Context): V2TableReference = { @@ -119,7 +117,7 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { ref.context match { case ctx: TemporaryViewContext => validateLoadedTableInTempView(table, ref, ctx) - case _: TestContext => + case TransactionContext => validateLoadedTableInTransaction(table, ref) case ctx => throw SparkException.internalError(s"Unknown table ref context: ${ctx.getClass.getName}") @@ -128,8 +126,8 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = { val dataErrors = V2TableUtil.validateCapturedColumns( - table, - ref.info.columns, + table = table, + originCols = ref.info.columns, mode = PROHIBIT_CHANGES) if (dataErrors.nonEmpty) { throw QueryCompilationErrors.columnsChangedAfterAnalysis(ref.name, dataErrors) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index fb54af2344d1b..774c783ecf8a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -188,6 +188,10 @@ case class InsertIntoStatement( byName: Boolean = false, replaceCriteriaOpt: Option[InsertReplaceCriteria] = None, withSchemaEvolution: Boolean = false) + // Extends TransactionalWrite so that QueryExecution can detect a potential transaction on the + // unresolved logical plan before analysis runs. InsertIntoStatement is shared between V1 and V2 + // inserts, but the LookupCatalog.TransactionalWrite extractor only matches when the target + // catalog implements TransactionalCatalogPlugin, so V1 inserts are never assigned a transaction. extends UnaryParsedStatement with TransactionalWrite { require(overwrite || !ifPartitionNotExists, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 9d419ad668462..c16087bdf9bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -521,8 +521,10 @@ case class WriteDelta( trait V2CreateTableAsSelectPlan extends V2CreateTablePlan with AnalysisOnlyCommand - with CTEInChildren { + with CTEInChildren + with TransactionalWrite { def query: LogicalPlan + override def table: LogicalPlan = name override def withCTEDefs(cteDefs: Seq[CTERelationDef]): LogicalPlan = { withNameAndQuery(newName = name, newQuery = WithCTE(query, cteDefs)) @@ -1013,13 +1015,13 @@ case class MergeIntoTable( notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction], withSchemaEvolution: Boolean) - extends BinaryCommand - with WriteWithSchemaEvolution - with TransactionalWrite - with SupportsSubquery { + extends BinaryCommand + with WriteWithSchemaEvolution + with SupportsSubquery + with TransactionalWrite { // Implements SupportsSchemaEvolution.table. - // Implements TransactionalWrite.table, identifying the MERGE target as the table being written. + // Implements TransactionalWrite.table. override val table: LogicalPlan = EliminateSubqueryAliases(targetTable) override def withNewTable(newTable: NamedRelation): MergeIntoTable = { @@ -1279,6 +1281,12 @@ case class Assignment(key: Expression, value: Expression) extends Expression newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight) } +/** + * Marker trait for write operations that participate in a DSv2 transaction. + * + * Implementations are expected to target a DSv2 catalog backed by a + * [[org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin]]. + */ trait TransactionalWrite extends LogicalPlan { def table: LogicalPlan } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala index fbb2938fd3da2..dd5be45bfc5f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connector.catalog import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.plans.logical.TransactionalWrite import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -171,6 +171,8 @@ private[sql] trait LookupCatalog extends Logging { EliminateSubqueryAliases(write.table) match { case UnresolvedRelation(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _, _) => Some(c) + case UnresolvedIdentifier(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _) => + Some(c) case _ => None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index 27231447a1273..7ba1e9747f52e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -21,17 +21,18 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} class InMemoryRowLevelOperationTableCatalog - extends InMemoryTableCatalog with TransactionalCatalogPlugin { + extends InMemoryTableCatalog + with TransactionalCatalogPlugin { import CatalogV2Implicits._ + // The current active transaction. var transaction: Txn = _ - // Tracks the last completed transaction for test assertions; cleared when a new one begins. + // The last completed transaction. var lastTransaction: Txn = _ override def beginTransaction(info: TransactionInfo): Transaction = { assert(transaction == null || transaction.currentState != Active) this.transaction = new Txn(new TxnTableCatalog(this)) - this.lastTransaction = transaction transaction } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 2f3c65924d6a7..15ed4136dbda8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -216,9 +216,8 @@ class InMemoryTable( object InMemoryTable { - // V1 filter values (from PredicateUtils.toV1) are Scala types (e.g. String), but partition - // keys stored in dataMap are Catalyst internal types (e.g. UTF8String). Normalize both sides - // before comparing so that string partitions work correctly. + // Convert UTF8String to string to make sure equality checks between filters and partitions + // work correctly. private def valuesEqual(filterValue: Any, partitionValue: Any): Boolean = (filterValue, partitionValue) match { case (s: String, u: UTF8String) => u.toString == s diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index c7195b512b8d6..ff7995ad6697e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -59,10 +59,6 @@ class BasicInMemoryTableCatalog extends TableCatalog { tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray } - def loadTableAs[T <: Table](ident: Identifier): T = { - loadTable(ident).asInstanceOf[T] - } - // load table for scans override def loadTable(ident: Identifier): Table = { Option(tables.get(ident)) match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 4feb89c78f56c..f4f56d59f7851 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -42,13 +42,13 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { override def commit(): Unit = { if (closed) throw new IllegalStateException("Can't commit, already closed") + if (state == Aborted) throw new IllegalStateException("Can't commit, already aborted") catalog.commit() this.state = Committed } override def abort(): Unit = { if (state == Committed || state == Aborted) return - // if (closed) throw new IllegalStateException("Can't abort, already closed") this.state = Aborted } @@ -58,9 +58,9 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { } } -// a special table used in row-level operation transactions -// it inherits data from the base table upon construction and -// propagates staged transaction state back after an explicit commit +// A special table used in row-level operation transactions. It inherits data +// from the base table upon construction and propagates staged transaction state +// back after an explicit commit. class TxnTable(val delegate: InMemoryRowLevelOperationTable) extends InMemoryRowLevelOperationTable( delegate.name, @@ -72,8 +72,7 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) // TODO(achatzis): Rethink how schema evolution works on top of transactions. alterTableWithData(delegate.data, schema) - // a tracker of filters used in each scan - // achatzis: Non-deterministic filters? + // A tracker of filters used in each scan. val scanEvents = new ArrayBuffer[Array[Filter]]() override protected def recordScanEvent(filters: Array[Filter]): Unit = { @@ -92,8 +91,8 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) } } -// a special table catalog used in row-level operation transactions -// table changes are initially staged in memory and propagated only after an explicit commit +// A special table catalog used in row-level operation transactions. +// Table changes are initially staged in memory and propagated only after an explicit commit. class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends TableCatalog { private val tables: util.Map[Identifier, TxnTable] = new ConcurrentHashMap[Identifier, TxnTable]() @@ -108,20 +107,25 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T override def loadTable(ident: Identifier): Table = { tables.computeIfAbsent(ident, _ => { - val table = delegate.loadTableAs[InMemoryRowLevelOperationTable](ident) + val table = delegate.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] new TxnTable(table) }) } + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { + delegate.createTable(ident, tableInfo) + loadTable(ident) + } + override def alterTable(ident: Identifier, changes: TableChange*): Table = { val newDelegateTable = delegate.alterTable(ident, changes: _*) - // Compute again if absent. - tables.remove(ident) + tables.remove(ident) // Load again. newDelegateTable } override def dropTable(ident: Identifier): Boolean = { - throw new UnsupportedOperationException() + tables.remove(ident) + delegate.dropTable(ident) } override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { @@ -133,6 +137,7 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T } def clearActiveTransaction(): Unit = { + delegate.lastTransaction = delegate.transaction delegate.transaction = null } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala index 5bd5d20692aed..9a5aed333a4c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala @@ -933,23 +933,6 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog with Logging { // caches referencing this relation. If this relation is cached as an InMemoryRelation, // this will clear the relation cache and caches of all its dependents. CommandUtils.recacheTableOrView(sparkSession, relation) - /* - EliminateSubqueryAliases(relation) match { - case r @ ExtractV2CatalogAndIdentifier(catalog, ident) if r.timeTravelSpec.isEmpty => - val nameParts = ident.toQualifiedNameParts(catalog) - sparkSession.sharedState.cacheManager.recacheTableOrView(sparkSession, nameParts) - case _ => - sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation) - */ - /* - relation match { - case r: DataSourceV2Relation if r.catalog.isDefined && r.identifier.isDefined => - val nameParts = r.identifier.get.toQualifiedNameParts(r.catalog.get) - sparkSession.sharedState.cacheManager.recacheTableOrView(sparkSession, nameParts) - case _ => - sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation) - } - */ } private def resolveRelation(tableName: String): LogicalPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 66f406d39f263..3f92f24156d3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -260,9 +260,6 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { val nameInCache = v2Ident.toQualifiedNameParts(catalog) isSameName(name, nameInCache, resolver) && (includeTimeTravel || timeTravelSpec.isEmpty) - // case r: TableReference => - // isSameName(name, r.identifier.toQualifiedNameParts(r.catalog), resolver) - case v: View => isSameName(name, v.desc.identifier.nameParts, resolver) From f20be5208601b408bd5540da7d941a744611a90d Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 22 Apr 2026 18:27:37 +0000 Subject: [PATCH 04/33] CTAS/RTAS support plus more cleaning --- .../analysis/RelationResolution.scala | 2 - .../TransactionAwareCatalogManager.scala | 7 +- ...nMemoryRowLevelOperationTableCatalog.scala | 4 +- .../spark/sql/connector/catalog/txns.scala | 4 +- .../spark/sql/execution/QueryExecution.scala | 49 +++--- .../datasources/v2/DataSourceV2Strategy.scala | 3 + .../v2/WriteToDataSourceV2Exec.scala | 47 ++++-- .../sql/connector/CTASTransactionSuite.scala | 140 ++++++++++++++++++ .../RowLevelOperationSuiteBase.scala | 13 +- 9 files changed, 228 insertions(+), 41 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index 05d5394c9dfa9..ac685f984ce4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -478,8 +478,6 @@ class RelationResolution( } } - // TODO: how to validate the output is compatible? - // TODO: what shall we do if the output mismatches (schema changes?) def resolveReference(ref: V2TableReference): LogicalPlan = { val relation = getOrLoadRelation(ref) val planId = ref.getTagValue(LogicalPlan.PLAN_ID_TAG) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala index 9403219f596da..aaeef4c2dea76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector.catalog +import org.apache.spark.sql.catalyst.catalog.TempVariableManager import org.apache.spark.sql.connector.catalog.transactions.Transaction /** @@ -25,13 +26,15 @@ import org.apache.spark.sql.connector.catalog.transactions.Transaction * All mutable state (current catalog, current namespace, loaded catalogs) is delegated to the * wrapped [[CatalogManager]]. */ -// TODO: Consider extracting a CatalogManager trait that both the real -// implementation and the decorator implement +// TODO: Extracting a CatalogManager trait (so this class can implement it instead of extending +// CatalogManager) would eliminate the inherited mutable state that this decorator doesn't use. private[sql] class TransactionAwareCatalogManager( delegate: CatalogManager, txn: Transaction) extends CatalogManager(delegate.defaultSessionCatalog, delegate.v1SessionCatalog) { + override val tempVariableManager: TempVariableManager = delegate.tempVariableManager + override def transaction: Option[Transaction] = Some(txn) override def catalog(name: String): CatalogPlugin = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index 7ba1e9747f52e..78c350b5145a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} class InMemoryRowLevelOperationTableCatalog - extends InMemoryTableCatalog - with TransactionalCatalogPlugin { + extends InMemoryTableCatalog + with TransactionalCatalogPlugin { import CatalogV2Implicits._ // The current active transaction. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index f4f56d59f7851..e76fdf97c6888 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -69,7 +69,7 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) delegate.properties, delegate.constraints) { - // TODO(achatzis): Rethink how schema evolution works on top of transactions. + // TODO: Revise schema evolution. alterTableWithData(delegate.data, schema) // A tracker of filters used in each scan. @@ -81,7 +81,7 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) def commit(): Unit = { delegate.dataMap.clear() - // TODO(achatzis): Rethink how schema evolution works on top of transactions. + // TODO: Revise schema evolution. delegate.alterTableWithData(data, delegate.schema) delegate.replacedPartitions = replacedPartitions delegate.lastWriteInfo = lastWriteInfo diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 9413354907678..ff7e03c124420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -100,8 +100,8 @@ class QueryExecution( // 1. At the pre-Analyzed plan we look for nodes that implement the TransactionalWrite trait. // When a plan contains such a node we initiate a transaction. Note, we should never start // a transaction for operations that are not executed, e.g. EXPLAIN. - // 2. Create an analyzer clone with a transaction aware Catalog Manager. The latter is the single - // choke point of all catalog access, and it is also the transaction context carrier. + // 2. Create an analyzer clone with a transaction aware Catalog Manager. The latter is the + // narrow waist of all catalog accesses, and it is also the transaction context carrier. // This is then passed to all rules during analysis that need to check the catalog. Rules // that are specifically interested in transactionality can access the transaction directly // from the Catalog Manager. The transaction catalog, is potentially the place where connectors @@ -281,15 +281,18 @@ class QueryExecution( sparkSession.withActive { assertAnalyzed() assertSupported() - // clone the plan to avoid sharing the plan instance between different stages like analyzing, - // optimizing and planning. - val plan = normalized.clone() - // During a transaction, skip cache substitution. useCachedData replaces DataSourceV2Relation - // nodes (loaded via the transaction catalog) with InMemoryRelation, which bypasses read - // tracking in the transaction catalog and may serve stale data. - // if (transactionOpt.isDefined) plan - // else sparkSession.sharedState.cacheManager.useCachedData(plan) - sparkSession.sharedState.cacheManager.useCachedData(plan) + + // During a transaction, skip cache substitution. This is to avoid replacing relations + // loaded by the transactional catalog with potentially stale relations cached before + // the transaction was active. + val plan = if (transactionOpt.isDefined) { + plan + } + else { + // clone the plan to avoid sharing the plan instance between different stages like + // analyzing, optimizing and planning. + sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + } } } @@ -332,11 +335,7 @@ class QueryExecution( // plan. QueryExecution.createSparkPlan(planner, optimizedPlan.clone()) } - transactionOpt match { - case Some(txn) => - plan.transformDown { case w: TransactionalExec => w.withTransaction(Some(txn)) } - case None => plan - } + attachTransaction(plan) } def sparkPlan: SparkPlan = executeWithTransactionContext { @@ -610,14 +609,11 @@ class QueryExecution( } /** - * Execute the given block with the transaction context if exists. If there is an exception thrown - * during the execution, the transaction will be aborted. + * Executes the given block with the transaction context if exists. If there is an exception + * thrown during the execution, the transaction will be aborted. * - * Note 1: The transaction is not committed in this method. The caller should commit the + * Note: The transaction is not committed in this method. The caller should commit the * transaction if the execution is successful. - * - * Note 2: In some cases, post commit execution might generate an exception. The abort operation - * should be no-op in this case. */ private def executeWithTransactionContext[T](block: => T): T = transactionOpt match { case Some(transaction) => @@ -626,6 +622,15 @@ class QueryExecution( case None => block } + + /** Attaches a transaction to the given SparkPlan to the transactional execution nodes. */ + private def attachTransaction(plan: SparkPlan): SparkPlan = transactionOpt match { + case Some(txn) => plan.transformDown { + case w: TransactionalExec => w.withTransaction(Some(txn)) + } + case None => plan + } + /** A special namespace for commands that can be used to debug query execution. */ // scalastyle:off object debug { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c73f4ad9eded9..e03928867e24d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -63,6 +63,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat private def hadoopConf = session.sessionState.newHadoopConf() + // recaches all cache entries without time travel for the given table + // after a write operation that moves the state of the table forward (e.g. append, overwrite) + // this method accounts for V2 tables loaded via TableProvider (no catalog/identifier) private def refreshCache(r: DataSourceV2Relation)(): Unit = r match { case ExtractV2CatalogAndIdentifier(catalog, ident) => val nameParts = ident.toQualifiedNameParts(catalog) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index e2851c3187f42..5580d33c9aaeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -76,7 +76,12 @@ case class CreateTableAsSelectExec( query: LogicalPlan, tableSpec: TableSpec, writeOptions: Map[String, String], - ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec { + ifNotExists: Boolean, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): CreateTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -94,7 +99,9 @@ case class CreateTableAsSelectExec( .build() val table = Option(catalog.createTable(ident, tableInfo)) .getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable(catalog, table, writeOptions, ident, query, overwrite = false) + val result = writeToTable(catalog, table, writeOptions, ident, query, overwrite = false) + transaction.foreach(TransactionUtils.commit) + result } } @@ -114,7 +121,13 @@ case class AtomicCreateTableAsSelectExec( query: LogicalPlan, tableSpec: TableSpec, writeOptions: Map[String, String], - ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec { + ifNotExists: Boolean, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec + with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): AtomicCreateTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -135,7 +148,9 @@ case class AtomicCreateTableAsSelectExec( .build() val stagedTable = Option(catalog.stageCreate(ident, tableInfo) ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable(catalog, stagedTable, writeOptions, ident, query, overwrite = false) + val result = writeToTable(catalog, stagedTable, writeOptions, ident, query, overwrite = false) + transaction.foreach(TransactionUtils.commit) + result } } @@ -157,8 +172,12 @@ case class ReplaceTableAsSelectExec( tableSpec: TableSpec, writeOptions: Map[String, String], orCreate: Boolean, - invalidateCache: (TableCatalog, Identifier) => Unit) - extends V2CreateTableAsSelectBaseExec { + invalidateCache: (TableCatalog, Identifier) => Unit, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): ReplaceTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -194,9 +213,11 @@ case class ReplaceTableAsSelectExec( .build() val table = Option(catalog.createTable(ident, tableInfo)) .getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable( + val result = writeToTable( catalog, table, writeOptions, ident, refreshedQuery, overwrite = true, refreshPhaseEnabled = false) + transaction.foreach(TransactionUtils.commit) + result } } @@ -220,8 +241,12 @@ case class AtomicReplaceTableAsSelectExec( tableSpec: TableSpec, writeOptions: Map[String, String], orCreate: Boolean, - invalidateCache: (TableCatalog, Identifier) => Unit) - extends V2CreateTableAsSelectBaseExec { + invalidateCache: (TableCatalog, Identifier) => Unit, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): AtomicReplaceTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -262,7 +287,9 @@ case class AtomicReplaceTableAsSelectExec( } val table = Option(staged).getOrElse( catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable(catalog, table, writeOptions, ident, query, overwrite = true) + val result = writeToTable(catalog, table, writeOptions, ident, query, overwrite = true) + transaction.foreach(TransactionUtils.commit) + result } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala new file mode 100644 index 0000000000000..1643fa6879525 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.Committed + +class CTASTransactionSuite extends RowLevelOperationSuiteBase { + + private val newTableNameAsString = "cat.ns1.new_table" + + test("CTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val (txn, txnTables) = executeTransactionMultiQE { + sql(s"""CREATE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size >= 1) + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq(Row(1, 100, "hr"))) + } + + test("CTAS with cached source and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // cache the source table before running CTAS + spark.catalog.cacheTable(tableNameAsString) + + try { + val (txn, txnTables) = executeTransactionMultiQE { + sql(s"""CREATE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + // cache miss: TxnTable-based relation is not structurally equal to the cached one, + // so the scan goes through the transaction catalog and scan events are captured + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size >= 1) + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq(Row(1, 100, "hr"))) + } finally { + spark.catalog.uncacheTable(tableNameAsString) + } + } + + test("RTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // pre-create the target so REPLACE TABLE (not CREATE OR REPLACE) is valid + sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + + val (txn, txnTables) = executeTransactionMultiQE { + sql(s"""REPLACE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) + + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size >= 1) + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(3, 300, "hr"))) + } + + test("RTAS self-reference with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // source and target are the same table: reads the old snapshot via TxnTable, + // replaces the table with a filtered version + val (txn, txnTables) = executeTransactionMultiQE { + sql(s"""CREATE OR REPLACE TABLE $tableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size >= 1) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(3, 300, "hr"))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index eb449cdaa449f..95a6720628162 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -145,7 +145,18 @@ abstract class RowLevelOperationSuiteBase (catalog.lastTransaction, indexByName(tables)) } - private def indexByName[T <: Table](tables: Seq[T]): Map[String, T] = { + protected def executeTransactionMultiQE(func: => Unit): (Txn, Map[String, TxnTable]) = { + val qes = withQueryExecutionsCaptured(spark)(func) + val tables = qes.flatMap { qe => + collectWithSubqueries(qe.executedPlan) { + case BatchScanExec(_, _, _, _, table: TxnTable, _) => table + case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => table + } + } + (catalog.lastTransaction, indexByName(tables)) + } + + protected def indexByName[T <: Table](tables: Seq[T]): Map[String, T] = { tables.groupBy(_.name).map { case (name, sameNameTables) => val Seq(table) = sameNameTables.distinct From 6cf1365643a5ea6eb20e75d42ccee0d9866312df Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Tue, 7 Apr 2026 12:29:28 +0000 Subject: [PATCH 05/33] Fix comp error + refactor executeTransaction --- .../spark/sql/execution/QueryExecution.scala | 4 ++-- .../sql/connector/CTASTransactionSuite.scala | 8 +++---- .../RowLevelOperationSuiteBase.scala | 23 +++---------------- 3 files changed, 9 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index ff7e03c124420..b1567cfdd1c59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -285,8 +285,8 @@ class QueryExecution( // During a transaction, skip cache substitution. This is to avoid replacing relations // loaded by the transactional catalog with potentially stale relations cached before // the transaction was active. - val plan = if (transactionOpt.isDefined) { - plan + if (transactionOpt.isDefined) { + normalized } else { // clone the plan to avoid sharing the plan instance between different stages like diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala index 1643fa6879525..ff055039173c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala @@ -30,7 +30,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { |{ "pk": 2, "salary": 200, "dep": "software" } |""".stripMargin) - val (txn, txnTables) = executeTransactionMultiQE { + val (txn, txnTables) = executeTransaction { sql(s"""CREATE TABLE $newTableNameAsString |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' |""".stripMargin) @@ -58,7 +58,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { spark.catalog.cacheTable(tableNameAsString) try { - val (txn, txnTables) = executeTransactionMultiQE { + val (txn, txnTables) = executeTransaction { sql(s"""CREATE TABLE $newTableNameAsString |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' |""".stripMargin) @@ -90,7 +90,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { // pre-create the target so REPLACE TABLE (not CREATE OR REPLACE) is valid sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") - val (txn, txnTables) = executeTransactionMultiQE { + val (txn, txnTables) = executeTransaction { sql(s"""REPLACE TABLE $newTableNameAsString |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' |""".stripMargin) @@ -119,7 +119,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { // source and target are the same table: reads the old snapshot via TxnTable, // replaces the table with a filtered version - val (txn, txnTables) = executeTransactionMultiQE { + val (txn, txnTables) = executeTransaction { sql(s"""CREATE OR REPLACE TABLE $tableNameAsString |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' |""".stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 95a6720628162..7c2d7463c3100 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Id import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.write.RowLevelOperationTable -import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.internal.SQLConf @@ -135,19 +135,7 @@ abstract class RowLevelOperationSuiteBase } protected def executeTransaction(func: => Unit): (Txn, Map[String, TxnTable]) = { - val qe = execute(func) - val tables = collectWithSubqueries(qe.executedPlan) { - case BatchScanExec(_, _, _, _, table: TxnTable, _) => - table - case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => - table - } - (catalog.lastTransaction, indexByName(tables)) - } - - protected def executeTransactionMultiQE(func: => Unit): (Txn, Map[String, TxnTable]) = { - val qes = withQueryExecutionsCaptured(spark)(func) - val tables = qes.flatMap { qe => + val tables = withQueryExecutionsCaptured(spark)(func).flatMap { qe => collectWithSubqueries(qe.executedPlan) { case BatchScanExec(_, _, _, _, table: TxnTable, _) => table case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => table @@ -166,13 +154,8 @@ abstract class RowLevelOperationSuiteBase // executes an operation and keeps the executed plan protected def executeAndKeepPlan(func: => Unit): SparkPlan = { - val qe = execute(func) - stripAQEPlan(qe.executedPlan) - } - - private def execute(func: => Unit): QueryExecution = { withQueryExecutionsCaptured(spark)(func) match { - case Seq(qe) => qe + case Seq(qe) => stripAQEPlan(qe.executedPlan) case other => fail(s"expected only one query execution, but got ${other.size}") } } From 0ea7e0576fe374f3826e35bae0bcb08feaf3684f Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 8 Apr 2026 08:22:35 +0000 Subject: [PATCH 06/33] Test improvements --- .../spark/sql/connector/catalog/txns.scala | 6 ++-- .../AppendDataTransactionSuite.scala | 7 ++++ ...e.scala => CTASRTASTransactionSuite.scala} | 5 ++- .../connector/DeleteFromTableSuiteBase.scala | 4 +++ .../connector/MergeIntoDataFrameSuite.scala | 36 +++++++++++++++++-- .../connector/MergeIntoTableSuiteBase.scala | 6 ++++ .../RowLevelOperationSuiteBase.scala | 16 ++++----- .../sql/connector/UpdateTableSuiteBase.scala | 10 ++++-- 8 files changed, 74 insertions(+), 16 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/connector/{CTASTransactionSuite.scala => CTASRTASTransactionSuite.scala} (96%) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index e76fdf97c6888..c6339d2099b63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -53,8 +53,10 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { } override def close(): Unit = { - catalog.clearActiveTransaction() - this.closed = true + if (!closed) { + catalog.clearActiveTransaction() + this.closed = true + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala index 1c8e7fc5a0fd4..379cf6df0f739 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -46,6 +46,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") // check the source scan was tracked via the transaction catalog val targetTxnTable = txnTables(tableNameAsString) @@ -75,6 +76,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { // check txn was properly committed and closed assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") // check data was inserted correctly checkAnswer( @@ -104,6 +106,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -131,6 +134,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -159,6 +163,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -186,6 +191,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -214,6 +220,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") // check data was inserted correctly checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala index ff055039173c8..c58a78498f9f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.Row import org.apache.spark.sql.connector.catalog.Committed -class CTASTransactionSuite extends RowLevelOperationSuiteBase { +class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { private val newTableNameAsString = "cat.ns1.new_table" @@ -39,6 +39,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") val sourceTxnTable = txnTables(tableNameAsString) assert(sourceTxnTable.scanEvents.size >= 1) @@ -66,6 +67,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") // cache miss: TxnTable-based relation is not structurally equal to the cached one, // so the scan goes through the transaction catalog and scan events are captured @@ -99,6 +101,7 @@ class CTASTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") val sourceTxnTable = txnTables(tableNameAsString) assert(sourceTxnTable.scanEvents.size >= 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index 0bca12b315515..f8d81ee086911 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -804,6 +804,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) + assert(table.version() == "2") checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -830,6 +831,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") val sourceTxnTable = txnTables(sourceNameAsString) val expectedNumSourceScans = if (deltaDelete) 1 else 2 @@ -875,6 +877,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") val targetTxnTable = txnTables(tableNameAsString) val expectedNumTargetScans = if (deltaDelete) 1 else 2 @@ -920,6 +923,7 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") val targetTxnTable = txnTables(tableNameAsString) val expectedNumTargetScans = if (deltaDelete) 1 else 2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala index 687aae91438da..d58e22e63d71e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.{sources, Column, Row} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.classic.MergeIntoWriter -import org.apache.spark.sql.connector.catalog.Committed +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableInfo import org.apache.spark.sql.functions._ @@ -66,6 +66,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") // check all table scans val targetTxnTable = txnTables(tableNameAsString) @@ -92,8 +93,39 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { Row(1, 101, "hr"), // update Row(2, 200, "software"), // unchanged Row(3, 300, "hr"))) // unchanged + } + + for (alterClause <- Seq( + "ADD COLUMN new_col INT", + "DROP COLUMN salary", + "ALTER COLUMN salary TYPE BIGINT", + "ALTER COLUMN pk DROP NOT NULL")) + test(s"self merge fails when source schema changes after analysis - DDL: $alterClause" ) { + withTable(tableNameAsString) { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = spark.table(tableNameAsString).where("salary == 100").as("source") + sourceDF.queryExecution.assertAnalyzed() - // TODO Achatzis check version. + sql(s"ALTER TABLE $tableNameAsString $alterClause") + + val e = intercept[AnalysisException] { + sourceDF + .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk")) + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .merge() + } + + assert( + e.getCondition == "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", + alterClause) + assert(catalog.lastTransaction.currentState == Aborted, alterClause) + assert(catalog.lastTransaction.isClosed, alterClause) + } } test("merge into empty table with NOT MATCHED clause") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index e14d0a2571bb7..91f20885beb45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -64,6 +64,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 1) + assert(table.version() == "2") // check all table scans val targetTxnTable = txnTables(tableNameAsString) @@ -163,6 +164,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -249,6 +251,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -318,6 +321,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -1106,6 +1110,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -1177,6 +1182,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.isClosed) // both target and source must have been read through the transaction catalog assert(txnTables.size == 2) + assert(table.version() == "2") assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) assert(txnTables(tableNameAsString).scanEvents.nonEmpty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 7c2d7463c3100..d0209c97cf93c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -134,6 +134,14 @@ abstract class RowLevelOperationSuiteBase } } + // executes an operation and keeps the executed plan + protected def executeAndKeepPlan(func: => Unit): SparkPlan = { + withQueryExecutionsCaptured(spark)(func) match { + case Seq(qe) => stripAQEPlan(qe.executedPlan) + case other => fail(s"expected only one query execution, but got ${other.size}") + } + } + protected def executeTransaction(func: => Unit): (Txn, Map[String, TxnTable]) = { val tables = withQueryExecutionsCaptured(spark)(func).flatMap { qe => collectWithSubqueries(qe.executedPlan) { @@ -152,14 +160,6 @@ abstract class RowLevelOperationSuiteBase } } - // executes an operation and keeps the executed plan - protected def executeAndKeepPlan(func: => Unit): SparkPlan = { - withQueryExecutionsCaptured(spark)(func) match { - case Seq(qe) => stripAQEPlan(qe.executedPlan) - case other => fail(s"expected only one query execution, but got ${other.size}") - } - } - // executes an operation and extracts conditions from ReplaceData or WriteDelta protected def executeAndKeepConditions(func: => Unit): (Expression, Option[Expression]) = { val Seq(qe) = withQueryExecutionsCaptured(spark)(func) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index 34cc6efb6db9a..65c3b68fa8cb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -914,6 +914,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -973,6 +974,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check source table was scanned correctly (dep = 'hr' filter in the subquery) val sourceTxnTable = txnTables(sourceNameAsString) @@ -1026,6 +1028,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check source table was scanned via the transaction catalog val sourceTxnTable = txnTables(sourceNameAsString) @@ -1106,6 +1109,7 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { assert(txn.currentState == Committed) assert(txn.isClosed) assert(txnTables.size == 2) + assert(table.version() == "2") // check target table was scanned correctly val targetTxnTable = txnTables(tableNameAsString) @@ -1149,15 +1153,15 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { |{ "pk": 2, "salary": 200, "dep": "software" } |""".stripMargin) - // NOTE: df.explain() on a DML command actually executes the write. - // TODO(achatzis): This is existing behavior but check why this is OK. Shouldn't sql() be lazy? + // sql() is lazy, but explain() forces executedPlan. sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'").explain() assert(catalog.lastTransaction != null) assert(catalog.lastTransaction.currentState == Committed) assert(catalog.lastTransaction.isClosed) + assert(table.version() == "2") - // the UPDATE was actually executed, not just planned + // The UPDATE was actually executed, not just planned. checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq( From 30e887b626d617f7acee9c92e56af30e37b24f48 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 8 Apr 2026 09:32:58 +0000 Subject: [PATCH 07/33] Append suite improvements pass 1 --- .../AppendDataTransactionSuite.scala | 211 ++++++++++++++---- 1 file changed, 172 insertions(+), 39 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala index 379cf6df0f739..22cdcd1cdfad2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.Committed +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode +import org.apache.spark.sql.sources class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { @@ -88,7 +91,8 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { Row(4, 400, "finance"))) } - test("SQL INSERT OVERWRITE with transactional checks") { + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE with transactional checks - isDynamic: $isDynamic") { // create table with initial data; table is partitioned by dep createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", """{ "pk": 1, "salary": 100, "dep": "hr" } @@ -96,12 +100,22 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { |{ "pk": 3, "salary": 300, "dep": "hr" } |""".stripMargin) - // INSERT OVERWRITE with static partition predicate -> OverwriteByExpression - val (txn, _) = executeTransaction { - sql(s"""INSERT OVERWRITE $tableNameAsString - |PARTITION (dep = 'hr') - |SELECT pk + 10, salary FROM $tableNameAsString WHERE dep = 'hr' - |""".stripMargin) + val insertOverwrite = if (isDynamic) { + // OverwritePartitionsDynamic + s"""INSERT OVERWRITE $tableNameAsString + |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } else { + // OverwriteByExpression + s"""INSERT OVERWRITE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT pk + 10, salary FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val (txn, _) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + executeTransaction { sql(insertOverwrite) } } assert(txn.currentState == Committed) @@ -116,35 +130,6 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { Row(13, 300, "hr"))) // overwritten } - test("SQL INSERT OVERWRITE dynamic partition with transactional checks") { - // create table with initial data; table is partitioned by dep - createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", - """{ "pk": 1, "salary": 100, "dep": "hr" } - |{ "pk": 2, "salary": 200, "dep": "software" } - |{ "pk": 3, "salary": 300, "dep": "hr" } - |""".stripMargin) - - // INSERT OVERWRITE with dynamic partitioning -> OverwritePartitionsDynamic - withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> "dynamic") { - val (txn, _) = executeTransaction { - sql(s"""INSERT OVERWRITE $tableNameAsString - |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' - |""".stripMargin) - } - - assert(txn.currentState == Committed) - assert(txn.isClosed) - assert(table.version() == "2") - - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(2, 200, "software"), // unchanged (different partition) - Row(11, 100, "hr"), // overwrote hr partition - Row(13, 300, "hr"))) // overwrote hr partition - } - } - test("writeTo overwrite with transactional checks") { // create table with initial data; table is partitioned by dep createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -168,7 +153,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq( - Row(2, 200, "software"), // unchanged (different partition) + Row(2, 200, "software"), // unchanged Row(11, 999, "hr"), // overwrote hr partition Row(12, 888, "hr"))) // overwrote hr partition } @@ -196,7 +181,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq( - Row(2, 200, "software"), // unchanged (different partition) + Row(2, 200, "software"), // unchanged Row(11, 999, "hr"), // overwrote hr partition Row(12, 888, "hr"))) // overwrote hr partition } @@ -232,4 +217,152 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { Row(11, 100, "hr"), // inserted from pk=1 Row(13, 300, "hr"))) // inserted from pk=3 } + + test("SQL INSERT INTO SELECT with subquery on source table and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')") + + // INSERT using a subquery that reads from the target to filter source rows + // both tables are scanned through the transaction catalog + val (txn, txnTables) = executeTransaction { + sql( + s"""INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $sourceNameAsString + |WHERE pk IN (SELECT pk FROM $tableNameAsString WHERE dep = 'hr') + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // target was scanned via the transaction catalog (IN subquery) + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // source was scanned via the transaction catalog + assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row + } + + test("SQL INSERT INTO SELECT with CTE and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')") + + // CTE reads from target; INSERT selects from source filtered by the CTE result + // both tables are scanned through the transaction catalog + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH hr_pks AS (SELECT pk FROM $tableNameAsString WHERE dep = 'hr') + |INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $sourceNameAsString + |WHERE pk IN (SELECT pk FROM hr_pks) + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // target was scanned via the transaction catalog (CTE) + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // source was scanned via the transaction catalog + assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row via CTE + } + + test("SQL INSERT with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val e = intercept[AnalysisException] { + sql(s"INSERT INTO $tableNameAsString SELECT nonexistent_col FROM $tableNameAsString") + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE with analysis failure and transactional checks" + + s"isDynamic: $isDynamic") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val insertOverwrite = if (isDynamic) { + s"""INSERT OVERWRITE $tableNameAsString + |SELECT nonexistent_col, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } else { + s"""INSERT OVERWRITE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT nonexistent_col FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val e = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + intercept[AnalysisException] { sql(insertOverwrite) } + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("EXPLAIN INSERT SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"EXPLAIN INSERT INTO $tableNameAsString VALUES (3, 300, 'hr')") + + // EXPLAIN should not start a transaction + assert(catalog.transaction === null) + + // INSERT was not executed; data is unchanged + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + } } From 0fc3049905fc2cce4aaaa9896cbe2b3a2deb3336 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 8 Apr 2026 13:35:25 +0000 Subject: [PATCH 08/33] Append suite improvements pass 2 --- .../AppendDataTransactionSuite.scala | 100 ++++++++++++------ 1 file changed, 67 insertions(+), 33 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala index 22cdcd1cdfad2..93da0a95de956 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -46,14 +46,18 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { } // check txn was properly committed and closed - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 1) - assert(table.version() == "2") + assert(txnTables.size === 1) + assert(table.version() === "2") // check the source scan was tracked via the transaction catalog val targetTxnTable = txnTables(tableNameAsString) - assert(targetTxnTable.scanEvents.size >= 1) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("pk", 1) => true + case _ => false + }) // check data was appended correctly checkAnswer( @@ -72,14 +76,17 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) // SQL INSERT INTO using VALUES - val (txn, _) = executeTransaction { + val (txn, txnTables) = executeTransaction { sql(s"INSERT INTO $tableNameAsString VALUES (3, 300, 'hr'), (4, 400, 'finance')") } // check txn was properly committed and closed - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(table.version() == "2") + assert(table.version() === "2") + + // VALUES literal - No catalog tables were scanned + assert(txnTables.isEmpty) // check data was inserted correctly checkAnswer( @@ -114,13 +121,22 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { } val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC - val (txn, _) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + val (txn, txnTables) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { executeTransaction { sql(insertOverwrite) } } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(table.version() == "2") + assert(table.version() === "2") + + // the SELECT reads from the target table once with a dep='hr' filter + assert(txnTables.size == 1) + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size == 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -142,13 +158,16 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). toDF("pk", "salary", "dep") - val (txn, _) = executeTransaction { + val (txn, txnTables) = executeTransaction { sourceDF.writeTo(tableNameAsString).overwrite(col("dep") === "hr") } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(table.version() == "2") + assert(table.version() === "2") + + // literal DataFrame source - no catalog tables were scanned + assert(txnTables.isEmpty) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -170,13 +189,16 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). toDF("pk", "salary", "dep") - val (txn, _) = executeTransaction { + val (txn, txnTables) = executeTransaction { sourceDF.writeTo(tableNameAsString).overwritePartitions() } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(table.version() == "2") + assert(table.version() === "2") + + // literal DataFrame source - no catalog tables were scanned + assert(txnTables.isEmpty) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -202,10 +224,18 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { } // check txn was properly committed and closed - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 1) - assert(table.version() == "2") + assert(table.version() === "2") + + // the SELECT reads from the target table once with a dep='hr' filter + assert(txnTables.size === 1) + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) // check data was inserted correctly checkAnswer( @@ -237,20 +267,22 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 2) - assert(table.version() == "2") + assert(txnTables.size === 2) + assert(table.version() === "2") - // target was scanned via the transaction catalog (IN subquery) + // target was scanned via the transaction catalog (IN subquery) once with dep='hr' filter val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) assert(targetTxnTable.scanEvents.flatten.exists { case sources.EqualTo("dep", "hr") => true case _ => false }) - // source was scanned via the transaction catalog - assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + // source was scanned via the transaction catalog exactly once (no filter) + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -280,20 +312,22 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 2) - assert(table.version() == "2") + assert(txnTables.size === 2) + assert(table.version() === "2") - // target was scanned via the transaction catalog (CTE) + // target was scanned via the transaction catalog (CTE) once with dep='hr' filter val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) assert(targetTxnTable.scanEvents.flatten.exists { case sources.EqualTo("dep", "hr") => true case _ => false }) - // source was scanned via the transaction catalog - assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + // source was scanned via the transaction catalog exactly once (no filter) + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -314,7 +348,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { } assert(e.getMessage.contains("nonexistent_col")) - assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.currentState === Aborted) assert(catalog.lastTransaction.isClosed) } @@ -343,7 +377,7 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { } assert(e.getMessage.contains("nonexistent_col")) - assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.currentState === Aborted) assert(catalog.lastTransaction.isClosed) } From dc9992c3c005e762168a09c4acaa8ec965f87742 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Thu, 9 Apr 2026 08:52:04 +0000 Subject: [PATCH 09/33] RTAS/CTAS improvements --- .../spark/sql/connector/catalog/txns.scala | 20 ++- .../v2/WriteToDataSourceV2Exec.scala | 25 +--- .../connector/CTASRTASTransactionSuite.scala | 138 +++++++++++++----- 3 files changed, 120 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index c6339d2099b63..0881daeb7635d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -74,6 +74,8 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) // TODO: Revise schema evolution. alterTableWithData(delegate.data, schema) + private val initialVersion: String = version() + // A tracker of filters used in each scan. val scanEvents = new ArrayBuffer[Array[Filter]]() @@ -82,14 +84,16 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) } def commit(): Unit = { - delegate.dataMap.clear() - // TODO: Revise schema evolution. - delegate.alterTableWithData(data, delegate.schema) - delegate.replacedPartitions = replacedPartitions - delegate.lastWriteInfo = lastWriteInfo - delegate.lastWriteLog = lastWriteLog - delegate.commits ++= commits - delegate.increaseVersion() + if (version() != initialVersion) { + delegate.dataMap.clear() + // TODO: Revise schema evolution. + delegate.alterTableWithData(data, delegate.schema) + delegate.replacedPartitions = replacedPartitions + delegate.lastWriteInfo = lastWriteInfo + delegate.lastWriteLog = lastWriteLog + delegate.commits ++= commits + delegate.increaseVersion() + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 5580d33c9aaeb..308f4bdc5042b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -121,13 +121,8 @@ case class AtomicCreateTableAsSelectExec( query: LogicalPlan, tableSpec: TableSpec, writeOptions: Map[String, String], - ifNotExists: Boolean, - transaction: Option[Transaction] = None) - extends V2CreateTableAsSelectBaseExec - with TransactionalExec { - - override def withTransaction(txn: Option[Transaction]): AtomicCreateTableAsSelectExec = - copy(transaction = txn) + ifNotExists: Boolean) + extends V2CreateTableAsSelectBaseExec { val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -148,9 +143,7 @@ case class AtomicCreateTableAsSelectExec( .build() val stagedTable = Option(catalog.stageCreate(ident, tableInfo) ).getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - val result = writeToTable(catalog, stagedTable, writeOptions, ident, query, overwrite = false) - transaction.foreach(TransactionUtils.commit) - result + writeToTable(catalog, stagedTable, writeOptions, ident, query, overwrite = false) } } @@ -241,12 +234,8 @@ case class AtomicReplaceTableAsSelectExec( tableSpec: TableSpec, writeOptions: Map[String, String], orCreate: Boolean, - invalidateCache: (TableCatalog, Identifier) => Unit, - transaction: Option[Transaction] = None) - extends V2CreateTableAsSelectBaseExec with TransactionalExec { - - override def withTransaction(txn: Option[Transaction]): AtomicReplaceTableAsSelectExec = - copy(transaction = txn) + invalidateCache: (TableCatalog, Identifier) => Unit) + extends V2CreateTableAsSelectBaseExec { val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -287,9 +276,7 @@ case class AtomicReplaceTableAsSelectExec( } val table = Option(staged).getOrElse( catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - val result = writeToTable(catalog, table, writeOptions, ident, query, overwrite = true) - transaction.foreach(TransactionUtils.commit) - result + writeToTable(catalog, table, writeOptions, ident, query, overwrite = true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala index c58a78498f9f4..8acdd8242ef1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala @@ -17,13 +17,19 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.Committed +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTable} +import org.apache.spark.sql.sources class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { private val newTableNameAsString = "cat.ns1.new_table" + private def newTable: InMemoryRowLevelOperationTable = + catalog.loadTable(Identifier.of(Array("ns1"), "new_table")) + .asInstanceOf[InMemoryRowLevelOperationTable] + test("CTAS with transactional checks") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", """{ "pk": 1, "salary": 100, "dep": "hr" } @@ -36,50 +42,56 @@ class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 1) - assert(table.version() == "2") + assert(txnTables.size === 1) + assert(table.version() === "1") // source table: read-only, version unchanged + assert(newTable.version() === "1") // target table: newly created and written + // the source table was scanned once through the transaction catalog with a dep='hr' filter val sourceTxnTable = txnTables(tableNameAsString) - assert(sourceTxnTable.scanEvents.size >= 1) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq(Row(1, 100, "hr"))) + } + + test("CTAS from literal source with transactional checks") { + // no source catalog table involved — the query is a pure literal SELECT + val (txn, txnTables) = executeTransaction { + sql(s"CREATE TABLE $newTableNameAsString AS SELECT 1 AS pk, 100 AS salary, 'hr' AS dep") + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // literal SELECT - no catalog tables were scanned + assert(txnTables.isEmpty) + assert(newTable.version() === "1") // target table: newly created and written checkAnswer( sql(s"SELECT * FROM $newTableNameAsString"), Seq(Row(1, 100, "hr"))) } - test("CTAS with cached source and transactional checks") { + test("CTAS with analysis failure and transactional checks") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", """{ "pk": 1, "salary": 100, "dep": "hr" } |{ "pk": 2, "salary": 200, "dep": "software" } |""".stripMargin) - // cache the source table before running CTAS - spark.catalog.cacheTable(tableNameAsString) - - try { - val (txn, txnTables) = executeTransaction { - sql(s"""CREATE TABLE $newTableNameAsString - |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' - |""".stripMargin) - } - - assert(txn.currentState == Committed) - assert(txn.isClosed) - assert(table.version() == "2") - - // cache miss: TxnTable-based relation is not structurally equal to the cached one, - // so the scan goes through the transaction catalog and scan events are captured - val sourceTxnTable = txnTables(tableNameAsString) - assert(sourceTxnTable.scanEvents.size >= 1) - - checkAnswer( - sql(s"SELECT * FROM $newTableNameAsString"), - Seq(Row(1, 100, "hr"))) - } finally { - spark.catalog.uncacheTable(tableNameAsString) + val e = intercept[AnalysisException] { + sql(s"CREATE TABLE $newTableNameAsString AS SELECT nonexistent_col FROM $tableNameAsString") } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) } test("RTAS with transactional checks") { @@ -98,13 +110,19 @@ class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) - assert(txnTables.size == 1) - assert(table.version() == "2") + assert(txnTables.size === 1) + assert(table.version() === "1") // source table: read-only, version unchanged + assert(newTable.version() === "1") // target table: replaced and written + // the source table was scanned once through the transaction catalog with a dep='hr' filter val sourceTxnTable = txnTables(tableNameAsString) - assert(sourceTxnTable.scanEvents.size >= 1) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) checkAnswer( sql(s"SELECT * FROM $newTableNameAsString"), @@ -128,11 +146,18 @@ class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { |""".stripMargin) } - assert(txn.currentState == Committed) + assert(txn.currentState === Committed) assert(txn.isClosed) + assert(txnTables.size === 1) + assert(table.version() === "1") // source/target table: replaced in place, version reset to 1 + // the source/target table was scanned once with a dep='hr' filter val sourceTxnTable = txnTables(tableNameAsString) - assert(sourceTxnTable.scanEvents.size >= 1) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -140,4 +165,45 @@ class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { Row(1, 100, "hr"), Row(3, 300, "hr"))) } + + test("RTAS with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val e = intercept[AnalysisException] { + sql(s"""CREATE OR REPLACE TABLE $tableNameAsString + |AS SELECT nonexistent_col FROM $tableNameAsString + |""".stripMargin) + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("simple CREATE TABLE and DROP TABLE do not create transactions") { + sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + assert(catalog.transaction === null) + assert(catalog.lastTransaction === null) + + sql(s"DROP TABLE $newTableNameAsString") + assert(catalog.transaction === null) + assert(catalog.lastTransaction === null) + } + + test("EXPLAIN CTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"""EXPLAIN CREATE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + + // EXPLAIN should not start a transaction + assert(catalog.transaction === null) + } } From 3ed327b5727b2681c2f901faedf62e5715c6d247 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Thu, 9 Apr 2026 14:07:14 +0000 Subject: [PATCH 10/33] Schema evolution --- .../connector/catalog/InMemoryBaseTable.scala | 4 + .../spark/sql/connector/catalog/txns.scala | 12 ++- .../AppendDataTransactionSuite.scala | 98 +++++++++++++++++++ 3 files changed, 110 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index af0860664312c..f49838f10c904 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -76,6 +76,10 @@ abstract class InMemoryBaseTable( override def columns(): Array[Column] = tableColumns + private[catalog] def updateColumns(newColumns: Array[Column]): Unit = { + tableColumns = newColumns + } + override def version(): String = tableVersion.toString def setVersion(version: String): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 0881daeb7635d..6a6897c896165 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -71,9 +71,9 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) delegate.properties, delegate.constraints) { - // TODO: Revise schema evolution. - alterTableWithData(delegate.data, schema) + withData(delegate.data) + // Keep initial version to detect any changes during the transaction. private val initialVersion: String = version() // A tracker of filters used in each scan. @@ -86,8 +86,8 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) def commit(): Unit = { if (version() != initialVersion) { delegate.dataMap.clear() - // TODO: Revise schema evolution. - delegate.alterTableWithData(data, delegate.schema) + delegate.alterTableWithData(data, schema) + delegate.updateColumns(columns()) // Evolve schema if needed. delegate.replacedPartitions = replacedPartitions delegate.lastWriteInfo = lastWriteInfo delegate.lastWriteLog = lastWriteLog @@ -124,6 +124,10 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T } override def alterTable(ident: Identifier, changes: TableChange*): Table = { + // TODO: This evicts the staged TxnTable, losing any in-flight DML changes. The correct + // approach is to apply only the schema change to the existing TxnTable so that the ongoing + // DML can observe the new schema and reconcile at commit time. Concurrent DDL + DML is not + // supported in this test catalog for now. val newDelegateTable = delegate.alterTable(ident, changes: _*) tables.remove(ident) // Load again. newDelegateTable diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala index 93da0a95de956..aef9c65550fc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -399,4 +399,102 @@ class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { Row(1, 100, "hr"), Row(2, 200, "software"))) } + + test("SQL INSERT WITH SCHEMA EVOLUTION adds new column with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + sql(s"INSERT INTO $sourceNameAsString VALUES (3, 300, 'hr', true), (4, 400, 'software', false)") + + val (txn, txnTables) = executeTransaction { + sql(s"INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString SELECT * FROM $sourceNameAsString") + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // the new column must be visible in the committed delegate's schema + assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep", "active")) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), // pre-existing rows: active is null + Row(2, 200, "software", null), + Row(3, 300, "hr", true), // inserted with active + Row(4, 400, "software", false))) + } + + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE WITH SCHEMA EVOLUTION adds new column with transactional checks " + + s"isDynamic: $isDynamic") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + sql(s"INSERT INTO $sourceNameAsString VALUES (11, 999, 'hr', true), (12, 888, 'hr', false)") + + val insertOverwrite = if (isDynamic) { + s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString + |SELECT * FROM $sourceNameAsString + |""".stripMargin + } else { + s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT pk, salary, active FROM $sourceNameAsString + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val (txn, _) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + executeTransaction { sql(insertOverwrite) } + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // the new column must be visible in the committed delegate's schema + assert(table.schema.fieldNames.contains("active")) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software", null), // unchanged (different partition) + Row(11, 999, "hr", true), // overwrote hr partition + Row(12, 888, "hr", false))) + } + + test("SQL INSERT WITH SCHEMA EVOLUTION analysis failure aborts transaction") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + + val e = intercept[AnalysisException] { + sql( + s"""INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString + |SELECT nonexistent_col FROM $sourceNameAsString + |""".stripMargin) + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + // schema must be unchanged after the aborted transaction + assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep")) + } } From 38b6e318107f28552d74785cd0e4bb045a4402ed Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 10 Apr 2026 08:23:39 +0000 Subject: [PATCH 11/33] Fix schema evolution --- .../scala/org/apache/spark/sql/connector/catalog/txns.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 6a6897c896165..ba6de6be9130a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -71,7 +71,7 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) delegate.properties, delegate.constraints) { - withData(delegate.data) + alterTableWithData(delegate.data, delegate.schema) // Keep initial version to detect any changes during the transaction. private val initialVersion: String = version() From 6e33fa0a8ae452e979ce489865ee7997e8608c0a Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 10 Apr 2026 14:48:29 +0000 Subject: [PATCH 12/33] Add schema evolution fixme --- .../org/apache/spark/sql/connector/catalog/txns.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index ba6de6be9130a..94c3bda826b32 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -124,10 +124,10 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T } override def alterTable(ident: Identifier, changes: TableChange*): Table = { - // TODO: This evicts the staged TxnTable, losing any in-flight DML changes. The correct - // approach is to apply only the schema change to the existing TxnTable so that the ongoing - // DML can observe the new schema and reconcile at commit time. Concurrent DDL + DML is not - // supported in this test catalog for now. + // FIXME: This is not transactional. The schema changes are applied directly to the delegate. + // The correct behavior is to apply the schema changes to the TxnTable and propagate them + // to the delegate only after commit. + // Furthermore, this also evicts the staged TxnTable, losing any in-flight DML changes. val newDelegateTable = delegate.alterTable(ident, changes: _*) tables.remove(ident) // Load again. newDelegateTable From 9f163a9338e92e22885cda51423c3f1b77656eac Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 10 Apr 2026 15:27:09 +0000 Subject: [PATCH 13/33] Schema evolution fix 2 --- .../spark/sql/connector/catalog/txns.scala | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 94c3bda826b32..4d11945a06580 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap sealed trait TransactionState @@ -63,15 +64,15 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { // A special table used in row-level operation transactions. It inherits data // from the base table upon construction and propagates staged transaction state // back after an explicit commit. -class TxnTable(val delegate: InMemoryRowLevelOperationTable) +class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) extends InMemoryRowLevelOperationTable( delegate.name, - delegate.schema, + schema, delegate.partitioning, delegate.properties, delegate.constraints) { - alterTableWithData(delegate.data, delegate.schema) + alterTableWithData(delegate.data, schema) // Keep initial version to detect any changes during the transaction. private val initialVersion: String = version() @@ -86,8 +87,8 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable) def commit(): Unit = { if (version() != initialVersion) { delegate.dataMap.clear() - delegate.alterTableWithData(data, schema) delegate.updateColumns(columns()) // Evolve schema if needed. + delegate.alterTableWithData(data, schema) delegate.replacedPartitions = replacedPartitions delegate.lastWriteInfo = lastWriteInfo delegate.lastWriteLog = lastWriteLog @@ -114,7 +115,7 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T override def loadTable(ident: Identifier): Table = { tables.computeIfAbsent(ident, _ => { val table = delegate.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] - new TxnTable(table) + new TxnTable(table, table.schema()) }) } @@ -124,13 +125,20 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T } override def alterTable(ident: Identifier, changes: TableChange*): Table = { - // FIXME: This is not transactional. The schema changes are applied directly to the delegate. - // The correct behavior is to apply the schema changes to the TxnTable and propagate them - // to the delegate only after commit. - // Furthermore, this also evicts the staged TxnTable, losing any in-flight DML changes. - val newDelegateTable = delegate.alterTable(ident, changes: _*) - tables.remove(ident) // Load again. - newDelegateTable + // AlterTable may be called by ResolveSchemaEvolution when schema evolution is enabled. Thus, + // it needs to be transactional. The schema changes are only propagated to the delegate at + // commit time. + val txnTable = tables.get(ident) + val schema = CatalogV2Util.applySchemaChanges( + txnTable.schema, changes, tableProvider = Some("in-memory"), statementType = "ALTER TABLE") + + if (schema.fields.isEmpty) { + throw new IllegalArgumentException(s"Cannot drop all fields") + } + + val newTxnTable = new TxnTable(txnTable.delegate, schema) + tables.put(ident, newTxnTable) + newTxnTable } override def dropTable(ident: Identifier): Boolean = { From a46d78f914a94fba91315b3927ad9576dfb5e45b Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Mon, 13 Apr 2026 12:54:00 +0000 Subject: [PATCH 14/33] Delegate schema computation changes to the underlying catalog --- ...nMemoryRowLevelOperationTableCatalog.scala | 25 +++++++++++++++---- .../spark/sql/connector/catalog/txns.scala | 7 ++++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index 78c350b5145a9..bdf19e0e9d355 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.catalog import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} +import org.apache.spark.sql.types.StructType class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog @@ -55,11 +56,7 @@ class InMemoryRowLevelOperationTableCatalog override def alterTable(ident: Identifier, changes: TableChange*): Table = { val table = loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes) - val schema = CatalogV2Util.applySchemaChanges( - table.schema, - changes, - tableProvider = Some("in-memory"), - statementType = "ALTER TABLE") + val schema = computeAlterTableSchema(table.schema, changes.toSeq) val partitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes) val constraints = CatalogV2Util.collectConstraintChanges(table, changes) @@ -80,6 +77,16 @@ class InMemoryRowLevelOperationTableCatalog newTable } + + /** + * Computes the schema that would result from applying `changes` to `currentSchema`. + * Overriding this allows subclasses to simulate catalogs that selectively ignore some changes + * (e.g. [[PartialSchemaEvolutionCatalog]]). + */ + def computeAlterTableSchema(currentSchema: StructType, changes: Seq[TableChange]): StructType = { + CatalogV2Util.applySchemaChanges( + currentSchema, changes, tableProvider = Some("in-memory"), statementType = "ALTER TABLE") + } } /** @@ -108,4 +115,12 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo tables.put(ident, newTable) newTable } + + // When used inside a transaction, TxnTableCatalog.alterTable uses this method to compute + // the resulting schema instead of calling CatalogV2Util.applySchemaChanges directly. + // Returning the current schema unchanged mirrors the behaviour of alterTable above (silently + // ignore all column changes), so ResolveSchemaEvolution can still detect pending changes. + override def computeAlterTableSchema( + currentSchema: StructType, + changes: Seq[TableChange]): StructType = currentSchema } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 4d11945a06580..1f0d36f925625 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -128,9 +128,12 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T // AlterTable may be called by ResolveSchemaEvolution when schema evolution is enabled. Thus, // it needs to be transactional. The schema changes are only propagated to the delegate at // commit time. + // + // We delegate schema computation to the underlying catalog so that catalogs that selectively + // ignore some changes (e.g. PartialSchemaEvolutionCatalog) have the same behaviour inside a + // transaction. This lets ResolveSchemaEvolution detect pending changes correctly. val txnTable = tables.get(ident) - val schema = CatalogV2Util.applySchemaChanges( - txnTable.schema, changes, tableProvider = Some("in-memory"), statementType = "ALTER TABLE") + val schema = delegate.computeAlterTableSchema(txnTable.schema, changes.toSeq) if (schema.fields.isEmpty) { throw new IllegalArgumentException(s"Cannot drop all fields") From bd1efb8d1f724d9280cd709ecc4a582cafe25bd6 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Tue, 14 Apr 2026 07:39:27 +0000 Subject: [PATCH 15/33] Improve comments in schema evolution --- .../spark/sql/catalyst/analysis/V2TableReference.scala | 6 ++++++ .../InMemoryRowLevelOperationTableCatalog.scala | 10 ++++------ .../org/apache/spark/sql/connector/catalog/txns.scala | 6 +++--- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index 76226056ffe65..f459706a690bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -92,6 +92,8 @@ private[sql] object V2TableReference { create(relation, TemporaryViewContext(viewName)) } + // V2TableReference nodes in the transaction context are produced by + // UnresolveTransactionRelations which unresolves already resolved relations. def createForTransaction(relation: DataSourceV2Relation): V2TableReference = { create(relation, TransactionContext) } @@ -125,6 +127,10 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { } private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = { + // Do not allow schema evolution to pre-analysed dataframes that are later used in + // transactional writes. This is because the entire plans was built based on the original schema + // and any schema change would make the plan structurally invalid. This is inline with the + // semantics of SPARK-54444. val dataErrors = V2TableUtil.validateCapturedColumns( table = table, originCols = ref.info.columns, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index bdf19e0e9d355..4a38285b685e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -80,7 +80,7 @@ class InMemoryRowLevelOperationTableCatalog /** * Computes the schema that would result from applying `changes` to `currentSchema`. - * Overriding this allows subclasses to simulate catalogs that selectively ignore some changes + * Can be overridden by subclasses to simulate catalogs that selectively ignore changes * (e.g. [[PartialSchemaEvolutionCatalog]]). */ def computeAlterTableSchema(currentSchema: StructType, changes: Seq[TableChange]): StructType = { @@ -105,9 +105,10 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo case _ => false } val properties = CatalogV2Util.applyPropertiesChanges(table.properties, propertyChanges) + val schema = computeAlterTableSchema(table.schema, changes.toSeq) val newTable = new InMemoryRowLevelOperationTable( name = table.name, - schema = table.schema, + schema = schema, partitioning = table.partitioning, properties = properties, constraints = table.constraints) @@ -116,10 +117,7 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo newTable } - // When used inside a transaction, TxnTableCatalog.alterTable uses this method to compute - // the resulting schema instead of calling CatalogV2Util.applySchemaChanges directly. - // Returning the current schema unchanged mirrors the behaviour of alterTable above (silently - // ignore all column changes), so ResolveSchemaEvolution can still detect pending changes. + // Ignores all schema changes and returns the current schema unchanged. override def computeAlterTableSchema( currentSchema: StructType, changes: Seq[TableChange]): StructType = currentSchema diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 1f0d36f925625..f19e360047d26 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -129,9 +129,9 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T // it needs to be transactional. The schema changes are only propagated to the delegate at // commit time. // - // We delegate schema computation to the underlying catalog so that catalogs that selectively - // ignore some changes (e.g. PartialSchemaEvolutionCatalog) have the same behaviour inside a - // transaction. This lets ResolveSchemaEvolution detect pending changes correctly. + // We delegate schema computation to the underlying catalog so that catalogs with special + // handling (e.g. PartialSchemaEvolutionCatalog) have the same behaviour inside a + // transaction. val txnTable = tables.get(ident) val schema = delegate.computeAlterTableSchema(txnTable.schema, changes.toSeq) From 21989decfd01960a87d0aadb3c107aa9ddbae456 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Tue, 14 Apr 2026 13:46:48 +0000 Subject: [PATCH 16/33] Mark new APIs as evolving + minor cleanup --- .../sql/connector/catalog/TransactionalCatalogPlugin.java | 2 ++ .../sql/connector/catalog/transactions/Transaction.java | 2 ++ .../sql/connector/catalog/transactions/TransactionInfo.java | 3 +++ .../org/apache/spark/sql/execution/QueryExecution.scala | 5 ++--- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java index 34a4fc68e9649..daa3176dcbba5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector.catalog; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.catalog.transactions.Transaction; import org.apache.spark.sql.connector.catalog.transactions.TransactionInfo; @@ -28,6 +29,7 @@ * * @since 4.2.0 */ +@Evolving public interface TransactionalCatalogPlugin extends CatalogPlugin { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java index 80513aff31506..77044c6202fbe 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector.catalog.transactions; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.catalog.CatalogPlugin; import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin; @@ -32,6 +33,7 @@ * * @since 4.2.0 */ +@Evolving public interface Transaction extends Closeable { /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java index a9c17d4b88274..3e6979cec469f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java @@ -17,11 +17,14 @@ package org.apache.spark.sql.connector.catalog.transactions; +import org.apache.spark.annotation.Evolving; + /** * Metadata about a transaction. * * @since 4.2.0 */ +@Evolving public interface TransactionInfo { /** * Returns a unique identifier for this transaction. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index b1567cfdd1c59..024073f55c6a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -287,9 +287,8 @@ class QueryExecution( // the transaction was active. if (transactionOpt.isDefined) { normalized - } - else { - // clone the plan to avoid sharing the plan instance between different stages like + } else { + // Clone the plan to avoid sharing the plan instance between different stages like // analyzing, optimizing and planning. sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) } From f93daf9dadd42c9d873570793cd9573f2d423334 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 15 Apr 2026 10:53:20 +0000 Subject: [PATCH 17/33] Add TODO plus nit --- .../spark/sql/connector/catalog/txns.scala | 1 + .../spark/sql/execution/QueryExecution.scala | 3 +- .../connector/StreamingTransactionSuite.scala | 213 ++++++++++++++++++ 3 files changed, 215 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index f19e360047d26..157c9a82d6f2f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -139,6 +139,7 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T throw new IllegalArgumentException(s"Cannot drop all fields") } + // TODO: We need to pass all tracked predicates to the new TXN table. val newTxnTable = new TxnTable(txnTable.delegate, schema) tables.put(ident, newTxnTable) newTxnTable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 024073f55c6a0..8ae58b76e8711 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -234,8 +234,7 @@ class QueryExecution( // for eagerly executed commands we mark this place as beginning of execution. tracker.setReadyForExecution() val (qe, result) = QueryExecution.runCommand( - sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode), - analyzerOpt = Some(analyzer)) + sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode), Some(analyzer)) CommandResult( qe.analyzed.output, qe.commandExecuted, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala new file mode 100644 index 0000000000000..bd3a8fc1307b2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{Committed, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, PhysicalWriteInfo, Write, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.sources.PackedRowWriterFactory +import org.apache.spark.sql.internal.connector.SimpleTableProvider +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.SlowSQLTest + +/** + * Tests that structured streaming micro-batch writes participate in the DSv2 transaction API. + * + * The V2 streaming path is: + * WriteToMicroBatchDataSource (logical) + * -> V2Writes rule -> WriteToDataSourceV2(MicroBatchWrite) (logical) + * -> WriteToDataSourceV2Exec (physical, implements TransactionalExec) + * + * Each micro-batch runs in its own IncrementalExecution, so transactionOpt is evaluated + * fresh per batch. The transaction is committed inside WriteToDataSourceV2Exec.run() after + * writeWithV2 completes, and aborted if writeWithV2 throws. + */ +@SlowSQLTest +class StreamingTransactionSuite extends StreamTest with BeforeAndAfter { + import testImplicits._ + + private val tableIdent = Identifier.of(Array("ns1"), "test_table") + private val tableNameAsString = "cat.ns1.test_table" + + before { + spark.conf.set("spark.sql.catalog.cat", + classOf[InMemoryRowLevelOperationTableCatalog].getName) + sql("CREATE NAMESPACE cat.ns1") + sql(s"CREATE TABLE $tableNameAsString (value INT) USING foo") + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.clear() + sqlContext.streams.active.foreach(_.stop()) + } + + private def catalog: InMemoryRowLevelOperationTableCatalog = + spark.sessionState.catalogManager.catalog("cat") + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + + private def delegateTable: InMemoryRowLevelOperationTable = + catalog.loadTable(tableIdent).asInstanceOf[InMemoryRowLevelOperationTable] + + test("streaming micro-batch append commits a transaction") { + val stream = MemoryStream[Int] + + withTempDir { checkpointDir => + val query = stream.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .toTable(tableNameAsString) + + try { + stream.addData(1, 2, 3) + query.processAllAvailable() + + val txn = catalog.lastTransaction + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(delegateTable.version() === "2") + + checkAnswer( + spark.table(tableNameAsString), + Seq(Row(1), Row(2), Row(3))) + } finally { + query.stop() + } + } + } + + test("each micro-batch gets a fresh transaction") { + val stream = MemoryStream[Int] + + withTempDir { checkpointDir => + val query = stream.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .toTable(tableNameAsString) + + try { + stream.addData(1, 2, 3) + query.processAllAvailable() + val txn1 = catalog.lastTransaction + assert(txn1.currentState === Committed) + assert(txn1.isClosed) + assert(delegateTable.version() === "2") + + stream.addData(4, 5) + query.processAllAvailable() + val txn2 = catalog.lastTransaction + assert(txn2.currentState === Committed) + assert(txn2.isClosed) + assert(txn2 ne txn1, "each batch must open a fresh transaction") + assert(delegateTable.version() === "3") + + checkAnswer( + spark.table(tableNameAsString), + Seq(Row(1), Row(2), Row(3), Row(4), Row(5))) + } finally { + query.stop() + } + } + } + + test("no transaction is started when the catalog is not transactional") { + // Writing to a non-transactional catalog (session catalog / parquet) must not + // open a transaction. Verify by confirming catalog.lastTransaction is untouched. + val initialLastTxn = catalog.lastTransaction // null at start + + withTempDir { dir => + val stream = MemoryStream[Int] + val query = stream.toDF() + .writeStream + .format("parquet") + .option("checkpointLocation", dir.getCanonicalPath + "/checkpoint") + .option("path", dir.getCanonicalPath + "/data") + .start() + + try { + stream.addData(1, 2, 3) + query.processAllAvailable() + + // our TxnTableCatalog was not involved - lastTransaction must be unchanged + assert(catalog.lastTransaction === initialLastTxn) + } finally { + query.stop() + } + } + } + + test("no transaction is started for an anonymous V2 sink (catalog = None)") { + // An anonymous V2 sink has DataSourceV2Relation.catalog == None (no catalog/ident). + // UnresolveTransactionRelations skips it since catalog doesn't match any + // TransactionalCatalogPlugin, so transactionOpt returns None and no transaction is opened. + val initialLastTxn = catalog.lastTransaction // null at start + + withTempDir { dir => + val stream = MemoryStream[Int] + val query = stream.toDF() + .writeStream + .format(classOf[NoOpV2SinkProvider].getName) + .option("checkpointLocation", dir.getCanonicalPath + "/checkpoint") + .start() + + try { + stream.addData(1, 2, 3) + query.processAllAvailable() + + // Anonymous V2 sink: no catalog involved, no transaction must be opened. + assert(catalog.lastTransaction === initialLastTxn) + } finally { + query.stop() + } + } + } +} + +/** + * A no-op V2 streaming sink with no catalog or identifier (anonymous sink). + * Used to verify that anonymous V2 sinks do not open transactions. + */ +class NoOpV2SinkProvider extends SimpleTableProvider { + override def getTable(options: CaseInsensitiveStringMap): Table = { + new Table with SupportsWrite { + override def name(): String = "noop-v2-sink" + override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = + util.EnumSet.of(TableCapability.STREAMING_WRITE) + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = + new WriteBuilder { + override def build(): Write = new Write { + override def toStreaming: StreamingWrite = new StreamingWrite { + override def createStreamingWriterFactory( + info2: PhysicalWriteInfo): StreamingDataWriterFactory = PackedRowWriterFactory + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + } + } + } + } + } +} From 391a1f3bbee4b95a028f3d57557ab63192e4f6ca Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 15 Apr 2026 11:00:26 +0000 Subject: [PATCH 18/33] Remove StreamingTransactionSuite --- .../connector/StreamingTransactionSuite.scala | 213 ------------------ 1 file changed, 213 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala deleted file mode 100644 index bd3a8fc1307b2..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector - -import java.util - -import org.scalatest.BeforeAndAfter - -import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.{Committed, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, SupportsWrite, Table, TableCapability} -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, PhysicalWriteInfo, Write, WriteBuilder, WriterCommitMessage} -import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} -import org.apache.spark.sql.execution.streaming.runtime.MemoryStream -import org.apache.spark.sql.execution.streaming.sources.PackedRowWriterFactory -import org.apache.spark.sql.internal.connector.SimpleTableProvider -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.tags.SlowSQLTest - -/** - * Tests that structured streaming micro-batch writes participate in the DSv2 transaction API. - * - * The V2 streaming path is: - * WriteToMicroBatchDataSource (logical) - * -> V2Writes rule -> WriteToDataSourceV2(MicroBatchWrite) (logical) - * -> WriteToDataSourceV2Exec (physical, implements TransactionalExec) - * - * Each micro-batch runs in its own IncrementalExecution, so transactionOpt is evaluated - * fresh per batch. The transaction is committed inside WriteToDataSourceV2Exec.run() after - * writeWithV2 completes, and aborted if writeWithV2 throws. - */ -@SlowSQLTest -class StreamingTransactionSuite extends StreamTest with BeforeAndAfter { - import testImplicits._ - - private val tableIdent = Identifier.of(Array("ns1"), "test_table") - private val tableNameAsString = "cat.ns1.test_table" - - before { - spark.conf.set("spark.sql.catalog.cat", - classOf[InMemoryRowLevelOperationTableCatalog].getName) - sql("CREATE NAMESPACE cat.ns1") - sql(s"CREATE TABLE $tableNameAsString (value INT) USING foo") - } - - after { - spark.sessionState.catalogManager.reset() - spark.sessionState.conf.clear() - sqlContext.streams.active.foreach(_.stop()) - } - - private def catalog: InMemoryRowLevelOperationTableCatalog = - spark.sessionState.catalogManager.catalog("cat") - .asInstanceOf[InMemoryRowLevelOperationTableCatalog] - - private def delegateTable: InMemoryRowLevelOperationTable = - catalog.loadTable(tableIdent).asInstanceOf[InMemoryRowLevelOperationTable] - - test("streaming micro-batch append commits a transaction") { - val stream = MemoryStream[Int] - - withTempDir { checkpointDir => - val query = stream.toDF() - .writeStream - .option("checkpointLocation", checkpointDir.getCanonicalPath) - .toTable(tableNameAsString) - - try { - stream.addData(1, 2, 3) - query.processAllAvailable() - - val txn = catalog.lastTransaction - assert(txn.currentState === Committed) - assert(txn.isClosed) - assert(delegateTable.version() === "2") - - checkAnswer( - spark.table(tableNameAsString), - Seq(Row(1), Row(2), Row(3))) - } finally { - query.stop() - } - } - } - - test("each micro-batch gets a fresh transaction") { - val stream = MemoryStream[Int] - - withTempDir { checkpointDir => - val query = stream.toDF() - .writeStream - .option("checkpointLocation", checkpointDir.getCanonicalPath) - .toTable(tableNameAsString) - - try { - stream.addData(1, 2, 3) - query.processAllAvailable() - val txn1 = catalog.lastTransaction - assert(txn1.currentState === Committed) - assert(txn1.isClosed) - assert(delegateTable.version() === "2") - - stream.addData(4, 5) - query.processAllAvailable() - val txn2 = catalog.lastTransaction - assert(txn2.currentState === Committed) - assert(txn2.isClosed) - assert(txn2 ne txn1, "each batch must open a fresh transaction") - assert(delegateTable.version() === "3") - - checkAnswer( - spark.table(tableNameAsString), - Seq(Row(1), Row(2), Row(3), Row(4), Row(5))) - } finally { - query.stop() - } - } - } - - test("no transaction is started when the catalog is not transactional") { - // Writing to a non-transactional catalog (session catalog / parquet) must not - // open a transaction. Verify by confirming catalog.lastTransaction is untouched. - val initialLastTxn = catalog.lastTransaction // null at start - - withTempDir { dir => - val stream = MemoryStream[Int] - val query = stream.toDF() - .writeStream - .format("parquet") - .option("checkpointLocation", dir.getCanonicalPath + "/checkpoint") - .option("path", dir.getCanonicalPath + "/data") - .start() - - try { - stream.addData(1, 2, 3) - query.processAllAvailable() - - // our TxnTableCatalog was not involved - lastTransaction must be unchanged - assert(catalog.lastTransaction === initialLastTxn) - } finally { - query.stop() - } - } - } - - test("no transaction is started for an anonymous V2 sink (catalog = None)") { - // An anonymous V2 sink has DataSourceV2Relation.catalog == None (no catalog/ident). - // UnresolveTransactionRelations skips it since catalog doesn't match any - // TransactionalCatalogPlugin, so transactionOpt returns None and no transaction is opened. - val initialLastTxn = catalog.lastTransaction // null at start - - withTempDir { dir => - val stream = MemoryStream[Int] - val query = stream.toDF() - .writeStream - .format(classOf[NoOpV2SinkProvider].getName) - .option("checkpointLocation", dir.getCanonicalPath + "/checkpoint") - .start() - - try { - stream.addData(1, 2, 3) - query.processAllAvailable() - - // Anonymous V2 sink: no catalog involved, no transaction must be opened. - assert(catalog.lastTransaction === initialLastTxn) - } finally { - query.stop() - } - } - } -} - -/** - * A no-op V2 streaming sink with no catalog or identifier (anonymous sink). - * Used to verify that anonymous V2 sinks do not open transactions. - */ -class NoOpV2SinkProvider extends SimpleTableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - new Table with SupportsWrite { - override def name(): String = "noop-v2-sink" - override def schema(): StructType = StructType(Nil) - override def capabilities(): util.Set[TableCapability] = - util.EnumSet.of(TableCapability.STREAMING_WRITE) - override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = - new WriteBuilder { - override def build(): Write = new Write { - override def toStreaming: StreamingWrite = new StreamingWrite { - override def createStreamingWriterFactory( - info2: PhysicalWriteInfo): StreamingDataWriterFactory = PackedRowWriterFactory - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - } - } - } - } - } -} From 60029299b73831c434bfea08316526a823ae43a5 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 15 Apr 2026 13:16:24 +0000 Subject: [PATCH 19/33] More comments and renames --- .../spark/sql/connector/catalog/txns.scala | 33 +++++++++++++++---- .../spark/sql/execution/QueryExecution.scala | 27 +++++++-------- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 157c9a82d6f2f..7b55e20c61676 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -48,11 +48,13 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { this.state = Committed } + // This is idempotent since nested QEs can cause multiple aborts. override def abort(): Unit = { if (state == Committed || state == Aborted) return this.state = Aborted } + // This is idempotent since nested QEs can cause multiple aborts. override def close(): Unit = { if (!closed) { catalog.clearActiveTransaction() @@ -64,6 +66,9 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { // A special table used in row-level operation transactions. It inherits data // from the base table upon construction and propagates staged transaction state // back after an explicit commit. +// Note, the in-memory data store does not handle concurrency at the moment. The assumes that the +// underlying delegate table cannot change from concurrent transactions. Data sources need to +// implement isolation semantics and make sure they are enforced. class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) extends InMemoryRowLevelOperationTable( delegate.name, @@ -80,10 +85,13 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) // A tracker of filters used in each scan. val scanEvents = new ArrayBuffer[Array[Filter]]() + // Record scan events. This is invoked when building a scan for the particular table. override protected def recordScanEvent(filters: Array[Filter]): Unit = { scanEvents += filters } + // Perform commit if there are any changes. This push metadata and data changes to the + // delegate table. def commit(): Unit = { if (version() != initialVersion) { delegate.dataMap.clear() @@ -98,8 +106,11 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) } } -// A special table catalog used in row-level operation transactions. -// Table changes are initially staged in memory and propagated only after an explicit commit. +// A special table catalog used in row-level operation transactions. The lifecycle of this catalog +// is tied to the transaction. A new catalog instance is created at the beginning of a transaction +// and discarded at the end. The catalog is responsible for pinning all tables involved in the +// transaction. Table changes are initially staged in memory and propagated only after an explicit +// commit. class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends TableCatalog { private val tables: util.Map[Identifier, TxnTable] = new ConcurrentHashMap[Identifier, TxnTable]() @@ -112,6 +123,9 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T throw new UnsupportedOperationException() } + // This is where the table pinning logic should occur. In this implementation, a tables is loaded + // (pinned) the first time is accessed. All subsequent accesses should return the same pinned + // table. override def loadTable(ident: Identifier): Table = { tables.computeIfAbsent(ident, _ => { val table = delegate.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] @@ -119,11 +133,6 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T }) } - override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { - delegate.createTable(ident, tableInfo) - loadTable(ident) - } - override def alterTable(ident: Identifier, changes: TableChange*): Table = { // AlterTable may be called by ResolveSchemaEvolution when schema evolution is enabled. Thus, // it needs to be transactional. The schema changes are only propagated to the delegate at @@ -145,6 +154,13 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T newTxnTable } + // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented. + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { + delegate.createTable(ident, tableInfo) + loadTable(ident) + } + + // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented. override def dropTable(ident: Identifier): Boolean = { tables.remove(ident) delegate.dropTable(ident) @@ -154,10 +170,13 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T throw new UnsupportedOperationException() } + // Invoke commit for all tables participated in the transaction. If a table is read-only + // this is a no-op. def commit(): Unit = { tables.values.forEach(table => table.commit()) } + // Clear transaction context. def clearActiveTransaction(): Unit = { delegate.lastTransaction = delegate.transaction delegate.transaction = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 8ae58b76e8711..42d72fb6d53a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -148,7 +148,7 @@ class QueryExecution( } } - def assertSupported(): Unit = executeWithTransactionContext { + def assertSupported(): Unit = withAbortTransactionOnFailure { if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForBatch(analyzed) } @@ -198,7 +198,7 @@ class QueryExecution( } } - def analyzed: LogicalPlan = executeWithTransactionContext { + def analyzed: LogicalPlan = withAbortTransactionOnFailure { lazyAnalyzed.get } @@ -210,7 +210,7 @@ class QueryExecution( } } - def commandExecuted: LogicalPlan = executeWithTransactionContext { + def commandExecuted: LogicalPlan = withAbortTransactionOnFailure { lazyCommandExecuted.get } @@ -272,7 +272,7 @@ class QueryExecution( } // The plan that has been normalized by custom rules, so that it's more likely to hit cache. - def normalized: LogicalPlan = executeWithTransactionContext { + def normalized: LogicalPlan = withAbortTransactionOnFailure { lazyNormalized.get } @@ -294,7 +294,7 @@ class QueryExecution( } } - def withCachedData: LogicalPlan = executeWithTransactionContext { + def withCachedData: LogicalPlan = withAbortTransactionOnFailure { lazyWithCachedData.get } @@ -318,7 +318,7 @@ class QueryExecution( } } - def optimizedPlan: LogicalPlan = executeWithTransactionContext { + def optimizedPlan: LogicalPlan = withAbortTransactionOnFailure { lazyOptimizedPlan.get } @@ -336,7 +336,7 @@ class QueryExecution( attachTransaction(plan) } - def sparkPlan: SparkPlan = executeWithTransactionContext { + def sparkPlan: SparkPlan = withAbortTransactionOnFailure { lazySparkPlan.get } @@ -359,7 +359,7 @@ class QueryExecution( // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - def executedPlan: SparkPlan = executeWithTransactionContext { + def executedPlan: SparkPlan = withAbortTransactionOnFailure { lazyExecutedPlan.get } @@ -379,7 +379,7 @@ class QueryExecution( * Given QueryExecution is not a public class, end users are discouraged to use this: please * use `Dataset.rdd` instead where conversion will be applied. */ - def toRdd: RDD[InternalRow] = executeWithTransactionContext { + def toRdd: RDD[InternalRow] = withAbortTransactionOnFailure { lazyToRdd.get } @@ -607,13 +607,10 @@ class QueryExecution( } /** - * Executes the given block with the transaction context if exists. If there is an exception - * thrown during the execution, the transaction will be aborted. - * - * Note: The transaction is not committed in this method. The caller should commit the - * transaction if the execution is successful. + * Runs the given block, aborting the active transaction if an exception is thrown. + * If no transaction is active, the block is executed as-is. */ - private def executeWithTransactionContext[T](block: => T): T = transactionOpt match { + private def withAbortTransactionOnFailure[T](block: => T): T = transactionOpt match { case Some(transaction) => try block catch { case e: Throwable => TransactionUtils.abort(transaction); throw e } From dfbbc0e87659de8ae346dd6d524b931f095ac79e Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 17 Apr 2026 12:02:34 +0000 Subject: [PATCH 20/33] Test coverage for SQL scripting --- ...nMemoryRowLevelOperationTableCatalog.scala | 5 + .../spark/sql/connector/catalog/txns.scala | 6 +- .../sql/scripting/SqlScriptingE2eSuite.scala | 145 ++++++++++++++++++ 3 files changed, 155 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index 4a38285b685e8..4e5e1e7c8c6e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.catalog +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} import org.apache.spark.sql.types.StructType @@ -30,6 +32,9 @@ class InMemoryRowLevelOperationTableCatalog var transaction: Txn = _ // The last completed transaction. var lastTransaction: Txn = _ + // All transactions in order (committed and aborted), allowing per-statement + // validation in SQL scripting tests. + val seenTransactions: ArrayBuffer[Txn] = new ArrayBuffer[Txn]() override def beginTransaction(info: TransactionInfo): Transaction = { assert(transaction == null || transaction.currentState != Active) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 7b55e20c61676..49ddeb2c7c809 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -22,6 +22,8 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.util.QuotingUtils import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -178,7 +180,9 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T // Clear transaction context. def clearActiveTransaction(): Unit = { - delegate.lastTransaction = delegate.transaction + val txn = delegate.transaction + delegate.lastTransaction = txn + delegate.seenTransactions += txn delegate.transaction = null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala index a36570467a9df..f9d939b632c40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala @@ -21,8 +21,10 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.plans.logical.CompoundBody import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLConf +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTableCatalog, Txn, TxnTable, TxnTableCatalog} import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -47,6 +49,27 @@ class SqlScriptingE2eSuite extends SharedSparkSession { } // Helpers + private def withCatalog( + name: String)( + f: InMemoryRowLevelOperationTableCatalog => Unit): Unit = { + withSQLConf(s"spark.sql.catalog.$name" -> + classOf[InMemoryRowLevelOperationTableCatalog].getName) { + val catalog = spark.sessionState.catalogManager + .catalog(name) + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + f(catalog) + } + } + + private def loadTxnTable( + txn: Txn, + tableName: String, + namespace: Array[String] = Array("ns1")): TxnTable = + txn.catalog + .asInstanceOf[TxnTableCatalog] + .loadTable(Identifier.of(namespace, tableName)) + .asInstanceOf[TxnTable] + private def verifySqlScriptResult( sqlText: String, expected: Seq[Row], @@ -174,6 +197,128 @@ class SqlScriptingE2eSuite extends SharedSparkSession { } } + test("multi statement with transactional checks - insert then delete") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'); + | DELETE FROM cat.ns1.t + | WHERE pk IN (SELECT pk FROM cat.ns1.t WHERE dep = 'hr'); + | SELECT * FROM cat.ns1.t; + |END + |""".stripMargin + + verifySqlScriptResult(sqlScript, Seq(Row(2, 200, "software"))) + + // Each DML statement in a script runs in its own independent QE and transaction. + assert(catalog.seenTransactions.size === 2) + assert(catalog.seenTransactions.forall(t => + t.currentState === Committed && t.isClosed)) + + // The DELETE subquery scans the table with a dep='hr' predicate; verify it was tracked. + val deleteTxnTable = loadTxnTable(catalog.seenTransactions(1), "t") + assert(deleteTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + } + } + } + + test("multi statement with transactional checks - second statement fails") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'); + | DELETE FROM cat.ns1.t WHERE nonexistent_column = 1; + |END + |""".stripMargin + + checkError( + exception = intercept[AnalysisException] { + spark.sql(sqlScript).collect() + }, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map( + "objectName" -> "`nonexistent_column`", + "proposal" -> ".*"), + matchPVals = true, + queryContext = Array(ExpectedContext("nonexistent_column"))) + + // INSERT committed; DELETE was aborted because analysis failed on the bad column. + assert(catalog.seenTransactions.size === 2) + assert(catalog.seenTransactions(0).currentState === Committed) + assert(catalog.seenTransactions(0).isClosed) + assert(catalog.seenTransactions(1).currentState === Aborted) + assert(catalog.seenTransactions(1).isClosed) + assert(catalog.lastTransaction.currentState === Aborted) + } + } + } + + test("multi statement with transactional checks - insert, merge, update") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t", "cat.ns1.src") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | CREATE TABLE cat.ns1.src (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'), (3, 300, 'hr'); + | INSERT INTO cat.ns1.src VALUES (1, 150, 'hr'), (4, 400, 'finance'); + | MERGE INTO cat.ns1.t AS t + | USING cat.ns1.src AS s + | ON t.pk = s.pk + | WHEN MATCHED THEN UPDATE SET salary = s.salary + | WHEN NOT MATCHED THEN INSERT (pk, salary, dep) + | VALUES (s.pk, s.salary, s.dep); + | UPDATE cat.ns1.t SET salary = salary + 50 WHERE dep = 'software'; + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq( + Row(1, 150, "hr"), + Row(2, 250, "software"), + Row(3, 300, "hr"), + Row(4, 400, "finance"))) + + // INSERT (x2), MERGE, and UPDATE each run in their own independent QE and transaction. + assert(catalog.seenTransactions.size === 4) + assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + + def txnTable(txnIdx: Int): TxnTable = + loadTxnTable(catalog.seenTransactions(txnIdx), "t") + + // Both inserts are pure writes - no scan. + assert(txnTable(0).scanEvents.isEmpty) + assert(txnTable(1).scanEvents.isEmpty) + + // MERGE scans the full target table. The join is on pk (not the partition column). + assert(txnTable(2).scanEvents.nonEmpty) + assert(txnTable(2).scanEvents.flatten.isEmpty) + + // UPDATE with WHERE dep='software' pushes an equality predicate on the partition column. + assert(txnTable(3).scanEvents.flatten.exists { + case sources.EqualTo("dep", "software") => true + case _ => false + }) + } + } + } + test("script without result statement") { val sqlScript = """ From 32a591e00a67202cb6cb098609b647444f1d8c74 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 17 Apr 2026 12:32:13 +0000 Subject: [PATCH 21/33] Extra SQL scripting tests --- .../sql/scripting/SqlScriptingE2eSuite.scala | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala index f9d939b632c40..13ee57d1a5a50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala @@ -319,6 +319,66 @@ class SqlScriptingE2eSuite extends SharedSparkSession { } } + test("loop with transactional checks - each iteration runs in its own transaction") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | DECLARE i INT = 1; + | CREATE TABLE + | cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | WHILE i <= 3 DO + | INSERT INTO cat.ns1.t VALUES (i, i * 100, 'hr'); + | SET i = i + 1; + | END WHILE; + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq(Row(1, 100, "hr"), Row(2, 200, "hr"), Row(3, 300, "hr"))) + + // Each loop iteration's INSERT runs in its own independent transaction. + assert(catalog.seenTransactions.size === 3) + assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + } + } + } + + test("continue handler with transactional checks - handler DML runs in its own transaction") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO + | BEGIN + | INSERT INTO cat.ns1.t VALUES (-1, -1, 'error'); + | END; + | CREATE TABLE + | cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'); + | SELECT 1/0; + | INSERT INTO cat.ns1.t VALUES (2, 200, 'software'); + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq(Row(-1, -1, "error"), Row(1, 100, "hr"), Row(2, 200, "software"))) + + // INSERT(1), handler INSERT(-1), INSERT(2) - each in its own transaction. + assert(catalog.seenTransactions.size === 3) + assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + } + } + } + test("script without result statement") { val sqlScript = """ From 6094689b3201d677cef2d2a5d3a145f8fd669c50 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Fri, 17 Apr 2026 13:58:11 +0000 Subject: [PATCH 22/33] Fix lint --- .../scala/org/apache/spark/sql/connector/catalog/txns.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 49ddeb2c7c809..232b623174996 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -22,8 +22,6 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException -import org.apache.spark.sql.catalyst.util.QuotingUtils import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType From 2faf2376e2aeaca0d8da7e735a90da275d54941d Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Tue, 21 Apr 2026 06:15:33 +0000 Subject: [PATCH 23/33] Path based table support v1 --- .../spark/sql/execution/QueryExecution.scala | 44 +++++++++++++++++-- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 42d72fb6d53a0..00a22be1cd7ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -22,6 +22,7 @@ import java.util.UUID import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import javax.annotation.concurrent.GuardedBy +import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -32,20 +33,21 @@ import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} -import org.apache.spark.sql.catalyst.analysis.{Analyzer, LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubqueryAliases, LazyExpression, NameParameterizedQuery, UnresolvedRelation, UnsupportedOperationChecker} import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, UnresolvedWith, WithCTE} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, TransactionalWrite => TransactionalWritePlan, Union, UnresolvedWith, WithCTE} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.classic.SparkSession -import org.apache.spark.sql.connector.catalog.LookupCatalog +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, LookupCatalog, SupportsCatalogOptions, TransactionalCatalogPlugin} import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.execution.SQLExecution.EXECUTION_ROOT_ID_KEY import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan} import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan} +import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.v2.{TransactionalExec, V2TableRefreshUtil} import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters import org.apache.spark.sql.execution.exchange.EnsureRequirements @@ -55,6 +57,7 @@ import org.apache.spark.sql.execution.streaming.runtime.{IncrementalExecution, W import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.scripting.SqlScriptingExecution import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{LazyTry, Utils, UUIDv7Generator} import org.apache.spark.util.ArrayImplicits._ @@ -116,7 +119,8 @@ class QueryExecution( val catalog = logical match { case UnresolvedWith(TransactionalWrite(c), _, _) => Some(c) case TransactionalWrite(c) => Some(c) - case _ => None + case UnresolvedWith(inner, _, _) => pathBasedTransactionalCatalog(inner) + case other => pathBasedTransactionalCatalog(other) } catalog.map(TransactionUtils.beginTransaction) } else { @@ -124,6 +128,38 @@ class QueryExecution( } } + // For path-based tables (e.g. `delta.`/path/to/table``) the first identifier part is a + // connector name, not a catalog. SupportsCatalogOptions on the connector tells us which + // catalog actually owns the table. We can only do this lookup in sql/core where DataSource + // is available; the catalyst-side TransactionalWrite extractor only handles catalog tables. + private def pathBasedTransactionalCatalog( + plan: LogicalPlan): Option[TransactionalCatalogPlugin] = { + EliminateSubqueryAliases(plan) match { + case write: TransactionalWritePlan => + EliminateSubqueryAliases(write.table) match { + case UnresolvedRelation(parts, _, _) if parts.length >= 2 => + // Only proceed if parts.head is not a registered catalog; if it were, the + // catalyst-side extractor would have already matched it above. + Try(catalogManager.catalog(parts.head)).toOption match { + case None => + DataSource.lookupDataSourceV2(parts.head, sparkSession.sessionState.conf).flatMap { + case sco: SupportsCatalogOptions => + val options = new CaseInsensitiveStringMap( + java.util.Collections.singletonMap("path", parts.last)) + CatalogV2Util.getTableProviderCatalog(sco, catalogManager, options) match { + case c: TransactionalCatalogPlugin => Some(c) + case _ => None + } + case _ => None + } + case _ => None + } + case _ => None + } + case _ => None + } + } + // Per-query analyzer: uses a transaction-aware CatalogManager when a transaction is active, // so that all catalog lookups and rule applications during analysis see the correct state // without relying on thread-local context. From 5854cd06a30acb3ff3b4261886f51c73e29cb031 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Tue, 21 Apr 2026 09:28:54 +0000 Subject: [PATCH 24/33] Path based tables support improvements --- .../spark/sql/execution/QueryExecution.scala | 69 ++++--- ...pache.spark.sql.sources.DataSourceRegister | 2 + .../PathBasedTableTransactionSuite.scala | 180 ++++++++++++++++++ 3 files changed, 216 insertions(+), 35 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 00a22be1cd7ca..ac8388ef72349 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -22,6 +22,7 @@ import java.util.UUID import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import javax.annotation.concurrent.GuardedBy +import scala.jdk.CollectionConverters._ import scala.util.Try import scala.util.control.NonFatal @@ -42,13 +43,13 @@ import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.classic.SparkSession -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, LookupCatalog, SupportsCatalogOptions, TransactionalCatalogPlugin} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, LookupCatalog, SupportsCatalogOptions, TableCatalog, TransactionalCatalogPlugin} import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.execution.SQLExecution.EXECUTION_ROOT_ID_KEY import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan} import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan} import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.v2.{TransactionalExec, V2TableRefreshUtil} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, TransactionalExec, V2TableRefreshUtil} import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery @@ -116,11 +117,17 @@ class QueryExecution( analyzerOpt.flatMap(_.catalogManager.transaction).orElse { // Only begin a new transaction for outer QEs that lead to execution. if (mode != CommandExecutionMode.SKIP) { + def resolve(w: TransactionalWritePlan): Option[TransactionalCatalogPlugin] = + pathBased(w) match { + case Some(c: TransactionalCatalogPlugin) => Some(c) + case Some(_) => None + // If the path is not data source based, fallback to catalog based resolution. + case None => TransactionalWrite.unapply(w) + } val catalog = logical match { - case UnresolvedWith(TransactionalWrite(c), _, _) => Some(c) - case TransactionalWrite(c) => Some(c) - case UnresolvedWith(inner, _, _) => pathBasedTransactionalCatalog(inner) - case other => pathBasedTransactionalCatalog(other) + case UnresolvedWith(w: TransactionalWritePlan, _, _) => resolve(w) + case w: TransactionalWritePlan => resolve(w) + case _ => None } catalog.map(TransactionUtils.beginTransaction) } else { @@ -128,37 +135,29 @@ class QueryExecution( } } - // For path-based tables (e.g. `delta.`/path/to/table``) the first identifier part is a - // connector name, not a catalog. SupportsCatalogOptions on the connector tells us which - // catalog actually owns the table. We can only do this lookup in sql/core where DataSource - // is available; the catalyst-side TransactionalWrite extractor only handles catalog tables. - private def pathBasedTransactionalCatalog( - plan: LogicalPlan): Option[TransactionalCatalogPlugin] = { - EliminateSubqueryAliases(plan) match { - case write: TransactionalWritePlan => - EliminateSubqueryAliases(write.table) match { - case UnresolvedRelation(parts, _, _) if parts.length >= 2 => - // Only proceed if parts.head is not a registered catalog; if it were, the - // catalyst-side extractor would have already matched it above. - Try(catalogManager.catalog(parts.head)).toOption match { - case None => - DataSource.lookupDataSourceV2(parts.head, sparkSession.sessionState.conf).flatMap { - case sco: SupportsCatalogOptions => - val options = new CaseInsensitiveStringMap( - java.util.Collections.singletonMap("path", parts.last)) - CatalogV2Util.getTableProviderCatalog(sco, catalogManager, options) match { - case c: TransactionalCatalogPlugin => Some(c) - case _ => None - } - case _ => None - } - case _ => None - } - case _ => None - } + // For path-based tables (e.g. `format.`/path/to/table``) the first identifier part is a + // connector name. SupportsCatalogOptions on the connector tells us which catalog actually + // owns the table. Returns Some(catalog) if parts.head is a recognized SupportsCatalogOptions + // data source (caller decides whether the catalog is transactional), or None to fall through + // to the catalog-based extractor. + private def pathBased(write: TransactionalWritePlan): Option[TableCatalog] = + EliminateSubqueryAliases(write.table) match { + case UnresolvedRelation(parts, _, _) if parts.length > 1 => + Try(DataSource.lookupDataSourceV2(parts.head, sparkSession.sessionState.conf)) + .toOption + .flatten + .collect { case sco: SupportsCatalogOptions => sco } + .map { sco => + val sessionConfigs = DataSourceV2Utils.extractSessionConfigs( + sco, sparkSession.sessionState.conf) + // Pass the entire identifier as option. The connector can decide how to parse it + // if needed. + val options = sessionConfigs + ("identifier" -> parts.mkString(".")) + CatalogV2Util.getTableProviderCatalog( + sco, catalogManager, new CaseInsensitiveStringMap(options.asJava)) + } case _ => None } - } // Per-query analyzer: uses a transaction-aware CatalogManager when a transaction is active, // so that all catalog lookups and rule applications during analysis see the correct state diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index c1fc7234d7c19..0354e545aa903 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -18,6 +18,8 @@ org.apache.spark.sql.sources.FakeSourceOne org.apache.spark.sql.sources.FakeSourceTwo org.apache.spark.sql.sources.FakeSourceThree +org.apache.spark.sql.connector.FakePathBasedSource +org.apache.spark.sql.connector.FakePathBasedSourceWithSessionConfig org.apache.spark.sql.sources.FakeSourceFour org.apache.fakesource.FakeExternalSourceOne org.apache.fakesource.FakeExternalSourceTwo diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala new file mode 100644 index 0000000000000..c6b2f33c25fe0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTableCatalog, InMemoryTableCatalog, SessionConfigSupport, SupportsCatalogOptions} +import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Tests for transactional writes to path-based tables, where the table is identified by a + * bare path with no catalog prefix (e.g. `/path/to/t`), or a connector-prefixed path + * (e.g. `pathformat.`/path/to/t``). The transactional catalog is registered as the session + * catalog (`spark_catalog`). + */ +class PathBasedTableTransactionSuite extends RowLevelOperationSuiteBase { + + private val tablePath = "`/path/to/t`" + private val tablePathWithFormat = "pathformat.`/path/to/t`" + + override def beforeEach(): Unit = { + super.beforeEach() + spark.conf.set( + V2_SESSION_CATALOG_IMPLEMENTATION.key, + classOf[InMemoryRowLevelOperationTableCatalog].getName) + } + + override def afterEach(): Unit = { + spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) + super.afterEach() + } + + override protected def catalog: InMemoryRowLevelOperationTableCatalog = { + spark.sessionState.catalogManager.v2SessionCatalog + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + } + + private def createPathTable(name: String): Unit = { + sql(s"CREATE TABLE $name (id INT, data STRING)") + } + + test("SQL insert into bare path-based table participates in transaction") { + createPathTable(tablePath) + val (txn, _) = executeTransaction { + sql(s"INSERT INTO $tablePath VALUES (1, 'a'), (2, 'b')") + } + assert(txn.currentState === Committed) + assert(txn.isClosed) + checkAnswer(spark.table(tablePath), Row(1, "a") :: Row(2, "b") :: Nil) + } + + test("SQL insert with connector-prefixed path participates in transaction") { + createPathTable(tablePathWithFormat) + val (txn, _) = executeTransaction { + sql(s"INSERT INTO $tablePathWithFormat VALUES (1, 'a'), (2, 'b')") + } + assert(txn.currentState === Committed) + assert(txn.isClosed) + checkAnswer(spark.table(tablePathWithFormat), Row(1, "a") :: Row(2, "b") :: Nil) + } + + test("SQL insert with CTE into connector-prefixed path participates in transaction") { + createPathTable(tablePathWithFormat) + val (txn, _) = executeTransaction { + sql(s""" + |WITH cte AS (SELECT 1 AS id, 'a' AS data) + |INSERT INTO $tablePathWithFormat SELECT * FROM cte + |""".stripMargin) + } + assert(txn.currentState === Committed) + assert(txn.isClosed) + checkAnswer(spark.table(tablePathWithFormat), Row(1, "a") :: Nil) + } + + test("session-config catalog controls which catalog is enrolled in transaction") { + withSQLConf( + "spark.sql.catalog.txncat" -> classOf[InMemoryRowLevelOperationTableCatalog].getName, + "spark.sql.catalog.nontxncat" -> classOf[InMemoryTableCatalog].getName) { + val txnCat = spark.sessionState.catalogManager.catalog("txncat") + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + + // Non-transactional catalog configured. + withSQLConf("spark.datasource.pathformat2.catalog" -> "nontxncat") { + createPathTable("pathformat2.`/path/to/t1`") + sql("INSERT INTO pathformat2.`/path/to/t1` VALUES (1, 'a')") + // The transaction was not routed to any of the transactional catalogs. + assert(catalog.lastTransaction == null) + assert(txnCat.lastTransaction == null) + } + + // Transactional catalog configured: pathBased resolves txncat as a + // TransactionalCatalogPlugin and opens the transaction there instead. + withSQLConf("spark.datasource.pathformat2.catalog" -> "txncat") { + createPathTable("pathformat2.`/path/to/t2`") + sql("INSERT INTO pathformat2.`/path/to/t2` VALUES (1, 'a')") + assert(txnCat.lastTransaction.currentState === Committed) + assert(txnCat.lastTransaction.isClosed) + } + } + } + + test("SQL insert with unregistered format produces analysis error and aborts transaction") { + createPathTable(tablePathWithFormat) + // "Unregistered" is not a known catalog and not registered data source. + // So Spark falls back to treating it as a namespace in spark_catalog. The table + // does not exist, causing an AnalysisException. The transaction is started (because + // spark_catalog IS a TransactionalCatalogPlugin) and then aborted on failure. + checkError( + exception = intercept[AnalysisException] { + sql("INSERT INTO unregistered.`/path/to/t` VALUES (1, 'a'), (2, 'b')") + }, + condition = "TABLE_OR_VIEW_NOT_FOUND", + parameters = Map("relationName" -> "`unregistered`.`/path/to/t`"), + context = ExpectedContext( + fragment = "unregistered.`/path/to/t`", + start = -1, + stop = -1)) + val txn = catalog.lastTransaction + assert(txn.currentState === Aborted) + assert(txn.isClosed) + } +} + +/** + * Simulates a path-based connector (e.g. Delta) that implements [[SupportsCatalogOptions]] + * to route `pathformat.\`/path/to/t\`` SQL identifiers to the session catalog. Returning + * null from [[extractCatalog]] signals that the session catalog (`spark_catalog`) owns the + * table, matching Delta's behavior where DeltaCatalog is registered as spark_catalog. + */ +class FakePathBasedSource + extends FakeV2ProviderWithCustomSchema + with SupportsCatalogOptions + with DataSourceRegister { + + override def shortName(): String = "pathformat" + + // Use the session catalog. + override def extractCatalog(options: CaseInsensitiveStringMap): String = null + + // Not used in the transactional path. + override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = null +} + +/** + * Like [[FakePathBasedSource]] but resolves the owning catalog from the session config + * `spark.datasource.pathformat2.catalog` instead of always returning null. This simulates + * a connector that lets users configure the target catalog. + */ +class FakePathBasedSourceWithSessionConfig + extends FakeV2ProviderWithCustomSchema + with SupportsCatalogOptions + with SessionConfigSupport + with DataSourceRegister { + + override def shortName(): String = "pathformat2" + + override def keyPrefix: String = "pathformat2" + + override def extractCatalog(options: CaseInsensitiveStringMap): String = options.get("catalog") + + // Not used in the transactional path. + override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = null +} From 097fa7c041d65a8178ce82d423afaca9f581996b Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Tue, 21 Apr 2026 17:27:02 +0000 Subject: [PATCH 25/33] Catalog reset fix --- .../org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala index 13ee57d1a5a50..2a3bd915eccd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala @@ -57,7 +57,7 @@ class SqlScriptingE2eSuite extends SharedSparkSession { val catalog = spark.sessionState.catalogManager .catalog(name) .asInstanceOf[InMemoryRowLevelOperationTableCatalog] - f(catalog) + try f(catalog) finally spark.sessionState.catalogManager.reset() } } From 76c567041fef128d2387ffeb9cdceb59ac55b2b4 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Mon, 27 Apr 2026 14:40:57 +0000 Subject: [PATCH 26/33] Transactional Streaming v1 --- .../sql/catalyst/analysis/Analyzer.scala | 42 +-- .../catalyst/plans/logical/v2Commands.scala | 6 + .../connector/catalog/InMemoryBaseTable.scala | 39 +++ .../catalog/InMemoryTableCatalog.scala | 2 +- .../spark/sql/connector/catalog/txns.scala | 32 +++ .../spark/sql/execution/QueryExecution.scala | 27 +- .../execution/datasources/v2/V2Writes.scala | 19 +- .../v2/WriteToDataSourceV2Exec.scala | 7 +- .../runtime/IncrementalExecution.scala | 3 + .../runtime/MicroBatchExecution.scala | 19 +- .../sources/WriteToMicroBatchDataSource.scala | 28 +- .../DeltaBasedDeleteFromTableSuite.scala | 2 - .../connector/StreamingTransactionSuite.scala | 253 ++++++++++++++++++ .../test/DataStreamTableAPISuite.scala | 46 +++- 14 files changed, 461 insertions(+), 64 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fdd7d09356d8d..bf81f3740e4a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1055,6 +1055,28 @@ class Analyzer( } } + // Resolve the write target of a V2 write command (batch or streaming). + private def resolveWriteTarget( + write: LogicalPlan, + table: NamedRelation, + withNewTable: NamedRelation => LogicalPlan): LogicalPlan = { + table match { + case u: UnresolvedRelation if !u.isStreaming => + resolveRelation(u).map(unwrapRelationPlan).map { + case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError( + v.desc.identifier, write) + case u: UnresolvedCatalogRelation => + throw QueryCompilationErrors.writeIntoV1TableNotAllowedError( + u.tableMeta.identifier, write) + case r: DataSourceV2Relation => withNewTable(r) + case _ => + throw QueryCompilationErrors.writeIntoTempViewNotAllowedError( + u.multipartIdentifier.quoted) + }.getOrElse(write) + case _ => write + } + } + // Resolve V2TableReference nodes created for: // 1 Temp views (via createForTempView). // 2. Transaction references (via createForTransaction). These are resolved by a @@ -1084,23 +1106,11 @@ class Analyzer( case other => i.copy(table = other) } - // TODO (SPARK-27484): handle streaming write commands when we have them. + case write: StreamingV2WriteCommand => + resolveWriteTarget(write, write.table, write.withNewTable) + case write: V2WriteCommand => - write.table match { - case u: UnresolvedRelation if !u.isStreaming => - resolveRelation(u).map(unwrapRelationPlan).map { - case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError( - v.desc.identifier, write) - case u: UnresolvedCatalogRelation => - throw QueryCompilationErrors.writeIntoV1TableNotAllowedError( - u.tableMeta.identifier, write) - case r: DataSourceV2Relation => write.withNewTable(r) - case _ => - throw QueryCompilationErrors.writeIntoTempViewNotAllowedError( - u.multipartIdentifier.quoted) - }.getOrElse(write) - case _ => write - } + resolveWriteTarget(write, write.table, write.withNewTable) case u: UnresolvedRelation => resolveRelation(u).map(resolveViews(_, u.options)).getOrElse(u) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index c16087bdf9bb7..d47abe284e6f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -1291,6 +1291,12 @@ trait TransactionalWrite extends LogicalPlan { def table: LogicalPlan } +/** Trait for streaming write commands that participate in DSv2 transactions. */ +trait StreamingV2WriteCommand extends TransactionalWrite { + override def table: NamedRelation + def withNewTable(newTable: NamedRelation): StreamingV2WriteCommand +} + /** * The logical plan of the DROP TABLE command. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index f49838f10c904..30e046bcd700e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, Cus import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram, HistogramBin} import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset} import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.internal.SQLConf @@ -412,6 +413,7 @@ abstract class InMemoryBaseTable( def baseCapabiilities: Set[TableCapability] = Set( TableCapability.BATCH_READ, + TableCapability.MICRO_BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.STREAMING_WRITE, TableCapability.OVERWRITE_BY_FILTER, @@ -501,6 +503,33 @@ abstract class InMemoryBaseTable( case class InMemoryHistogram(height: Double, bins: Array[HistogramBin]) extends Histogram + private class InMemoryTableOffset(val rowCount: Long) extends Offset { + override def json(): String = rowCount.toString + } + + class InMemoryMicroBatchStream extends MicroBatchStream { + override def initialOffset(): Offset = new InMemoryTableOffset(0) + override def latestOffset(): Offset = + new InMemoryTableOffset(InMemoryBaseTable.this.rows.size.toLong) + override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = { + val s = start.asInstanceOf[InMemoryTableOffset].rowCount.toInt + val e = end.asInstanceOf[InMemoryTableOffset].rowCount.toInt + Array(InMemoryMicroBatchPartition(InMemoryBaseTable.this.rows.slice(s, e))) + } + override def createReaderFactory(): PartitionReaderFactory = { partition => + val rows = partition.asInstanceOf[InMemoryMicroBatchPartition].rows + new PartitionReader[InternalRow] { + private var idx = -1 + override def next(): Boolean = { idx += 1; idx < rows.size } + override def get(): InternalRow = rows(idx) + override def close(): Unit = {} + } + } + override def deserializeOffset(json: String): Offset = new InMemoryTableOffset(json.toLong) + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} + } + abstract class BatchScanBaseClass( var data: Seq[InputPartition], readSchema: StructType, @@ -586,6 +615,9 @@ abstract class InMemoryBaseTable( override def supportedCustomMetrics(): Array[CustomMetric] = { Array(new RowsReadCustomMetric) } + + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = + new InMemoryMicroBatchStream } case class InMemoryBatchScan( @@ -806,6 +838,13 @@ object InMemoryBaseTable { } } +/** + * A partition for [[InMemoryBaseTable]] micro-batch streaming reads, holding a slice of rows. + * Defined at the top level (not as an inner class) so that Java serialization to executors + * does not attempt to serialize the enclosing [[InMemoryBaseTable]] instance. + */ +case class InMemoryMicroBatchPartition(rows: Seq[InternalRow]) extends InputPartition + /** * Represent a set of rows buffered in memory for a given partition key. * @param key partition key diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index ff7995ad6697e..5ce1804c11a7d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -40,7 +40,7 @@ class BasicInMemoryTableCatalog extends TableCatalog { protected val namespaces: util.Map[List[String], Map[String, String]] = new ConcurrentHashMap[List[String], Map[String, String]]() - protected val tables: util.Map[Identifier, Table] = + protected var tables: util.Map[Identifier, Table] = new ConcurrentHashMap[Identifier, Table]() private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 232b623174996..99b1e4c8f9870 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -21,6 +21,7 @@ import java.util import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.sources.Filter @@ -170,6 +171,9 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T throw new UnsupportedOperationException() } + // Returns all tables that participated in this transaction, keyed by identifier. + def txnTables: scala.collection.Map[Identifier, TxnTable] = tables.asScala + // Invoke commit for all tables participated in the transaction. If a table is read-only // this is a no-op. def commit(): Unit = { @@ -193,3 +197,31 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T override def hashCode(): Int = name.hashCode() } + +/** + * An InMemoryRowLevelOperationTableCatalog that utilizes tables backed by a shared map. This + * simulates the behavior of real catalogs (Delta, Iceberg, etc.) where multiple instances + * of the catalog share the same underlying persistent storage, thus, they see the same tables. + * + * This is needed for testing execution that spans multiple Spark sessions. In particular, + * streaming queries execute micro-batches in cloned Spark sessions. Without this, the cloned + * spark session catalog will not see any tables created in the original session. + * + * Tests that use this catalog must call + * [[SharedTablesInMemoryRowLevelOperationTableCatalog.reset()]] in `afterEach` to clear the + * shared state between test cases. + */ +class SharedTablesInMemoryRowLevelOperationTableCatalog + extends InMemoryRowLevelOperationTableCatalog { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { + super.initialize(name, options) + tables = SharedTablesInMemoryRowLevelOperationTableCatalog.sharedTables + } +} + +object SharedTablesInMemoryRowLevelOperationTableCatalog { + private[catalog] val sharedTables: ConcurrentHashMap[Identifier, Table] = + new ConcurrentHashMap[Identifier, Table]() + + def reset(): Unit = sharedTables.clear() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index ac8388ef72349..bbe14db33ca86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -496,17 +496,24 @@ class QueryExecution( } } + /** + * Returns the QueryExecution to use when generating an explain string. + * Overridden by IncrementalExecution to reuse `this` so that the already-open transaction and + * cached executedPlan are not duplicated. + */ + protected def queryExecutionForExplain: QueryExecution = if (logical.isStreaming) { + // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the + // output mode does not matter since there is no `Sink`. + new IncrementalExecution( + sparkSession, logical, OutputMode.Append(), "", + UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0), + WatermarkPropagator.noop(), false, mode = this.mode) + } else { + this + } + private def explainString(mode: ExplainMode, maxFields: Int, append: String => Unit): Unit = { - val queryExecution = if (logical.isStreaming) { - // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the - // output mode does not matter since there is no `Sink`. - new IncrementalExecution( - sparkSession, logical, OutputMode.Append(), "", - UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0), - WatermarkPropagator.noop(), false, mode = this.mode) - } else { - this - } + val queryExecution = queryExecutionForExplain mode match { case SimpleMode => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index 0249e5b49c9bc..581027f951938 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -91,17 +91,18 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { o.copy(write = Some(write), query = newQuery) case WriteToMicroBatchDataSource( - relationOpt, table, query, queryId, options, outputMode, Some(batchId)) => - val writeOptions = mergeOptions( - options, - relationOpt.map(r => r.options.asCaseSensitiveMap.asScala.toMap).getOrElse(Map.empty)) - val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId = queryId) - val write = buildWriteForMicroBatch(tableDataSourceV2Strategy, writeBuilder, outputMode) + relation, query, queryId, options, outputMode, Some(batchId)) => + val v2Relation = relation.asInstanceOf[DataSourceV2Relation] + val writeOptions = mergeOptions(options, v2Relation.options.asCaseSensitiveMap.asScala.toMap) + // Guaranteed to support writes since it is a strict requirement to construct + // WriteToMicroBatchDataSource. + val writeTable = v2Relation.table.asInstanceOf[SupportsWrite] + val writeBuilder = newWriteBuilder(writeTable, writeOptions, query.schema, queryId = queryId) + val write = buildWriteForMicroBatch(writeTable, writeBuilder, outputMode) val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming) val customMetrics = write.supportedCustomMetrics.toImmutableArraySeq - val funCatalogOpt = relationOpt.flatMap(_.funCatalog) - val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt) - WriteToDataSourceV2(relationOpt, microBatchWrite, newQuery, customMetrics) + val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, v2Relation.funCatalog) + WriteToDataSourceV2(Some(v2Relation), microBatchWrite, newQuery, customMetrics) case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, projections, _, None) => val rowSchema = projections.rowProjection.schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 308f4bdc5042b..9e579ae779f3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -416,7 +416,11 @@ case class WriteToDataSourceV2Exec( batchWrite: BatchWrite, refreshCache: () => Unit, query: SparkPlan, - writeMetrics: Seq[CustomMetric]) extends V2TableWriteExec { + writeMetrics: Seq[CustomMetric], + transaction: Option[Transaction] = None) extends V2TableWriteExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): WriteToDataSourceV2Exec = + copy(transaction = txn) override def stringArgs: Iterator[Any] = Iterator(batchWrite, query) @@ -426,6 +430,7 @@ case class WriteToDataSourceV2Exec( override protected def run(): Seq[InternalRow] = { val writtenRows = writeWithV2(batchWrite) + transaction.foreach(TransactionUtils.commit) refreshCache() writtenRows } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala index 1587fd4786a35..9fc72241e83b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala @@ -143,6 +143,9 @@ class IncrementalExecution( } } + // Use `this` for explain so the already-open transaction and executedPlan are reused. + override protected def queryExecutionForExplain: QueryExecution = this + private val allowMultipleStatefulOperators: Boolean = sparkSession.sessionState.conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala index 973af04e04307..e6d0666aca259 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException} import org.apache.spark.internal.LogKeys import org.apache.spark.internal.LogKeys._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Deduplicate, DeduplicateWithinWatermark, Distinct, FlatMapGroupsInPandasWithState, FlatMapGroupsWithState, GlobalLimit, Join, LeafNode, LocalRelation, LogicalPlan, Project, StreamSourceAwareLogicalPlan, TransformWithState, TransformWithStateInPySpark} @@ -37,7 +38,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.classic.{Dataset, SparkSession} import org.apache.spark.sql.classic.ClassicConversions.castToImpl -import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability} +import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability, TransactionalCatalogPlugin} import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, SupportsRealTimeMode, SupportsTriggerAvailableNow} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} @@ -46,7 +47,6 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, Real import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, Offset, OneTimeTrigger, ProcessingTimeTrigger, RealTimeModeAllowlist, RealTimeTrigger, Sink, Source, StreamingQueryPlanTraverseHelper} import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CommitMetadata, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataV2} import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, StateStoreWriter} -import org.apache.spark.sql.execution.streaming.runtime.AcceptsLatestSeenOffsetHandler import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE} import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchSink, WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1} import org.apache.spark.sql.execution.streaming.state.{OfflineStateRepartitionUtils, StateSchemaBroadcast, StateStoreErrors} @@ -346,15 +346,20 @@ class MicroBatchExecution( ) } - // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. sink match { case s: SupportsWrite => - val relationOpt = plan.catalogAndIdent.map { - case (catalog, ident) => DataSourceV2Relation.create(s, Some(catalog), Some(ident)) + val relation = plan.catalogAndIdent match { + // When the catalog is transactional, instead of eagerly creating the relation, we + // delegate resolution to ResolveRelations. This allows to resolve the relation against + // a transactional catalo which keeps track of all tables loaded within the transaction. + case Some((catalog: TransactionalCatalogPlugin, ident)) => + UnresolvedRelation(catalog.name +: ident.namespace().toSeq :+ ident.name()) + case Some((catalog, ident)) => + DataSourceV2Relation.create(s, Some(catalog), Some(ident)) + case None => DataSourceV2Relation.create(s, None, None) } WriteToMicroBatchDataSource( - relationOpt, - table = s, + relation, query = _logicalPlan, queryId = id.toString, extraOptions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala index 0a33093dcbcea..7aa7a31bb085e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.execution.streaming.sources +import org.apache.spark.sql.catalyst.analysis.NamedRelation import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} -import org.apache.spark.sql.connector.catalog.SupportsWrite -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, StreamingV2WriteCommand, UnaryNode} import org.apache.spark.sql.streaming.OutputMode /** @@ -29,19 +28,36 @@ import org.apache.spark.sql.streaming.OutputMode * Note that this logical plan does not have a corresponding physical plan, as it will be converted * to [[org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 WriteToDataSourceV2]] * with [[MicroBatchWrite]] before execution. + * + * [[relation]] starts as [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation]] when the + * sink has a catalog+identifier (transactional catalogs), or as a resolved + * [[org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation]] for non-transactional + * catalog-backed sinks and format-based sinks. + * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveRelations]] + * resolves it to [[org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation]] during + * each micro-batch analysis, going through the transaction-aware catalog when a transaction is + * active. */ case class WriteToMicroBatchDataSource( - relation: Option[DataSourceV2Relation], - table: SupportsWrite, + relation: NamedRelation, query: LogicalPlan, queryId: String, writeOptions: Map[String, String], outputMode: OutputMode, batchId: Option[Long] = None) - extends UnaryNode { + extends UnaryNode with StreamingV2WriteCommand { + override def child: LogicalPlan = query override def output: Seq[Attribute] = Nil + override def simpleString(maxFields: Int): String = + s"WriteToMicroBatchDataSource ${relation.name}" + + override def table: NamedRelation = relation + + override def withNewTable(newTable: NamedRelation): WriteToMicroBatchDataSource = + copy(relation = newTable) + def withNewBatchId(batchId: Long): WriteToMicroBatchDataSource = { copy(batchId = Some(batchId)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala index 9b630b25f658e..15d259d44a4fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedDeleteFromTableSuite.scala @@ -35,8 +35,6 @@ class DeltaBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { override def enforceCheckConstraintOnDelete: Boolean = false - override protected def deltaDelete: Boolean = true - test("delete handles metadata columns correctly") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala new file mode 100644 index 0000000000000..10551f29eb1b7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util.Collections + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{Aborted, CatalogV2Util, Committed, Identifier, InMemoryBaseTable, InMemoryRowLevelOperationTableCatalog, InMemoryTableCatalog, SharedTablesInMemoryRowLevelOperationTableCatalog, TableInfo} +import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.types.StructType + +class StreamingTransactionSuite extends RowLevelOperationSuiteBase { + + import testImplicits._ + + override def beforeEach(): Unit = { + super.beforeEach() + spark.conf.set( + "spark.sql.catalog.cat", + classOf[SharedTablesInMemoryRowLevelOperationTableCatalog].getName) + } + + override def afterEach(): Unit = { + SharedTablesInMemoryRowLevelOperationTableCatalog.reset() + super.afterEach() + } + + private def createSimpleTable(schemaString: String): Unit = { + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) + val tableInfo = new TableInfo.Builder().withColumns(columns).build() + catalog.createTable(ident, tableInfo) + } + + private def streamCatalog(query: StreamingQuery): InMemoryRowLevelOperationTableCatalog = { + val session = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sparkSessionForStream + session.sessionState.catalogManager.catalog("cat") + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + } + + test("streaming write commits a transaction") { + createSimpleTable("value INT") + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + + val query = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tableNameAsString) + + assert(table.version() === "0") + + inputData.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + val txn = streamCatalog(query).lastTransaction + assert(txn != null, "expected a transaction to have been committed") + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // Pure streaming append: the write target is not read, source is not a TxnTable. + val targetTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(tableNameAsString) + assert(txn.catalog.txnTables.size === 1) + assert(targetTxnTable.scanEvents.isEmpty) + assert(table.version() === "1") + + // Transaction must be scoped to the streaming session; main session catalog is untouched. + assert(catalog.seenTransactions.isEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1), Row(2), Row(3))) + } + } + + test("each micro-batch is an independent transaction") { + createSimpleTable("value INT") + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + + val query = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tableNameAsString) + + assert(table.version() === "0") + + inputData.addData(1, 2, 3) + query.processAllAvailable() + + inputData.addData(4, 5, 6) + query.processAllAvailable() + + query.stop() + + val sc = streamCatalog(query) + assert(sc.seenTransactions.size === 2) + assert(sc.seenTransactions.forall(_.currentState === Committed)) + // Pure streaming append: write target is not read in any micro-batch. + assert(sc.seenTransactions.forall { t => + indexByName(t.catalog.txnTables.values.toSeq)(tableNameAsString).scanEvents.isEmpty + }) + // Each committed micro-batch increments the delegate version exactly once. + assert(table.version() === "2") + + // Transaction must be scoped to the streaming session; main session catalog is untouched. + assert(catalog.seenTransactions.isEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1), Row(2), Row(3), Row(4), Row(5), Row(6))) + } + } + + test("batch read from catalog-backed table inside streaming query is tracked as a scan event") { + // Target table for the stream. + createSimpleTable("value INT") + + // Catalog-backed static table used as a batch (non-streaming) source. + val sourceIdent = Identifier.of(namespace, "source_table") + val srcColumns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT")) + catalog.createTable(sourceIdent, new TableInfo.Builder().withColumns(srcColumns).build()) + sql(s"INSERT INTO $sourceNameAsString VALUES (1), (2), (3)") + // The INSERT above runs a transaction on the main session catalog; capture the count now + // so we can assert the streaming query does not add more. + val mainTxnsBefore = catalog.seenTransactions.size + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + + // spark.read produces a DataSourceV2Relation (batch), not a streaming source. + // UnresolveTransactionRelations converts it to V2TableReference each micro-batch so + // the transaction-aware catalog can record the scan event. + val staticData = spark.read.table(sourceNameAsString) + + val query = inputData.toDF() + .join(staticData, "value") + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tableNameAsString) + + inputData.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + val txn = streamCatalog(query).lastTransaction + assert(txn != null, "expected a transaction to have been committed") + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // Both the write target and the batch source participate in the transaction. + assert(txn.catalog.txnTables.size === 2) + + val targetTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(tableNameAsString) + assert(targetTxnTable.scanEvents.isEmpty) + + // The static source was read exactly once and its scan event was captured. + val sourceTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + + // Streaming must not add transactions to the main session catalog beyond the pre-existing + // INSERT transaction. + assert(catalog.seenTransactions.size === mainTxnsBefore) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1), Row(2), Row(3))) + } + } + + test("transaction is aborted when micro-batch write fails and no data is written") { + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT")) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(Collections.singletonMap( + InMemoryBaseTable.SIMULATE_FAILED_WRITE_OPTION, "true")) + .build() + catalog.createTable(ident, tableInfo) + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tableNameAsString) + + inputData.addData(1, 2, 3) + intercept[Exception] { query.processAllAvailable() } + query.stop() + + val txn = streamCatalog(query).lastTransaction + assert(txn != null, "expected a transaction to have been recorded") + assert(txn.currentState === Aborted) + assert(txn.isClosed) + // Aborted transaction must not advance the delegate version. + assert(table.version() === "0") + + // Transaction must be scoped to the streaming session; main session catalog is untouched. + assert(catalog.seenTransactions.isEmpty) + + // Writes must not be visible after an aborted transaction. + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Seq.empty) + } + } + + test("streaming write to non-transactional catalog does not start a transaction") { + withSQLConf("spark.sql.catalog.nonTxnCat" -> classOf[InMemoryTableCatalog].getName) { + val nonTxnCat = spark + .sessionState + .catalogManager + .catalog("nonTxnCat") + .asInstanceOf[InMemoryTableCatalog] + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT")) + nonTxnCat.createTable( + Identifier.of(Array("ns"), "tbl"), + new TableInfo.Builder().withColumns(columns).build()) + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable("nonTxnCat.ns.tbl") + + inputData.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + assert(catalog.seenTransactions.isEmpty, + "no transaction expected for non-transactional catalog") + checkAnswer(spark.table("nonTxnCat.ns.tbl"), Seq(Row(1), Row(2), Row(3))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala index 89f6556229527..dab6677310192 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.{FakeV2Provider, FakeV2ProviderWithCustomSchema, InMemoryTableSessionCatalog} -import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog, MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability, TableInfo, V2TableWithV1Fallback} +import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTable, InMemoryTableCatalog, MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability, TableInfo, V2TableWithV1Fallback} import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, Transform} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, MemoryStreamScanBuilder, StreamingQueryWrapper} @@ -109,20 +109,23 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { } test("read: read table without streaming capability support") { - val tableIdentifier = "testcat.table_name" + withSQLConf("spark.sql.catalog.testcat" -> + classOf[DataStreamTableAPISuite.NonStreamingInMemoryTableCatalog].getName) { + val tableIdentifier = "testcat.table_name" - spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING foo") + spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING foo") - checkError( - exception = intercept[AnalysisException] { - spark.readStream.table(tableIdentifier) - }, - condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", - parameters = Map( - "tableName" -> "`testcat`.`table_name`", - "operation" -> "either micro-batch or continuous scan" + checkError( + exception = intercept[AnalysisException] { + spark.readStream.table(tableIdentifier) + }, + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + parameters = Map( + "tableName" -> "`testcat`.`table_name`", + "operation" -> "either micro-batch or continuous scan" + ) ) - ) + } } test("read: read table with custom catalog") { @@ -638,6 +641,25 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { object DataStreamTableAPISuite { val V1FallbackTestTableName = "fallbackV1Test" + + class NonStreamingInMemoryTableCatalog extends InMemoryTableCatalog { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident.asMultipartIdentifier) + } + val tableName = s"$name.${ident.quoted}" + val table = new InMemoryTable(tableName, tableInfo.columns(), tableInfo.partitions(), + tableInfo.properties, tableInfo.constraints()) { + override def baseCapabiilities: Set[TableCapability] = + super.baseCapabiilities - TableCapability.MICRO_BATCH_READ + } + tables.put(ident, table) + namespaces.putIfAbsent(ident.namespace.toList, Map()) + table + } + } } class InMemoryStreamTable(override val name: String) From 951ad430855800de4c8a2871fe83151b0069b353 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 29 Apr 2026 06:22:47 +0000 Subject: [PATCH 27/33] Remove comment --- .../apache/spark/sql/connector/catalog/InMemoryBaseTable.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 30e046bcd700e..5a1e3a30b150d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -840,8 +840,6 @@ object InMemoryBaseTable { /** * A partition for [[InMemoryBaseTable]] micro-batch streaming reads, holding a slice of rows. - * Defined at the top level (not as an inner class) so that Java serialization to executors - * does not attempt to serialize the enclosing [[InMemoryBaseTable]] instance. */ case class InMemoryMicroBatchPartition(rows: Seq[InternalRow]) extends InputPartition From dfd5f2453a3757aa5c64ea30e5c2ea17a25fe10a Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 29 Apr 2026 06:39:56 +0000 Subject: [PATCH 28/33] Fix import --- .../scala/org/apache/spark/sql/execution/QueryExecution.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index bbe14db33ca86..a0297cfcf8cc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -29,7 +29,6 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row} From 3262808df310432692b621f2558d1fa9a3dee507 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 29 Apr 2026 07:38:57 +0000 Subject: [PATCH 29/33] Fix compilation error and addressed some comments --- .../sql/catalyst/analysis/Analyzer.scala | 21 +++--- .../analysis/RelationResolution.scala | 12 +++- .../UnresolveTransactionRelations.scala | 64 ------------------- .../catalyst/analysis/V2TableReference.scala | 2 +- .../connector/DataSourceV2OptionSuite.scala | 16 ++--- .../connector/StreamingTransactionSuite.scala | 2 +- 6 files changed, 31 insertions(+), 86 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bf81f3740e4a5..ed8b3f6ee018f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -484,7 +484,7 @@ class Analyzer( Batch("Keep Legacy Outputs", Once, KeepLegacyOutputs), Batch("Unresolve Relations", Once, - new UnresolveTransactionRelations(catalogManager)) + new UnresolveRelationsInTransaction(catalogManager)) ) override def batches: Seq[Batch] = earlyBatches ++ Seq( @@ -1041,7 +1041,7 @@ class Analyzer( // DataSourceV2Relation on each view access. Only dataframe temp view may contain it // as it stores resolved plans directly. case view: View if view.isTempViewStoringAnalyzedPlan => - view.copy(child = resolveTableReferences(view.child)) + view.copy(child = resolveTableReferencesInTempView(view.child)) case p @ SubqueryAlias(_, view: View) => p.copy(child = resolveViews(view, options)) case _ => plan @@ -1050,7 +1050,7 @@ class Analyzer( // Unwrap temp views storing analyzed plans and resolve V2TableReference nodes in the child. private def unwrapRelationPlan(plan: LogicalPlan): LogicalPlan = { EliminateSubqueryAliases(plan) match { - case v: View if v.isTempViewStoringAnalyzedPlan => resolveTableReferences(v.child) + case v: View if v.isTempViewStoringAnalyzedPlan => resolveTableReferencesInTempView(v.child) case other => other } } @@ -1077,13 +1077,16 @@ class Analyzer( } } - // Resolve V2TableReference nodes created for: - // 1 Temp views (via createForTempView). - // 2. Transaction references (via createForTransaction). These are resolved by a - // separate analysis batch in the transaction-aware analyzer instance. - private def resolveTableReferences(plan: LogicalPlan): LogicalPlan = { + // Resolve V2TableReference nodes inside temp view plans. These are created by + // V2TableReference.createForTempView. We only need to resolve it when returning + // the plan of temp views (in resolveViews and unwrapRelationPlan). + private def resolveTableReferencesInTempView(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsUp { - case r: V2TableReference => relationResolution.resolveReference(r) + case r: V2TableReference => + assert(r.context.isInstanceOf[V2TableReference.TemporaryViewContext], + s"""Expected TemporaryViewContext in temp view but got + |${r.context.getClass.getSimpleName}""".stripMargin) + relationResolution.resolveReference(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index ac685f984ce4f..efff6b1fd1f43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -501,13 +501,19 @@ class RelationResolution( } } + /** + * Loads the table for a [[V2TableReference]] and returns a resolved [[DataSourceV2Relation]]. + * + * The catalog is re-resolved by name through the [[CatalogManager]] rather than reusing + * [[V2TableReference#catalog]] directly. When a transaction is active, the + * [[TransactionAwareCatalogManager]] redirects catalog lookups to the transaction's catalog + * instance, so the [[TableCatalog#loadTable]] call is intercepted by the transaction catalog, + * which uses it to track which tables are read as part of the transaction. + */ private def loadRelation(ref: V2TableReference): LogicalPlan = { - // Resolve catalog. When a transaction is active we return the transaction - // aware catalog instance. val resolvedCatalog = catalogManager.catalog(ref.catalog.name).asTableCatalog val table = resolvedCatalog.loadTable(ref.identifier) V2TableReferenceUtils.validateLoadedTable(table, ref) - // Create relation with resolved Catalog. DataSourceV2Relation( table = table, output = ref.output, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala deleted file mode 100644 index 0e344173d7892..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveTransactionRelations.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis - -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TransactionalWrite} -import org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.allowInvokingTransformsInAnalyzer -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation - -/** - * When a transaction is active, converts resolved [[DataSourceV2Relation]] nodes back to - * [[V2TableReference]] placeholders for all relations loaded by a catalog with the same - * name as the transaction catalog. - * - * This forces re-resolution of those relations against the transaction's catalog, which - * intercepts [[TableCatalog#loadTable]] calls to track which tables are read as part of - * the transaction. - */ -class UnresolveTransactionRelations(val catalogManager: CatalogManager) - extends Rule[LogicalPlan] with LookupCatalog { - - override def apply(plan: LogicalPlan): LogicalPlan = - catalogManager.transaction match { - case Some(transaction) => - allowInvokingTransformsInAnalyzer { - plan.transform { - case tw: TransactionalWrite => - unresolveRelations(tw, transaction.catalog) - } - } - case _ => plan - } - - private def unresolveRelations( - plan: LogicalPlan, - catalog: CatalogPlugin): LogicalPlan = { - plan transform { - case r: DataSourceV2Relation if isLoadedFromCatalog(r, catalog) => - V2TableReference.createForTransaction(r) - } - } - - private def isLoadedFromCatalog( - relation: DataSourceV2Relation, - catalog: CatalogPlugin): Boolean = { - relation.catalog.exists(_.name == catalog.name) && relation.identifier.isDefined - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index f459706a690bb..f2020d39e8d05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -93,7 +93,7 @@ private[sql] object V2TableReference { } // V2TableReference nodes in the transaction context are produced by - // UnresolveTransactionRelations which unresolves already resolved relations. + // UnresolveRelationsInTransaction which unresolves already resolved relations. def createForTransaction(relation: DataSourceV2Relation): V2TableReference = { create(relation, TransactionContext) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala index fbcfdfb20c6ec..803dd35513f45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala @@ -109,7 +109,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, AppendDataExec(_, _, write, _), + _, AppendDataExec(_, _, write, _, _), _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") @@ -141,7 +141,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case AppendDataExec(_, _, write, _) => + case AppendDataExec(_, _, write, _, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") } @@ -168,7 +168,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case AppendDataExec(_, _, write, _) => + case AppendDataExec(_, _, write, _, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") } @@ -194,7 +194,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, OverwriteByExpressionExec(_, _, write, _), + _, OverwriteByExpressionExec(_, _, write, _, _), _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") @@ -227,7 +227,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwritePartitionsDynamicExec(_, _, write, _) => + case OverwritePartitionsDynamicExec(_, _, write, _, _) => val dynOverwrite = write.toBatch.asInstanceOf[InMemoryBaseTable#DynamicOverwrite] assert(dynOverwrite.info.options.get("write.split-size") === "10") } @@ -254,7 +254,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, OverwriteByExpressionExec(_, _, write, _), + _, OverwriteByExpressionExec(_, _, write, _, _), _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") @@ -287,7 +287,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwriteByExpressionExec(_, _, write, _) => + case OverwriteByExpressionExec(_, _, write, _, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") } @@ -317,7 +317,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwriteByExpressionExec(_, _, write, _) => + case OverwriteByExpressionExec(_, _, write, _, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala index 10551f29eb1b7..af9d50f167145 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala @@ -147,7 +147,7 @@ class StreamingTransactionSuite extends RowLevelOperationSuiteBase { val inputData = MemoryStream[Int] // spark.read produces a DataSourceV2Relation (batch), not a streaming source. - // UnresolveTransactionRelations converts it to V2TableReference each micro-batch so + // UnresolveRelationsInTransaction converts it to V2TableReference each micro-batch so // the transaction-aware catalog can record the scan event. val staticData = spark.read.table(sourceNameAsString) From 0cffb4bdb279e73ead4ee00bb44a2237f5d64bd4 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 29 Apr 2026 07:39:32 +0000 Subject: [PATCH 30/33] Rename unresolve relations rule --- .../UnresolveRelationsInTransaction.scala | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala new file mode 100644 index 0000000000000..23e13a6078f64 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TransactionalWrite} +import org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.allowInvokingTransformsInAnalyzer +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +/** + * When a transaction is active, converts resolved [[DataSourceV2Relation]] nodes back to + * [[V2TableReference]] placeholders for all relations loaded by a catalog with the same + * name as the transaction catalog. + * + * This forces re-resolution of those relations against the transaction's catalog, which + * intercepts [[TableCatalog#loadTable]] calls to track which tables are read as part of + * the transaction. + */ +class UnresolveRelationsInTransaction(val catalogManager: CatalogManager) + extends Rule[LogicalPlan] with LookupCatalog { + + override def apply(plan: LogicalPlan): LogicalPlan = + catalogManager.transaction match { + case Some(transaction) => + allowInvokingTransformsInAnalyzer { + plan.transform { + case tw: TransactionalWrite => + unresolveRelations(tw, transaction.catalog) + } + } + case _ => plan + } + + private def unresolveRelations( + plan: LogicalPlan, + catalog: CatalogPlugin): LogicalPlan = { + plan transform { + case r: DataSourceV2Relation if isLoadedFromCatalog(r, catalog) => + V2TableReference.createForTransaction(r) + } + } + + private def isLoadedFromCatalog( + relation: DataSourceV2Relation, + catalog: CatalogPlugin): Boolean = { + relation.catalog.exists(_.name == catalog.name) && relation.identifier.isDefined + } +} From cb6c361ddbfc9fa9e80a6cdb22513256cd37ae82 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 29 Apr 2026 12:53:18 +0000 Subject: [PATCH 31/33] Address rest comments --- .../sql/catalyst/analysis/Analyzer.scala | 3 + .../analysis/RelationResolution.scala | 5 - .../UnresolveRelationsInTransaction.scala | 6 +- .../catalyst/analysis/V2TableReference.scala | 6 + .../catalyst/plans/logical/v2Commands.scala | 3 +- .../transactions/TransactionUtils.scala | 28 ++--- .../TransactionAwareCatalogManager.scala | 4 + .../sql/connector/catalog/V2TableUtil.scala | 15 +++ .../AnalyzerExtensionPropagationSuite.scala | 105 ++++++++++++++++++ .../InMemoryRowLevelOperationTable.scala | 6 +- ...nMemoryRowLevelOperationTableCatalog.scala | 3 +- .../spark/sql/connector/catalog/txns.scala | 4 + .../spark/sql/execution/QueryExecution.scala | 38 ++++--- .../datasources/v2/V2TableRefreshUtil.scala | 7 +- .../connector/MergeIntoDataFrameSuite.scala | 31 ++++++ 15 files changed, 218 insertions(+), 46 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ed8b3f6ee018f..50dc055a98a0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -356,6 +356,9 @@ class Analyzer( * lookups. All other configuration (extended rules, checks, etc.) is preserved. Used by * [[QueryExecution]] to create a per-query analyzer for transactional operations for * transaction-aware catalog resolution. + * + * IMPORTANT: any new extension point added to Analyzer must also be copied here, otherwise + * transaction-aware analyzer clones (created by QueryExecution) will silently miss those rules. */ def withCatalogManager(newCatalogManager: CatalogManager): Analyzer = { val self = this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index efff6b1fd1f43..4546c1dad0011 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -485,11 +485,6 @@ class RelationResolution( } private def getOrLoadRelation(ref: V2TableReference): LogicalPlan = { - // Skip cache when a transaction is active. - if (catalogManager.transaction.isDefined) { - return loadRelation(ref) - } - val key = toCacheKey(ref.catalog, ref.identifier) relationCache.get(key) match { case Some(cached) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala index 23e13a6078f64..8ee64e32376fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala @@ -38,6 +38,10 @@ class UnresolveRelationsInTransaction(val catalogManager: CatalogManager) override def apply(plan: LogicalPlan): LogicalPlan = catalogManager.transaction match { case Some(transaction) => + // We use plain transform rather than resolveOperators* because the latter skips subtrees + // that have already been analyzed. Furthermore, allowInvokingTransformsInAnalyzer + // allows to suppress the assertNotAnalysisRule safety check, which forbids calling + // transform directly inside the analyzer when not within a resolveOperators call. allowInvokingTransformsInAnalyzer { plan.transform { case tw: TransactionalWrite => @@ -50,7 +54,7 @@ class UnresolveRelationsInTransaction(val catalogManager: CatalogManager) private def unresolveRelations( plan: LogicalPlan, catalog: CatalogPlugin): LogicalPlan = { - plan transform { + plan.transform { case r: DataSourceV2Relation if isLoadedFromCatalog(r, catalog) => V2TableReference.createForTransaction(r) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index f2020d39e8d05..5545141640a39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -80,10 +80,12 @@ private[sql] case class V2TableReference private( private[sql] object V2TableReference { case class TableInfo( + tableId: Option[String], columns: Seq[Column], metadataColumns: Seq[MetadataColumn]) sealed trait Context + /** Context for relations that are re-resolved on access of a dataframe temp view. */ case class TemporaryViewContext(viewName: Seq[String]) extends Context /** Context for relations that are re-resolved through a transaction catalog. */ case object TransactionContext extends Context @@ -104,6 +106,7 @@ private[sql] object V2TableReference { relation.identifier.get, relation.options, TableInfo( + tableId = Option(relation.table.id()), columns = relation.table.columns.toImmutableArraySeq, metadataColumns = V2TableUtil.extractMetadataColumns(relation)), relation.output, @@ -127,6 +130,9 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { } private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = { + // Make sure the table was not dropped and recreated. + ref.info.tableId.foreach(V2TableUtil.validateTableId(ref.name, _, table)) + // Do not allow schema evolution to pre-analysed dataframes that are later used in // transactional writes. This is because the entire plans was built based on the original schema // and any schema change would make the plan structurally invalid. This is inline with the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index d47abe284e6f4..15573b157d5c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -1020,8 +1020,7 @@ case class MergeIntoTable( with SupportsSubquery with TransactionalWrite { - // Implements SupportsSchemaEvolution.table. - // Implements TransactionalWrite.table. + // Implements WriteWithSchemaEvolution.table and TransactionalWrite.table. override val table: LogicalPlan = EliminateSubqueryAliases(targetTable) override def withNewTable(newTable: NamedRelation): MergeIntoTable = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala index d160aafdea34e..a5f8afddf01c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala @@ -19,37 +19,37 @@ package org.apache.spark.sql.catalyst.transactions import java.util.UUID +import org.apache.spark.SparkException import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfoImpl} import org.apache.spark.util.Utils object TransactionUtils { - def commit(transaction: Transaction): Unit = { + def commit(txn: Transaction): Unit = { Utils.tryWithSafeFinally { - transaction.commit() + txn.commit() } { - transaction.close() + txn.close() } } - def abort(transaction: Transaction): Unit = { + def abort(txn: Transaction): Unit = { Utils.tryWithSafeFinally { - transaction.abort() + txn.abort() } { - transaction.close() + txn.close() } } def beginTransaction(catalog: TransactionalCatalogPlugin): Transaction = { val info = TransactionInfoImpl(id = UUID.randomUUID.toString) - val transaction = catalog.beginTransaction(info) - if (transaction.catalog.name != catalog.name) { - abort(transaction) - throw new IllegalStateException( - s"""Transaction catalog name (${transaction.catalog.name}) - |must match original catalog name (${catalog.name}). - |""".stripMargin) + val txn = catalog.beginTransaction(info) + if (txn.catalog.name != catalog.name) { + abort(txn) + throw SparkException.internalError( + s"""Transaction catalog name (${txn.catalog.name}) + |must match original catalog name (${catalog.name}).""".stripMargin) } - transaction + txn } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala index aaeef4c2dea76..70079357b6dde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector.catalog +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.catalog.TempVariableManager import org.apache.spark.sql.connector.catalog.transactions.Transaction @@ -37,6 +38,9 @@ private[sql] class TransactionAwareCatalogManager( override def transaction: Option[Transaction] = Some(txn) + override def withTransaction(newTxn: Transaction): CatalogManager = + throw SparkException.internalError("Cannot nest transactions: a transaction is already active.") + override def catalog(name: String): CatalogPlugin = { val resolved = delegate.catalog(name) if (txn.catalog.name() == resolved.name()) txn.catalog else resolved diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala index c7f7b17a58430..af7edce47427a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, MetadataColumnHelper} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.sql.util.SchemaValidationMode @@ -131,6 +132,20 @@ private[sql] object V2TableUtil extends SQLConfHelper { case _ => Seq.empty } + /** + * Validates that the identity of a loaded table matches a previously captured table id. + * Throws if the table was dropped and recreated under the same name (which changes the id). + * No-op if the connector does not support table ids (capturedId is null). + */ + def validateTableId(name: String, capturedId: String, currentTable: Table): Unit = { + if (capturedId != null && capturedId != currentTable.id) { + throw QueryCompilationErrors.tableIdChangedAfterAnalysis( + name, + capturedTableId = capturedId, + currentTableId = currentTable.id) + } + } + private def normalize(name: String): String = { if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala new file mode 100644 index 0000000000000..02cfe6b4eb7e8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.resolver.{ResolverExtension, TreeNodeResolver} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.CatalogManager + +/** + * Verifies that [[Analyzer.withCatalogManager]] propagates all extension points. + * + * If this suite fails with an unexpected method count, a new extension point was added to + * [[Analyzer.withCatalogManager]] without being verified here. Add the corresponding assertion + * and update the expected count. + * + * If [[Analyzer]] gains a new extension point that is NOT yet in [[Analyzer.withCatalogManager]], + * add it there first, then update this suite. + */ +class AnalyzerExtensionPropagationSuite extends SparkFunSuite { + + private val dummyRule: Rule[LogicalPlan] = new Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan + } + + private val dummyCheck: LogicalPlan => Unit = (_: LogicalPlan) => () + + private val dummyExtension: ResolverExtension = new ResolverExtension { + override def resolveOperator( + operator: LogicalPlan, + resolver: TreeNodeResolver[LogicalPlan, LogicalPlan]): Option[LogicalPlan] = None + } + + private def newCatalogManager(): CatalogManager = + new CatalogManager( + FakeV2SessionCatalog, + new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry)) + + test("withCatalogManager propagates all extension points") { + val analyzer = new Analyzer(newCatalogManager()) { + override val hintResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule) + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule) + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule) + override val extendedCheckRules: Seq[LogicalPlan => Unit] = Seq(dummyCheck) + override val singlePassResolverExtensions: Seq[ResolverExtension] = Seq(dummyExtension) + override val singlePassMetadataResolverExtensions: Seq[ResolverExtension] = + Seq(dummyExtension) + override val singlePassPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule) + override val singlePassExtendedResolutionChecks: Seq[LogicalPlan => Unit] = Seq(dummyCheck) + } + + val clone = analyzer.withCatalogManager(newCatalogManager()) + + assert(clone.hintResolutionRules eq analyzer.hintResolutionRules) + assert(clone.extendedResolutionRules eq analyzer.extendedResolutionRules) + assert(clone.postHocResolutionRules eq analyzer.postHocResolutionRules) + assert(clone.extendedCheckRules eq analyzer.extendedCheckRules) + assert(clone.singlePassResolverExtensions eq analyzer.singlePassResolverExtensions) + assert(clone.singlePassMetadataResolverExtensions eq + analyzer.singlePassMetadataResolverExtensions) + assert(clone.singlePassPostHocResolutionRules eq analyzer.singlePassPostHocResolutionRules) + assert(clone.singlePassExtendedResolutionChecks eq analyzer.singlePassExtendedResolutionChecks) + + // Verify the clone's anonymous class overrides exactly the expected extension points. + // If this assertion fails, withCatalogManager was updated but this test was not. + // Add the corresponding assert above and update the expected set. + val overriddenMethods = clone.getClass.getDeclaredMethods + .filterNot(m => m.isSynthetic || m.isBridge || m.getName.contains("$")) + .map(_.getName) + .toSet + + val expectedExtensions = Set( + "hintResolutionRules", + "extendedResolutionRules", + "postHocResolutionRules", + "extendedCheckRules", + "singlePassResolverExtensions", + "singlePassMetadataResolverExtensions", + "singlePassPostHocResolutionRules", + "singlePassExtendedResolutionChecks" + ) + + assert(overriddenMethods == expectedExtensions, + s"withCatalogManager does not copy the expected set of extension points. " + + s"Missing from withCatalogManager: ${expectedExtensions -- overriddenMethods}. " + + s"Unexpected overrides: ${overriddenMethods -- expectedExtensions}.") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index 91e899bc1169e..406d83aa86abb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -38,13 +38,15 @@ class InMemoryRowLevelOperationTable( schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String], - constraints: Array[Constraint] = Array.empty) + constraints: Array[Constraint] = Array.empty, + tableId: String = java.util.UUID.randomUUID().toString) extends InMemoryTable( name, CatalogV2Util.structTypeToV2Columns(schema), partitioning, properties, - constraints) + constraints, + id = tableId) with SupportsRowLevelOperations { private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index 4e5e1e7c8c6e9..ff3b2673dbf12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -75,7 +75,8 @@ class InMemoryRowLevelOperationTableCatalog schema = schema, partitioning = partitioning, properties = properties, - constraints = constraints) + constraints = constraints, + tableId = table.id) newTable.alterTableWithData(table.data, schema) tables.put(ident, newTable) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 99b1e4c8f9870..ea2fd23fad911 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -78,6 +78,10 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) delegate.properties, delegate.constraints) { + // Expose the same id as the delegate so that identity checks during transaction re-resolution + // don't false-positive on the TxnTable wrapper having a different UUID. + override val id: String = delegate.id + alterTableWithData(delegate.data, schema) // Keep initial version to detect any changes during the transaction. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index a0297cfcf8cc4..5bfb73e6b8875 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import javax.annotation.concurrent.GuardedBy import scala.jdk.CollectionConverters._ -import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -76,6 +75,10 @@ class QueryExecution( val shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None, val refreshPhaseEnabled: Boolean = true, val queryId: UUID = UUIDv7Generator.generate(), + // When a transaction is active, callers creating nested QueryExecution instances MUST pass + // the enclosing QueryExecution's analyzer here to propagate the transaction context. + // Omitting it causes the nested QE to use sessionState.analyzer, which has no knowledge + // of the transaction and will load tables outside the transaction's catalog scope. val analyzerOpt: Option[Analyzer] = None) extends LookupCatalog { val id: Long = QueryExecution.nextExecutionId @@ -142,25 +145,30 @@ class QueryExecution( private def pathBased(write: TransactionalWritePlan): Option[TableCatalog] = EliminateSubqueryAliases(write.table) match { case UnresolvedRelation(parts, _, _) if parts.length > 1 => - Try(DataSource.lookupDataSourceV2(parts.head, sparkSession.sessionState.conf)) - .toOption - .flatten - .collect { case sco: SupportsCatalogOptions => sco } - .map { sco => - val sessionConfigs = DataSourceV2Utils.extractSessionConfigs( - sco, sparkSession.sessionState.conf) - // Pass the entire identifier as option. The connector can decide how to parse it - // if needed. - val options = sessionConfigs + ("identifier" -> parts.mkString(".")) - CatalogV2Util.getTableProviderCatalog( - sco, catalogManager, new CaseInsensitiveStringMap(options.asJava)) - } + try { + DataSource.lookupDataSourceV2(parts.head, sparkSession.sessionState.conf) + .collect { case sco: SupportsCatalogOptions => sco } + .map { sco => + val sessionConfigs = DataSourceV2Utils.extractSessionConfigs( + sco, sparkSession.sessionState.conf) + // Pass the entire identifier as option. The connector can decide how to parse it + // if needed. + val options = sessionConfigs + ("identifier" -> parts.mkString(".")) + CatalogV2Util.getTableProviderCatalog( + sco, catalogManager, new CaseInsensitiveStringMap(options.asJava)) + } + } catch { + // The head of the multipart identifier is not a registered data source. + // Fallback to catalog-based detection. + case _: ClassNotFoundException => None + } case _ => None } // Per-query analyzer: uses a transaction-aware CatalogManager when a transaction is active, // so that all catalog lookups and rule applications during analysis see the correct state - // without relying on thread-local context. + // without relying on thread-local context. Any nested QueryExecution that is created during + // analysis or execution of a transactional plan must receive this analyzer via analyzerOpt. private lazy val analyzer: Analyzer = analyzerOpt.getOrElse { transactionOpt match { case Some(txn) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala index 151329de9e6f2..60965453d9ee5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala @@ -119,12 +119,7 @@ private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging { } private def validateTableIdentity(currentTable: Table, relation: DataSourceV2Relation): Unit = { - if (relation.table.id != null && relation.table.id != currentTable.id) { - throw QueryCompilationErrors.tableIdChangedAfterAnalysis( - relation.name, - capturedTableId = relation.table.id, - currentTableId = currentTable.id) - } + V2TableUtil.validateTableId(relation.name, relation.table.id, currentTable) } private def validateDataColumns( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala index d58e22e63d71e..c5d82ec3c1d65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala @@ -128,6 +128,37 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { } } + test("self merge fails when source table is dropped and recreated after analysis") { + withTable(tableNameAsString) { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = spark.table(tableNameAsString).where("salary == 100").as("source") + sourceDF.queryExecution.assertAnalyzed() + + val originalId = catalog.loadTable(ident).id + + sql(s"DROP TABLE $tableNameAsString") + sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + val newId = catalog.loadTable(ident).id + assert(originalId != newId) + + val e = intercept[AnalysisException] { + sourceDF + .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk")) + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .merge() + } + + assert(e.getCondition == "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.TABLE_ID_MISMATCH") + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + } + test("merge into empty table with NOT MATCHED clause") { withTempView("source") { createTable("pk INT NOT NULL, salary INT, dep STRING") From 009a604449547b4dfe0d7876e4d3b199bb300611 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 29 Apr 2026 12:59:17 +0000 Subject: [PATCH 32/33] Fix utils suite --- .../sql/catalyst/transactions/TransactionUtilsSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala index d409316e667b1..ee771fe3f2460 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.transactions -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.connector.catalog.{CatalogPlugin, TransactionalCatalogPlugin} import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -98,7 +98,7 @@ class TransactionUtilsSuite extends SparkFunSuite { test("beginTransaction: fails when transaction catalog name does not match") { val catalog = mockTransactionalCatalog(catalogName = testCatalogName, txnCatalogName = "other") - val e = intercept[IllegalStateException] { + val e = intercept[SparkException] { TransactionUtils.beginTransaction(catalog) } assert(e.getMessage.contains("other")) @@ -117,7 +117,7 @@ class TransactionUtilsSuite extends SparkFunSuite { onAbort = () => { aborted = true }, onClose = () => { closed = true }) } - intercept[IllegalStateException] { TransactionUtils.beginTransaction(catalog) } + intercept[SparkException] { TransactionUtils.beginTransaction(catalog) } assert(aborted) assert(closed) } From 367c66a5430317da9f4ce0ca70c66a4d335a5710 Mon Sep 17 00:00:00 2001 From: Andreas Chatzistergiou Date: Wed, 29 Apr 2026 15:27:40 +0000 Subject: [PATCH 33/33] Improvements --- .../spark/sql/connector/write/BatchWrite.java | 6 ++ .../write/streaming/StreamingWrite.java | 6 ++ ...nMemoryRowLevelOperationTableCatalog.scala | 2 +- .../spark/sql/connector/catalog/txns.scala | 59 ++++++++++++------- .../connector/StreamingTransactionSuite.scala | 18 +++--- .../sql/scripting/SqlScriptingE2eSuite.scala | 30 +++++----- 6 files changed, 74 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java index 44fc5f9d794bf..75816349af38d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java @@ -85,6 +85,12 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple * tasks may have committed successfully and one successful commit message per task will be * passed to this commit method. The remaining commit messages are ignored by Spark. + *

+ * Note: this method signals that all data for this write operation has been successfully written. + * It is NOT a transactional commit. When this write is part of a + * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction}, the transaction is + * committed separately via + * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction#commit()}. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java index ab98bc01b3aed..764ed0a35a3f2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java @@ -80,6 +80,12 @@ default boolean useCommitCoordinator() { * The execution engine may call {@code commit} multiple times for the same epoch in some * circumstances. To support exactly-once data semantics, implementations must ensure that * multiple commits for the same epoch are idempotent. + *

+ * Note: this method signals that all data for this write operation has been successfully written. + * It is NOT a transactional commit. When this write is part of a + * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction}, the transaction is + * committed separately via + * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction#commit()}. */ void commit(long epochId, WriterCommitMessage[] messages); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index ff3b2673dbf12..95d5975d269ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -34,7 +34,7 @@ class InMemoryRowLevelOperationTableCatalog var lastTransaction: Txn = _ // All transactions in order (committed and aborted), allowing per-statement // validation in SQL scripting tests. - val seenTransactions: ArrayBuffer[Txn] = new ArrayBuffer[Txn]() + val observedTransactions: ArrayBuffer[Txn] = new ArrayBuffer[Txn]() override def beginTransaction(info: TransactionInfo): Transaction = { assert(transaction == null || transaction.currentState != Active) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index ea2fd23fad911..bd7d6d689f515 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import org.apache.spark.sql.connector.catalog.transactions.Transaction +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, RowLevelOperationBuilder, RowLevelOperationInfo, WriteBuilder} import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -70,7 +71,10 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { // Note, the in-memory data store does not handle concurrency at the moment. The assumes that the // underlying delegate table cannot change from concurrent transactions. Data sources need to // implement isolation semantics and make sure they are enforced. -class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) +class TxnTable( + val delegate: InMemoryRowLevelOperationTable, + schema: StructType, + catalog: TxnTableCatalog) extends InMemoryRowLevelOperationTable( delegate.name, schema, @@ -84,9 +88,6 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) alterTableWithData(delegate.data, schema) - // Keep initial version to detect any changes during the transaction. - private val initialVersion: String = version() - // A tracker of filters used in each scan. val scanEvents = new ArrayBuffer[Array[Filter]]() @@ -95,19 +96,32 @@ class TxnTable(val delegate: InMemoryRowLevelOperationTable, schema: StructType) scanEvents += filters } - // Perform commit if there are any changes. This push metadata and data changes to the - // delegate table. + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + catalog.writeTarget = this + super.newWriteBuilder(info) + } + + override def newRowLevelOperationBuilder( + info: RowLevelOperationInfo): RowLevelOperationBuilder = { + catalog.writeTarget = this + super.newRowLevelOperationBuilder(info) + } + + override def deleteWhere(filters: Array[Filter]): Unit = { + catalog.writeTarget = this + super.deleteWhere(filters) + } + + // Propagates staged data and metadata changes to the delegate table. def commit(): Unit = { - if (version() != initialVersion) { - delegate.dataMap.clear() - delegate.updateColumns(columns()) // Evolve schema if needed. - delegate.alterTableWithData(data, schema) - delegate.replacedPartitions = replacedPartitions - delegate.lastWriteInfo = lastWriteInfo - delegate.lastWriteLog = lastWriteLog - delegate.commits ++= commits - delegate.increaseVersion() - } + delegate.dataMap.clear() + delegate.updateColumns(columns()) // Evolve schema if needed. + delegate.alterTableWithData(data, schema) + delegate.replacedPartitions = replacedPartitions + delegate.lastWriteInfo = lastWriteInfo + delegate.lastWriteLog = lastWriteLog + delegate.commits ++= commits + delegate.increaseVersion() } } @@ -120,6 +134,8 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T private val tables: util.Map[Identifier, TxnTable] = new ConcurrentHashMap[Identifier, TxnTable]() + var writeTarget: TxnTable = _ + override def name: String = delegate.name override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {} @@ -134,7 +150,7 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T override def loadTable(ident: Identifier): Table = { tables.computeIfAbsent(ident, _ => { val table = delegate.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] - new TxnTable(table, table.schema()) + new TxnTable(table, table.schema(), this) }) } @@ -154,7 +170,7 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T } // TODO: We need to pass all tracked predicates to the new TXN table. - val newTxnTable = new TxnTable(txnTable.delegate, schema) + val newTxnTable = new TxnTable(txnTable.delegate, schema, this) tables.put(ident, newTxnTable) newTxnTable } @@ -178,17 +194,16 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T // Returns all tables that participated in this transaction, keyed by identifier. def txnTables: scala.collection.Map[Identifier, TxnTable] = tables.asScala - // Invoke commit for all tables participated in the transaction. If a table is read-only - // this is a no-op. + // Commit the write target table, propagating staged changes to the delegate. def commit(): Unit = { - tables.values.forEach(table => table.commit()) + if (writeTarget != null) writeTarget.commit() } // Clear transaction context. def clearActiveTransaction(): Unit = { val txn = delegate.transaction delegate.lastTransaction = txn - delegate.seenTransactions += txn + delegate.observedTransactions += txn delegate.transaction = null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala index af9d50f167145..13b6267a28ff8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala @@ -82,7 +82,7 @@ class StreamingTransactionSuite extends RowLevelOperationSuiteBase { assert(table.version() === "1") // Transaction must be scoped to the streaming session; main session catalog is untouched. - assert(catalog.seenTransactions.isEmpty) + assert(catalog.observedTransactions.isEmpty) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -112,17 +112,17 @@ class StreamingTransactionSuite extends RowLevelOperationSuiteBase { query.stop() val sc = streamCatalog(query) - assert(sc.seenTransactions.size === 2) - assert(sc.seenTransactions.forall(_.currentState === Committed)) + assert(sc.observedTransactions.size === 2) + assert(sc.observedTransactions.forall(_.currentState === Committed)) // Pure streaming append: write target is not read in any micro-batch. - assert(sc.seenTransactions.forall { t => + assert(sc.observedTransactions.forall { t => indexByName(t.catalog.txnTables.values.toSeq)(tableNameAsString).scanEvents.isEmpty }) // Each committed micro-batch increments the delegate version exactly once. assert(table.version() === "2") // Transaction must be scoped to the streaming session; main session catalog is untouched. - assert(catalog.seenTransactions.isEmpty) + assert(catalog.observedTransactions.isEmpty) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -141,7 +141,7 @@ class StreamingTransactionSuite extends RowLevelOperationSuiteBase { sql(s"INSERT INTO $sourceNameAsString VALUES (1), (2), (3)") // The INSERT above runs a transaction on the main session catalog; capture the count now // so we can assert the streaming query does not add more. - val mainTxnsBefore = catalog.seenTransactions.size + val mainTxnsBefore = catalog.observedTransactions.size withTempDir { checkpointDir => val inputData = MemoryStream[Int] @@ -178,7 +178,7 @@ class StreamingTransactionSuite extends RowLevelOperationSuiteBase { // Streaming must not add transactions to the main session catalog beyond the pre-existing // INSERT transaction. - assert(catalog.seenTransactions.size === mainTxnsBefore) + assert(catalog.observedTransactions.size === mainTxnsBefore) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), @@ -214,7 +214,7 @@ class StreamingTransactionSuite extends RowLevelOperationSuiteBase { assert(table.version() === "0") // Transaction must be scoped to the streaming session; main session catalog is untouched. - assert(catalog.seenTransactions.isEmpty) + assert(catalog.observedTransactions.isEmpty) // Writes must not be visible after an aborted transaction. checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Seq.empty) @@ -244,7 +244,7 @@ class StreamingTransactionSuite extends RowLevelOperationSuiteBase { query.processAllAvailable() query.stop() - assert(catalog.seenTransactions.isEmpty, + assert(catalog.observedTransactions.isEmpty, "no transaction expected for non-transactional catalog") checkAnswer(spark.table("nonTxnCat.ns.tbl"), Seq(Row(1), Row(2), Row(3))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala index 2a3bd915eccd1..da697847874dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala @@ -215,12 +215,12 @@ class SqlScriptingE2eSuite extends SharedSparkSession { verifySqlScriptResult(sqlScript, Seq(Row(2, 200, "software"))) // Each DML statement in a script runs in its own independent QE and transaction. - assert(catalog.seenTransactions.size === 2) - assert(catalog.seenTransactions.forall(t => + assert(catalog.observedTransactions.size === 2) + assert(catalog.observedTransactions.forall(t => t.currentState === Committed && t.isClosed)) // The DELETE subquery scans the table with a dep='hr' predicate; verify it was tracked. - val deleteTxnTable = loadTxnTable(catalog.seenTransactions(1), "t") + val deleteTxnTable = loadTxnTable(catalog.observedTransactions(1), "t") assert(deleteTxnTable.scanEvents.flatten.exists { case sources.EqualTo("dep", "hr") => true case _ => false @@ -254,11 +254,11 @@ class SqlScriptingE2eSuite extends SharedSparkSession { queryContext = Array(ExpectedContext("nonexistent_column"))) // INSERT committed; DELETE was aborted because analysis failed on the bad column. - assert(catalog.seenTransactions.size === 2) - assert(catalog.seenTransactions(0).currentState === Committed) - assert(catalog.seenTransactions(0).isClosed) - assert(catalog.seenTransactions(1).currentState === Aborted) - assert(catalog.seenTransactions(1).isClosed) + assert(catalog.observedTransactions.size === 2) + assert(catalog.observedTransactions(0).currentState === Committed) + assert(catalog.observedTransactions(0).isClosed) + assert(catalog.observedTransactions(1).currentState === Aborted) + assert(catalog.observedTransactions(1).isClosed) assert(catalog.lastTransaction.currentState === Aborted) } } @@ -296,11 +296,11 @@ class SqlScriptingE2eSuite extends SharedSparkSession { Row(4, 400, "finance"))) // INSERT (x2), MERGE, and UPDATE each run in their own independent QE and transaction. - assert(catalog.seenTransactions.size === 4) - assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + assert(catalog.observedTransactions.size === 4) + assert(catalog.observedTransactions.forall(t => t.currentState === Committed && t.isClosed)) def txnTable(txnIdx: Int): TxnTable = - loadTxnTable(catalog.seenTransactions(txnIdx), "t") + loadTxnTable(catalog.observedTransactions(txnIdx), "t") // Both inserts are pure writes - no scan. assert(txnTable(0).scanEvents.isEmpty) @@ -342,8 +342,8 @@ class SqlScriptingE2eSuite extends SharedSparkSession { Seq(Row(1, 100, "hr"), Row(2, 200, "hr"), Row(3, 300, "hr"))) // Each loop iteration's INSERT runs in its own independent transaction. - assert(catalog.seenTransactions.size === 3) - assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + assert(catalog.observedTransactions.size === 3) + assert(catalog.observedTransactions.forall(t => t.currentState === Committed && t.isClosed)) } } } @@ -373,8 +373,8 @@ class SqlScriptingE2eSuite extends SharedSparkSession { Seq(Row(-1, -1, "error"), Row(1, 100, "hr"), Row(2, 200, "software"))) // INSERT(1), handler INSERT(-1), INSERT(2) - each in its own transaction. - assert(catalog.seenTransactions.size === 3) - assert(catalog.seenTransactions.forall(t => t.currentState === Committed && t.isClosed)) + assert(catalog.observedTransactions.size === 3) + assert(catalog.observedTransactions.forall(t => t.currentState === Committed && t.isClosed)) } } }