diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java
new file mode 100644
index 000000000000..daa3176dcbba
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.catalog.transactions.Transaction;
+import org.apache.spark.sql.connector.catalog.transactions.TransactionInfo;
+
+/**
+ * A {@link CatalogPlugin} that supports transactions.
+ * <p>
+ * Catalogs that implement this interface opt in to transactional query execution. A catalog
+ * implementing this interface is responsible for starting transactions.
+ *
+ * @since 4.2.0
+ */
+@Evolving
+public interface TransactionalCatalogPlugin extends CatalogPlugin {
+
+ /**
+ * Begins a new transaction.
+ * @param info metadata about the transaction being started
+ * @return a {@link Transaction} representing the newly started transaction
+ */
+ Transaction beginTransaction(TransactionInfo info);
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java
new file mode 100644
index 000000000000..77044c6202fb
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.transactions;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.catalog.CatalogPlugin;
+import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin;
+
+import java.io.Closeable;
+
+/**
+ * Represents a transaction.
+ * <p>
+ * Spark begins a transaction with {@link TransactionalCatalogPlugin#beginTransaction} and
+ * executes read/write operations against the transaction's catalog. On success, Spark
+ * calls {@link #commit()}; on failure, Spark calls {@link #abort()}. In both cases Spark
+ * subsequently calls {@link #close()} to release resources.
+ *
+ * @since 4.2.0
+ */
+@Evolving
+public interface Transaction extends Closeable {
+
+ /**
+ * Returns the catalog associated with this transaction. It is responsible for tracking the
+ * read/write operations of the transaction so connectors can resolve conflicts at commit time.
+ * @return the catalog against which Spark executes this transaction's read/write operations
+ */
+ CatalogPlugin catalog();
+
+ /**
+ * Commits the transaction. All writes performed under it become visible to other readers.
+ * <p>
+ * The connector is responsible for detecting and resolving conflicting commits or throwing
+ * an exception if resolution is not possible.
+ * <p>
+ * This method is called at most once per transaction, and only if the operation succeeded.
+ * Spark calls {@link #close()} immediately after this method returns or throws.
+ *
+ * @throws IllegalStateException if the transaction has already been committed or aborted.
+ */
+ void commit();
+
+ /**
+ * Aborts the transaction, discarding any staged changes.
+ * <p>
+ * This method must be idempotent. If the transaction has already been committed or aborted,
+ * invoking it must have no effect.
+ * <p>
+ * Spark calls {@link #close()} immediately after this method returns or throws.
+ */
+ void abort();
+
+ /**
+ * Releases any resources held by this transaction.
+ * <p>
+ * Spark always calls this method after {@link #commit()} or {@link #abort()}, regardless of
+ * whether those methods succeed or not.
+ * <p>
+ * This method must be idempotent. If the transaction has already been closed,
+ * invoking it must have no effect.
+ */
+ @Override
+ void close();
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java
new file mode 100644
index 000000000000..3e6979cec469
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.transactions;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * Metadata about a transaction, supplied by Spark when the transaction is begun.
+ *
+ * @since 4.2.0
+ */
+@Evolving
+public interface TransactionInfo {
+ /**
+ * Returns a unique identifier for this transaction.
+ */
+ String id();
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java
index 44fc5f9d794b..75816349af38 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java
@@ -85,6 +85,12 @@ default void onDataWriterCommit(WriterCommitMessage message) {}
* disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple
* tasks may have committed successfully and one successful commit message per task will be
* passed to this commit method. The remaining commit messages are ignored by Spark.
+ * <p>
+ * Note: this method signals that all data for this write operation has been successfully written.
+ * It is NOT a transactional commit. When this write is part of a
+ * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction}, the transaction is
+ * committed separately via
+ * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction#commit()}.
*/
void commit(WriterCommitMessage[] messages);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java
index ab98bc01b3ae..764ed0a35a3f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java
@@ -80,6 +80,12 @@ default boolean useCommitCoordinator() {
* The execution engine may call {@code commit} multiple times for the same epoch in some
* circumstances. To support exactly-once data semantics, implementations must ensure that
* multiple commits for the same epoch are idempotent.
+ * <p>
+ * Note: this method signals that all data for this epoch has been successfully written.
+ * It is NOT a transactional commit. When this write is part of a
+ * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction}, the transaction is
+ * committed separately via
+ * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction#commit()}.
*/
void commit(long epochId, WriterCommitMessage[] messages);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 323a7db9c7ad..50dc055a98a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -351,6 +351,33 @@ class Analyzer(
}
}
+ /**
+ * Returns a copy of this analyzer that uses the given [[CatalogManager]] for all catalog
+ * lookups. All other configuration (extended rules, checks, etc.) is preserved, and the
+ * relation cache is shared with the original. Used by [[QueryExecution]] to create a
+ * per-query analyzer with transaction-aware catalog resolution for transactional operations.
+ *
+ * IMPORTANT: any new extension point added to Analyzer must also be copied here, otherwise
+ * transaction-aware analyzer clones (created by QueryExecution) will silently miss those rules.
+ */
+ def withCatalogManager(newCatalogManager: CatalogManager): Analyzer = {
+ val self = this // Inside the anonymous subclass below, `this` is the new analyzer.
+ new Analyzer(newCatalogManager, sharedRelationCache) {
+ override val hintResolutionRules: Seq[Rule[LogicalPlan]] = self.hintResolutionRules
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = self.extendedResolutionRules
+ override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = self.postHocResolutionRules
+ override val extendedCheckRules: Seq[LogicalPlan => Unit] = self.extendedCheckRules
+ override val singlePassResolverExtensions: Seq[ResolverExtension] =
+ self.singlePassResolverExtensions
+ override val singlePassMetadataResolverExtensions: Seq[ResolverExtension] =
+ self.singlePassMetadataResolverExtensions
+ override val singlePassPostHocResolutionRules: Seq[Rule[LogicalPlan]] =
+ self.singlePassPostHocResolutionRules
+ override val singlePassExtendedResolutionChecks: Seq[LogicalPlan => Unit] =
+ self.singlePassExtendedResolutionChecks
+ }
+ }
+
override def execute(plan: LogicalPlan): LogicalPlan = {
AnalysisContext.withNewAnalysisContext {
executeSameContext(plan)
@@ -458,7 +485,9 @@ class Analyzer(
Batch("Simple Sanity Check", Once,
LookupFunctions),
Batch("Keep Legacy Outputs", Once,
- KeepLegacyOutputs)
+ KeepLegacyOutputs),
+ Batch("Unresolve Relations", Once,
+ new UnresolveRelationsInTransaction(catalogManager))
)
override def batches: Seq[Batch] = earlyBatches ++ Seq(
@@ -1015,7 +1044,7 @@ class Analyzer(
// DataSourceV2Relation on each view access. Only dataframe temp view may contain it
// as it stores resolved plans directly.
case view: View if view.isTempViewStoringAnalyzedPlan =>
- view.copy(child = resolveTableReferences(view.child))
+ view.copy(child = resolveTableReferencesInTempView(view.child))
case p @ SubqueryAlias(_, view: View) =>
p.copy(child = resolveViews(view, options))
case _ => plan
@@ -1024,17 +1053,43 @@ class Analyzer(
// Unwrap temp views storing analyzed plans and resolve V2TableReference nodes in the child.
private def unwrapRelationPlan(plan: LogicalPlan): LogicalPlan = {
EliminateSubqueryAliases(plan) match {
- case v: View if v.isTempViewStoringAnalyzedPlan => resolveTableReferences(v.child)
+ case v: View if v.isTempViewStoringAnalyzedPlan => resolveTableReferencesInTempView(v.child)
case other => other
}
}
- // Resolve V2TableReference nodes in a plan. V2TableReference is only created for temp views
- // (via V2TableReference.createForTempView), so we only need to resolve it when returning
+ // Resolve the write target of a V2 write command (batch or streaming).
+ private def resolveWriteTarget(
+ write: LogicalPlan,
+ table: NamedRelation,
+ withNewTable: NamedRelation => LogicalPlan): LogicalPlan = {
+ table match {
+ case u: UnresolvedRelation if !u.isStreaming =>
+ resolveRelation(u).map(unwrapRelationPlan).map {
+ case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError(
+ v.desc.identifier, write)
+ case u: UnresolvedCatalogRelation => // Shadows the outer `u` within this case only.
+ throw QueryCompilationErrors.writeIntoV1TableNotAllowedError(
+ u.tableMeta.identifier, write)
+ case r: DataSourceV2Relation => withNewTable(r)
+ case _ => // Here `u` is the outer UnresolvedRelation again.
+ throw QueryCompilationErrors.writeIntoTempViewNotAllowedError(
+ u.multipartIdentifier.quoted)
+ }.getOrElse(write) // No relation found: keep the command as is.
+ case _ => write // Only non-streaming UnresolvedRelation targets are resolved here.
+ }
+ }
+
+ // Resolve V2TableReference nodes inside temp view plans. These are created by
+ // V2TableReference.createForTempView. We only need to resolve them when returning
// the plan of temp views (in resolveViews and unwrapRelationPlan).
- private def resolveTableReferences(plan: LogicalPlan): LogicalPlan = {
+ private def resolveTableReferencesInTempView(plan: LogicalPlan): LogicalPlan = {
plan.resolveOperatorsUp {
- case r: V2TableReference => relationResolution.resolveReference(r)
+ case r: V2TableReference =>
+ assert(r.context.isInstanceOf[V2TableReference.TemporaryViewContext],
+ s"""Expected TemporaryViewContext in temp view but got
+ |${r.context.getClass.getSimpleName}""".stripMargin)
+ relationResolution.resolveReference(r)
}
}
@@ -1057,23 +1112,11 @@ class Analyzer(
case other => i.copy(table = other)
}
- // TODO (SPARK-27484): handle streaming write commands when we have them.
+ case write: StreamingV2WriteCommand =>
+ resolveWriteTarget(write, write.table, write.withNewTable)
+
case write: V2WriteCommand =>
- write.table match {
- case u: UnresolvedRelation if !u.isStreaming =>
- resolveRelation(u).map(unwrapRelationPlan).map {
- case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError(
- v.desc.identifier, write)
- case u: UnresolvedCatalogRelation =>
- throw QueryCompilationErrors.writeIntoV1TableNotAllowedError(
- u.tableMeta.identifier, write)
- case r: DataSourceV2Relation => write.withNewTable(r)
- case _ =>
- throw QueryCompilationErrors.writeIntoTempViewNotAllowedError(
- u.multipartIdentifier.quoted)
- }.getOrElse(write)
- case _ => write
- }
+ resolveWriteTarget(write, write.table, write.withNewTable)
case u: UnresolvedRelation =>
resolveRelation(u).map(resolveViews(_, u.options)).getOrElse(u)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
index 7a5077a8a3e1..4546c1dad001 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
@@ -496,10 +496,25 @@ class RelationResolution(
}
}
+ /**
+ * Loads the table for a [[V2TableReference]] and returns a resolved [[DataSourceV2Relation]].
+ *
+ * The catalog is re-resolved by name through the [[CatalogManager]] rather than reusing
+ * [[V2TableReference#catalog]] directly. When a transaction is active, the
+ * [[TransactionAwareCatalogManager]] redirects catalog lookups to the transaction's catalog
+ * instance, so the [[TableCatalog#loadTable]] call is intercepted by the transaction catalog,
+ * which uses it to track which tables are read as part of the transaction.
+ */
private def loadRelation(ref: V2TableReference): LogicalPlan = {
- val table = ref.catalog.loadTable(ref.identifier)
+ val resolvedCatalog = catalogManager.catalog(ref.catalog.name).asTableCatalog
+ val table = resolvedCatalog.loadTable(ref.identifier)
V2TableReferenceUtils.validateLoadedTable(table, ref)
- ref.toRelation(table)
+ DataSourceV2Relation(
+ table = table,
+ output = ref.output,
+ catalog = Some(resolvedCatalog),
+ identifier = Some(ref.identifier),
+ options = ref.options)
}
private def adaptCachedRelation(cached: LogicalPlan, ref: V2TableReference): LogicalPlan = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala
new file mode 100644
index 000000000000..8ee64e32376f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TransactionalWrite}
+import org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.allowInvokingTransformsInAnalyzer
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+/**
+ * When a transaction is active, converts resolved [[DataSourceV2Relation]] nodes back to
+ * [[V2TableReference]] placeholders for all relations loaded by a catalog with the same
+ * name as the transaction catalog. Only relations under a [[TransactionalWrite]] are rewritten.
+ *
+ * This forces re-resolution of those relations against the transaction's catalog, which
+ * intercepts [[TableCatalog#loadTable]] calls to track which tables are read as part of
+ * the transaction.
+ */
+class UnresolveRelationsInTransaction(val catalogManager: CatalogManager)
+ extends Rule[LogicalPlan] with LookupCatalog {
+
+ override def apply(plan: LogicalPlan): LogicalPlan =
+ catalogManager.transaction match {
+ case Some(transaction) =>
+ // We use plain transform rather than resolveOperators* because the latter skips subtrees
+ // that have already been analyzed. The allowInvokingTransformsInAnalyzer wrapper
+ // suppresses the assertNotAnalysisRule safety check, which forbids calling
+ // transform directly inside the analyzer when not within a resolveOperators call.
+ allowInvokingTransformsInAnalyzer {
+ plan.transform {
+ case tw: TransactionalWrite =>
+ unresolveRelations(tw, transaction.catalog)
+ }
+ }
+ case _ => plan // No active transaction: leave the plan untouched.
+ }
+
+ private def unresolveRelations(
+ plan: LogicalPlan,
+ catalog: CatalogPlugin): LogicalPlan = {
+ plan.transform {
+ case r: DataSourceV2Relation if isLoadedFromCatalog(r, catalog) =>
+ V2TableReference.createForTransaction(r)
+ }
+ }
+
+ private def isLoadedFromCatalog(
+ relation: DataSourceV2Relation,
+ catalog: CatalogPlugin): Boolean = {
+ relation.catalog.exists(_.name == catalog.name) && relation.identifier.isDefined
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala
index 85c36d452b30..5545141640a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.V2TableReference.Context
import org.apache.spark.sql.catalyst.analysis.V2TableReference.TableInfo
import org.apache.spark.sql.catalyst.analysis.V2TableReference.TemporaryViewContext
+import org.apache.spark.sql.catalyst.analysis.V2TableReference.TransactionContext
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.plans.logical.Statistics
@@ -37,7 +38,7 @@ import org.apache.spark.sql.connector.catalog.V2TableUtil
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.apache.spark.sql.util.SchemaValidationMode.ALLOW_NEW_TOP_LEVEL_FIELDS
+import org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_TOP_LEVEL_FIELDS, PROHIBIT_CHANGES}
import org.apache.spark.util.ArrayImplicits._
/**
@@ -79,22 +80,33 @@ private[sql] case class V2TableReference private(
private[sql] object V2TableReference {
case class TableInfo(
+ tableId: Option[String],
columns: Seq[Column],
metadataColumns: Seq[MetadataColumn])
sealed trait Context
+ /** Context for relations that are re-resolved on access of a dataframe temp view. */
case class TemporaryViewContext(viewName: Seq[String]) extends Context
+ /** Context for relations that are re-resolved through a transaction catalog. */
+ case object TransactionContext extends Context
def createForTempView(relation: DataSourceV2Relation, viewName: Seq[String]): V2TableReference = {
create(relation, TemporaryViewContext(viewName))
}
+ // V2TableReference nodes in the transaction context are produced by
+ // UnresolveRelationsInTransaction which unresolves already resolved relations.
+ def createForTransaction(relation: DataSourceV2Relation): V2TableReference = {
+ create(relation, TransactionContext)
+ }
+
private def create(relation: DataSourceV2Relation, context: Context): V2TableReference = {
val ref = V2TableReference(
relation.catalog.get.asTableCatalog,
relation.identifier.get,
relation.options,
TableInfo(
+ tableId = Option(relation.table.id()),
columns = relation.table.columns.toImmutableArraySeq,
metadataColumns = V2TableUtil.extractMetadataColumns(relation)),
relation.output,
@@ -110,11 +122,35 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper {
ref.context match {
case ctx: TemporaryViewContext =>
validateLoadedTableInTempView(table, ref, ctx)
+ case TransactionContext =>
+ validateLoadedTableInTransaction(table, ref)
case ctx =>
throw SparkException.internalError(s"Unknown table ref context: ${ctx.getClass.getName}")
}
}
+ private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = {
+ // Make sure the table was not dropped and recreated (skipped if no table ID was captured).
+ ref.info.tableId.foreach(V2TableUtil.validateTableId(ref.name, _, table))
+
+ // Do not allow schema evolution for pre-analyzed dataframes that are later used in
+ // transactional writes. This is because the entire plan was built based on the original schema
+ // and any schema change would make the plan structurally invalid. This is in line with the
+ // semantics of SPARK-54444.
+ val dataErrors = V2TableUtil.validateCapturedColumns(
+ table = table,
+ originCols = ref.info.columns,
+ mode = PROHIBIT_CHANGES)
+ if (dataErrors.nonEmpty) {
+ throw QueryCompilationErrors.columnsChangedAfterAnalysis(ref.name, dataErrors)
+ }
+
+ val metaErrors = V2TableUtil.validateCapturedMetadataColumns(table, ref.info.metadataColumns)
+ if (metaErrors.nonEmpty) {
+ throw QueryCompilationErrors.metadataColumnsChangedAfterAnalysis(ref.name, metaErrors)
+ }
+ }
+
private def validateLoadedTableInTempView(
table: Table,
ref: V2TableReference,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
index c38377582c15..774c783ecf8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -188,7 +188,11 @@ case class InsertIntoStatement(
byName: Boolean = false,
replaceCriteriaOpt: Option[InsertReplaceCriteria] = None,
withSchemaEvolution: Boolean = false)
- extends UnaryParsedStatement {
+ // Extends TransactionalWrite so that QueryExecution can detect a potential transaction on the
+ // unresolved logical plan before analysis runs. InsertIntoStatement is shared between V1 and V2
+ // inserts, but the LookupCatalog.TransactionalWrite extractor only matches when the target
+ // catalog implements TransactionalCatalogPlugin, so V1 inserts are never assigned a transaction.
+ extends UnaryParsedStatement with TransactionalWrite {
require(overwrite || !ifPartitionNotExists,
"IF NOT EXISTS is only valid in INSERT OVERWRITE")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 0eded2d9dbdf..15573b157d5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -157,7 +157,7 @@ case class AppendData(
isByName: Boolean,
withSchemaEvolution: Boolean,
write: Option[Write] = None,
- analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand {
+ analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite {
override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT)
override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery)
override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable)
@@ -205,7 +205,7 @@ case class OverwriteByExpression(
isByName: Boolean,
withSchemaEvolution: Boolean,
write: Option[Write] = None,
- analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand {
+ analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite {
override val writePrivileges: Set[TableWritePrivilege] =
Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)
override lazy val resolved: Boolean = {
@@ -265,7 +265,7 @@ case class OverwritePartitionsDynamic(
writeOptions: Map[String, String],
isByName: Boolean,
withSchemaEvolution: Boolean,
- write: Option[Write] = None) extends V2WriteCommand {
+ write: Option[Write] = None) extends V2WriteCommand with TransactionalWrite {
override val writePrivileges: Set[TableWritePrivilege] =
Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)
override def withNewQuery(newQuery: LogicalPlan): OverwritePartitionsDynamic = {
@@ -521,8 +521,10 @@ case class WriteDelta(
trait V2CreateTableAsSelectPlan
extends V2CreateTablePlan
with AnalysisOnlyCommand
- with CTEInChildren {
+ with CTEInChildren
+ with TransactionalWrite {
def query: LogicalPlan
+ override def table: LogicalPlan = name
override def withCTEDefs(cteDefs: Seq[CTERelationDef]): LogicalPlan = {
withNameAndQuery(newName = name, newQuery = WithCTE(query, cteDefs))
@@ -956,7 +958,8 @@ object DescribeColumn {
*/
case class DeleteFromTable(
table: LogicalPlan,
- condition: Expression) extends UnaryCommand with SupportsSubquery {
+ condition: Expression)
+ extends UnaryCommand with TransactionalWrite with SupportsSubquery {
override def child: LogicalPlan = table
override protected def withNewChildInternal(newChild: LogicalPlan): DeleteFromTable =
copy(table = newChild)
@@ -978,7 +981,8 @@ case class DeleteFromTableWithFilters(
case class UpdateTable(
table: LogicalPlan,
assignments: Seq[Assignment],
- condition: Option[Expression]) extends UnaryCommand with SupportsSubquery {
+ condition: Option[Expression])
+ extends UnaryCommand with TransactionalWrite with SupportsSubquery {
lazy val aligned: Boolean = AssignmentUtils.aligned(table.output, assignments)
@@ -1011,8 +1015,12 @@ case class MergeIntoTable(
notMatchedActions: Seq[MergeAction],
notMatchedBySourceActions: Seq[MergeAction],
withSchemaEvolution: Boolean)
- extends BinaryCommand with WriteWithSchemaEvolution with SupportsSubquery {
+ extends BinaryCommand
+ with WriteWithSchemaEvolution
+ with SupportsSubquery
+ with TransactionalWrite {
+ // Implements WriteWithSchemaEvolution.table and TransactionalWrite.table.
override val table: LogicalPlan = EliminateSubqueryAliases(targetTable)
override def withNewTable(newTable: NamedRelation): MergeIntoTable = {
@@ -1272,6 +1280,22 @@ case class Assignment(key: Expression, value: Expression) extends Expression
newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight)
}
+/**
+ * Marker trait for write operations that participate in a DSv2 transaction.
+ *
+ * Implementations are expected to target a DSv2 catalog backed by a
+ * [[org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin]].
+ */
+trait TransactionalWrite extends LogicalPlan {
+ def table: LogicalPlan // The plan node for the target table of this write.
+}
+
+/** Trait for streaming write commands that participate in DSv2 transactions. */
+trait StreamingV2WriteCommand extends TransactionalWrite {
+ override def table: NamedRelation
+ // Used by the analyzer to swap in the resolved target relation.
+ def withNewTable(newTable: NamedRelation): StreamingV2WriteCommand
+}
+
/**
* The logical plan of the DROP TABLE command.
*
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala
new file mode 100644
index 000000000000..a5f8afddf01c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.transactions
+
+import java.util.UUID
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin
+import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfoImpl}
+import org.apache.spark.util.Utils
+
+object TransactionUtils {
+ def commit(txn: Transaction): Unit = {
+ Utils.tryWithSafeFinally {
+ txn.commit()
+ } {
+ txn.close() // Per the Transaction contract, close() always follows commit().
+ }
+ }
+
+ def abort(txn: Transaction): Unit = {
+ Utils.tryWithSafeFinally {
+ txn.abort()
+ } {
+ txn.close() // Per the Transaction contract, close() always follows abort().
+ }
+ }
+
+ def beginTransaction(catalog: TransactionalCatalogPlugin): Transaction = {
+ val info = TransactionInfoImpl(id = UUID.randomUUID.toString)
+ val txn = catalog.beginTransaction(info)
+ if (txn.catalog.name != catalog.name) { // Relations are re-resolved by catalog name.
+ abort(txn) // Release the mismatched transaction before reporting the error.
+ throw SparkException.internalError(
+ s"""Transaction catalog name (${txn.catalog.name})
+ |must match original catalog name (${catalog.name}).""".stripMargin)
+ }
+ txn
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
index 3f5afd9ce0de..c851e931aad4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, TempVariableManager}
import org.apache.spark.sql.catalyst.util.StringUtils
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
@@ -39,7 +40,7 @@ import org.apache.spark.sql.internal.SQLConf
// need to track current database at all.
private[sql]
class CatalogManager(
- defaultSessionCatalog: CatalogPlugin,
+ val defaultSessionCatalog: CatalogPlugin,
val v1SessionCatalog: SessionCatalog) extends SQLConfHelper with Logging {
import CatalogManager.SESSION_CATALOG_NAME
import CatalogV2Util._
@@ -57,6 +58,11 @@ class CatalogManager(
}
}
+ /**
+ * The transaction this catalog manager is scoped to. Always `None` in this class;
+ * [[TransactionAwareCatalogManager]] overrides it to return its transaction.
+ */
+ def transaction: Option[Transaction] = None
+
+ /**
+ * Returns a [[TransactionAwareCatalogManager]] that wraps this manager together with the
+ * given transaction.
+ *
+ * @param transaction the transaction the returned manager is scoped to.
+ */
+ def withTransaction(transaction: Transaction): CatalogManager =
+ new TransactionAwareCatalogManager(this, transaction)
+
def isCatalogRegistered(name: String): Boolean = {
try {
catalog(name)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
index 203cfc23452a..dd5be45bfc5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.connector.catalog
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedIdentifier, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.plans.logical.TransactionalWrite
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
@@ -163,4 +165,17 @@ private[sql] trait LookupCatalog extends Logging {
}
}
}
+
+ /**
+  * Extractor that yields the [[TransactionalCatalogPlugin]] targeted by a transactional write.
+  *
+  * The write's table is inspected after subquery aliases are eliminated. A catalog is returned
+  * only when that table is an [[UnresolvedRelation]] or [[UnresolvedIdentifier]] whose name is
+  * matched by `CatalogAndIdentifier` with a catalog that is a [[TransactionalCatalogPlugin]].
+  */
+ object TransactionalWrite {
+   def unapply(write: TransactionalWrite): Option[TransactionalCatalogPlugin] = {
+     val target = EliminateSubqueryAliases(write.table)
+     Option(target).collect {
+       case UnresolvedRelation(
+           CatalogAndIdentifier(txnCatalog: TransactionalCatalogPlugin, _), _, _) =>
+         txnCatalog
+       case UnresolvedIdentifier(
+           CatalogAndIdentifier(txnCatalog: TransactionalCatalogPlugin, _), _) =>
+         txnCatalog
+     }
+   }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala
new file mode 100644
index 000000000000..70079357b6dd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.catalog.TempVariableManager
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
+
+/**
+ * A [[CatalogManager]] decorator that redirects catalog lookups to the transaction's catalog
+ * instance when names match, ensuring table loads during analysis are scoped to the transaction.
+ * All mutable state (current catalog, current namespace, loaded catalogs) is delegated to the
+ * wrapped [[CatalogManager]].
+ */
+// TODO: Extracting a CatalogManager trait (so this class can implement it instead of extending
+// CatalogManager) would eliminate the inherited mutable state that this decorator doesn't use.
+private[sql] class TransactionAwareCatalogManager(
+    delegate: CatalogManager,
+    txn: Transaction)
+  extends CatalogManager(delegate.defaultSessionCatalog, delegate.v1SessionCatalog) {
+
+  // Reuse the wrapped manager's temp variables instead of the ones this instance would create.
+  override val tempVariableManager: TempVariableManager = delegate.tempVariableManager
+
+  /** The transaction this manager is bound to; always defined. */
+  override def transaction: Option[Transaction] = Some(txn)
+
+  /** Always fails: a manager that is already bound to a transaction cannot take another one. */
+  override def withTransaction(newTxn: Transaction): CatalogManager =
+    throw SparkException.internalError("Cannot nest transactions: a transaction is already active.")
+
+  /**
+   * Looks up `name` through the wrapped manager and substitutes the transaction's catalog when
+   * the resolved catalog has the same name.
+   */
+  override def catalog(name: String): CatalogPlugin =
+    redirectToTransaction(delegate.catalog(name))
+
+  /**
+   * Returns the wrapped manager's current catalog, substituting the transaction's catalog when
+   * the names are equal.
+   */
+  override def currentCatalog: CatalogPlugin =
+    redirectToTransaction(delegate.currentCatalog)
+
+  // Everything below simply forwards to the wrapped manager, which owns the session state.
+
+  override def currentNamespace: Array[String] = delegate.currentNamespace
+
+  override def setCurrentNamespace(namespace: Array[String]): Unit = {
+    delegate.setCurrentNamespace(namespace)
+  }
+
+  override def setCurrentCatalog(catalogName: String): Unit = {
+    delegate.setCurrentCatalog(catalogName)
+  }
+
+  override def listCatalogs(pattern: Option[String]): Seq[String] = {
+    delegate.listCatalogs(pattern)
+  }
+
+  /** Returns the transaction's catalog if it shares its name with `plugin`, else `plugin`. */
+  private def redirectToTransaction(plugin: CatalogPlugin): CatalogPlugin = {
+    if (txn.catalog.name() == plugin.name()) txn.catalog else plugin
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala
index c7f7b17a5843..af7edce47427 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, MetadataColumnHelper}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.sql.util.SchemaValidationMode
@@ -131,6 +132,20 @@ private[sql] object V2TableUtil extends SQLConfHelper {
case _ => Seq.empty
}
+ /**
+  * Checks that a freshly loaded table still has the id that was captured earlier.
+  *
+  * A differing id means the table was dropped and recreated under the same name, in which
+  * case an error is thrown. When `capturedId` is null (the connector does not support table
+  * ids) nothing is checked.
+  *
+  * @param name table name reported in the error.
+  * @param capturedId previously captured table id, or null if none is available.
+  * @param currentTable the table as currently loaded.
+  */
+ def validateTableId(name: String, capturedId: String, currentTable: Table): Unit = {
+   Option(capturedId).foreach { expectedId =>
+     val actualId = currentTable.id
+     if (expectedId != actualId) {
+       throw QueryCompilationErrors.tableIdChangedAfterAnalysis(
+         name,
+         capturedTableId = expectedId,
+         currentTableId = actualId)
+     }
+   }
+ }
+
private def normalize(name: String): String = {
if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala
new file mode 100644
index 000000000000..4cb53da0a59e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.transactions
+
+/**
+ * A [[TransactionInfo]] that carries only the transaction id.
+ *
+ * @param id identifier of the transaction; `TransactionUtils.beginTransaction` passes a
+ * random UUID string here.
+ */
+case class TransactionInfoImpl(id: String) extends TransactionInfo
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala
new file mode 100644
index 000000000000..02cfe6b4eb7e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.resolver.{ResolverExtension, TreeNodeResolver}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.CatalogManager
+
+/**
+ * Verifies that [[Analyzer.withCatalogManager]] propagates all extension points.
+ *
+ * If this suite fails with an unexpected set of overridden methods, a new extension point was
+ * added to [[Analyzer.withCatalogManager]] without being verified here. Add the corresponding
+ * assertion and update the expected set.
+ *
+ * If [[Analyzer]] gains a new extension point that is NOT yet in [[Analyzer.withCatalogManager]],
+ * add it there first, then update this suite.
+ */
+class AnalyzerExtensionPropagationSuite extends SparkFunSuite {
+
+ // Identity rule: returns the plan unchanged. Only its object identity matters to the test.
+ private val dummyRule: Rule[LogicalPlan] = new Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan
+ }
+
+ // Check that accepts every plan. Only its object identity matters to the test.
+ private val dummyCheck: LogicalPlan => Unit = (_: LogicalPlan) => ()
+
+ // Resolver extension that never resolves anything. Only its object identity matters.
+ private val dummyExtension: ResolverExtension = new ResolverExtension {
+ override def resolveOperator(
+ operator: LogicalPlan,
+ resolver: TreeNodeResolver[LogicalPlan, LogicalPlan]): Option[LogicalPlan] = None
+ }
+
+ // Builds a fresh CatalogManager backed by an in-memory session catalog.
+ private def newCatalogManager(): CatalogManager =
+ new CatalogManager(
+ FakeV2SessionCatalog,
+ new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry))
+
+ test("withCatalogManager propagates all extension points") {
+ // Override every extension point with a recognizable, non-default value.
+ val analyzer = new Analyzer(newCatalogManager()) {
+ override val hintResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule)
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule)
+ override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule)
+ override val extendedCheckRules: Seq[LogicalPlan => Unit] = Seq(dummyCheck)
+ override val singlePassResolverExtensions: Seq[ResolverExtension] = Seq(dummyExtension)
+ override val singlePassMetadataResolverExtensions: Seq[ResolverExtension] =
+ Seq(dummyExtension)
+ override val singlePassPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule)
+ override val singlePassExtendedResolutionChecks: Seq[LogicalPlan => Unit] = Seq(dummyCheck)
+ }
+
+ val clone = analyzer.withCatalogManager(newCatalogManager())
+
+ // Reference equality (`eq`): the clone must expose the very same sequences as the original.
+ assert(clone.hintResolutionRules eq analyzer.hintResolutionRules)
+ assert(clone.extendedResolutionRules eq analyzer.extendedResolutionRules)
+ assert(clone.postHocResolutionRules eq analyzer.postHocResolutionRules)
+ assert(clone.extendedCheckRules eq analyzer.extendedCheckRules)
+ assert(clone.singlePassResolverExtensions eq analyzer.singlePassResolverExtensions)
+ assert(clone.singlePassMetadataResolverExtensions eq
+ analyzer.singlePassMetadataResolverExtensions)
+ assert(clone.singlePassPostHocResolutionRules eq analyzer.singlePassPostHocResolutionRules)
+ assert(clone.singlePassExtendedResolutionChecks eq analyzer.singlePassExtendedResolutionChecks)
+
+ // Verify the clone's anonymous class overrides exactly the expected extension points.
+ // If this assertion fails, withCatalogManager was updated but this test was not.
+ // Add the corresponding assert above and update the expected set.
+ // Synthetic methods, bridge methods and names containing '$' are filtered out so that only
+ // the explicitly declared members of the clone's class are compared.
+ val overriddenMethods = clone.getClass.getDeclaredMethods
+ .filterNot(m => m.isSynthetic || m.isBridge || m.getName.contains("$"))
+ .map(_.getName)
+ .toSet
+
+ val expectedExtensions = Set(
+ "hintResolutionRules",
+ "extendedResolutionRules",
+ "postHocResolutionRules",
+ "extendedCheckRules",
+ "singlePassResolverExtensions",
+ "singlePassMetadataResolverExtensions",
+ "singlePassPostHocResolutionRules",
+ "singlePassExtendedResolutionChecks"
+ )
+
+ assert(overriddenMethods == expectedExtensions,
+ s"withCatalogManager does not copy the expected set of extension points. " +
+ s"Missing from withCatalogManager: ${expectedExtensions -- overriddenMethods}. " +
+ s"Unexpected overrides: ${overriddenMethods -- expectedExtensions}.")
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala
new file mode 100644
index 000000000000..ee771fe3f246
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.transactions
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, TransactionalCatalogPlugin}
+import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * Tests for [[TransactionUtils]]: commit and abort must always close the transaction, and
+ * beginTransaction must reject a transaction whose catalog name differs from the catalog it
+ * was started on.
+ */
+class TransactionUtilsSuite extends SparkFunSuite {
+ val testCatalogName = "test_catalog"
+
+ // --- Helpers ---------------------------------------------------------------
+ // Minimal catalog that only reports the given name; initialize is a no-op.
+ private def mockCatalog(catalogName: String): CatalogPlugin = new CatalogPlugin {
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = ()
+ override def name(): String = catalogName
+ }
+
+ // Shared no-op callback used as the default for the TestTransaction hooks.
+ private val emptyFunction = () => ()
+ // Transaction stub that records which lifecycle methods were called and then runs the
+ // corresponding hook, which lets a test inject a failure or observe the call.
+ private class TestTransaction(
+ catalogName: String,
+ onCommit: () => Unit = emptyFunction,
+ onAbort: () => Unit = emptyFunction,
+ onClose: () => Unit = emptyFunction) extends Transaction {
+ var committed = false
+ var aborted = false
+ var closed = false
+
+ // Returns a new mock catalog on every call; only its name is meaningful.
+ override def catalog(): CatalogPlugin = mockCatalog(catalogName)
+ override def commit(): Unit = { committed = true; onCommit() }
+ override def abort(): Unit = { aborted = true; onAbort() }
+ override def close(): Unit = { closed = true; onClose() }
+ }
+
+ // Builds a transactional catalog named `catalogName` whose transactions report
+ // `txnCatalogName` as their catalog name (defaults to `catalogName` when null).
+ private def mockTransactionalCatalog(
+ catalogName: String,
+ txnCatalogName: String = null): TransactionalCatalogPlugin = {
+ val resolvedTxnCatalogName = Option(txnCatalogName).getOrElse(catalogName)
+ new TransactionalCatalogPlugin {
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = ()
+ override def name(): String = catalogName
+ override def beginTransaction(info: TransactionInfo): Transaction =
+ new TestTransaction(resolvedTxnCatalogName)
+ }
+ }
+
+ // --- Commit ----------------------------------------------------------------
+ test("commit: calls commit then close") {
+ val txn = new TestTransaction(testCatalogName)
+ TransactionUtils.commit(txn)
+ assert(txn.committed)
+ assert(txn.closed)
+ }
+
+ test("commit: close is called even if commit fails") {
+ val txn = new TestTransaction(
+ testCatalogName, onCommit = () => throw new RuntimeException("commit failed"))
+ // The commit failure must still reach the caller.
+ intercept[RuntimeException] { TransactionUtils.commit(txn) }
+ assert(txn.closed)
+ }
+
+ // --- Abort -----------------------------------------------------------------
+ test("abort: calls abort then close") {
+ val txn = new TestTransaction(testCatalogName)
+ TransactionUtils.abort(txn)
+ assert(txn.aborted)
+ assert(txn.closed)
+ }
+
+ test("abort: close is called even if abort fails") {
+ val txn = new TestTransaction(testCatalogName,
+ onAbort = () => throw new RuntimeException("abort failed"))
+ // The abort failure must still reach the caller.
+ intercept[RuntimeException] { TransactionUtils.abort(txn) }
+ assert(txn.closed)
+ }
+
+ // --- Begin Transaction -----------------------------------------------------
+ test("beginTransaction: returns transaction when catalog names match") {
+ val catalog = mockTransactionalCatalog(testCatalogName)
+ val txn = TransactionUtils.beginTransaction(catalog)
+ assert(txn.catalog().name() == testCatalogName)
+ }
+
+ test("beginTransaction: fails when transaction catalog name does not match") {
+ val catalog = mockTransactionalCatalog(catalogName = testCatalogName, txnCatalogName = "other")
+ val e = intercept[SparkException] {
+ TransactionUtils.beginTransaction(catalog)
+ }
+ // The error message must mention both the transaction's and the original catalog name.
+ assert(e.getMessage.contains("other"))
+ assert(e.getMessage.contains(testCatalogName))
+ }
+
+ test("beginTransaction: aborts and closes transaction on catalog name mismatch") {
+ // Flags set by the transaction hooks, since the transaction itself is never returned.
+ var aborted = false
+ var closed = false
+ val catalog = new TransactionalCatalogPlugin {
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = ()
+ override def name(): String = testCatalogName
+ override def beginTransaction(info: TransactionInfo): Transaction =
+ new TestTransaction(
+ "other",
+ onAbort = () => { aborted = true },
+ onClose = () => { closed = true })
+ }
+ intercept[SparkException] { TransactionUtils.beginTransaction(catalog) }
+ assert(aborted)
+ assert(closed)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index fd2c0f6e9c2e..5a1e3a30b150 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, Cus
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram, HistogramBin}
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.internal.SQLConf
@@ -76,6 +77,10 @@ abstract class InMemoryBaseTable(
override def columns(): Array[Column] = tableColumns
+ /**
+ * Replaces the column definitions returned by `columns()` with `newColumns`.
+ *
+ * @param newColumns the columns this table should report from now on.
+ */
+ private[catalog] def updateColumns(newColumns: Array[Column]): Unit = {
+ tableColumns = newColumns
+ }
+
override def version(): String = tableVersion.toString
def setVersion(version: String): Unit = {
@@ -94,6 +99,8 @@ abstract class InMemoryBaseTable(
validatedTableVersion = version
}
+ /**
+ * Hook that receives the pushed filters when a scan is built. The body is empty here, so
+ * nothing is recorded unless a subclass overrides it.
+ *
+ * @param filters the filters passed by the scan-building code.
+ */
+ protected def recordScanEvent(filters: Array[Filter]): Unit = {}
+
protected object PartitionKeyColumn extends MetadataColumn {
override def name: String = "_partition"
override def dataType: DataType = StringType
@@ -406,6 +413,7 @@ abstract class InMemoryBaseTable(
def baseCapabiilities: Set[TableCapability] = Set(
TableCapability.BATCH_READ,
+ TableCapability.MICRO_BATCH_READ,
TableCapability.BATCH_WRITE,
TableCapability.STREAMING_WRITE,
TableCapability.OVERWRITE_BY_FILTER,
@@ -455,6 +463,7 @@ abstract class InMemoryBaseTable(
if (evaluableFilters.nonEmpty) {
scan.filter(evaluableFilters)
}
+ recordScanEvent(_pushedFilters)
scan
}
@@ -494,6 +503,33 @@ abstract class InMemoryBaseTable(
case class InMemoryHistogram(height: Double, bins: Array[HistogramBin]) extends Histogram
+ /**
+ * Streaming offset expressed as a row count. Its JSON form is the count as a decimal string.
+ *
+ * @param rowCount number of rows this offset stands for.
+ */
+ private class InMemoryTableOffset(val rowCount: Long) extends Offset {
+ override def json(): String = rowCount.toString
+ }
+
+ /**
+ * A [[MicroBatchStream]] over the enclosing table's rows, using row counts as offsets.
+ */
+ class InMemoryMicroBatchStream extends MicroBatchStream {
+ // The stream starts at row count 0.
+ override def initialOffset(): Offset = new InMemoryTableOffset(0)
+ // The latest offset is the number of rows the enclosing table holds when this is called.
+ override def latestOffset(): Offset =
+ new InMemoryTableOffset(InMemoryBaseTable.this.rows.size.toLong)
+ // Produces a single partition holding the table rows between the two offsets' row counts.
+ override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = {
+ val s = start.asInstanceOf[InMemoryTableOffset].rowCount.toInt
+ val e = end.asInstanceOf[InMemoryTableOffset].rowCount.toInt
+ Array(InMemoryMicroBatchPartition(InMemoryBaseTable.this.rows.slice(s, e)))
+ }
+ // The factory reads rows straight out of the InMemoryMicroBatchPartition it is given.
+ override def createReaderFactory(): PartitionReaderFactory = { partition =>
+ val rows = partition.asInstanceOf[InMemoryMicroBatchPartition].rows
+ new PartitionReader[InternalRow] {
+ // Cursor starts before the first row; next() advances it before get() reads it.
+ private var idx = -1
+ override def next(): Boolean = { idx += 1; idx < rows.size }
+ override def get(): InternalRow = rows(idx)
+ override def close(): Unit = {}
+ }
+ }
+ // Inverse of InMemoryTableOffset.json(): parses the row count back from its string form.
+ override def deserializeOffset(json: String): Offset = new InMemoryTableOffset(json.toLong)
+ // Nothing to do on commit or stop.
+ override def commit(end: Offset): Unit = {}
+ override def stop(): Unit = {}
+ }
+
abstract class BatchScanBaseClass(
var data: Seq[InputPartition],
readSchema: StructType,
@@ -579,6 +615,9 @@ abstract class InMemoryBaseTable(
override def supportedCustomMetrics(): Array[CustomMetric] = {
Array(new RowsReadCustomMetric)
}
+
+ /**
+ * Returns a new [[InMemoryMicroBatchStream]]; `checkpointLocation` is not used.
+ */
+ override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream =
+ new InMemoryMicroBatchStream
}
case class InMemoryBatchScan(
@@ -799,6 +838,11 @@ object InMemoryBaseTable {
}
}
+/**
+ * A partition for [[InMemoryBaseTable]] micro-batch streaming reads, holding a slice of rows.
+ *
+ * @param rows the rows that a reader for this partition returns, in order.
+ */
+case class InMemoryMicroBatchPartition(rows: Seq[InternalRow]) extends InputPartition
+
/**
* Represent a set of rows buffered in memory for a given partition key.
* @param key partition key
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
index 91e899bc1169..406d83aa86ab 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
@@ -38,13 +38,15 @@ class InMemoryRowLevelOperationTable(
schema: StructType,
partitioning: Array[Transform],
properties: util.Map[String, String],
- constraints: Array[Constraint] = Array.empty)
+ constraints: Array[Constraint] = Array.empty,
+ tableId: String = java.util.UUID.randomUUID().toString)
extends InMemoryTable(
name,
CatalogV2Util.structTypeToV2Columns(schema),
partitioning,
properties,
- constraints)
+ constraints,
+ id = tableId)
with SupportsRowLevelOperations {
private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
index bbb9041bab37..95d5975d269f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
@@ -17,11 +17,31 @@
package org.apache.spark.sql.connector.catalog
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo}
+import org.apache.spark.sql.types.StructType
-class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog {
+class InMemoryRowLevelOperationTableCatalog
+ extends InMemoryTableCatalog
+ with TransactionalCatalogPlugin {
import CatalogV2Implicits._
+ // The current active transaction.
+ var transaction: Txn = _
+ // The last completed transaction.
+ var lastTransaction: Txn = _
+ // All transactions in order (committed and aborted), allowing per-statement
+ // validation in SQL scripting tests.
+ val observedTransactions: ArrayBuffer[Txn] = new ArrayBuffer[Txn]()
+
+ /**
+  * Starts a new [[Txn]] backed by a [[TxnTableCatalog]] over this catalog and records it as
+  * the current transaction. Asserts that no previously started transaction is still active.
+  *
+  * @param info metadata about the transaction being started; not used by this catalog.
+  * @return the newly started transaction.
+  */
+ override def beginTransaction(info: TransactionInfo): Transaction = {
+   val previousStillActive = transaction != null && transaction.currentState == Active
+   assert(!previousStillActive)
+   val started = new Txn(new TxnTableCatalog(this))
+   this.transaction = started
+   started
+ }
+
override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
if (tables.containsKey(ident)) {
throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
@@ -41,11 +61,7 @@ class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog {
override def alterTable(ident: Identifier, changes: TableChange*): Table = {
val table = loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable]
val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes)
- val schema = CatalogV2Util.applySchemaChanges(
- table.schema,
- changes,
- tableProvider = Some("in-memory"),
- statementType = "ALTER TABLE")
+ val schema = computeAlterTableSchema(table.schema, changes.toSeq)
val partitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes)
val constraints = CatalogV2Util.collectConstraintChanges(table, changes)
@@ -59,13 +75,24 @@ class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog {
schema = schema,
partitioning = partitioning,
properties = properties,
- constraints = constraints)
+ constraints = constraints,
+ tableId = table.id)
newTable.alterTableWithData(table.data, schema)
tables.put(ident, newTable)
newTable
}
+
+ /**
+ * Computes the schema that would result from applying `changes` to `currentSchema`.
+ * Can be overridden by subclasses to simulate catalogs that selectively ignore changes
+ * (e.g. [[PartialSchemaEvolutionCatalog]]).
+ *
+ * @param currentSchema the schema of the table before the changes.
+ * @param changes the table changes to apply.
+ * @return the schema produced by `CatalogV2Util.applySchemaChanges`.
+ */
+ def computeAlterTableSchema(currentSchema: StructType, changes: Seq[TableChange]): StructType = {
+ CatalogV2Util.applySchemaChanges(
+ currentSchema, changes, tableProvider = Some("in-memory"), statementType = "ALTER TABLE")
+ }
}
/**
@@ -84,9 +111,10 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo
case _ => false
}
val properties = CatalogV2Util.applyPropertiesChanges(table.properties, propertyChanges)
+ val schema = computeAlterTableSchema(table.schema, changes.toSeq)
val newTable = new InMemoryRowLevelOperationTable(
name = table.name,
- schema = table.schema,
+ schema = schema,
partitioning = table.partitioning,
properties = properties,
constraints = table.constraints)
@@ -94,4 +122,9 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo
tables.put(ident, newTable)
newTable
}
+
+ // Ignores all schema changes and returns the current schema unchanged, so `changes` never
+ // affects the result.
+ override def computeAlterTableSchema(
+ currentSchema: StructType,
+ changes: Seq[TableChange]): StructType = currentSchema
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index d5738475031d..15ed4136dbda 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
/**
@@ -215,6 +216,15 @@ class InMemoryTable(
object InMemoryTable {
+ // A filter value and a partition value may use different string representations (String
+ // vs. UTF8String). Such mixed pairs are compared through the UTF8String's String form;
+ // every other pair is compared with plain equality.
+ private def valuesEqual(filterValue: Any, partitionValue: Any): Boolean = filterValue match {
+   case text: String =>
+     partitionValue match {
+       case utf8: UTF8String => utf8.toString == text
+       case other => text == other
+     }
+   case utf8: UTF8String =>
+     partitionValue match {
+       case text: String => utf8.toString == text
+       case other => utf8 == other
+     }
+   case other =>
+     other == partitionValue
+ }
+
def filtersToKeys(
keys: Iterable[Seq[Any]],
partitionNames: Seq[String],
@@ -222,7 +232,7 @@ object InMemoryTable {
keys.filter { partValues =>
filters.flatMap(splitAnd).forall {
case EqualTo(attr, value) =>
- value == InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
+ valuesEqual(value, InMemoryBaseTable.extractValue(attr, partitionNames, partValues))
case EqualNullSafe(attr, value) =>
val attrVal = InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
if (attrVal == null && value == null) {
@@ -230,7 +240,7 @@ object InMemoryTable {
} else if (attrVal == null || value == null) {
false
} else {
- value == attrVal
+ valuesEqual(value, attrVal)
}
case IsNull(attr) =>
null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
index ff7995ad6697..5ce1804c11a7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
@@ -40,7 +40,7 @@ class BasicInMemoryTableCatalog extends TableCatalog {
protected val namespaces: util.Map[List[String], Map[String, String]] =
new ConcurrentHashMap[List[String], Map[String, String]]()
- protected val tables: util.Map[Identifier, Table] =
+ protected var tables: util.Map[Identifier, Table] =
new ConcurrentHashMap[Identifier, Table]()
private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala
new file mode 100644
index 000000000000..bd7d6d689f51
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, RowLevelOperationBuilder, RowLevelOperationInfo, WriteBuilder}
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+// Lifecycle states of a Txn. A Txn starts out Active and can move to Committed (via commit())
+// or Aborted (via abort()).
+sealed trait TransactionState
+// Initial state of every Txn.
+case object Active extends TransactionState
+// Set by Txn.commit() once the catalog commit has returned.
+case object Committed extends TransactionState
+// Set by Txn.abort() when the transaction has neither committed nor aborted before.
+case object Aborted extends TransactionState
+
+// Test Transaction backed by a TxnTableCatalog. It tracks a commit/abort state and, separately,
+// whether close() has already run.
+class Txn(override val catalog: TxnTableCatalog) extends Transaction {
+
+  // Starts out Active; changed only by commit() and abort().
+  private[this] var state: TransactionState = Active
+  // Flipped by the first close() call.
+  private[this] var closed: Boolean = false
+
+  def currentState: TransactionState = state
+
+  def isClosed: Boolean = closed
+
+  // Commits through the catalog and then marks the transaction Committed. Throws
+  // IllegalStateException when the transaction is already closed or already aborted.
+  override def commit(): Unit = {
+    if (closed) {
+      throw new IllegalStateException("Can't commit, already closed")
+    } else if (state == Aborted) {
+      throw new IllegalStateException("Can't commit, already aborted")
+    }
+    catalog.commit()
+    state = Committed
+  }
+
+  // Idempotent, since nested QEs can cause multiple aborts: only an Active transaction changes
+  // state, a Committed or Aborted one is left untouched.
+  override def abort(): Unit = {
+    if (state == Active) {
+      state = Aborted
+    }
+  }
+
+  // Idempotent as well: only the first call clears the catalog's active transaction, later
+  // calls are no-ops.
+  override def close(): Unit = if (!closed) {
+    catalog.clearActiveTransaction()
+    closed = true
+  }
+}
+
+// A special table used in row-level operation transactions. It inherits data
+// from the base table upon construction and propagates staged transaction state
+// back after an explicit commit.
+// Note, the in-memory data store does not handle concurrency at the moment. It assumes that the
+// underlying delegate table cannot change from concurrent transactions. Data sources need to
+// implement isolation semantics and make sure they are enforced.
+class TxnTable(
+ val delegate: InMemoryRowLevelOperationTable,
+ schema: StructType,
+ catalog: TxnTableCatalog)
+ extends InMemoryRowLevelOperationTable(
+ delegate.name,
+ schema,
+ delegate.partitioning,
+ delegate.properties,
+ delegate.constraints) {
+
+ // Expose the same id as the delegate so that identity checks during transaction re-resolution
+ // don't false-positive on the TxnTable wrapper having a different UUID.
+ override val id: String = delegate.id
+
+ // Copy the delegate's current data into this table, using the schema passed to the
+ // constructor (which may differ from the delegate's schema after an alterTable).
+ alterTableWithData(delegate.data, schema)
+
+ // A tracker of filters used in each scan.
+ val scanEvents = new ArrayBuffer[Array[Filter]]()
+
+ // Record scan events. This is invoked when building a scan for the particular table.
+ override protected def recordScanEvent(filters: Array[Filter]): Unit = {
+ scanEvents += filters
+ }
+
+ // The three write entry points below first register this table as the catalog's write target
+ // and then defer to the parent implementation. TxnTableCatalog.commit() commits only the
+ // registered write target.
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ catalog.writeTarget = this
+ super.newWriteBuilder(info)
+ }
+
+ override def newRowLevelOperationBuilder(
+ info: RowLevelOperationInfo): RowLevelOperationBuilder = {
+ catalog.writeTarget = this
+ super.newRowLevelOperationBuilder(info)
+ }
+
+ override def deleteWhere(filters: Array[Filter]): Unit = {
+ catalog.writeTarget = this
+ super.deleteWhere(filters)
+ }
+
+ // Propagates staged data and metadata changes to the delegate table.
+ def commit(): Unit = {
+ // Clear the delegate's data map first; it is repopulated from this table's data below.
+ delegate.dataMap.clear()
+ delegate.updateColumns(columns()) // Evolve schema if needed.
+ delegate.alterTableWithData(data, schema)
+ // Carry over the write bookkeeping recorded on this table during the transaction.
+ delegate.replacedPartitions = replacedPartitions
+ delegate.lastWriteInfo = lastWriteInfo
+ delegate.lastWriteLog = lastWriteLog
+ delegate.commits ++= commits
+ delegate.increaseVersion()
+ }
+}
+
+// A special table catalog used in row-level operation transactions. The lifecycle of this catalog
+// is tied to the transaction. A new catalog instance is created at the beginning of a transaction
+// and discarded at the end. The catalog is responsible for pinning all tables involved in the
+// transaction. Table changes are initially staged in memory and propagated only after an explicit
+// commit.
+class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends TableCatalog {
+
+  // Tables pinned by this transaction, keyed by identifier.
+  private val tables: util.Map[Identifier, TxnTable] = new ConcurrentHashMap[Identifier, TxnTable]()
+
+  // The table most recently registered as write target by a TxnTable; null until a write,
+  // row-level operation or delete has been requested on one of the pinned tables.
+  var writeTarget: TxnTable = _
+
+  override def name: String = delegate.name
+
+  // Nothing to do: this catalog only wraps the already initialized delegate.
+  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {}
+
+  override def listTables(namespace: Array[String]): Array[Identifier] = {
+    throw new UnsupportedOperationException()
+  }
+
+  // This is where the table pinning logic should occur. In this implementation, a table is
+  // loaded (pinned) the first time it is accessed. All subsequent accesses should return the
+  // same pinned table.
+  override def loadTable(ident: Identifier): Table = pinTable(ident)
+
+  // Returns the pinned table for `ident`, loading it from the delegate and wrapping it in a
+  // TxnTable on first access.
+  private def pinTable(ident: Identifier): TxnTable = {
+    tables.computeIfAbsent(ident, _ => {
+      val table = delegate.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable]
+      new TxnTable(table, table.schema(), this)
+    })
+  }
+
+  override def alterTable(ident: Identifier, changes: TableChange*): Table = {
+    // AlterTable may be called by ResolveSchemaEvolution when schema evolution is enabled. Thus,
+    // it needs to be transactional. The schema changes are only propagated to the delegate at
+    // commit time.
+    //
+    // We delegate schema computation to the underlying catalog so that catalogs with special
+    // handling (e.g. PartialSchemaEvolutionCatalog) have the same behaviour inside a
+    // transaction.
+    //
+    // Pin the table on demand instead of doing a plain map lookup: if alterTable is the first
+    // access to `ident` in this transaction, the lookup would return null and fail with an NPE.
+    val txnTable = pinTable(ident)
+    val schema = delegate.computeAlterTableSchema(txnTable.schema, changes.toSeq)
+
+    if (schema.fields.isEmpty) {
+      throw new IllegalArgumentException("Cannot drop all fields")
+    }
+
+    // TODO: We need to pass all tracked predicates to the new TXN table.
+    // The altered table wraps the same delegate and replaces the previously pinned table.
+    val newTxnTable = new TxnTable(txnTable.delegate, schema, this)
+    tables.put(ident, newTxnTable)
+    newTxnTable
+  }
+
+  // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented.
+  override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
+    delegate.createTable(ident, tableInfo)
+    loadTable(ident)
+  }
+
+  // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented.
+  override def dropTable(ident: Identifier): Boolean = {
+    tables.remove(ident)
+    delegate.dropTable(ident)
+  }
+
+  override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = {
+    throw new UnsupportedOperationException()
+  }
+
+  // Returns all tables that participated in this transaction, keyed by identifier.
+  def txnTables: scala.collection.Map[Identifier, TxnTable] = tables.asScala
+
+  // Commit the write target table, propagating staged changes to the delegate. This is a no-op
+  // when no table has been registered as write target.
+  def commit(): Unit = {
+    if (writeTarget != null) writeTarget.commit()
+  }
+
+  // Clear transaction context: the delegate's current transaction is recorded as its last
+  // transaction and appended to its observed transactions before being reset to null.
+  def clearActiveTransaction(): Unit = {
+    val txn = delegate.transaction
+    delegate.lastTransaction = txn
+    delegate.observedTransactions += txn
+    delegate.transaction = null
+  }
+
+  // Equality is by catalog name only, so this catalog is equal to any CatalogPlugin that
+  // reports the same name.
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case that: CatalogPlugin => this.name == that.name
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = name.hashCode()
+}
+
+/**
+ * An InMemoryRowLevelOperationTableCatalog that utilizes tables backed by a shared map. This
+ * simulates the behavior of real catalogs (Delta, Iceberg, etc.) where multiple instances
+ * of the catalog share the same underlying persistent storage, thus, they see the same tables.
+ *
+ * This is needed for testing execution that spans multiple Spark sessions. In particular,
+ * streaming queries execute micro-batches in cloned Spark sessions. Without this, the cloned
+ * spark session catalog will not see any tables created in the original session.
+ *
+ * Tests that use this catalog must call
+ * [[SharedTablesInMemoryRowLevelOperationTableCatalog.reset()]] in `afterEach` to clear the
+ * shared state between test cases.
+ */
+class SharedTablesInMemoryRowLevelOperationTableCatalog
+    extends InMemoryRowLevelOperationTableCatalog {
+
+  import SharedTablesInMemoryRowLevelOperationTableCatalog.sharedTables
+
+  // Runs the regular initialization and then replaces this instance's table map with the map
+  // held by the companion object.
+  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
+    super.initialize(name, options)
+    tables = sharedTables
+  }
+}
+
+object SharedTablesInMemoryRowLevelOperationTableCatalog {
+  // The single map installed into every catalog instance by initialize().
+  private[catalog] val sharedTables: ConcurrentHashMap[Identifier, Table] =
+    new ConcurrentHashMap[Identifier, Table]()
+
+  // Removes every entry from the shared map.
+  def reset(): Unit = {
+    sharedTables.clear()
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index c0ab906de484..5bfb73e6b887 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -22,28 +22,32 @@ import java.util.UUID
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
import javax.annotation.concurrent.GuardedBy
+import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkContext, SparkException}
-import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row}
import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubqueryAliases, LazyExpression, NameParameterizedQuery, UnresolvedRelation, UnsupportedOperationChecker}
import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, WithCTE}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, TransactionalWrite => TransactionalWritePlan, Union, UnresolvedWith, WithCTE}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
+import org.apache.spark.sql.catalyst.transactions.TransactionUtils
import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, LookupCatalog, SupportsCatalogOptions, TableCatalog, TransactionalCatalogPlugin}
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
import org.apache.spark.sql.execution.SQLExecution.EXECUTION_ROOT_ID_KEY
import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
-import org.apache.spark.sql.execution.datasources.v2.V2TableRefreshUtil
+import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, TransactionalExec, V2TableRefreshUtil}
import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
@@ -52,6 +56,7 @@ import org.apache.spark.sql.execution.streaming.runtime.{IncrementalExecution, W
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.scripting.SqlScriptingExecution
import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.{LazyTry, Utils, UUIDv7Generator}
import org.apache.spark.util.ArrayImplicits._
@@ -69,7 +74,12 @@ class QueryExecution(
val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL,
val shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None,
val refreshPhaseEnabled: Boolean = true,
- val queryId: UUID = UUIDv7Generator.generate()) extends Logging {
+ val queryId: UUID = UUIDv7Generator.generate(),
+ // When a transaction is active, callers creating nested QueryExecution instances MUST pass
+ // the enclosing QueryExecution's analyzer here to propagate the transaction context.
+ // Omitting it causes the nested QE to use sessionState.analyzer, which has no knowledge
+ // of the transaction and will load tables outside the transaction's catalog scope.
+ val analyzerOpt: Option[Analyzer] = None) extends LookupCatalog {
val id: Long = QueryExecution.nextExecutionId
@@ -79,6 +89,8 @@ class QueryExecution(
// TODO: Move the planner an optimizer into here from SessionState.
protected def planner = sparkSession.sessionState.planner
+ protected val catalogManager = sparkSession.sessionState.catalogManager
+
/**
* Check whether the query represented by this QueryExecution is a SQL script.
* @return True if the query is a SQL script, False otherwise.
@@ -90,6 +102,82 @@ class QueryExecution(
logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyExpression])))
}
+
+ // 1. At the pre-Analyzed plan we look for nodes that implement the TransactionalWrite trait.
+ // When a plan contains such a node we initiate a transaction. Note, we should never start
+ // a transaction for operations that are not executed, e.g. EXPLAIN.
+ // 2. Create an analyzer clone with a transaction aware Catalog Manager. The latter is the
+ // narrow waist of all catalog accesses, and it is also the transaction context carrier.
+ // This is then passed to all rules during analysis that need to check the catalog. Rules
+ // that are specifically interested in transactionality can access the transaction directly
+ // from the Catalog Manager. The transaction catalog is potentially the place where connectors
+ // should keep state about the reads (tables+predicates) that occurred during the transaction.
+ // 3. The analyzer instance is passed to nested Query Execution instances. These need to respect
+ // the open transaction instead of creating their own.
+ // The transaction this QueryExecution runs in, if any. Being a lazy val, the lookup (and a
+ // possible beginTransaction call) happens on first access only.
+ private lazy val transactionOpt: Option[Transaction] =
+ // Always inherit an active transaction from the outer analyzer, regardless of mode.
+ analyzerOpt.flatMap(_.catalogManager.transaction).orElse {
+ // Only begin a new transaction for outer QEs that lead to execution.
+ if (mode != CommandExecutionMode.SKIP) {
+ // Finds the transactional catalog for a write: path-based lookup first, then the
+ // catalog-based extractor.
+ def resolve(w: TransactionalWritePlan): Option[TransactionalCatalogPlugin] =
+ pathBased(w) match {
+ case Some(c: TransactionalCatalogPlugin) => Some(c)
+ // A path-based catalog was found but it is not transactional: no transaction.
+ case Some(_) => None
+ // If the path is not data source based, fallback to catalog based resolution.
+ case None => TransactionalWrite.unapply(w)
+ }
+ // Only the root of the unanalyzed plan is inspected, either directly or as the first
+ // component of an UnresolvedWith.
+ val catalog = logical match {
+ case UnresolvedWith(w: TransactionalWritePlan, _, _) => resolve(w)
+ case w: TransactionalWritePlan => resolve(w)
+ case _ => None
+ }
+ // A transaction is begun only when a transactional catalog was found.
+ catalog.map(TransactionUtils.beginTransaction)
+ } else {
+ None
+ }
+ }
+
+ // For path-based tables (e.g. `format.`/path/to/table``) the first identifier part is a
+ // connector name. SupportsCatalogOptions on the connector tells us which catalog actually
+ // owns the table. Returns Some(catalog) if parts.head is a recognized SupportsCatalogOptions
+ // data source (caller decides whether the catalog is transactional), or None to fall through
+ // to the catalog-based extractor.
+ private def pathBased(write: TransactionalWritePlan): Option[TableCatalog] =
+ // Match on the write target after subquery aliases have been eliminated.
+ EliminateSubqueryAliases(write.table) match {
+ // Only unresolved, multi-part identifiers are candidates; the first part is looked up as a
+ // data source name.
+ case UnresolvedRelation(parts, _, _) if parts.length > 1 =>
+ try {
+ DataSource.lookupDataSourceV2(parts.head, sparkSession.sessionState.conf)
+ // Providers that do not implement SupportsCatalogOptions result in None.
+ .collect { case sco: SupportsCatalogOptions => sco }
+ .map { sco =>
+ val sessionConfigs = DataSourceV2Utils.extractSessionConfigs(
+ sco, sparkSession.sessionState.conf)
+ // Pass the entire identifier as option. The connector can decide how to parse it
+ // if needed.
+ val options = sessionConfigs + ("identifier" -> parts.mkString("."))
+ CatalogV2Util.getTableProviderCatalog(
+ sco, catalogManager, new CaseInsensitiveStringMap(options.asJava))
+ }
+ } catch {
+ // The head of the multipart identifier is not a registered data source.
+ // Fallback to catalog-based detection.
+ case _: ClassNotFoundException => None
+ }
+ // Anything else (e.g. a single-part identifier) is not treated as path-based.
+ case _ => None
+ }
+
+ // Per-query analyzer: uses a transaction-aware CatalogManager when a transaction is active,
+ // so that all catalog lookups and rule applications during analysis see the correct state
+ // without relying on thread-local context. Any nested QueryExecution that is created during
+ // analysis or execution of a transactional plan must receive this analyzer via analyzerOpt.
+  private lazy val analyzer: Analyzer = analyzerOpt.getOrElse {
+    // Resolved on demand, i.e. only after transactionOpt has been evaluated.
+    def sessionAnalyzer: Analyzer = sparkSession.sessionState.analyzer
+    // No transaction: the session analyzer is used as-is. Otherwise it is rebound to a
+    // CatalogManager that carries the transaction.
+    transactionOpt.fold(sessionAnalyzer) { txn =>
+      sessionAnalyzer.withCatalogManager(catalogManager.withTransaction(txn))
+    }
+  }
+
def assertAnalyzed(): Unit = {
try {
analyzed
@@ -102,7 +190,7 @@ class QueryExecution(
}
}
- def assertSupported(): Unit = {
+ def assertSupported(): Unit = withAbortTransactionOnFailure {
if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) {
UnsupportedOperationChecker.checkForBatch(analyzed)
}
@@ -141,7 +229,7 @@ class QueryExecution(
try {
val plan = executePhase(QueryPlanningTracker.ANALYSIS) {
// We can't clone `logical` here, which will reset the `_analyzed` flag.
- sparkSession.sessionState.analyzer.executeAndCheck(sqlScriptExecuted, tracker)
+ analyzer.executeAndCheck(sqlScriptExecuted, tracker)
}
tracker.setAnalyzed(plan)
plan
@@ -152,7 +240,9 @@ class QueryExecution(
}
}
- def analyzed: LogicalPlan = lazyAnalyzed.get
+ def analyzed: LogicalPlan = withAbortTransactionOnFailure {
+ lazyAnalyzed.get
+ }
private val lazyCommandExecuted = LazyTry {
mode match {
@@ -162,7 +252,9 @@ class QueryExecution(
}
}
- def commandExecuted: LogicalPlan = lazyCommandExecuted.get
+ def commandExecuted: LogicalPlan = withAbortTransactionOnFailure {
+ lazyCommandExecuted.get
+ }
private def commandExecutionName(command: Command): String = command match {
case _: CreateTableAsSelect => "create"
@@ -184,7 +276,7 @@ class QueryExecution(
// for eagerly executed commands we mark this place as beginning of execution.
tracker.setReadyForExecution()
val (qe, result) = QueryExecution.runCommand(
- sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode))
+ sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode), Some(analyzer))
CommandResult(
qe.analyzed.output,
qe.commandExecuted,
@@ -222,19 +314,31 @@ class QueryExecution(
}
// The plan that has been normalized by custom rules, so that it's more likely to hit cache.
- def normalized: LogicalPlan = lazyNormalized.get
+ def normalized: LogicalPlan = withAbortTransactionOnFailure {
+ lazyNormalized.get
+ }
private val lazyWithCachedData = LazyTry {
sparkSession.withActive {
assertAnalyzed()
assertSupported()
- // clone the plan to avoid sharing the plan instance between different stages like analyzing,
- // optimizing and planning.
- sparkSession.sharedState.cacheManager.useCachedData(normalized.clone())
+
+ // During a transaction, skip cache substitution. This is to avoid replacing relations
+ // loaded by the transactional catalog with potentially stale relations cached before
+ // the transaction was active.
+ if (transactionOpt.isDefined) {
+ normalized
+ } else {
+ // Clone the plan to avoid sharing the plan instance between different stages like
+ // analyzing, optimizing and planning.
+ sparkSession.sharedState.cacheManager.useCachedData(normalized.clone())
+ }
}
}
- def withCachedData: LogicalPlan = lazyWithCachedData.get
+ def withCachedData: LogicalPlan = withAbortTransactionOnFailure {
+ lazyWithCachedData.get
+ }
def assertCommandExecuted(): Unit = commandExecuted
@@ -256,7 +360,9 @@ class QueryExecution(
}
}
- def optimizedPlan: LogicalPlan = lazyOptimizedPlan.get
+ def optimizedPlan: LogicalPlan = withAbortTransactionOnFailure {
+ lazyOptimizedPlan.get
+ }
def assertOptimized(): Unit = optimizedPlan
@@ -264,14 +370,17 @@ class QueryExecution(
// We need to materialize the optimizedPlan here because sparkPlan is also tracked under
// the planning phase
assertOptimized()
- executePhase(QueryPlanningTracker.PLANNING) {
+ val plan = executePhase(QueryPlanningTracker.PLANNING) {
// Clone the logical plan here, in case the planner rules change the states of the logical
// plan.
QueryExecution.createSparkPlan(planner, optimizedPlan.clone())
}
+ attachTransaction(plan)
}
- def sparkPlan: SparkPlan = lazySparkPlan.get
+ def sparkPlan: SparkPlan = withAbortTransactionOnFailure {
+ lazySparkPlan.get
+ }
def assertSparkPlanPrepared(): Unit = sparkPlan
@@ -292,7 +401,9 @@ class QueryExecution(
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
- def executedPlan: SparkPlan = lazyExecutedPlan.get
+ def executedPlan: SparkPlan = withAbortTransactionOnFailure {
+ lazyExecutedPlan.get
+ }
def assertExecutedPlanPrepared(): Unit = executedPlan
@@ -310,7 +421,9 @@ class QueryExecution(
* Given QueryExecution is not a public class, end users are discouraged to use this: please
* use `Dataset.rdd` instead where conversion will be applied.
*/
- def toRdd: RDD[InternalRow] = lazyToRdd.get
+ def toRdd: RDD[InternalRow] = withAbortTransactionOnFailure {
+ lazyToRdd.get
+ }
private val observedMetricsLock = new Object
@@ -390,17 +503,24 @@ class QueryExecution(
}
}
+ /**
+ * Returns the QueryExecution to use when generating an explain string.
+ * Overridden by IncrementalExecution to reuse `this` so that the already-open transaction and
+ * cached executedPlan are not duplicated.
+ */
+  protected def queryExecutionForExplain: QueryExecution = {
+    if (!logical.isStreaming) {
+      // Non-streaming plans are explained with this QueryExecution itself.
+      this
+    } else {
+      // A streaming plan gets here only when explaining a `Dataset/DataFrame` created by
+      // `spark.readStream`; there is no `Sink`, so the output mode does not matter.
+      new IncrementalExecution(
+        sparkSession, logical, OutputMode.Append(), "",
+        UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0),
+        WatermarkPropagator.noop(), false, mode = this.mode)
+    }
+  }
+
private def explainString(mode: ExplainMode, maxFields: Int, append: String => Unit): Unit = {
- val queryExecution = if (logical.isStreaming) {
- // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the
- // output mode does not matter since there is no `Sink`.
- new IncrementalExecution(
- sparkSession, logical, OutputMode.Append(), "",
- UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0),
- WatermarkPropagator.noop(), false, mode = this.mode)
- } else {
- this
- }
+ val queryExecution = queryExecutionForExplain
mode match {
case SimpleMode =>
@@ -535,6 +655,26 @@ class QueryExecution(
}
}
+  /**
+   * Runs the given block, aborting the active transaction if an exception is thrown.
+   * If no transaction is active, the block is executed as-is.
+   *
+   * The failure raised by the block is always the one that is rethrown. If aborting the
+   * transaction fails with a non-fatal error as well, that error is attached to the original
+   * failure as a suppressed exception instead of replacing it.
+   *
+   * @param block the computation to run within the transaction scope
+   * @return the result of `block`
+   */
+  private def withAbortTransactionOnFailure[T](block: => T): T = transactionOpt match {
+    case Some(transaction) =>
+      try {
+        block
+      } catch {
+        case e: Throwable =>
+          try {
+            TransactionUtils.abort(transaction)
+          } catch {
+            // A failing abort must not mask the root cause. The identity check avoids the
+            // IllegalArgumentException that addSuppressed throws on self-suppression.
+            case NonFatal(abortFailure) if abortFailure ne e => e.addSuppressed(abortFailure)
+          }
+          throw e
+      }
+    case None => block
+  }
+
+
+  /** Attaches the active transaction, if any, to the transactional nodes of the given plan. */
+  private def attachTransaction(plan: SparkPlan): SparkPlan = {
+    if (transactionOpt.isEmpty) {
+      // Nothing to attach; the plan is returned untouched.
+      plan
+    } else {
+      plan.transformDown {
+        case txnExec: TransactionalExec => txnExec.withTransaction(transactionOpt)
+      }
+    }
+  }
+
/** A special namespace for commands that can be used to debug query execution. */
// scalastyle:off
object debug {
@@ -819,14 +959,16 @@ object QueryExecution {
name: String,
refreshPhaseEnabled: Boolean = true,
mode: CommandExecutionMode.Value = CommandExecutionMode.SKIP,
- shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None)
+ shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None,
+ analyzerOpt: Option[Analyzer] = None)
: (QueryExecution, Array[InternalRow]) = {
val qe = new QueryExecution(
sparkSession,
command,
mode = mode,
shuffleCleanupModeOpt = shuffleCleanupModeOpt,
- refreshPhaseEnabled = refreshPhaseEnabled)
+ refreshPhaseEnabled = refreshPhaseEnabled,
+ analyzerOpt = analyzerOpt)
val result = QueryExecution.withInternalError(s"Executed $name failed.") {
SQLExecution.withNewExecutionId(qe, Some(name)) {
qe.executedPlan.executeCollect()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala
index 8d5ee6038e80..c6b1bae89b15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala
@@ -19,16 +19,23 @@ package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.transactions.TransactionUtils
import org.apache.spark.sql.connector.catalog.SupportsDeleteV2
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
import org.apache.spark.sql.connector.expressions.filter.Predicate
case class DeleteFromTableExec(
table: SupportsDeleteV2,
condition: Array[Predicate],
- refreshCache: () => Unit) extends LeafV2CommandExec {
+ refreshCache: () => Unit,
+ transaction: Option[Transaction] = None) extends LeafV2CommandExec with TransactionalExec {
+
+ override def withTransaction(txn: Option[Transaction]): DeleteFromTableExec =
+ copy(transaction = txn)
override protected def run(): Seq[InternalRow] = {
table.deleteWhere(condition)
+ transaction.foreach(TransactionUtils.commit)
refreshCache()
Seq.empty
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala
index 151329de9e6f..60965453d9ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala
@@ -119,12 +119,7 @@ private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging {
}
private def validateTableIdentity(currentTable: Table, relation: DataSourceV2Relation): Unit = {
- if (relation.table.id != null && relation.table.id != currentTable.id) {
- throw QueryCompilationErrors.tableIdChangedAfterAnalysis(
- relation.name,
- capturedTableId = relation.table.id,
- currentTableId = currentTable.id)
- }
+ V2TableUtil.validateTableId(relation.name, relation.table.id, currentTable)
}
private def validateDataColumns(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index d8e871bcf482..581027f95193 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -91,17 +91,18 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
o.copy(write = Some(write), query = newQuery)
case WriteToMicroBatchDataSource(
- relationOpt, table, query, queryId, options, outputMode, Some(batchId)) =>
- val writeOptions = mergeOptions(
- options,
- relationOpt.map(r => r.options.asCaseSensitiveMap.asScala.toMap).getOrElse(Map.empty))
- val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId = queryId)
- val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
+ relation, query, queryId, options, outputMode, Some(batchId)) =>
+ val v2Relation = relation.asInstanceOf[DataSourceV2Relation]
+ val writeOptions = mergeOptions(options, v2Relation.options.asCaseSensitiveMap.asScala.toMap)
+ // Guaranteed to support writes since it is a strict requirement to construct
+ // WriteToMicroBatchDataSource.
+ val writeTable = v2Relation.table.asInstanceOf[SupportsWrite]
+ val writeBuilder = newWriteBuilder(writeTable, writeOptions, query.schema, queryId = queryId)
+ val write = buildWriteForMicroBatch(writeTable, writeBuilder, outputMode)
val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
val customMetrics = write.supportedCustomMetrics.toImmutableArraySeq
- val funCatalogOpt = relationOpt.flatMap(_.funCatalog)
- val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt)
- WriteToDataSourceV2(relationOpt, microBatchWrite, newQuery, customMetrics)
+ val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, v2Relation.funCatalog)
+ WriteToDataSourceV2(Some(v2Relation), microBatchWrite, newQuery, customMetrics)
case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, projections, _, None) =>
val rowSchema = projections.rowProjection.schema
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 2071024c5b7e..9e579ae779f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -26,9 +26,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, ProjectingInternalRow}
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, TableSpec, UnaryNode}
+import org.apache.spark.sql.catalyst.transactions.TransactionUtils
import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{COPY_OPERATION, DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege}
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeleteSummaryImpl, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperation, RowLevelOperationTable, UpdateSummaryImpl, Write, WriterCommitMessage, WriteSummary}
@@ -74,7 +76,12 @@ case class CreateTableAsSelectExec(
query: LogicalPlan,
tableSpec: TableSpec,
writeOptions: Map[String, String],
- ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec {
+ ifNotExists: Boolean,
+ transaction: Option[Transaction] = None)
+ extends V2CreateTableAsSelectBaseExec with TransactionalExec {
+
+ override def withTransaction(txn: Option[Transaction]): CreateTableAsSelectExec =
+ copy(transaction = txn)
val properties = CatalogV2Util.convertTableProperties(tableSpec)
@@ -92,7 +99,9 @@ case class CreateTableAsSelectExec(
.build()
val table = Option(catalog.createTable(ident, tableInfo))
.getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
- writeToTable(catalog, table, writeOptions, ident, query, overwrite = false)
+ val result = writeToTable(catalog, table, writeOptions, ident, query, overwrite = false)
+ transaction.foreach(TransactionUtils.commit)
+ result
}
}
@@ -112,7 +121,8 @@ case class AtomicCreateTableAsSelectExec(
query: LogicalPlan,
tableSpec: TableSpec,
writeOptions: Map[String, String],
- ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec {
+ ifNotExists: Boolean)
+ extends V2CreateTableAsSelectBaseExec {
val properties = CatalogV2Util.convertTableProperties(tableSpec)
@@ -155,8 +165,12 @@ case class ReplaceTableAsSelectExec(
tableSpec: TableSpec,
writeOptions: Map[String, String],
orCreate: Boolean,
- invalidateCache: (TableCatalog, Identifier) => Unit)
- extends V2CreateTableAsSelectBaseExec {
+ invalidateCache: (TableCatalog, Identifier) => Unit,
+ transaction: Option[Transaction] = None)
+ extends V2CreateTableAsSelectBaseExec with TransactionalExec {
+
+ override def withTransaction(txn: Option[Transaction]): ReplaceTableAsSelectExec =
+ copy(transaction = txn)
val properties = CatalogV2Util.convertTableProperties(tableSpec)
@@ -192,9 +206,11 @@ case class ReplaceTableAsSelectExec(
.build()
val table = Option(catalog.createTable(ident, tableInfo))
.getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
- writeToTable(
+ val result = writeToTable(
catalog, table, writeOptions, ident, refreshedQuery,
overwrite = true, refreshPhaseEnabled = false)
+ transaction.foreach(TransactionUtils.commit)
+ result
}
}
@@ -273,7 +289,9 @@ case class AppendDataExec(
query: SparkPlan,
refreshCache: () => Unit,
write: Write,
- tableName: String) extends V2ExistingTableWriteExec {
+ tableName: String,
+ transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec {
+ override def withTransaction(txn: Option[Transaction]): AppendDataExec = copy(transaction = txn)
override protected def withNewChildInternal(newChild: SparkPlan): AppendDataExec =
copy(query = newChild)
}
@@ -292,7 +310,10 @@ case class OverwriteByExpressionExec(
query: SparkPlan,
refreshCache: () => Unit,
write: Write,
- tableName: String) extends V2ExistingTableWriteExec {
+ tableName: String,
+ transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec {
+ override def withTransaction(txn: Option[Transaction]): OverwriteByExpressionExec =
+ copy(transaction = txn)
override protected def withNewChildInternal(newChild: SparkPlan): OverwriteByExpressionExec =
copy(query = newChild)
}
@@ -310,7 +331,10 @@ case class OverwritePartitionsDynamicExec(
query: SparkPlan,
refreshCache: () => Unit,
write: Write,
- tableName: String) extends V2ExistingTableWriteExec {
+ tableName: String,
+ transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec {
+ override def withTransaction(txn: Option[Transaction]): OverwritePartitionsDynamicExec =
+ copy(transaction = txn)
override protected def withNewChildInternal(newChild: SparkPlan): OverwritePartitionsDynamicExec =
copy(query = newChild)
}
@@ -324,7 +348,8 @@ case class ReplaceDataExec(
projections: ReplaceDataProjections,
write: Write,
rowLevelCommand: RowLevelOperation.Command,
- tableName: String) extends RowLevelWriteExec {
+ tableName: String,
+ transaction: Option[Transaction] = None) extends RowLevelWriteExec {
override def writingTask: WritingSparkTask[_] = {
projections.metadataProjection match {
@@ -335,6 +360,7 @@ case class ReplaceDataExec(
}
}
+ override def withTransaction(txn: Option[Transaction]): ReplaceDataExec = copy(transaction = txn)
override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = {
copy(query = newChild)
}
@@ -369,7 +395,8 @@ case class WriteDeltaExec(
projections: WriteDeltaProjections,
write: DeltaWrite,
rowLevelCommand: RowLevelOperation.Command,
- tableName: String) extends RowLevelWriteExec {
+ tableName: String,
+ transaction: Option[Transaction] = None) extends RowLevelWriteExec {
override lazy val writingTask: WritingSparkTask[_] = {
if (projections.metadataProjection.isDefined) {
@@ -379,6 +406,7 @@ case class WriteDeltaExec(
}
}
+ override def withTransaction(txn: Option[Transaction]): WriteDeltaExec = copy(transaction = txn)
override protected def withNewChildInternal(newChild: SparkPlan): WriteDeltaExec = {
copy(query = newChild)
}
@@ -388,7 +416,11 @@ case class WriteToDataSourceV2Exec(
batchWrite: BatchWrite,
refreshCache: () => Unit,
query: SparkPlan,
- writeMetrics: Seq[CustomMetric]) extends V2TableWriteExec {
+ writeMetrics: Seq[CustomMetric],
+ transaction: Option[Transaction] = None) extends V2TableWriteExec with TransactionalExec {
+
+ override def withTransaction(txn: Option[Transaction]): WriteToDataSourceV2Exec =
+ copy(transaction = txn)
override def stringArgs: Iterator[Any] = Iterator(batchWrite, query)
@@ -398,6 +430,7 @@ case class WriteToDataSourceV2Exec(
override protected def run(): Seq[InternalRow] = {
val writtenRows = writeWithV2(batchWrite)
+ transaction.foreach(TransactionUtils.commit)
refreshCache()
writtenRows
}
@@ -406,7 +439,16 @@ case class WriteToDataSourceV2Exec(
copy(query = newChild)
}
-trait V2ExistingTableWriteExec extends V2TableWriteExec {
+/**
+ * Trait for physical plan nodes that write to a table as part of a transaction.
+ * The [[transaction]] is injected post-planning by [[QueryExecution]].
+ */
+trait TransactionalExec extends SparkPlan {
+ def transaction: Option[Transaction]
+ def withTransaction(txn: Option[Transaction]): SparkPlan
+}
+
+trait V2ExistingTableWriteExec extends V2TableWriteExec with TransactionalExec {
def refreshCache: () => Unit
def write: Write
def tableName: String
@@ -426,6 +468,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
} finally {
postDriverMetrics()
}
+ transaction.foreach(TransactionUtils.commit)
refreshCache()
writtenRows
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
index 1587fd4786a3..9fc72241e83b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
@@ -143,6 +143,9 @@ class IncrementalExecution(
}
}
+ // Use `this` for explain so the already-open transaction and executedPlan are reused.
+ override protected def queryExecutionForExplain: QueryExecution = this
+
private val allowMultipleStatefulOperators: Boolean =
sparkSession.sessionState.conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
index 973af04e0430..e6d0666aca25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException}
import org.apache.spark.internal.LogKeys
import org.apache.spark.internal.LogKeys._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Deduplicate, DeduplicateWithinWatermark, Distinct, FlatMapGroupsInPandasWithState, FlatMapGroupsWithState, GlobalLimit, Join, LeafNode, LocalRelation, LogicalPlan, Project, StreamSourceAwareLogicalPlan, TransformWithState, TransformWithStateInPySpark}
@@ -37,7 +38,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.classic.{Dataset, SparkSession}
import org.apache.spark.sql.classic.ClassicConversions.castToImpl
-import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability}
+import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability, TransactionalCatalogPlugin}
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, SupportsRealTimeMode, SupportsTriggerAvailableNow}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
@@ -46,7 +47,6 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, Real
import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, Offset, OneTimeTrigger, ProcessingTimeTrigger, RealTimeModeAllowlist, RealTimeTrigger, Sink, Source, StreamingQueryPlanTraverseHelper}
import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CommitMetadata, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataV2}
import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, StateStoreWriter}
-import org.apache.spark.sql.execution.streaming.runtime.AcceptsLatestSeenOffsetHandler
import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchSink, WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1}
import org.apache.spark.sql.execution.streaming.state.{OfflineStateRepartitionUtils, StateSchemaBroadcast, StateStoreErrors}
@@ -346,15 +346,20 @@ class MicroBatchExecution(
)
}
- // TODO (SPARK-27484): we should add the writing node before the plan is analyzed.
sink match {
case s: SupportsWrite =>
- val relationOpt = plan.catalogAndIdent.map {
- case (catalog, ident) => DataSourceV2Relation.create(s, Some(catalog), Some(ident))
+ val relation = plan.catalogAndIdent match {
+          // When the catalog is transactional, instead of eagerly creating the relation, we
+          // delegate resolution to ResolveRelations. This allows resolving the relation against
+          // a transactional catalog which keeps track of all tables loaded within the transaction.
+ case Some((catalog: TransactionalCatalogPlugin, ident)) =>
+ UnresolvedRelation(catalog.name +: ident.namespace().toSeq :+ ident.name())
+ case Some((catalog, ident)) =>
+ DataSourceV2Relation.create(s, Some(catalog), Some(ident))
+ case None => DataSourceV2Relation.create(s, None, None)
}
WriteToMicroBatchDataSource(
- relationOpt,
- table = s,
+ relation,
query = _logicalPlan,
queryId = id.toString,
extraOptions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
index 0a33093dcbce..7aa7a31bb085 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
@@ -17,10 +17,9 @@
package org.apache.spark.sql.execution.streaming.sources
+import org.apache.spark.sql.catalyst.analysis.NamedRelation
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
-import org.apache.spark.sql.connector.catalog.SupportsWrite
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, StreamingV2WriteCommand, UnaryNode}
import org.apache.spark.sql.streaming.OutputMode
/**
@@ -29,19 +28,36 @@ import org.apache.spark.sql.streaming.OutputMode
* Note that this logical plan does not have a corresponding physical plan, as it will be converted
* to [[org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 WriteToDataSourceV2]]
* with [[MicroBatchWrite]] before execution.
+ *
+ * [[relation]] starts as [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation]] when the
+ * sink is a table in a transactional catalog, or as a resolved
+ * [[org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation]] for non-transactional
+ * catalog-backed sinks and format-based sinks.
+ * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveRelations]]
+ * resolves it to [[org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation]] during
+ * each micro-batch analysis, going through the transaction-aware catalog when a transaction is
+ * active.
*/
case class WriteToMicroBatchDataSource(
- relation: Option[DataSourceV2Relation],
- table: SupportsWrite,
+ relation: NamedRelation,
query: LogicalPlan,
queryId: String,
writeOptions: Map[String, String],
outputMode: OutputMode,
batchId: Option[Long] = None)
- extends UnaryNode {
+ extends UnaryNode with StreamingV2WriteCommand {
+
override def child: LogicalPlan = query
override def output: Seq[Attribute] = Nil
+ override def simpleString(maxFields: Int): String =
+ s"WriteToMicroBatchDataSource ${relation.name}"
+
+ override def table: NamedRelation = relation
+
+ override def withNewTable(newTable: NamedRelation): WriteToMicroBatchDataSource =
+ copy(relation = newTable)
+
def withNewBatchId(batchId: Long): WriteToMicroBatchDataSource = {
copy(batchId = Some(batchId))
}
diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index c1fc7234d7c1..0354e545aa90 100644
--- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -18,6 +18,8 @@
org.apache.spark.sql.sources.FakeSourceOne
org.apache.spark.sql.sources.FakeSourceTwo
org.apache.spark.sql.sources.FakeSourceThree
+org.apache.spark.sql.connector.FakePathBasedSource
+org.apache.spark.sql.connector.FakePathBasedSourceWithSessionConfig
org.apache.spark.sql.sources.FakeSourceFour
org.apache.fakesource.FakeExternalSourceOne
org.apache.fakesource.FakeExternalSourceTwo
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala
new file mode 100644
index 000000000000..aef9c65550fc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala
@@ -0,0 +1,500 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.connector.catalog.{Aborted, Committed}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
+import org.apache.spark.sql.sources
+
+class AppendDataTransactionSuite extends RowLevelOperationSuiteBase {
+
+ test("writeTo append with transactional checks") {
+ // create table with initial data
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ // create a source on top of itself that will be fully resolved and analyzed
+ val sourceDF = spark.table(tableNameAsString)
+ .where("pk == 1")
+ .select(col("pk") + 10 as "pk", col("salary"), col("dep"))
+ sourceDF.queryExecution.assertAnalyzed()
+
+ // append data using the DataFrame API
+ val (txn, txnTables) = executeTransaction {
+ sourceDF.writeTo(tableNameAsString).append()
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 1)
+ assert(table.version() === "2")
+
+ // check the source scan was tracked via the transaction catalog
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size === 1)
+ assert(targetTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("pk", 1) => true
+ case _ => false
+ })
+
+ // check data was appended correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software"),
+ Row(11, 100, "hr"))) // appended
+ }
+
+ test("SQL INSERT INTO with transactional checks") {
+ // create table with initial data
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ // SQL INSERT INTO using VALUES
+ val (txn, txnTables) = executeTransaction {
+ sql(s"INSERT INTO $tableNameAsString VALUES (3, 300, 'hr'), (4, 400, 'finance')")
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(table.version() === "2")
+
+ // VALUES literal - No catalog tables were scanned
+ assert(txnTables.isEmpty)
+
+ // check data was inserted correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software"),
+ Row(3, 300, "hr"),
+ Row(4, 400, "finance")))
+ }
+
+ for (isDynamic <- Seq(false, true))
+ test(s"SQL INSERT OVERWRITE with transactional checks - isDynamic: $isDynamic") {
+ // create table with initial data; table is partitioned by dep
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ val insertOverwrite = if (isDynamic) {
+ // OverwritePartitionsDynamic
+ s"""INSERT OVERWRITE $tableNameAsString
+ |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin
+ } else {
+ // OverwriteByExpression
+ s"""INSERT OVERWRITE $tableNameAsString
+ |PARTITION (dep = 'hr')
+ |SELECT pk + 10, salary FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin
+ }
+
+ val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC
+ val (txn, txnTables) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) {
+ executeTransaction { sql(insertOverwrite) }
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(table.version() === "2")
+
+ // the SELECT reads from the target table once with a dep='hr' filter
+ assert(txnTables.size == 1)
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size == 1)
+ assert(targetTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(11, 100, "hr"), // overwritten
+ Row(13, 300, "hr"))) // overwritten
+ }
+
+ test("writeTo overwrite with transactional checks") {
+ // create table with initial data; table is partitioned by dep
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // overwrite using a condition that covers the hr partition -> OverwriteByExpression
+ val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))).
+ toDF("pk", "salary", "dep")
+
+ val (txn, txnTables) = executeTransaction {
+ sourceDF.writeTo(tableNameAsString).overwrite(col("dep") === "hr")
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(table.version() === "2")
+
+ // literal DataFrame source - no catalog tables were scanned
+ assert(txnTables.isEmpty)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(11, 999, "hr"), // overwrote hr partition
+ Row(12, 888, "hr"))) // overwrote hr partition
+ }
+
+ test("writeTo overwritePartitions with transactional checks") {
+ // create table with initial data; table is partitioned by dep
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // overwrite partitions dynamically -> OverwritePartitionsDynamic
+ val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))).
+ toDF("pk", "salary", "dep")
+
+ val (txn, txnTables) = executeTransaction {
+ sourceDF.writeTo(tableNameAsString).overwritePartitions()
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(table.version() === "2")
+
+ // literal DataFrame source - no catalog tables were scanned
+ assert(txnTables.isEmpty)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(11, 999, "hr"), // overwrote hr partition
+ Row(12, 888, "hr"))) // overwrote hr partition
+ }
+
+ test("SQL INSERT INTO SELECT with transactional checks") {
+ // create table with initial data
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // SQL INSERT INTO using SELECT from the same table (self-insert)
+ val (txn, txnTables) = executeTransaction {
+ sql(s"""INSERT INTO $tableNameAsString
+ |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(table.version() === "2")
+
+ // the SELECT reads from the target table once with a dep='hr' filter
+ assert(txnTables.size === 1)
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size === 1)
+ assert(targetTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ // check data was inserted correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software"),
+ Row(3, 300, "hr"),
+ Row(11, 100, "hr"), // inserted from pk=1
+ Row(13, 300, "hr"))) // inserted from pk=3
+ }
+
+ test("SQL INSERT INTO SELECT with subquery on source table and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')")
+
+ // INSERT using a subquery that reads from the target to filter source rows
+ // both tables are scanned through the transaction catalog
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""INSERT INTO $tableNameAsString
+ |SELECT pk + 10, salary, dep FROM $sourceNameAsString
+ |WHERE pk IN (SELECT pk FROM $tableNameAsString WHERE dep = 'hr')
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 2)
+ assert(table.version() === "2")
+
+ // target was scanned via the transaction catalog (IN subquery) once with dep='hr' filter
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size === 1)
+ assert(targetTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ // source was scanned via the transaction catalog exactly once (no filter)
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ assert(sourceTxnTable.scanEvents.size === 1)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software"),
+ Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row
+ }
+
+ test("SQL INSERT INTO SELECT with CTE and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')")
+
+ // CTE reads from target; INSERT selects from source filtered by the CTE result
+ // both tables are scanned through the transaction catalog
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""WITH hr_pks AS (SELECT pk FROM $tableNameAsString WHERE dep = 'hr')
+ |INSERT INTO $tableNameAsString
+ |SELECT pk + 10, salary, dep FROM $sourceNameAsString
+ |WHERE pk IN (SELECT pk FROM hr_pks)
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 2)
+ assert(table.version() === "2")
+
+ // target was scanned via the transaction catalog (CTE) once with dep='hr' filter
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size === 1)
+ assert(targetTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ // source was scanned via the transaction catalog exactly once (no filter)
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ assert(sourceTxnTable.scanEvents.size === 1)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software"),
+ Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row via CTE
+ }
+
+ test("SQL INSERT with analysis failure and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val e = intercept[AnalysisException] {
+ sql(s"INSERT INTO $tableNameAsString SELECT nonexistent_col FROM $tableNameAsString")
+ }
+
+ assert(e.getMessage.contains("nonexistent_col"))
+ assert(catalog.lastTransaction.currentState === Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+  for (isDynamic <- Seq(false, true)) // registers one test per partition overwrite mode
+  test(s"SQL INSERT OVERWRITE with analysis failure and transactional checks - " +
+    s"isDynamic: $isDynamic") {
+    createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+      """{ "pk": 1, "salary": 100, "dep": "hr" }
+        |{ "pk": 2, "salary": 200, "dep": "software" }
+        |""".stripMargin)
+
+    val insertOverwrite = if (isDynamic) {
+      s"""INSERT OVERWRITE $tableNameAsString
+         |SELECT nonexistent_col, salary, dep FROM $tableNameAsString WHERE dep = 'hr'
+         |""".stripMargin
+    } else {
+      s"""INSERT OVERWRITE $tableNameAsString
+         |PARTITION (dep = 'hr')
+         |SELECT nonexistent_col FROM $tableNameAsString WHERE dep = 'hr'
+         |""".stripMargin
+    }
+
+    val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC
+    val e = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) {
+      intercept[AnalysisException] { sql(insertOverwrite) }
+    }
+
+    assert(e.getMessage.contains("nonexistent_col"))
+    assert(catalog.lastTransaction.currentState === Aborted)
+    assert(catalog.lastTransaction.isClosed)
+  }
+
+ test("EXPLAIN INSERT SQL with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"EXPLAIN INSERT INTO $tableNameAsString VALUES (3, 300, 'hr')")
+
+ // EXPLAIN should not start a transaction
+ assert(catalog.transaction === null)
+
+ // INSERT was not executed; data is unchanged
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software")))
+ }
+
+ test("SQL INSERT WITH SCHEMA EVOLUTION adds new column with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(
+ s"""CREATE TABLE $sourceNameAsString
+ |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin)
+ sql(s"INSERT INTO $sourceNameAsString VALUES (3, 300, 'hr', true), (4, 400, 'software', false)")
+
+ val (txn, txnTables) = executeTransaction {
+ sql(s"INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString SELECT * FROM $sourceNameAsString")
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+
+ // the new column must be visible in the committed delegate's schema
+ assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep", "active"))
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr", null), // pre-existing rows: active is null
+ Row(2, 200, "software", null),
+ Row(3, 300, "hr", true), // inserted with active
+ Row(4, 400, "software", false)))
+ }
+
+ for (isDynamic <- Seq(false, true))
+ test(s"SQL INSERT OVERWRITE WITH SCHEMA EVOLUTION adds new column with transactional checks " +
+ s"isDynamic: $isDynamic") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(
+ s"""CREATE TABLE $sourceNameAsString
+ |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin)
+ sql(s"INSERT INTO $sourceNameAsString VALUES (11, 999, 'hr', true), (12, 888, 'hr', false)")
+
+ val insertOverwrite = if (isDynamic) {
+ s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString
+ |SELECT * FROM $sourceNameAsString
+ |""".stripMargin
+ } else {
+ s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString
+ |PARTITION (dep = 'hr')
+ |SELECT pk, salary, active FROM $sourceNameAsString
+ |""".stripMargin
+ }
+
+ val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC
+ val (txn, _) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) {
+ executeTransaction { sql(insertOverwrite) }
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+
+ // the new column must be visible in the committed delegate's schema
+ assert(table.schema.fieldNames.contains("active"))
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software", null), // unchanged (different partition)
+ Row(11, 999, "hr", true), // overwrote hr partition
+ Row(12, 888, "hr", false)))
+ }
+
+ test("SQL INSERT WITH SCHEMA EVOLUTION analysis failure aborts transaction") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(
+ s"""CREATE TABLE $sourceNameAsString
+ |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin)
+
+ val e = intercept[AnalysisException] {
+ sql(
+ s"""INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString
+ |SELECT nonexistent_col FROM $sourceNameAsString
+ |""".stripMargin)
+ }
+
+ assert(e.getMessage.contains("nonexistent_col"))
+ assert(catalog.lastTransaction.currentState === Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ // schema must be unchanged after the aborted transaction
+ assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep"))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala
new file mode 100644
index 000000000000..8acdd8242ef1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTable}
+import org.apache.spark.sql.sources
+
+class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase {
+
+ private val newTableNameAsString = "cat.ns1.new_table" // presumably what newTable loads
+
+ private def newTable: InMemoryRowLevelOperationTable =
+ catalog.loadTable(Identifier.of(Array("ns1"), "new_table"))
+ .asInstanceOf[InMemoryRowLevelOperationTable]
+
+ test("CTAS with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val (txn, txnTables) = executeTransaction {
+ sql(s"""CREATE TABLE $newTableNameAsString
+ |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 1)
+ assert(table.version() === "1") // source table: read-only, version unchanged
+ assert(newTable.version() === "1") // target table: newly created and written
+
+ // the source table was scanned once through the transaction catalog with a dep='hr' filter
+ val sourceTxnTable = txnTables(tableNameAsString)
+ assert(sourceTxnTable.scanEvents.size === 1)
+ assert(sourceTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ checkAnswer(
+ sql(s"SELECT * FROM $newTableNameAsString"),
+ Seq(Row(1, 100, "hr")))
+ }
+
+ test("CTAS from literal source with transactional checks") {
+ // no source catalog table involved - the query is a pure literal SELECT
+ val (txn, txnTables) = executeTransaction {
+ sql(s"CREATE TABLE $newTableNameAsString AS SELECT 1 AS pk, 100 AS salary, 'hr' AS dep")
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+
+ // literal SELECT - no catalog tables were scanned
+ assert(txnTables.isEmpty)
+ assert(newTable.version() === "1") // target table: newly created and written
+
+ checkAnswer(
+ sql(s"SELECT * FROM $newTableNameAsString"),
+ Seq(Row(1, 100, "hr")))
+ }
+
+ test("CTAS with analysis failure and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val e = intercept[AnalysisException] { // expected to fail on nonexistent_col
+ sql(s"CREATE TABLE $newTableNameAsString AS SELECT nonexistent_col FROM $tableNameAsString")
+ }
+
+ assert(e.getMessage.contains("nonexistent_col"))
+ assert(catalog.lastTransaction.currentState === Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("RTAS with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // pre-create the target so REPLACE TABLE (not CREATE OR REPLACE) is valid
+ sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+
+ val (txn, txnTables) = executeTransaction {
+ sql(s"""REPLACE TABLE $newTableNameAsString
+ |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 1)
+ assert(table.version() === "1") // source table: read-only, version unchanged
+ assert(newTable.version() === "1") // target table: replaced and written
+
+ // the source table was scanned once through the transaction catalog with a dep='hr' filter
+ val sourceTxnTable = txnTables(tableNameAsString)
+ assert(sourceTxnTable.scanEvents.size === 1)
+ assert(sourceTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ checkAnswer(
+ sql(s"SELECT * FROM $newTableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(3, 300, "hr")))
+ }
+
+ test("RTAS self-reference with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // source and target are the same table: reads the old snapshot via TxnTable,
+ // replaces the table with a filtered version
+ val (txn, txnTables) = executeTransaction {
+ sql(s"""CREATE OR REPLACE TABLE $tableNameAsString
+ |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size === 1)
+ assert(table.version() === "1") // source/target table: replaced in place, version reset to 1
+
+ // the source/target table was scanned once with a dep='hr' filter
+ val sourceTxnTable = txnTables(tableNameAsString)
+ assert(sourceTxnTable.scanEvents.size === 1)
+ assert(sourceTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(3, 300, "hr")))
+ }
+
+ test("RTAS with analysis failure and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val e = intercept[AnalysisException] { // expected to fail on nonexistent_col
+ sql(s"""CREATE OR REPLACE TABLE $tableNameAsString
+ |AS SELECT nonexistent_col FROM $tableNameAsString
+ |""".stripMargin)
+ }
+
+ assert(e.getMessage.contains("nonexistent_col"))
+ assert(catalog.lastTransaction.currentState === Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("simple CREATE TABLE and DROP TABLE do not create transactions") {
+ sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ assert(catalog.transaction === null) // nothing in flight after CREATE TABLE
+ assert(catalog.lastTransaction === null) // and nothing was recorded either
+
+ sql(s"DROP TABLE $newTableNameAsString")
+ assert(catalog.transaction === null)
+ assert(catalog.lastTransaction === null)
+ }
+
+ test("EXPLAIN CTAS with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"""EXPLAIN CREATE TABLE $newTableNameAsString
+ |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+
+ // EXPLAIN should not start a transaction
+ assert(catalog.transaction === null)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala
index fbcfdfb20c6e..803dd35513f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala
@@ -109,7 +109,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
collected = df.queryExecution.executedPlan.collect {
case CommandResultExec(
- _, AppendDataExec(_, _, write, _),
+ _, AppendDataExec(_, _, write, _, _),
_) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
assert(append.info.options.get("write.split-size") === "10")
@@ -141,7 +141,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
assert (collected.size == 1)
collected = qe.executedPlan.collect {
- case AppendDataExec(_, _, write, _) =>
+ case AppendDataExec(_, _, write, _, _) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
assert(append.info.options.get("write.split-size") === "10")
}
@@ -168,7 +168,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
assert (collected.size == 1)
collected = qe.executedPlan.collect {
- case AppendDataExec(_, _, write, _) =>
+ case AppendDataExec(_, _, write, _, _) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
assert(append.info.options.get("write.split-size") === "10")
}
@@ -194,7 +194,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
collected = df.queryExecution.executedPlan.collect {
case CommandResultExec(
- _, OverwriteByExpressionExec(_, _, write, _),
+ _, OverwriteByExpressionExec(_, _, write, _, _),
_) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
assert(append.info.options.get("write.split-size") === "10")
@@ -227,7 +227,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
assert (collected.size == 1)
collected = qe.executedPlan.collect {
- case OverwritePartitionsDynamicExec(_, _, write, _) =>
+ case OverwritePartitionsDynamicExec(_, _, write, _, _) =>
val dynOverwrite = write.toBatch.asInstanceOf[InMemoryBaseTable#DynamicOverwrite]
assert(dynOverwrite.info.options.get("write.split-size") === "10")
}
@@ -254,7 +254,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
collected = df.queryExecution.executedPlan.collect {
case CommandResultExec(
- _, OverwriteByExpressionExec(_, _, write, _),
+ _, OverwriteByExpressionExec(_, _, write, _, _),
_) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
assert(append.info.options.get("write.split-size") === "10")
@@ -287,7 +287,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
assert (collected.size == 1)
collected = qe.executedPlan.collect {
- case OverwriteByExpressionExec(_, _, write, _) =>
+ case OverwriteByExpressionExec(_, _, write, _, _) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
assert(append.info.options.get("write.split-size") === "10")
}
@@ -317,7 +317,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
assert (collected.size == 1)
collected = qe.executedPlan.collect {
- case OverwriteByExpressionExec(_, _, write, _) =>
+ case OverwriteByExpressionExec(_, _, write, _, _) =>
val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
assert(append.info.options.get("write.split-size") === "10")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
index adc88f5a54a0..f8d81ee08691 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
@@ -17,12 +17,14 @@
package org.apache.spark.sql.connector
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.expressions.CheckInvariant
import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.connector.catalog.{Aborted, Committed}
import org.apache.spark.sql.connector.catalog.InMemoryTable
import org.apache.spark.sql.connector.write.DeleteSummary
import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec, WriteDeltaExec}
+import org.apache.spark.sql.sources
abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase {
@@ -773,6 +775,190 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase {
}
}
+ test("delete with analysis failure and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val exception = intercept[AnalysisException] { // the predicate names an unknown column
+ sql(s"DELETE FROM $tableNameAsString WHERE invalid_column = 1")
+ }
+
+ assert(exception.getMessage.contains("invalid_column"))
+ assert(catalog.lastTransaction.currentState == Aborted) // failure must abort, not commit
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("delete with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // simple predicate delete: goes through SupportsDelete.deleteWhere (no Spark-side scan)
+ val (txn, _) = executeTransaction { // the second tuple element is not inspected here
+ sql(s"DELETE FROM $tableNameAsString WHERE dep = 'hr'")
+ }
+
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(table.version() == "2") // the committed delete is expected to yield version 2
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(2, 200, "software")))
+ }
+
+ test("delete with subquery on source table and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""DELETE FROM $tableNameAsString
+ |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 'hr')
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2) // the target table and the subquery's source table
+ assert(table.version() == "2")
+
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaDelete) 1 else 2 // non-delta deletes expect two scans
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numSubquerySourceScans == expectedNumSourceScans) // one dep = 'hr' filter per scan
+
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaDelete) 1 else 2
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (pk 3 not in subquery result)
+ }
+
+ test("delete with CTE and transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""WITH cte AS (
+ | SELECT pk FROM $sourceNameAsString WHERE dep = 'hr'
+ |)
+ |DELETE FROM $tableNameAsString
+ |WHERE pk IN (SELECT pk FROM cte)
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2) // the target table and the table read by the CTE
+ assert(table.version() == "2")
+
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaDelete) 1 else 2 // non-delta deletes expect two scans
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaDelete) 1 else 2
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numCteSourceScans == expectedNumSourceScans) // one dep = 'hr' filter per scan
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (pk 3 not in source)
+ }
+
+ test("delete using view with transactional checks") {
+ withView("temp_view") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ sql(
+ s"""CREATE VIEW temp_view AS
+ |SELECT pk FROM $sourceNameAsString WHERE dep = 'hr'
+ |""".stripMargin)
+
+ val (txn, txnTables) = executeTransaction { // the source table is reached only via the view
+ sql(s"DELETE FROM $tableNameAsString WHERE pk IN (SELECT pk FROM temp_view)")
+ }
+
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2) // the target table and the table behind the view
+ assert(table.version() == "2")
+
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaDelete) 1 else 2 // non-delta deletes expect two scans
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaDelete) 1 else 2
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (pk 3 not in source)
+ }
+ }
+
+ test("EXPLAIN DELETE SQL with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"EXPLAIN DELETE FROM $tableNameAsString WHERE dep = 'hr'")
+
+ // EXPLAIN should not start a new transaction
+ assert(catalog.transaction === null)
+
+ checkAnswer( // both initial rows, including the dep = 'hr' one, must still be present
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 100, "hr"),
+ Row(2, 200, "software")))
+ }
+
private def executeDeleteWithFilters(query: String): Unit = {
val executedPlan = executeAndKeepPlan {
sql(query)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
index e1c574ec7ba6..c5d82ec3c1d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
@@ -17,10 +17,10 @@
package org.apache.spark.sql.connector
+import org.apache.spark.sql.{sources, Column, Row}
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.Row
import org.apache.spark.sql.classic.MergeIntoWriter
-import org.apache.spark.sql.connector.catalog.Column
+import org.apache.spark.sql.connector.catalog.{Aborted, Committed}
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.TableInfo
import org.apache.spark.sql.functions._
@@ -31,6 +31,134 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase {
import testImplicits._
+ private def targetTableCol(colName: String): Column = { // column qualified by the target table
+ col(s"${tableNameAsString}.${colName}")
+ }
+
+ test("self merge with transactional checks") {
+ // create table
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create a source on top of itself that will be fully resolved and analyzed
+ val sourceDF = spark.table(tableNameAsString)
+ .where("salary == 100")
+ .as("source")
+ sourceDF.queryExecution.assertAnalyzed() // analyze the source before the merge is built
+
+ // merge into table using the source on top of itself
+ val (txn, txnTables) = executeTransaction {
+ sourceDF
+ .mergeInto(
+ tableNameAsString,
+ $"source.pk" === targetTableCol("pk") && targetTableCol("dep") === "hr")
+ .whenMatched()
+ .update(Map("salary" -> targetTableCol("salary").plus(1)))
+ .whenNotMatched()
+ .insertAll()
+ .merge()
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 1) // source and target are the same table
+ assert(table.version() == "2")
+
+ // check all table scans
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size == 4) // 2 as MERGE target + 2 as MERGE source, see below
+
+ // check table scans as MERGE target
+ val numTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numTargetScans == 2)
+
+ // check table scans as MERGE source
+ val numSourceScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("salary", 100) => true
+ case _ => false
+ }
+ assert(numSourceScans == 2)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "hr"), // update
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged
+ }
+
+ for (alterClause <- Seq( // one test is registered per DDL clause below
+ "ADD COLUMN new_col INT",
+ "DROP COLUMN salary",
+ "ALTER COLUMN salary TYPE BIGINT",
+ "ALTER COLUMN pk DROP NOT NULL"))
+ test(s"self merge fails when source schema changes after analysis - DDL: $alterClause") {
+ withTable(tableNameAsString) {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val sourceDF = spark.table(tableNameAsString).where("salary == 100").as("source")
+ sourceDF.queryExecution.assertAnalyzed() // analyze the source before the schema changes
+
+ sql(s"ALTER TABLE $tableNameAsString $alterClause") // schema change after analysis
+
+ val e = intercept[AnalysisException] {
+ sourceDF
+ .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk"))
+ .whenMatched()
+ .update(Map("salary" -> targetTableCol("salary").plus(1)))
+ .merge()
+ }
+
+ assert( // alterClause is passed as the clue so a failure names the offending DDL
+ e.getCondition == "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH",
+ alterClause)
+ assert(catalog.lastTransaction.currentState == Aborted, alterClause)
+ assert(catalog.lastTransaction.isClosed, alterClause)
+ }
+ }
+
+ test("self merge fails when source table is dropped and recreated after analysis") {
+ withTable(tableNameAsString) {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ val sourceDF = spark.table(tableNameAsString).where("salary == 100").as("source")
+ sourceDF.queryExecution.assertAnalyzed() // analyze the source before the table is replaced
+
+ val originalId = catalog.loadTable(ident).id
+
+ sql(s"DROP TABLE $tableNameAsString")
+ sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ val newId = catalog.loadTable(ident).id
+ assert(originalId != newId) // same name and columns, but the recreated table has a new id
+
+ val e = intercept[AnalysisException] {
+ sourceDF
+ .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk"))
+ .whenMatched()
+ .update(Map("salary" -> targetTableCol("salary").plus(1)))
+ .merge()
+ }
+
+ assert(e.getCondition == "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.TABLE_ID_MISMATCH")
+ assert(catalog.lastTransaction.currentState == Aborted)
+ assert(catalog.lastTransaction.isClosed)
+ }
+ }
+
test("merge into empty table with NOT MATCHED clause") {
withTempView("source") {
createTable("pk INT NOT NULL, salary INT, dep STRING")
@@ -979,6 +1107,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase {
}
test("SPARK-54157: version is refreshed when source is V2 table") {
+ import org.apache.spark.sql.connector.catalog.Column
val sourceTable = "cat.ns1.source_table"
withTable(sourceTable) {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
@@ -1026,6 +1155,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase {
}
test("SPARK-54444: any schema changes after analysis are prohibited") {
+ import org.apache.spark.sql.connector.catalog.Column
val sourceTable = "cat.ns1.source_table"
withTable(sourceTable) {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index 069781e40d8c..91f20885beb4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not}
import org.apache.spark.sql.catalyst.optimizer.BuildLeft
-import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, InMemoryTable, TableInfo}
+import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, TableInfo}
import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.write.MergeSummary
import org.apache.spark.sql.execution.SparkPlan
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.MergeRowsExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{IntegerType, StringType}
abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
@@ -38,6 +39,309 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
protected def deltaMerge: Boolean = false
+ test("self merge with transactional checks") {
+ // create table
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // merge into table using a subquery on top of itself
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING (SELECT * FROM $tableNameAsString WHERE salary = 100) s
+ |ON t.pk = s.pk AND t.dep = 'hr'
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = t.salary + 1
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin)
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 1) // source and target are the same table
+ assert(table.version() == "2")
+
+ // check all table scans
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumScans = if (deltaMerge) 2 else 4 // target scans + source scans, counted below
+ assert(targetTxnTable.scanEvents.size == expectedNumScans)
+
+ // check table scans as MERGE target
+ val numTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ val expectedNumTargetScans = if (deltaMerge) 1 else 2
+ assert(numTargetScans == expectedNumTargetScans)
+
+ // check table scans as MERGE source
+ val numSourceScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("salary", 100) => true
+ case _ => false
+ }
+ val expectedNumSourceScans = if (deltaMerge) 1 else 2
+ assert(numSourceScans == expectedNumSourceScans)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "hr"), // update
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged
+ }
+
+ test("merge into table with analysis failure and transactional checks") {
+ createAndInitTable(
+ "pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')")
+
+ val exception = intercept[AnalysisException] { // the UPDATE reads a column the source lacks
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING $sourceNameAsString s
+ |ON t.pk = s.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = s.invalid_column
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending')
+ |""".stripMargin)
+ }
+
+ assert(exception.getMessage.contains("invalid_column"))
+ assert(catalog.lastTransaction.currentState == Aborted) // failure must abort, not commit
+ assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("merge into table using view with transactional checks") {
+ withView("temp_view") {
+ // create target table
+ createAndInitTable(
+ "pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)")
+ sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)")
+
+ // create view on top of source and target tables
+ sql(
+ s"""CREATE VIEW temp_view AS
+ |SELECT s.pk, s.salary, t.dep
+ |FROM $sourceNameAsString s
+ |LEFT JOIN (
+ | SELECT * FROM $tableNameAsString WHERE pk < 10
+ |) t ON s.pk = t.pk
+ |""".stripMargin)
+
+ // merge into target table using view
+ val (txn, txnTables) = executeTransaction {
+ sql(s"""MERGE INTO $tableNameAsString t
+ |USING temp_view s
+ |ON t.pk = s.pk AND t.dep = 'hr'
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = s.salary, dep = s.dep
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending')
+ |""".stripMargin)
+ }
+
+ // check txn covers both tables and was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check target table was scanned correctly
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaMerge) 2 else 4 // MERGE-target scans + scans via the view
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check target table scans as MERGE target (dep = 'hr')
+ val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ val expectedNumMergeTargetScans = if (deltaMerge) 1 else 2
+ assert(numMergeTargetScans == expectedNumMergeTargetScans)
+
+ // check target table scans in view as MERGE source (pk < 10)
+ val numViewTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.LessThan("pk", 10L) => true
+ case _ => false
+ }
+ val expectedNumViewTargetScans = if (deltaMerge) 1 else 2
+ assert(numViewTargetScans == expectedNumViewTargetScans)
+
+ // check source table scans in view as MERGE source (no predicate)
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaMerge) 1 else 2
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 150, "hr"), // update
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"), // unchanged
+ Row(4, 400, "pending"))) // new
+ }
+ }
+
+ // Runs a MERGE whose source is a view defined on top of another view. The inner view joins
+ // the source table with a filtered read (pk < 10) of the target table, so the target table is
+ // read both as the MERGE target and as part of the MERGE source. Asserts that the statement
+ // ran in a committed, closed transaction that tracked both tables and their scans.
+ test("merge into table using nested view with transactional checks") {
+ withView("base_view", "nested_view") {
+ withTable(sourceNameAsString) {
+ // create target table
+ createAndInitTable(
+ "pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)")
+ sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)")
+
+ // create base view: source rows LEFT JOINed with the target rows that satisfy pk < 10
+ sql(
+ s"""CREATE VIEW base_view AS
+ |SELECT s.pk, s.salary, t.dep
+ |FROM $sourceNameAsString s
+ |LEFT JOIN (
+ | SELECT * FROM $tableNameAsString WHERE pk < 10
+ |) t ON s.pk = t.pk
+ |""".stripMargin)
+
+ // create nested view on top of base view
+ sql(
+ s"""CREATE VIEW nested_view AS
+ |SELECT * FROM base_view
+ |""".stripMargin)
+
+ // merge into target table using nested view
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING nested_view s
+ |ON t.pk = s.pk AND t.dep = 'hr'
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = s.salary, dep = s.dep
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending')
+ |""".stripMargin)
+ }
+
+ // check txn covers both tables and was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check target table was scanned correctly (expected counts depend on deltaMerge)
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaMerge) 2 else 4
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check target table scans as MERGE target (dep = 'hr')
+ val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ val expectedNumMergeTargetScans = if (deltaMerge) 1 else 2
+ assert(numMergeTargetScans == expectedNumMergeTargetScans)
+
+ // check target table scans in view as MERGE source (pk < 10)
+ val numViewTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.LessThan("pk", 10L) => true
+ case _ => false
+ }
+ val expectedNumViewTargetScans = if (deltaMerge) 1 else 2
+ assert(numViewTargetScans == expectedNumViewTargetScans)
+
+ // check source table scans in view as MERGE source (no predicate)
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaMerge) 1 else 2
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ // check the merged rows are visible when reading the target table after commit
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 150, "hr"), // updated
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"), // unchanged
+ Row(4, 400, "pending"))) // new
+ }
+ }
+ }
+
+ // Runs a MERGE that has only WHEN NOT MATCHED clauses, each guarded by its own condition,
+ // and asserts that it ran in a committed, closed transaction in which the target table and
+ // the source table were each scanned exactly once.
+ test("merge into table rewritten as INSERT with transactional checks") {
+ withTable(sourceNameAsString) {
+ // create target table
+ createAndInitTable(
+ "pk INT, value STRING, dep STRING",
+ """{ "pk": 1, "value": "a", "dep": "hr" }
+ |{ "pk": 2, "value": "b", "dep": "finance" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT, value STRING, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (3, 'c', 'hr'), (4, 'd', 'software')")
+
+ // merge into target with only WHEN NOT MATCHED clauses (rewritten as insert)
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING $sourceNameAsString s
+ |ON t.pk = s.pk
+ |WHEN NOT MATCHED AND s.pk < 4 THEN
+ | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_low'), s.dep)
+ |WHEN NOT MATCHED AND s.pk >= 4 THEN
+ | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_high'), s.dep)
+ |""".stripMargin)
+ }
+
+ // check txn covers both tables and was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check target table was scanned exactly once
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.size == 1)
+
+ // check source table was scanned exactly once
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ assert(sourceTxnTable.scanEvents.size == 1)
+
+ // check the inserted rows are visible when reading the target table after commit
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, "a", "hr"), // unchanged
+ Row(2, "b", "finance"), // unchanged
+ Row(3, "c_low", "hr"), // inserted via first NOT MATCHED clause
+ Row(4, "d_high", "software"))) // inserted via second NOT MATCHED clause
+ }
+ }
+
test("merge into table with expression-based default values") {
val columns = Array(
Column.create("pk", IntegerType),
@@ -770,6 +1074,131 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
}
}
+ // Runs a MERGE whose source is a CTE that filters (salary > 100) and transforms
+ // (salary + 50) the source table. Asserts that the statement ran in a committed, closed
+ // transaction that tracked both tables, and that the recorded scan filters include
+ // dep = 'hr' on the target and salary > 100 on the source.
+ test("merge with CTE with transactional checks") {
+ withTable(sourceNameAsString) {
+ // create target table
+ createAndInitTable(
+ "pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ // merge into target table using CTE
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""WITH cte AS (
+ | SELECT pk, salary + 50 AS salary, dep
+ | FROM $sourceNameAsString
+ | WHERE salary > 100
+ |)
+ |MERGE INTO $tableNameAsString t
+ |USING cte s
+ |ON t.pk = s.pk AND t.dep = 'hr'
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = s.salary
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending')
+ |""".stripMargin)
+ }
+
+ // check txn covers both tables and was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check target table was scanned correctly (expected count depends on deltaMerge)
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaMerge) 1 else 2
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check target table scans as MERGE target (dep = 'hr'); the number of recorded
+ // dep = 'hr' filters must equal the number of target scans
+ val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numMergeTargetScans == expectedNumTargetScans)
+
+ // check source table was scanned correctly (expected count depends on deltaMerge)
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaMerge) 1 else 2
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ // check source table scans in CTE (salary > 100); the number of recorded
+ // salary > 100 filters must equal the number of source scans
+ val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count {
+ case sources.GreaterThan("salary", 100) => true
+ case _ => false
+ }
+ assert(numCteSourceScans == expectedNumSourceScans)
+
+ // check the merged rows are visible when reading the target table after commit
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 200, "hr"), // updated (150 + 50)
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"), // unchanged
+ Row(4, 450, "pending"))) // inserted (400 + 50)
+ }
+ }
+
+ // Caches the source table before running a MERGE and asserts that the source is still read
+ // through the transaction (it shows up among the transaction tables with scan events),
+ // i.e. the pre-existing cache entry does not bypass the transaction catalog.
+ test("merge with cached source and transactional checks") {
+ withTable(sourceNameAsString) {
+ // create target table
+ createAndInitTable(
+ "pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create and populate source table
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')")
+
+ // Cache the source table before the transaction starts. While the transaction is active,
+ // the catalog must still create a transaction table for it.
+ spark.table(sourceNameAsString).cache()
+
+ // always uncache in the finally block so the cached relation does not leak into other tests
+ try {
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING $sourceNameAsString s
+ |ON t.pk = s.pk AND t.dep = 'hr'
+ |WHEN MATCHED THEN
+ | UPDATE SET salary = s.salary, dep = s.dep
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending')
+ |""".stripMargin)
+ }
+
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ // both target and source must have been read through the transaction catalog
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+ assert(txnTables(sourceNameAsString).scanEvents.nonEmpty)
+ assert(txnTables(tableNameAsString).scanEvents.nonEmpty)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 150, "support"), // matched and updated
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"), // unchanged (no match in source)
+ Row(4, 400, "pending"))) // not matched, inserted
+ } finally {
+ spark.catalog.uncacheTable(sourceNameAsString)
+ }
+ }
+ }
+
test("merge with subquery as source") {
withTempView("source") {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
@@ -2322,6 +2751,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
sql(query)
}
assert(e.getMessage.contains("ON search condition of the MERGE statement"))
+ assert(catalog.lastTransaction.currentState == Aborted)
+ assert(catalog.lastTransaction.isClosed)
}
private def assertMetric(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala
new file mode 100644
index 000000000000..c6b2f33c25fe
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTableCatalog, InMemoryTableCatalog, SessionConfigSupport, SupportsCatalogOptions}
+import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * Tests for transactional writes to path-based tables, where the table is identified by a
+ * bare path with no catalog prefix (e.g. `/path/to/t`), or a connector-prefixed path
+ * (e.g. `pathformat.`/path/to/t``). The transactional catalog is registered as the session
+ * catalog (`spark_catalog`).
+ */
+class PathBasedTableTransactionSuite extends RowLevelOperationSuiteBase {
+
+ // Bare path identifier (no catalog or format prefix).
+ private val tablePath = "`/path/to/t`"
+ // Same path, prefixed with the short name of FakePathBasedSource.
+ private val tablePathWithFormat = "pathformat.`/path/to/t`"
+
+ // Registers the transactional in-memory catalog as the session catalog for every test.
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ spark.conf.set(
+ V2_SESSION_CATALOG_IMPLEMENTATION.key,
+ classOf[InMemoryRowLevelOperationTableCatalog].getName)
+ }
+
+ // Removes the session catalog override before running the parent cleanup.
+ override def afterEach(): Unit = {
+ spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
+ super.afterEach()
+ }
+
+ // The catalog under test is the v2 session catalog rather than a named catalog.
+ override protected def catalog: InMemoryRowLevelOperationTableCatalog = {
+ spark.sessionState.catalogManager.v2SessionCatalog
+ .asInstanceOf[InMemoryRowLevelOperationTableCatalog]
+ }
+
+ // Creates a two-column (id INT, data STRING) table under the given SQL identifier.
+ private def createPathTable(name: String): Unit = {
+ sql(s"CREATE TABLE $name (id INT, data STRING)")
+ }
+
+ test("SQL insert into bare path-based table participates in transaction") {
+ createPathTable(tablePath)
+ val (txn, _) = executeTransaction {
+ sql(s"INSERT INTO $tablePath VALUES (1, 'a'), (2, 'b')")
+ }
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ checkAnswer(spark.table(tablePath), Row(1, "a") :: Row(2, "b") :: Nil)
+ }
+
+ test("SQL insert with connector-prefixed path participates in transaction") {
+ createPathTable(tablePathWithFormat)
+ val (txn, _) = executeTransaction {
+ sql(s"INSERT INTO $tablePathWithFormat VALUES (1, 'a'), (2, 'b')")
+ }
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ checkAnswer(spark.table(tablePathWithFormat), Row(1, "a") :: Row(2, "b") :: Nil)
+ }
+
+ test("SQL insert with CTE into connector-prefixed path participates in transaction") {
+ createPathTable(tablePathWithFormat)
+ val (txn, _) = executeTransaction {
+ sql(s"""
+ |WITH cte AS (SELECT 1 AS id, 'a' AS data)
+ |INSERT INTO $tablePathWithFormat SELECT * FROM cte
+ |""".stripMargin)
+ }
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+ checkAnswer(spark.table(tablePathWithFormat), Row(1, "a") :: Nil)
+ }
+
+ // Uses the `pathformat2` source, whose owning catalog is taken from the session config
+ // `spark.datasource.pathformat2.catalog`, and checks which catalog records a transaction.
+ test("session-config catalog controls which catalog is enrolled in transaction") {
+ withSQLConf(
+ "spark.sql.catalog.txncat" -> classOf[InMemoryRowLevelOperationTableCatalog].getName,
+ "spark.sql.catalog.nontxncat" -> classOf[InMemoryTableCatalog].getName) {
+ val txnCat = spark.sessionState.catalogManager.catalog("txncat")
+ .asInstanceOf[InMemoryRowLevelOperationTableCatalog]
+
+ // Non-transactional catalog configured.
+ withSQLConf("spark.datasource.pathformat2.catalog" -> "nontxncat") {
+ createPathTable("pathformat2.`/path/to/t1`")
+ sql("INSERT INTO pathformat2.`/path/to/t1` VALUES (1, 'a')")
+ // The transaction was not routed to any of the transactional catalogs.
+ assert(catalog.lastTransaction == null)
+ assert(txnCat.lastTransaction == null)
+ }
+
+ // Transactional catalog configured: pathBased resolves txncat as a
+ // TransactionalCatalogPlugin and opens the transaction there instead.
+ withSQLConf("spark.datasource.pathformat2.catalog" -> "txncat") {
+ createPathTable("pathformat2.`/path/to/t2`")
+ sql("INSERT INTO pathformat2.`/path/to/t2` VALUES (1, 'a')")
+ assert(txnCat.lastTransaction.currentState === Committed)
+ assert(txnCat.lastTransaction.isClosed)
+ }
+ }
+ }
+
+ test("SQL insert with unregistered format produces analysis error and aborts transaction") {
+ createPathTable(tablePathWithFormat)
+ // "unregistered" is neither a known catalog nor a registered data source, so Spark
+ // falls back to treating it as a namespace in spark_catalog. The table
+ // does not exist, causing an AnalysisException. The transaction is started (because
+ // spark_catalog IS a TransactionalCatalogPlugin) and then aborted on failure.
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("INSERT INTO unregistered.`/path/to/t` VALUES (1, 'a'), (2, 'b')")
+ },
+ condition = "TABLE_OR_VIEW_NOT_FOUND",
+ parameters = Map("relationName" -> "`unregistered`.`/path/to/t`"),
+ context = ExpectedContext(
+ fragment = "unregistered.`/path/to/t`",
+ start = -1,
+ stop = -1))
+ // the failed statement must leave behind an aborted, closed transaction
+ val txn = catalog.lastTransaction
+ assert(txn.currentState === Aborted)
+ assert(txn.isClosed)
+ }
+}
+
+/**
+ * Simulates a path-based connector (e.g. Delta) that implements [[SupportsCatalogOptions]]
+ * to route `pathformat.\`/path/to/t\`` SQL identifiers to the session catalog. Returning
+ * null from [[extractCatalog]] signals that the session catalog (`spark_catalog`) owns the
+ * table, matching Delta's behavior where DeltaCatalog is registered as spark_catalog.
+ */
+class FakePathBasedSource
+ extends FakeV2ProviderWithCustomSchema
+ with SupportsCatalogOptions
+ with DataSourceRegister {
+
+ // Short name under which this source is registered; it is the prefix used in SQL
+ // identifiers such as pathformat.`/path/to/t`.
+ override def shortName(): String = "pathformat"
+
+ // Always null, regardless of options: use the session catalog.
+ override def extractCatalog(options: CaseInsensitiveStringMap): String = null
+
+ // Not used in the transactional path.
+ override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = null
+}
+
+/**
+ * Like [[FakePathBasedSource]] but resolves the owning catalog from the session config
+ * `spark.datasource.pathformat2.catalog` instead of always returning null. This simulates
+ * a connector that lets users configure the target catalog.
+ */
+class FakePathBasedSourceWithSessionConfig
+ extends FakeV2ProviderWithCustomSchema
+ with SupportsCatalogOptions
+ with SessionConfigSupport
+ with DataSourceRegister {
+
+ // Short name under which this source is registered; used as the SQL identifier prefix.
+ override def shortName(): String = "pathformat2"
+
+ // Session config key prefix for this source (the tests set
+ // `spark.datasource.pathformat2.catalog`).
+ override def keyPrefix: String = "pathformat2"
+
+ // Returns the value of the "catalog" option.
+ // NOTE(review): presumably null when the option is unset, which would select the session
+ // catalog as in FakePathBasedSource -- confirm against CaseInsensitiveStringMap.get.
+ override def extractCatalog(options: CaseInsensitiveStringMap): String = options.get("catalog")
+
+ // Not used in the transactional path.
+ override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = null
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
index 79387821bf08..d0209c97cf93 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
@@ -28,16 +28,16 @@ import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expr
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReplaceData, WriteDelta}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, TableInfo, Update, Write}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, Table, TableInfo, Txn, TxnTable, Update, Write}
import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference}
import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan}
+import org.apache.spark.sql.connector.write.RowLevelOperationTable
+import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructField, StructType}
-import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
@@ -82,6 +82,7 @@ abstract class RowLevelOperationSuiteBase
protected val namespace: Array[String] = Array("ns1")
protected val ident: Identifier = Identifier.of(namespace, "test_table")
protected val tableNameAsString: String = "cat." + ident.toString
+ protected val sourceNameAsString: String = "cat.ns1.source_table"
protected def extraTableProps: java.util.Map[String, String] = {
Collections.emptyMap[String, String]
@@ -135,22 +136,28 @@ abstract class RowLevelOperationSuiteBase
// executes an operation and keeps the executed plan
protected def executeAndKeepPlan(func: => Unit): SparkPlan = {
- var executedPlan: SparkPlan = null
+ // func must trigger exactly one query execution; its executed plan is returned with the
+ // AQE wrapper stripped. Any other number of captured executions fails the test.
+ withQueryExecutionsCaptured(spark)(func) match {
+ case Seq(qe) => stripAQEPlan(qe.executedPlan)
+ case other => fail(s"expected only one query execution, but got ${other.size}")
+ }
+ }
- val listener = new QueryExecutionListener {
- override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
- executedPlan = qe.executedPlan
- }
- override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+ // Runs func, captures every query execution it triggers, and collects the TxnTable behind
+ // each BatchScanExec in the executed plans (subqueries included). A TxnTable is picked up
+ // either directly or when wrapped in a RowLevelOperationTable. Returns the catalog's last
+ // transaction together with the collected tables indexed by table name.
+ protected def executeTransaction(func: => Unit): (Txn, Map[String, TxnTable]) = {
+ val tables = withQueryExecutionsCaptured(spark)(func).flatMap { qe =>
+ collectWithSubqueries(qe.executedPlan) {
+ case BatchScanExec(_, _, _, _, table: TxnTable, _) => table
+ case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => table
}
}
- spark.listenerManager.register(listener)
-
- func
-
- sparkContext.listenerBus.waitUntilEmpty()
+ (catalog.lastTransaction, indexByName(tables))
+ }
- stripAQEPlan(executedPlan)
+ // Builds a name -> table index. All tables sharing a name must collapse to one distinct
+ // instance; otherwise the single-element pattern below fails with a MatchError.
+ protected def indexByName[T <: Table](tables: Seq[T]): Map[String, T] = {
+   val tablesByName = tables.groupBy(_.name)
+   tablesByName.map { case (tableName, candidates) =>
+     val Seq(onlyTable) = candidates.distinct
+     tableName -> onlyTable
+   }
}
// executes an operation and extracts conditions from ReplaceData or WriteDelta
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala
new file mode 100644
index 000000000000..13b6267a28ff
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import java.util.Collections
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.connector.catalog.{Aborted, CatalogV2Util, Committed, Identifier, InMemoryBaseTable, InMemoryRowLevelOperationTableCatalog, InMemoryTableCatalog, SharedTablesInMemoryRowLevelOperationTableCatalog, TableInfo}
+import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryWrapper}
+import org.apache.spark.sql.streaming.StreamingQuery
+import org.apache.spark.sql.types.StructType
+
+// Tests how streaming writes (writeStream.toTable) interact with catalog transactions:
+// commit per micro-batch, abort on write failure, tracking of batch reads inside a streaming
+// query, and no transaction for a non-transactional catalog.
+class StreamingTransactionSuite extends RowLevelOperationSuiteBase {
+
+ import testImplicits._
+
+ // Registers the shared-tables catalog implementation under the name `cat` for every test.
+ // NOTE(review): presumably this lets the catalog instance of the streaming session see
+ // tables created through the main session -- confirm against the catalog implementation.
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ spark.conf.set(
+ "spark.sql.catalog.cat",
+ classOf[SharedTablesInMemoryRowLevelOperationTableCatalog].getName)
+ }
+
+ // Resets the shared catalog state before running the parent cleanup.
+ override def afterEach(): Unit = {
+ SharedTablesInMemoryRowLevelOperationTableCatalog.reset()
+ super.afterEach()
+ }
+
+ // Creates the test table (`ident`) with columns parsed from a DDL schema string and no
+ // additional table properties.
+ private def createSimpleTable(schemaString: String): Unit = {
+ val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString))
+ val tableInfo = new TableInfo.Builder().withColumns(columns).build()
+ catalog.createTable(ident, tableInfo)
+ }
+
+ // Returns the `cat` catalog of the session the streaming query runs in
+ // (sparkSessionForStream). The tests treat it as distinct from the main session's `catalog`.
+ private def streamCatalog(query: StreamingQuery): InMemoryRowLevelOperationTableCatalog = {
+ val session = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sparkSessionForStream
+ session.sessionState.catalogManager.catalog("cat")
+ .asInstanceOf[InMemoryRowLevelOperationTableCatalog]
+ }
+
+ test("streaming write commits a transaction") {
+ createSimpleTable("value INT")
+
+ withTempDir { checkpointDir =>
+ val inputData = MemoryStream[Int]
+
+ val query = inputData.toDF()
+ .writeStream
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .toTable(tableNameAsString)
+
+ // no data has been processed yet, so the table is still at its initial version
+ assert(table.version() === "0")
+
+ inputData.addData(1, 2, 3)
+ query.processAllAvailable()
+ query.stop()
+
+ val txn = streamCatalog(query).lastTransaction
+ assert(txn != null, "expected a transaction to have been committed")
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+
+ // Pure streaming append: the write target is not read, source is not a TxnTable.
+ val targetTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(tableNameAsString)
+ assert(txn.catalog.txnTables.size === 1)
+ assert(targetTxnTable.scanEvents.isEmpty)
+ assert(table.version() === "1")
+
+ // Transaction must be scoped to the streaming session; main session catalog is untouched.
+ assert(catalog.observedTransactions.isEmpty)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(1), Row(2), Row(3)))
+ }
+ }
+
+ test("each micro-batch is an independent transaction") {
+ createSimpleTable("value INT")
+
+ withTempDir { checkpointDir =>
+ val inputData = MemoryStream[Int]
+
+ val query = inputData.toDF()
+ .writeStream
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .toTable(tableNameAsString)
+
+ assert(table.version() === "0")
+
+ // two separate addData/processAllAvailable rounds, expected to yield two transactions
+ inputData.addData(1, 2, 3)
+ query.processAllAvailable()
+
+ inputData.addData(4, 5, 6)
+ query.processAllAvailable()
+
+ query.stop()
+
+ val sc = streamCatalog(query)
+ assert(sc.observedTransactions.size === 2)
+ assert(sc.observedTransactions.forall(_.currentState === Committed))
+ // Pure streaming append: write target is not read in any micro-batch.
+ assert(sc.observedTransactions.forall { t =>
+ indexByName(t.catalog.txnTables.values.toSeq)(tableNameAsString).scanEvents.isEmpty
+ })
+ // Each committed micro-batch increments the delegate version exactly once.
+ assert(table.version() === "2")
+
+ // Transaction must be scoped to the streaming session; main session catalog is untouched.
+ assert(catalog.observedTransactions.isEmpty)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(1), Row(2), Row(3), Row(4), Row(5), Row(6)))
+ }
+ }
+
+ test("batch read from catalog-backed table inside streaming query is tracked as a scan event") {
+ // Target table for the stream.
+ createSimpleTable("value INT")
+
+ // Catalog-backed static table used as a batch (non-streaming) source.
+ val sourceIdent = Identifier.of(namespace, "source_table")
+ val srcColumns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT"))
+ catalog.createTable(sourceIdent, new TableInfo.Builder().withColumns(srcColumns).build())
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1), (2), (3)")
+ // The INSERT above runs a transaction on the main session catalog; capture the count now
+ // so we can assert the streaming query does not add more.
+ val mainTxnsBefore = catalog.observedTransactions.size
+
+ withTempDir { checkpointDir =>
+ val inputData = MemoryStream[Int]
+
+ // spark.read produces a DataSourceV2Relation (batch), not a streaming source.
+ // UnresolveRelationsInTransaction converts it to V2TableReference each micro-batch so
+ // the transaction-aware catalog can record the scan event.
+ val staticData = spark.read.table(sourceNameAsString)
+
+ val query = inputData.toDF()
+ .join(staticData, "value")
+ .writeStream
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .toTable(tableNameAsString)
+
+ inputData.addData(1, 2, 3)
+ query.processAllAvailable()
+ query.stop()
+
+ val txn = streamCatalog(query).lastTransaction
+ assert(txn != null, "expected a transaction to have been committed")
+ assert(txn.currentState === Committed)
+ assert(txn.isClosed)
+
+ // Both the write target and the batch source participate in the transaction.
+ assert(txn.catalog.txnTables.size === 2)
+
+ val targetTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(tableNameAsString)
+ assert(targetTxnTable.scanEvents.isEmpty)
+
+ // The static source was read exactly once and its scan event was captured.
+ val sourceTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(sourceNameAsString)
+ assert(sourceTxnTable.scanEvents.size === 1)
+
+ // Streaming must not add transactions to the main session catalog beyond the pre-existing
+ // INSERT transaction.
+ assert(catalog.observedTransactions.size === mainTxnsBefore)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(1), Row(2), Row(3)))
+ }
+ }
+
+ test("transaction is aborted when micro-batch write fails and no data is written") {
+ // create the target table with the property that makes the in-memory table simulate a
+ // failed write
+ val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT"))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(Collections.singletonMap(
+ InMemoryBaseTable.SIMULATE_FAILED_WRITE_OPTION, "true"))
+ .build()
+ catalog.createTable(ident, tableInfo)
+
+ withTempDir { checkpointDir =>
+ val inputData = MemoryStream[Int]
+ val query = inputData.toDF()
+ .writeStream
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .toTable(tableNameAsString)
+
+ inputData.addData(1, 2, 3)
+ // the simulated write failure is expected to surface from processAllAvailable
+ intercept[Exception] { query.processAllAvailable() }
+ query.stop()
+
+ val txn = streamCatalog(query).lastTransaction
+ assert(txn != null, "expected a transaction to have been recorded")
+ assert(txn.currentState === Aborted)
+ assert(txn.isClosed)
+ // Aborted transaction must not advance the delegate version.
+ assert(table.version() === "0")
+
+ // Transaction must be scoped to the streaming session; main session catalog is untouched.
+ assert(catalog.observedTransactions.isEmpty)
+
+ // Writes must not be visible after an aborted transaction.
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Seq.empty)
+ }
+ }
+
+ test("streaming write to non-transactional catalog does not start a transaction") {
+ withSQLConf("spark.sql.catalog.nonTxnCat" -> classOf[InMemoryTableCatalog].getName) {
+ val nonTxnCat = spark
+ .sessionState
+ .catalogManager
+ .catalog("nonTxnCat")
+ .asInstanceOf[InMemoryTableCatalog]
+ val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT"))
+ nonTxnCat.createTable(
+ Identifier.of(Array("ns"), "tbl"),
+ new TableInfo.Builder().withColumns(columns).build())
+
+ withTempDir { checkpointDir =>
+ val inputData = MemoryStream[Int]
+ val query = inputData.toDF()
+ .writeStream
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .toTable("nonTxnCat.ns.tbl")
+
+ inputData.addData(1, 2, 3)
+ query.processAllAvailable()
+ query.stop()
+
+ // only the main session's `cat` catalog is inspected here; the write itself must
+ // still succeed, which the checkAnswer below verifies
+ assert(catalog.observedTransactions.isEmpty,
+ "no transaction expected for non-transactional catalog")
+ checkAnswer(spark.table("nonTxnCat.ns.tbl"), Seq(Row(1), Row(2), Row(3)))
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
index d32a1e5c7f56..65c3b68fa8cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.connector
import org.apache.spark.SparkRuntimeException
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, InMemoryTable, TableChange, TableInfo}
+import org.apache.spark.sql.{sources, AnalysisException, Row}
+import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, TableChange, TableInfo}
import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.write.UpdateSummary
import org.apache.spark.sql.internal.SQLConf
@@ -867,4 +867,325 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase {
Row(5)))
}
}
+
+ test("update with analysis failure and transactional checks") {
+   // seed the target table with three rows
+   val schema = "pk INT NOT NULL, salary INT, dep STRING"
+   val initialRows =
+     """{ "pk": 1, "salary": 100, "dep": "hr" }
+       |{ "pk": 2, "salary": 200, "dep": "software" }
+       |{ "pk": 3, "salary": 300, "dep": "hr" }
+       |""".stripMargin
+   createAndInitTable(schema, initialRows)
+
+   // the SET clause names a column the table does not have, so analysis must fail
+   val analysisError = intercept[AnalysisException] {
+     sql(s"UPDATE $tableNameAsString SET invalid_column = -1")
+   }
+   assert(analysisError.getMessage.contains("invalid_column"))
+
+   // the catalog's last transaction must be left aborted and closed
+   assert(catalog.lastTransaction.currentState == Aborted)
+   assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("update with CTE and transactional checks") {
+ // create target table (pk 1 and pk 3 are in dep 'hr')
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table (only pk 1 also exists in the target)
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ // update using CTE; the helper (defined elsewhere) yields the txn and a name -> table map
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""WITH cte AS (
+ | SELECT pk, salary + 50 AS adjusted_salary, dep
+ | FROM $sourceNameAsString
+ | WHERE salary > 100
+ |)
+ |UPDATE $tableNameAsString t
+ |SET salary = -1
+ |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM cte WHERE cte.pk = t.pk)
+ |""".stripMargin)
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2) // target and source
+ assert(table.version() == "2")
+
+ // check target table was scanned correctly
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaUpdate) 1 else 3 // depends on the suite's mode
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check target table scans for UPDATE condition (dep = 'hr')
+ val numUpdateTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numUpdateTargetScans == expectedNumTargetScans) // as many as scan events
+
+ // check source table was scanned correctly
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaUpdate) 1 else 4 // depends on the suite's mode
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ // check source table scans in CTE (salary > 100)
+ val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count {
+ case sources.GreaterThan("salary", 100) => true
+ case _ => false
+ }
+ assert(numCteSourceScans == expectedNumSourceScans) // as many as scan events
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, -1, "hr"), // updated
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (no matching pk in source)
+ }
+
+ test("update with subquery on source table and transactional checks") {
+ // create target table (pk 1 and pk 3 are in dep 'hr')
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table (its only 'hr' row has pk 1)
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ // update using an uncorrelated IN subquery that reads from a transactional catalog table
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""UPDATE $tableNameAsString
+ |SET salary = -1
+ |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 'hr')
+ |""".stripMargin)
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2) // target and source
+ assert(table.version() == "2")
+
+ // check source table was scanned correctly (dep = 'hr' filter in the subquery)
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaUpdate) 1 else 4 // depends on the suite's mode
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ assert(numSubquerySourceScans == expectedNumSourceScans) // as many as scan events
+
+ // check target table was scanned correctly
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaUpdate) 1 else 3 // depends on the suite's mode
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, -1, "hr"), // updated (pk 1 is in subquery result)
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (pk 3 not in subquery result)
+ }
+
+ test("update with uncorrelated scalar subquery on source table and transactional checks") {
+ // create target table (pk 1 and pk 3 are in dep 'hr')
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table (its only 'hr' row has salary 150)
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)")
+ sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')")
+
+ // update using an uncorrelated scalar subquery in the SET clause that reads from a
+ // transactional catalog table; scalar subqueries are executed as SubqueryExec at runtime
+ // and cannot be rewritten as joins
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""UPDATE $tableNameAsString
+ |SET salary = (SELECT max(salary) FROM $sourceNameAsString WHERE dep = 'hr')
+ |WHERE dep = 'hr'
+ |""".stripMargin)
+ }
+
+ // check txn was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2) // target and source
+ assert(table.version() == "2")
+
+ // check source table was scanned via the transaction catalog
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ assert(sourceTxnTable.scanEvents.nonEmpty) // no exact count is pinned in this test
+ assert(sourceTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+
+ // check target table was scanned via the transaction catalog
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.nonEmpty) // no exact count is pinned in this test
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 150, "hr"), // updated (max salary in source for 'hr' is 150)
+ Row(2, 200, "software"), // unchanged
+ Row(3, 150, "hr"))) // updated
+ }
+
+ test("update with constraint violation and transactional checks") {
+   // seed the target table with three rows
+   val schema = "pk INT NOT NULL, salary INT, dep STRING"
+   val initialRows =
+     """{ "pk": 1, "salary": 100, "dep": "hr" }
+       |{ "pk": 2, "salary": 200, "dep": "software" }
+       |{ "pk": 3, "salary": 300, "dep": "hr" }
+       |""".stripMargin
+   createAndInitTable(schema, initialRows)
+
+   // assigning NULL to pk violates the column's NOT NULL constraint
+   val nullPkUpdate =
+     s"""UPDATE $tableNameAsString
+        |SET pk = NULL
+        |WHERE dep = 'hr'
+        |""".stripMargin
+   val violation = intercept[SparkRuntimeException] {
+     executeTransaction {
+       sql(nullPkUpdate)
+     }
+   }
+   assert(violation.getMessage.contains("NOT_NULL_ASSERT_VIOLATION"))
+
+   // the catalog's last transaction must be left aborted and closed
+   assert(catalog.lastTransaction.currentState == Aborted)
+   assert(catalog.lastTransaction.isClosed)
+ }
+
+ test("update using view with transactional checks") {
+ withView("temp_view") {
+ // create target table (pk 1 and pk 3 are in dep 'hr')
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // create source table (only pk 1 also exists in the target)
+ sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)")
+ sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)")
+
+ // create view on top of source and target tables
+ sql(
+ s"""CREATE VIEW temp_view AS
+ |SELECT s.pk, s.salary, t.dep
+ |FROM $sourceNameAsString s
+ |LEFT JOIN (
+ | SELECT * FROM $tableNameAsString WHERE pk < 10
+ |) t ON s.pk = t.pk
+ |""".stripMargin)
+
+ // update target table using view; the target is both written and read through the view
+ val (txn, txnTables) = executeTransaction {
+ sql(
+ s"""UPDATE $tableNameAsString t
+ |SET salary = -1
+ |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM temp_view v WHERE v.pk = t.pk)
+ |""".stripMargin)
+ }
+
+ // check txn covers both tables and was properly committed and closed
+ assert(txn.currentState == Committed)
+ assert(txn.isClosed)
+ assert(txnTables.size == 2)
+ assert(table.version() == "2")
+
+ // check target table was scanned correctly
+ val targetTxnTable = txnTables(tableNameAsString)
+ val expectedNumTargetScans = if (deltaUpdate) 2 else 7 // sum of the two counts below
+ assert(targetTxnTable.scanEvents.size == expectedNumTargetScans)
+
+ // check target table scans as UPDATE target (dep = 'hr')
+ val numUpdateTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ }
+ val expectedNumUpdateTargetScans = if (deltaUpdate) 1 else 3
+ assert(numUpdateTargetScans == expectedNumUpdateTargetScans)
+
+ // check target table scans in view as source (pk < 10)
+ val numViewTargetScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.LessThan("pk", 10L) => true // pk is INT; 10L relies on numeric == (confirm)
+ case _ => false
+ }
+ val expectedNumViewTargetScans = if (deltaUpdate) 1 else 4
+ assert(numViewTargetScans == expectedNumViewTargetScans)
+
+ // check source table scans in view
+ val sourceTxnTable = txnTables(sourceNameAsString)
+ val expectedNumSourceScans = if (deltaUpdate) 1 else 4 // depends on the suite's mode
+ assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans)
+
+ // check txn state was propagated correctly
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, -1, "hr"), // updated from view
+ Row(2, 200, "software"), // unchanged
+ Row(3, 300, "hr"))) // unchanged (no matching pk in source)
+ }
+ }
+
+ test("df.explain() on update with transactional checks") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ // sql() is lazy, but explain() forces executedPlan.
+ sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'").explain()
+
+ assert(catalog.lastTransaction != null) // guard before dereferencing it below
+ assert(catalog.lastTransaction.currentState == Committed)
+ assert(catalog.lastTransaction.isClosed)
+ assert(table.version() == "2")
+
+ // The UPDATE was actually executed, not just planned.
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, -1, "hr"), // updated
+ Row(2, 200, "software"))) // unchanged
+ }
+
+ test("EXPLAIN UPDATE SQL with transactional checks") {
+   createAndInitTable(
+     "pk INT NOT NULL, salary INT, dep STRING",
+     """{ "pk": 1, "salary": 100, "dep": "hr" }
+       |{ "pk": 2, "salary": 200, "dep": "software" }
+       |""".stripMargin)
+
+   // Planning only: EXPLAIN must not run the write it describes.
+   val explainStatement = s"EXPLAIN UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'"
+   sql(explainStatement)
+
+   // No transaction may have been opened for the EXPLAIN.
+   assert(catalog.transaction === null)
+
+   // The rows are exactly as initialized, so the UPDATE did not execute.
+   val unchangedRows = Seq(Row(1, 100, "hr"), Row(2, 200, "software"))
+   checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), unchangedRows)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala
new file mode 100644
index 000000000000..141d5966b4b6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import scala.concurrent.duration._
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.classic
+import org.apache.spark.sql.execution.QueryExecution
+
+/**
+ * Benchmark to measure the overhead of cloning the analyzer for transactional query execution.
+ * Each transactional query creates a new [[Analyzer]] instance via
+ * [[Analyzer.withCatalogManager]], which shares all rules with the original but carries a
+ * per-query [[org.apache.spark.sql.connector.catalog.CatalogManager]]. This benchmark checks
+ * whether the cloning introduces measurable overhead.
+ *
+ * To run this benchmark:
+ * {{{
+ * build/sbt "sql/Test/runMain org.apache.spark.sql.execution.benchmark.AnalyzerBenchmark"
+ * }}}
+ */
+object AnalyzerBenchmark extends SqlBasedBenchmark {
+
+  // Rows per temp view; also used as valuesPerIteration by analysisBenchmark().
+  private val numRows = 100
+
+  // Column names of wide_t. Defined once so the view built in setupTables() and the
+  // "wide schema" query cannot drift apart: the query used to hard-code 100 columns
+  // while the view width followed numRows.
+  private val numWideCols = 100
+  private val wideColumns = (1 to numWideCols).map(i => s"col_$i")
+
+  // (benchmark label, SQL text) pairs exercised by analysisBenchmark().
+  private val queries = Seq(
+    "simple select" -> "SELECT id, val FROM t1",
+    "join" -> "SELECT t1.id, t2.val FROM t1 JOIN t2 ON t1.id = t2.id",
+    "wide schema" -> s"SELECT ${wideColumns.mkString(", ")} FROM wide_t"
+  )
+
+  /** Registers the temp views (t1, t2, wide_t) that the benchmark queries read from. */
+  private def setupTables(): Unit = {
+    spark.range(numRows).selectExpr("id", "id * 2 as val").createOrReplaceTempView("t1")
+    spark.range(numRows).selectExpr("id", "id * 3 as val").createOrReplaceTempView("t2")
+    spark.range(numRows)
+      .selectExpr(wideColumns.map(c => s"id as $c"): _*)
+      .createOrReplaceTempView("wide_t")
+  }
+
+  /**
+   * Measures parse + analysis time for each query, comparing the session analyzer against a
+   * cloned analyzer created via [[Analyzer.withCatalogManager]]. The SQL text is parsed inside
+   * every case, so parsing cost is included equally on both sides of the comparison.
+   *
+   * Two cases:
+   *  - "session analyzer"           : baseline, uses the session analyzer directly.
+   *  - "cloned analyzer (per query)": analyzer cloned every iteration; reflects the full
+   *    per-transactional-query cost (clone + analysis).
+   */
+  def analysisBenchmark(): Unit = {
+    for ((name, sql) <- queries) {
+      runBenchmark(s"analysis overhead $name") {
+        val benchmark = new Benchmark(
+          name = s"analysis overhead $name",
+          // Per row measurements are not meaningful here.
+          valuesPerIteration = numRows,
+          minTime = 10.seconds,
+          output = output)
+        val catalogManager = spark.sessionState.catalogManager
+
+        benchmark.addCase("session analyzer") { _ =>
+          val plan = spark.sessionState.sqlParser.parsePlan(sql)
+          new QueryExecution(spark.asInstanceOf[classic.SparkSession], plan).analyzed
+        }
+
+        benchmark.addCase("cloned analyzer (per query)") { _ =>
+          val cloned = spark.sessionState.analyzer.withCatalogManager(catalogManager)
+          val plan = spark.sessionState.sqlParser.parsePlan(sql)
+          new QueryExecution(spark.asInstanceOf[classic.SparkSession],
+            plan, analyzerOpt = Some(cloned)).analyzed
+        }
+
+        benchmark.run()
+      }
+    }
+  }
+
+  /**
+   * Micro-benchmark for [[Analyzer.withCatalogManager]] in isolation: measures the cost of
+   * creating the cloned analyzer, independent of any analysis work.
+   */
+  def cloneCostBenchmark(): Unit = {
+    runBenchmark("analyzer clone cost") {
+      val benchmark = new Benchmark(
+        name = "analyzer clone cost",
+        // Per row measurements are not meaningful here, so report one value per iteration.
+        // Passed as a literal so the object-level numRows is no longer shadowed.
+        valuesPerIteration = 1,
+        output = output)
+      val catalogManager = spark.sessionState.catalogManager
+
+      benchmark.addCase("withCatalogManager") { _ =>
+        spark.sessionState.analyzer.withCatalogManager(catalogManager)
+      }
+
+      benchmark.run()
+    }
+  }
+
+  /** Entry point: creates the temp views, then runs the clone-cost and analysis benchmarks. */
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    setupTables()
+    cloneCostBenchmark()
+    analysisBenchmark()
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
index a36570467a9d..da697847874d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
@@ -21,8 +21,10 @@ import org.apache.spark.SparkConf
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLConf
+import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTableCatalog, Txn, TxnTable, TxnTableCatalog}
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
@@ -47,6 +49,27 @@ class SqlScriptingE2eSuite extends SharedSparkSession {
}
// Helpers
+ private def withCatalog(
+     name: String)(
+     f: InMemoryRowLevelOperationTableCatalog => Unit): Unit = {
+   // Registers an InMemoryRowLevelOperationTableCatalog under `name` for the duration of `f`
+   // and resets the session's catalog manager once `f` has finished, even if it throws.
+   val confKey = s"spark.sql.catalog.$name"
+   val catalogClassName = classOf[InMemoryRowLevelOperationTableCatalog].getName
+   withSQLConf(confKey -> catalogClassName) {
+     val registered = spark.sessionState.catalogManager.catalog(name)
+     val catalog = registered.asInstanceOf[InMemoryRowLevelOperationTableCatalog]
+     try {
+       f(catalog)
+     } finally {
+       spark.sessionState.catalogManager.reset()
+     }
+   }
+ }
+
+ private def loadTxnTable(
+     txn: Txn,
+     tableName: String,
+     namespace: Array[String] = Array("ns1")): TxnTable = {
+   // Loads `namespace.tableName` through the transaction's own catalog and returns it as a
+   // TxnTable; both casts fail with ClassCastException if the types differ.
+   val txnCatalog = txn.catalog.asInstanceOf[TxnTableCatalog]
+   val ident = Identifier.of(namespace, tableName)
+   txnCatalog.loadTable(ident).asInstanceOf[TxnTable]
+ }
+
private def verifySqlScriptResult(
sqlText: String,
expected: Seq[Row],
@@ -174,6 +197,188 @@ class SqlScriptingE2eSuite extends SharedSparkSession {
}
}
+ test("multi statement with transactional checks - insert then delete") {
+ withCatalog("cat") { catalog =>
+ withTable("cat.ns1.t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING)
+ | PARTITIONED BY (dep);
+ | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software');
+ | DELETE FROM cat.ns1.t
+ | WHERE pk IN (SELECT pk FROM cat.ns1.t WHERE dep = 'hr');
+ | SELECT * FROM cat.ns1.t;
+ |END
+ |""".stripMargin
+
+ verifySqlScriptResult(sqlScript, Seq(Row(2, 200, "software"))) // the 'hr' row is deleted
+
+ // Each DML statement in a script runs in its own independent QE and transaction.
+ assert(catalog.observedTransactions.size === 2) // one each for INSERT and DELETE
+ assert(catalog.observedTransactions.forall(t =>
+ t.currentState === Committed && t.isClosed))
+
+ // The DELETE subquery scans the table with a dep='hr' predicate; verify it was tracked.
+ val deleteTxnTable = loadTxnTable(catalog.observedTransactions(1), "t")
+ assert(deleteTxnTable.scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "hr") => true
+ case _ => false
+ })
+ }
+ }
+ }
+
+ test("multi statement with transactional checks - second statement fails") {
+ withCatalog("cat") { catalog =>
+ withTable("cat.ns1.t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING)
+ | PARTITIONED BY (dep);
+ | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software');
+ | DELETE FROM cat.ns1.t WHERE nonexistent_column = 1;
+ |END
+ |""".stripMargin
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.sql(sqlScript).collect()
+ },
+ condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ parameters = Map(
+ "objectName" -> "`nonexistent_column`",
+ "proposal" -> ".*"), // any suggestion text is accepted
+ matchPVals = true,
+ queryContext = Array(ExpectedContext("nonexistent_column")))
+
+ // INSERT committed; DELETE was aborted because analysis failed on the bad column.
+ assert(catalog.observedTransactions.size === 2)
+ assert(catalog.observedTransactions(0).currentState === Committed) // INSERT
+ assert(catalog.observedTransactions(0).isClosed)
+ assert(catalog.observedTransactions(1).currentState === Aborted) // DELETE
+ assert(catalog.observedTransactions(1).isClosed)
+ assert(catalog.lastTransaction.currentState === Aborted)
+ }
+ }
+ }
+
+ test("multi statement with transactional checks - insert, merge, update") {
+ withCatalog("cat") { catalog =>
+ withTable("cat.ns1.t", "cat.ns1.src") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING)
+ | PARTITIONED BY (dep);
+ | CREATE TABLE cat.ns1.src (pk INT NOT NULL, salary INT, dep STRING)
+ | PARTITIONED BY (dep);
+ | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'), (3, 300, 'hr');
+ | INSERT INTO cat.ns1.src VALUES (1, 150, 'hr'), (4, 400, 'finance');
+ | MERGE INTO cat.ns1.t AS t
+ | USING cat.ns1.src AS s
+ | ON t.pk = s.pk
+ | WHEN MATCHED THEN UPDATE SET salary = s.salary
+ | WHEN NOT MATCHED THEN INSERT (pk, salary, dep)
+ | VALUES (s.pk, s.salary, s.dep);
+ | UPDATE cat.ns1.t SET salary = salary + 50 WHERE dep = 'software';
+ | SELECT * FROM cat.ns1.t ORDER BY pk;
+ |END
+ |""".stripMargin
+
+ verifySqlScriptResult(
+ sqlScript,
+ Seq(
+ Row(1, 150, "hr"), // salary taken from src by MERGE
+ Row(2, 250, "software"), // +50 from the UPDATE
+ Row(3, 300, "hr"), // untouched
+ Row(4, 400, "finance"))) // inserted by MERGE
+
+ // INSERT (x2), MERGE, and UPDATE each run in their own independent QE and transaction.
+ assert(catalog.observedTransactions.size === 4)
+ assert(catalog.observedTransactions.forall(t => t.currentState === Committed && t.isClosed))
+
+ def txnTable(txnIdx: Int): TxnTable = // table "t" as seen by the txnIdx-th transaction
+ loadTxnTable(catalog.observedTransactions(txnIdx), "t")
+
+ // Both inserts are pure writes - no scan.
+ assert(txnTable(0).scanEvents.isEmpty)
+ assert(txnTable(1).scanEvents.isEmpty) // txn 1 inserts into src, not into "t"
+
+ // MERGE scans the full target table. The join is on pk (not the partition column).
+ assert(txnTable(2).scanEvents.nonEmpty)
+ assert(txnTable(2).scanEvents.flatten.isEmpty) // scans recorded, but no filters
+
+ // UPDATE with WHERE dep='software' pushes an equality predicate on the partition column.
+ assert(txnTable(3).scanEvents.flatten.exists {
+ case sources.EqualTo("dep", "software") => true
+ case _ => false
+ })
+ }
+ }
+ }
+
+ test("loop with transactional checks - each iteration runs in its own transaction") {
+ withCatalog("cat") { catalog =>
+ withTable("cat.ns1.t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE i INT = 1;
+ | CREATE TABLE
+ | cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING)
+ | PARTITIONED BY (dep);
+ | WHILE i <= 3 DO
+ | INSERT INTO cat.ns1.t VALUES (i, i * 100, 'hr');
+ | SET i = i + 1;
+ | END WHILE;
+ | SELECT * FROM cat.ns1.t ORDER BY pk;
+ |END
+ |""".stripMargin
+
+ verifySqlScriptResult( // one row per loop iteration, i = 1..3
+ sqlScript,
+ Seq(Row(1, 100, "hr"), Row(2, 200, "hr"), Row(3, 300, "hr")))
+
+ // Each loop iteration's INSERT runs in its own independent transaction.
+ assert(catalog.observedTransactions.size === 3) // three iterations, three INSERTs
+ assert(catalog.observedTransactions.forall(t => t.currentState === Committed && t.isClosed))
+ }
+ }
+ }
+
+ test("continue handler with transactional checks - handler DML runs in its own transaction") {
+ withCatalog("cat") { catalog =>
+ withTable("cat.ns1.t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO
+ | BEGIN
+ | INSERT INTO cat.ns1.t VALUES (-1, -1, 'error');
+ | END;
+ | CREATE TABLE
+ | cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING)
+ | PARTITIONED BY (dep);
+ | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr');
+ | SELECT 1/0;
+ | INSERT INTO cat.ns1.t VALUES (2, 200, 'software');
+ | SELECT * FROM cat.ns1.t ORDER BY pk;
+ |END
+ |""".stripMargin
+
+ verifySqlScriptResult( // the handler row (-1) sorts first under ORDER BY pk
+ sqlScript,
+ Seq(Row(-1, -1, "error"), Row(1, 100, "hr"), Row(2, 200, "software")))
+
+ // INSERT(1), handler INSERT(-1), INSERT(2) - each in its own transaction.
+ assert(catalog.observedTransactions.size === 3) // SELECT 1/0 is not counted here
+ assert(catalog.observedTransactions.forall(t => t.currentState === Committed && t.isClosed))
+ }
+ }
+ }
+
test("script without result statement") {
val sqlScript =
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
index 89f655622952..dab667731019 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.{FakeV2Provider, FakeV2ProviderWithCustomSchema, InMemoryTableSessionCatalog}
-import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog, MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability, TableInfo, V2TableWithV1Fallback}
+import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTable, InMemoryTableCatalog, MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability, TableInfo, V2TableWithV1Fallback}
import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, Transform}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, MemoryStreamScanBuilder, StreamingQueryWrapper}
@@ -109,20 +109,23 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
}
test("read: read table without streaming capability support") {
- val tableIdentifier = "testcat.table_name"
+ withSQLConf("spark.sql.catalog.testcat" ->
+ classOf[DataStreamTableAPISuite.NonStreamingInMemoryTableCatalog].getName) {
+ val tableIdentifier = "testcat.table_name" // "testcat" is the catalog configured above
- spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING foo")
+ spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING foo")
- checkError(
- exception = intercept[AnalysisException] {
- spark.readStream.table(tableIdentifier)
- },
- condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION",
- parameters = Map(
- "tableName" -> "`testcat`.`table_name`",
- "operation" -> "either micro-batch or continuous scan"
+ checkError( // a streaming read of this table must be rejected
+ exception = intercept[AnalysisException] {
+ spark.readStream.table(tableIdentifier)
+ },
+ condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ parameters = Map(
+ "tableName" -> "`testcat`.`table_name`",
+ "operation" -> "either micro-batch or continuous scan"
+ )
)
- )
+ }
}
test("read: read table with custom catalog") {
@@ -638,6 +641,25 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter {
object DataStreamTableAPISuite {
val V1FallbackTestTableName = "fallbackV1Test"
+
+ /**
+  * An [[InMemoryTableCatalog]] whose tables are created without
+  * [[TableCapability.MICRO_BATCH_READ]].
+  */
+ class NonStreamingInMemoryTableCatalog extends InMemoryTableCatalog {
+   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+   override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
+     // refuse to overwrite an existing table
+     if (tables.containsKey(ident)) {
+       throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
+     }
+
+     // same table as the parent catalog would build, minus the micro-batch read capability
+     val nonStreamingTable = new InMemoryTable(
+       s"$name.${ident.quoted}",
+       tableInfo.columns(),
+       tableInfo.partitions(),
+       tableInfo.properties,
+       tableInfo.constraints()) {
+       override def baseCapabiilities: Set[TableCapability] =
+         super.baseCapabiilities.filterNot(_ == TableCapability.MICRO_BATCH_READ)
+     }
+
+     tables.put(ident, nonStreamingTable)
+     namespaces.putIfAbsent(ident.namespace.toList, Map())
+     nonStreamingTable
+   }
+ }
}
class InMemoryStreamTable(override val name: String)