diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java new file mode 100644 index 000000000000..daa3176dcbba --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TransactionalCatalogPlugin.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.transactions.Transaction; +import org.apache.spark.sql.connector.catalog.transactions.TransactionInfo; + +/** + * A {@link CatalogPlugin} that supports transactions. + *

+ * Catalogs that implement this interface opt in to transactional query execution. A catalog + * implementing this interface is responsible for starting transactions. + * + * @since 4.2.0 + */ +@Evolving +public interface TransactionalCatalogPlugin extends CatalogPlugin { + + /** + * Begins a new transaction and returns a {@link Transaction} representing it. + * + * @param info metadata about the transaction being started. + */ + Transaction beginTransaction(TransactionInfo info); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java new file mode 100644 index 000000000000..77044c6202fb --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog.transactions; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.CatalogPlugin; +import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin; + +import java.io.Closeable; + +/** + * Represents a transaction. + *

+ * Spark begins a transaction with {@link TransactionalCatalogPlugin#beginTransaction} and + * executes read/write operations against the transaction's catalog. On success, Spark + * calls {@link #commit()}; on failure, Spark calls {@link #abort()}. In both cases Spark + * subsequently calls {@link #close()} to release resources. + * + * @since 4.2.0 + */ +@Evolving +public interface Transaction extends Closeable { + + /** + * Returns the catalog associated with this transaction. This catalog is responsible for tracking + * read/write operations that occur within the boundaries of a transaction. This allows + * connectors to perform conflict resolution at commit time. + */ + CatalogPlugin catalog(); + + /** + * Commits the transaction. All writes performed under it become visible to other readers. + *

+ * The connector is responsible for detecting and resolving conflicting commits or throwing + * an exception if resolution is not possible. + *

+ * This method will be called at most once per transaction. Spark calls {@link #close()} + * immediately after this method returns or throws. + * + * @throws IllegalStateException if the transaction has already been committed or aborted. + */ + void commit(); + + /** + * Aborts the transaction, discarding any staged changes. + *

+ * This method must be idempotent. If the transaction has already been committed or aborted, + * invoking it must have no effect. + *

+ * Spark calls {@link #close()} immediately after this method returns or throws. + */ + void abort(); + + /** + * Releases any resources held by this transaction. + *

+ * Spark always calls this method after {@link #commit()} or {@link #abort()}, regardless of + * whether those methods succeed or not. + *

+ * This method must be idempotent. If the transaction has already been closed, + * invoking it must have no effect. + */ + @Override + void close(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java new file mode 100644 index 000000000000..3e6979cec469 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/TransactionInfo.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.transactions; + +import org.apache.spark.annotation.Evolving; + +/** + * Metadata about a transaction. + * + * @since 4.2.0 + */ +@Evolving +public interface TransactionInfo { + /** + * Returns a unique identifier for this transaction. 
+ */ + String id(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java index 44fc5f9d794b..75816349af38 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java @@ -85,6 +85,12 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple * tasks may have committed successfully and one successful commit message per task will be * passed to this commit method. The remaining commit messages are ignored by Spark. + *

+ * Note: this method signals that all data for this write operation has been successfully written. + * It is NOT a transactional commit. When this write is part of a + * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction}, the transaction is + * committed separately via + * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction#commit()}. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java index ab98bc01b3ae..764ed0a35a3f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java @@ -80,6 +80,12 @@ default boolean useCommitCoordinator() { * The execution engine may call {@code commit} multiple times for the same epoch in some * circumstances. To support exactly-once data semantics, implementations must ensure that * multiple commits for the same epoch are idempotent. + *

+ * Note: this method signals that all data for this write operation has been successfully written. + * It is NOT a transactional commit. When this write is part of a + * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction}, the transaction is + * committed separately via + * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction#commit()}. */ void commit(long epochId, WriterCommitMessage[] messages); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 323a7db9c7ad..50dc055a98a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -351,6 +351,33 @@ class Analyzer( } } + /** + * Returns a copy of this analyzer that uses the given [[CatalogManager]] for all catalog + * lookups. All other configuration (extended rules, checks, etc.) is preserved. Used by + * [[QueryExecution]] to create a per-query analyzer for transactional operations for + * transaction-aware catalog resolution. + * + * IMPORTANT: any new extension point added to Analyzer must also be copied here, otherwise + * transaction-aware analyzer clones (created by QueryExecution) will silently miss those rules. 
+ */ + def withCatalogManager(newCatalogManager: CatalogManager): Analyzer = { + val self = this + new Analyzer(newCatalogManager, sharedRelationCache) { + override val hintResolutionRules: Seq[Rule[LogicalPlan]] = self.hintResolutionRules + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = self.extendedResolutionRules + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = self.postHocResolutionRules + override val extendedCheckRules: Seq[LogicalPlan => Unit] = self.extendedCheckRules + override val singlePassResolverExtensions: Seq[ResolverExtension] = + self.singlePassResolverExtensions + override val singlePassMetadataResolverExtensions: Seq[ResolverExtension] = + self.singlePassMetadataResolverExtensions + override val singlePassPostHocResolutionRules: Seq[Rule[LogicalPlan]] = + self.singlePassPostHocResolutionRules + override val singlePassExtendedResolutionChecks: Seq[LogicalPlan => Unit] = + self.singlePassExtendedResolutionChecks + } + } + override def execute(plan: LogicalPlan): LogicalPlan = { AnalysisContext.withNewAnalysisContext { executeSameContext(plan) @@ -458,7 +485,9 @@ class Analyzer( Batch("Simple Sanity Check", Once, LookupFunctions), Batch("Keep Legacy Outputs", Once, - KeepLegacyOutputs) + KeepLegacyOutputs), + Batch("Unresolve Relations", Once, + new UnresolveRelationsInTransaction(catalogManager)) ) override def batches: Seq[Batch] = earlyBatches ++ Seq( @@ -1015,7 +1044,7 @@ class Analyzer( // DataSourceV2Relation on each view access. Only dataframe temp view may contain it // as it stores resolved plans directly. case view: View if view.isTempViewStoringAnalyzedPlan => - view.copy(child = resolveTableReferences(view.child)) + view.copy(child = resolveTableReferencesInTempView(view.child)) case p @ SubqueryAlias(_, view: View) => p.copy(child = resolveViews(view, options)) case _ => plan @@ -1024,17 +1053,43 @@ class Analyzer( // Unwrap temp views storing analyzed plans and resolve V2TableReference nodes in the child. 
private def unwrapRelationPlan(plan: LogicalPlan): LogicalPlan = { EliminateSubqueryAliases(plan) match { - case v: View if v.isTempViewStoringAnalyzedPlan => resolveTableReferences(v.child) + case v: View if v.isTempViewStoringAnalyzedPlan => resolveTableReferencesInTempView(v.child) case other => other } } - // Resolve V2TableReference nodes in a plan. V2TableReference is only created for temp views - // (via V2TableReference.createForTempView), so we only need to resolve it when returning + // Resolve the write target of a V2 write command (batch or streaming). + private def resolveWriteTarget( + write: LogicalPlan, + table: NamedRelation, + withNewTable: NamedRelation => LogicalPlan): LogicalPlan = { + table match { + case u: UnresolvedRelation if !u.isStreaming => + resolveRelation(u).map(unwrapRelationPlan).map { + case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError( + v.desc.identifier, write) + case u: UnresolvedCatalogRelation => + throw QueryCompilationErrors.writeIntoV1TableNotAllowedError( + u.tableMeta.identifier, write) + case r: DataSourceV2Relation => withNewTable(r) + case _ => + throw QueryCompilationErrors.writeIntoTempViewNotAllowedError( + u.multipartIdentifier.quoted) + }.getOrElse(write) + case _ => write + } + } + + // Resolve V2TableReference nodes inside temp view plans. These are created by + // V2TableReference.createForTempView. We only need to resolve it when returning // the plan of temp views (in resolveViews and unwrapRelationPlan). 
- private def resolveTableReferences(plan: LogicalPlan): LogicalPlan = { + private def resolveTableReferencesInTempView(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsUp { - case r: V2TableReference => relationResolution.resolveReference(r) + case r: V2TableReference => + assert(r.context.isInstanceOf[V2TableReference.TemporaryViewContext], + s"""Expected TemporaryViewContext in temp view but got + |${r.context.getClass.getSimpleName}""".stripMargin) + relationResolution.resolveReference(r) } } @@ -1057,23 +1112,11 @@ class Analyzer( case other => i.copy(table = other) } - // TODO (SPARK-27484): handle streaming write commands when we have them. + case write: StreamingV2WriteCommand => + resolveWriteTarget(write, write.table, write.withNewTable) + case write: V2WriteCommand => - write.table match { - case u: UnresolvedRelation if !u.isStreaming => - resolveRelation(u).map(unwrapRelationPlan).map { - case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError( - v.desc.identifier, write) - case u: UnresolvedCatalogRelation => - throw QueryCompilationErrors.writeIntoV1TableNotAllowedError( - u.tableMeta.identifier, write) - case r: DataSourceV2Relation => write.withNewTable(r) - case _ => - throw QueryCompilationErrors.writeIntoTempViewNotAllowedError( - u.multipartIdentifier.quoted) - }.getOrElse(write) - case _ => write - } + resolveWriteTarget(write, write.table, write.withNewTable) case u: UnresolvedRelation => resolveRelation(u).map(resolveViews(_, u.options)).getOrElse(u) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index 7a5077a8a3e1..4546c1dad001 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -496,10 +496,25 @@ class 
RelationResolution( } } + /** + * Loads the table for a [[V2TableReference]] and returns a resolved [[DataSourceV2Relation]]. + * + * The catalog is re-resolved by name through the [[CatalogManager]] rather than reusing + * [[V2TableReference#catalog]] directly. When a transaction is active, the + * [[TransactionAwareCatalogManager]] redirects catalog lookups to the transaction's catalog + * instance, so the [[TableCatalog#loadTable]] call is intercepted by the transaction catalog, + * which uses it to track which tables are read as part of the transaction. + */ private def loadRelation(ref: V2TableReference): LogicalPlan = { - val table = ref.catalog.loadTable(ref.identifier) + val resolvedCatalog = catalogManager.catalog(ref.catalog.name).asTableCatalog + val table = resolvedCatalog.loadTable(ref.identifier) V2TableReferenceUtils.validateLoadedTable(table, ref) - ref.toRelation(table) + DataSourceV2Relation( + table = table, + output = ref.output, + catalog = Some(resolvedCatalog), + identifier = Some(ref.identifier), + options = ref.options) } private def adaptCachedRelation(cached: LogicalPlan, ref: V2TableReference): LogicalPlan = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala new file mode 100644 index 000000000000..8ee64e32376f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolveRelationsInTransaction.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TransactionalWrite} +import org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.allowInvokingTransformsInAnalyzer +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +/** + * When a transaction is active, converts resolved [[DataSourceV2Relation]] nodes back to + * [[V2TableReference]] placeholders for all relations loaded by a catalog with the same + * name as the transaction catalog. + * + * This forces re-resolution of those relations against the transaction's catalog, which + * intercepts [[TableCatalog#loadTable]] calls to track which tables are read as part of + * the transaction. + */ +class UnresolveRelationsInTransaction(val catalogManager: CatalogManager) + extends Rule[LogicalPlan] with LookupCatalog { + + override def apply(plan: LogicalPlan): LogicalPlan = + catalogManager.transaction match { + case Some(transaction) => + // We use plain transform rather than resolveOperators* because the latter skips subtrees + // that have already been analyzed. Furthermore, allowInvokingTransformsInAnalyzer + // suppresses the assertNotAnalysisRule safety check, which forbids calling + // transform directly inside the analyzer when not within a resolveOperators call. 
+ allowInvokingTransformsInAnalyzer { + plan.transform { + case tw: TransactionalWrite => + unresolveRelations(tw, transaction.catalog) + } + } + case _ => plan + } + + private def unresolveRelations( + plan: LogicalPlan, + catalog: CatalogPlugin): LogicalPlan = { + plan.transform { + case r: DataSourceV2Relation if isLoadedFromCatalog(r, catalog) => + V2TableReference.createForTransaction(r) + } + } + + private def isLoadedFromCatalog( + relation: DataSourceV2Relation, + catalog: CatalogPlugin): Boolean = { + relation.catalog.exists(_.name == catalog.name) && relation.identifier.isDefined + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index 85c36d452b30..5545141640a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.V2TableReference.Context import org.apache.spark.sql.catalyst.analysis.V2TableReference.TableInfo import org.apache.spark.sql.catalyst.analysis.V2TableReference.TemporaryViewContext +import org.apache.spark.sql.catalyst.analysis.V2TableReference.TransactionContext import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics @@ -37,7 +38,7 @@ import org.apache.spark.sql.connector.catalog.V2TableUtil import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.sql.util.SchemaValidationMode.ALLOW_NEW_TOP_LEVEL_FIELDS +import 
org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_TOP_LEVEL_FIELDS, PROHIBIT_CHANGES} import org.apache.spark.util.ArrayImplicits._ /** @@ -79,22 +80,33 @@ private[sql] case class V2TableReference private( private[sql] object V2TableReference { case class TableInfo( + tableId: Option[String], columns: Seq[Column], metadataColumns: Seq[MetadataColumn]) sealed trait Context + /** Context for relations that are re-resolved on access of a dataframe temp view. */ case class TemporaryViewContext(viewName: Seq[String]) extends Context + /** Context for relations that are re-resolved through a transaction catalog. */ + case object TransactionContext extends Context def createForTempView(relation: DataSourceV2Relation, viewName: Seq[String]): V2TableReference = { create(relation, TemporaryViewContext(viewName)) } + // V2TableReference nodes in the transaction context are produced by + // UnresolveRelationsInTransaction which unresolves already resolved relations. + def createForTransaction(relation: DataSourceV2Relation): V2TableReference = { + create(relation, TransactionContext) + } + private def create(relation: DataSourceV2Relation, context: Context): V2TableReference = { val ref = V2TableReference( relation.catalog.get.asTableCatalog, relation.identifier.get, relation.options, TableInfo( + tableId = Option(relation.table.id()), columns = relation.table.columns.toImmutableArraySeq, metadataColumns = V2TableUtil.extractMetadataColumns(relation)), relation.output, @@ -110,11 +122,35 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { ref.context match { case ctx: TemporaryViewContext => validateLoadedTableInTempView(table, ref, ctx) + case TransactionContext => + validateLoadedTableInTransaction(table, ref) case ctx => throw SparkException.internalError(s"Unknown table ref context: ${ctx.getClass.getName}") } } + private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = { + // Make sure the table was not dropped and 
recreated. + ref.info.tableId.foreach(V2TableUtil.validateTableId(ref.name, _, table)) + + // Do not allow schema evolution for pre-analyzed dataframes that are later used in + // transactional writes. This is because the entire plan was built based on the original schema + // and any schema change would make the plan structurally invalid. This is in line with the + // semantics of SPARK-54444. + val dataErrors = V2TableUtil.validateCapturedColumns( + table = table, + originCols = ref.info.columns, + mode = PROHIBIT_CHANGES) + if (dataErrors.nonEmpty) { + throw QueryCompilationErrors.columnsChangedAfterAnalysis(ref.name, dataErrors) + } + + val metaErrors = V2TableUtil.validateCapturedMetadataColumns(table, ref.info.metadataColumns) + if (metaErrors.nonEmpty) { + throw QueryCompilationErrors.metadataColumnsChangedAfterAnalysis(ref.name, metaErrors) + } + } + private def validateLoadedTableInTempView( table: Table, ref: V2TableReference, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index c38377582c15..774c783ecf8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -188,7 +188,11 @@ case class InsertIntoStatement( byName: Boolean = false, replaceCriteriaOpt: Option[InsertReplaceCriteria] = None, withSchemaEvolution: Boolean = false) - extends UnaryParsedStatement { + // Extends TransactionalWrite so that QueryExecution can detect a potential transaction on the + // unresolved logical plan before analysis runs. InsertIntoStatement is shared between V1 and V2 + // inserts, but the LookupCatalog.TransactionalWrite extractor only matches when the target + // catalog implements TransactionalCatalogPlugin, so V1 inserts are never assigned a transaction. 
+ extends UnaryParsedStatement with TransactionalWrite { require(overwrite || !ifPartitionNotExists, "IF NOT EXISTS is only valid in INSERT OVERWRITE") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 0eded2d9dbdf..15573b157d5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -157,7 +157,7 @@ case class AppendData( isByName: Boolean, withSchemaEvolution: Boolean, write: Option[Write] = None, - analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand { + analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT) override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery) override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable) @@ -205,7 +205,7 @@ case class OverwriteByExpression( isByName: Boolean, withSchemaEvolution: Boolean, write: Option[Write] = None, - analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand { + analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE) override lazy val resolved: Boolean = { @@ -265,7 +265,7 @@ case class OverwritePartitionsDynamic( writeOptions: Map[String, String], isByName: Boolean, withSchemaEvolution: Boolean, - write: Option[Write] = None) extends V2WriteCommand { + write: Option[Write] = None) extends V2WriteCommand with TransactionalWrite { override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE) override def 
withNewQuery(newQuery: LogicalPlan): OverwritePartitionsDynamic = { @@ -521,8 +521,10 @@ case class WriteDelta( trait V2CreateTableAsSelectPlan extends V2CreateTablePlan with AnalysisOnlyCommand - with CTEInChildren { + with CTEInChildren + with TransactionalWrite { def query: LogicalPlan + override def table: LogicalPlan = name override def withCTEDefs(cteDefs: Seq[CTERelationDef]): LogicalPlan = { withNameAndQuery(newName = name, newQuery = WithCTE(query, cteDefs)) @@ -956,7 +958,8 @@ object DescribeColumn { */ case class DeleteFromTable( table: LogicalPlan, - condition: Expression) extends UnaryCommand with SupportsSubquery { + condition: Expression) + extends UnaryCommand with TransactionalWrite with SupportsSubquery { override def child: LogicalPlan = table override protected def withNewChildInternal(newChild: LogicalPlan): DeleteFromTable = copy(table = newChild) @@ -978,7 +981,8 @@ case class DeleteFromTableWithFilters( case class UpdateTable( table: LogicalPlan, assignments: Seq[Assignment], - condition: Option[Expression]) extends UnaryCommand with SupportsSubquery { + condition: Option[Expression]) + extends UnaryCommand with TransactionalWrite with SupportsSubquery { lazy val aligned: Boolean = AssignmentUtils.aligned(table.output, assignments) @@ -1011,8 +1015,12 @@ case class MergeIntoTable( notMatchedActions: Seq[MergeAction], notMatchedBySourceActions: Seq[MergeAction], withSchemaEvolution: Boolean) - extends BinaryCommand with WriteWithSchemaEvolution with SupportsSubquery { + extends BinaryCommand + with WriteWithSchemaEvolution + with SupportsSubquery + with TransactionalWrite { + // Implements WriteWithSchemaEvolution.table and TransactionalWrite.table. 
override val table: LogicalPlan = EliminateSubqueryAliases(targetTable) override def withNewTable(newTable: NamedRelation): MergeIntoTable = { @@ -1272,6 +1280,22 @@ case class Assignment(key: Expression, value: Expression) extends Expression newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight) } +/** + * Marker trait for write operations that participate in a DSv2 transaction. + * + * Implementations are expected to target a DSv2 catalog backed by a + * [[org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin]]. + */ +trait TransactionalWrite extends LogicalPlan { + def table: LogicalPlan +} + +/** Trait for streaming write commands that participate in DSv2 transactions. */ +trait StreamingV2WriteCommand extends TransactionalWrite { + override def table: NamedRelation + def withNewTable(newTable: NamedRelation): StreamingV2WriteCommand +} + /** * The logical plan of the DROP TABLE command. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala new file mode 100644 index 000000000000..a5f8afddf01c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.transactions + +import java.util.UUID + +import org.apache.spark.SparkException +import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfoImpl} +import org.apache.spark.util.Utils + +object TransactionUtils { + def commit(txn: Transaction): Unit = { + Utils.tryWithSafeFinally { + txn.commit() + } { + txn.close() + } + } + + def abort(txn: Transaction): Unit = { + Utils.tryWithSafeFinally { + txn.abort() + } { + txn.close() + } + } + + def beginTransaction(catalog: TransactionalCatalogPlugin): Transaction = { + val info = TransactionInfoImpl(id = UUID.randomUUID.toString) + val txn = catalog.beginTransaction(info) + if (txn.catalog.name != catalog.name) { + abort(txn) + throw SparkException.internalError( + s"""Transaction catalog name (${txn.catalog.name}) + |must match original catalog name (${catalog.name}).""".stripMargin) + } + txn + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index 3f5afd9ce0de..c851e931aad4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper import 
org.apache.spark.sql.catalyst.catalog.{SessionCatalog, TempVariableManager} import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -39,7 +40,7 @@ import org.apache.spark.sql.internal.SQLConf // need to track current database at all. private[sql] class CatalogManager( - defaultSessionCatalog: CatalogPlugin, + val defaultSessionCatalog: CatalogPlugin, val v1SessionCatalog: SessionCatalog) extends SQLConfHelper with Logging { import CatalogManager.SESSION_CATALOG_NAME import CatalogV2Util._ @@ -57,6 +58,11 @@ class CatalogManager( } } + def transaction: Option[Transaction] = None + + def withTransaction(transaction: Transaction): CatalogManager = + new TransactionAwareCatalogManager(this, transaction) + def isCatalogRegistered(name: String): Boolean = { try { catalog(name) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala index 203cfc23452a..dd5be45bfc5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.connector.catalog import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedIdentifier, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.TransactionalWrite import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -163,4 +165,17 @@ private[sql] trait LookupCatalog extends Logging { } } } + + object TransactionalWrite { + def unapply(write: TransactionalWrite): 
Option[TransactionalCatalogPlugin] = { + EliminateSubqueryAliases(write.table) match { + case UnresolvedRelation(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _, _) => + Some(c) + case UnresolvedIdentifier(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _) => + Some(c) + case _ => + None + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala new file mode 100644 index 000000000000..70079357b6dd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/TransactionAwareCatalogManager.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.catalog.TempVariableManager +import org.apache.spark.sql.connector.catalog.transactions.Transaction + +/** + * A [[CatalogManager]] decorator that redirects catalog lookups to the transaction's catalog + * instance when names match, ensuring table loads during analysis are scoped to the transaction. 
+ * All mutable state (current catalog, current namespace, loaded catalogs) is delegated to the + * wrapped [[CatalogManager]]. + */ +// TODO: Extracting a CatalogManager trait (so this class can implement it instead of extending +// CatalogManager) would eliminate the inherited mutable state that this decorator doesn't use. +private[sql] class TransactionAwareCatalogManager( + delegate: CatalogManager, + txn: Transaction) + extends CatalogManager(delegate.defaultSessionCatalog, delegate.v1SessionCatalog) { + + override val tempVariableManager: TempVariableManager = delegate.tempVariableManager + + override def transaction: Option[Transaction] = Some(txn) + + override def withTransaction(newTxn: Transaction): CatalogManager = + throw SparkException.internalError("Cannot nest transactions: a transaction is already active.") + + override def catalog(name: String): CatalogPlugin = { + val resolved = delegate.catalog(name) + if (txn.catalog.name() == resolved.name()) txn.catalog else resolved + } + + override def currentCatalog: CatalogPlugin = { + val c = delegate.currentCatalog + if (txn.catalog.name() == c.name()) txn.catalog else c + } + + override def currentNamespace: Array[String] = delegate.currentNamespace + + override def setCurrentNamespace(namespace: Array[String]): Unit = + delegate.setCurrentNamespace(namespace) + + override def setCurrentCatalog(catalogName: String): Unit = + delegate.setCurrentCatalog(catalogName) + + override def listCatalogs(pattern: Option[String]): Seq[String] = + delegate.listCatalogs(pattern) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala index c7f7b17a5843..af7edce47427 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala @@ -23,6 +23,7 @@ import 
org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, MetadataColumnHelper} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.sql.util.SchemaValidationMode @@ -131,6 +132,20 @@ private[sql] object V2TableUtil extends SQLConfHelper { case _ => Seq.empty } + /** + * Validates that the identity of a loaded table matches a previously captured table id. + * Throws if the table was dropped and recreated under the same name (which changes the id). + * No-op if the connector does not support table ids (capturedId is null). + */ + def validateTableId(name: String, capturedId: String, currentTable: Table): Unit = { + if (capturedId != null && capturedId != currentTable.id) { + throw QueryCompilationErrors.tableIdChangedAfterAnalysis( + name, + capturedTableId = capturedId, + currentTableId = currentTable.id) + } + } + private def normalize(name: String): String = { if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala new file mode 100644 index 000000000000..4cb53da0a59e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/transactions/TransactionInfoImpl.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.transactions + +case class TransactionInfoImpl(id: String) extends TransactionInfo diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala new file mode 100644 index 000000000000..02cfe6b4eb7e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.resolver.{ResolverExtension, TreeNodeResolver} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.CatalogManager + +/** + * Verifies that [[Analyzer.withCatalogManager]] propagates all extension points. + * + * If this suite fails with an unexpected set of overridden methods, a new extension point was + * added to [[Analyzer.withCatalogManager]] without being verified here. Add the corresponding + * assertion and update the expected set. + * + * If [[Analyzer]] gains a new extension point that is NOT yet in [[Analyzer.withCatalogManager]], + * add it there first, then update this suite. + */ +class AnalyzerExtensionPropagationSuite extends SparkFunSuite { + + private val dummyRule: Rule[LogicalPlan] = new Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan + } + + private val dummyCheck: LogicalPlan => Unit = (_: LogicalPlan) => () + + private val dummyExtension: ResolverExtension = new ResolverExtension { + override def resolveOperator( + operator: LogicalPlan, + resolver: TreeNodeResolver[LogicalPlan, LogicalPlan]): Option[LogicalPlan] = None + } + + private def newCatalogManager(): CatalogManager = + new CatalogManager( + FakeV2SessionCatalog, + new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry)) + + test("withCatalogManager propagates all extension points") { + val analyzer = new Analyzer(newCatalogManager()) { + override val hintResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule) + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule) + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule) + override val extendedCheckRules: Seq[LogicalPlan => Unit] =
Seq(dummyCheck) + override val singlePassResolverExtensions: Seq[ResolverExtension] = Seq(dummyExtension) + override val singlePassMetadataResolverExtensions: Seq[ResolverExtension] = + Seq(dummyExtension) + override val singlePassPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule) + override val singlePassExtendedResolutionChecks: Seq[LogicalPlan => Unit] = Seq(dummyCheck) + } + + val clone = analyzer.withCatalogManager(newCatalogManager()) + + assert(clone.hintResolutionRules eq analyzer.hintResolutionRules) + assert(clone.extendedResolutionRules eq analyzer.extendedResolutionRules) + assert(clone.postHocResolutionRules eq analyzer.postHocResolutionRules) + assert(clone.extendedCheckRules eq analyzer.extendedCheckRules) + assert(clone.singlePassResolverExtensions eq analyzer.singlePassResolverExtensions) + assert(clone.singlePassMetadataResolverExtensions eq + analyzer.singlePassMetadataResolverExtensions) + assert(clone.singlePassPostHocResolutionRules eq analyzer.singlePassPostHocResolutionRules) + assert(clone.singlePassExtendedResolutionChecks eq analyzer.singlePassExtendedResolutionChecks) + + // Verify the clone's anonymous class overrides exactly the expected extension points. + // If this assertion fails, withCatalogManager was updated but this test was not. + // Add the corresponding assert above and update the expected set. + val overriddenMethods = clone.getClass.getDeclaredMethods + .filterNot(m => m.isSynthetic || m.isBridge || m.getName.contains("$")) + .map(_.getName) + .toSet + + val expectedExtensions = Set( + "hintResolutionRules", + "extendedResolutionRules", + "postHocResolutionRules", + "extendedCheckRules", + "singlePassResolverExtensions", + "singlePassMetadataResolverExtensions", + "singlePassPostHocResolutionRules", + "singlePassExtendedResolutionChecks" + ) + + assert(overriddenMethods == expectedExtensions, + s"withCatalogManager does not copy the expected set of extension points. 
" + + s"Missing from withCatalogManager: ${expectedExtensions -- overriddenMethods}. " + + s"Unexpected overrides: ${overriddenMethods -- expectedExtensions}.") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala new file mode 100644 index 000000000000..ee771fe3f246 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtilsSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.transactions + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, TransactionalCatalogPlugin} +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class TransactionUtilsSuite extends SparkFunSuite { + val testCatalogName = "test_catalog" + + // --- Helpers --------------------------------------------------------------- + private def mockCatalog(catalogName: String): CatalogPlugin = new CatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = catalogName + } + + private val emptyFunction = () => () + private class TestTransaction( + catalogName: String, + onCommit: () => Unit = emptyFunction, + onAbort: () => Unit = emptyFunction, + onClose: () => Unit = emptyFunction) extends Transaction { + var committed = false + var aborted = false + var closed = false + + override def catalog(): CatalogPlugin = mockCatalog(catalogName) + override def commit(): Unit = { committed = true; onCommit() } + override def abort(): Unit = { aborted = true; onAbort() } + override def close(): Unit = { closed = true; onClose() } + } + + private def mockTransactionalCatalog( + catalogName: String, + txnCatalogName: String = null): TransactionalCatalogPlugin = { + val resolvedTxnCatalogName = Option(txnCatalogName).getOrElse(catalogName) + new TransactionalCatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = catalogName + override def beginTransaction(info: TransactionInfo): Transaction = + new TestTransaction(resolvedTxnCatalogName) + } + } + + // --- Commit ---------------------------------------------------------------- + test("commit: calls commit then close") { + val txn = new TestTransaction(testCatalogName) + 
TransactionUtils.commit(txn) + assert(txn.committed) + assert(txn.closed) + } + + test("commit: close is called even if commit fails") { + val txn = new TestTransaction( + testCatalogName, onCommit = () => throw new RuntimeException("commit failed")) + intercept[RuntimeException] { TransactionUtils.commit(txn) } + assert(txn.closed) + } + + // --- Abort ----------------------------------------------------------------- + test("abort: calls abort then close") { + val txn = new TestTransaction(testCatalogName) + TransactionUtils.abort(txn) + assert(txn.aborted) + assert(txn.closed) + } + + test("abort: close is called even if abort fails") { + val txn = new TestTransaction(testCatalogName, + onAbort = () => throw new RuntimeException("abort failed")) + intercept[RuntimeException] { TransactionUtils.abort(txn) } + assert(txn.closed) + } + + // --- Begin Transaction ----------------------------------------------------- + test("beginTransaction: returns transaction when catalog names match") { + val catalog = mockTransactionalCatalog(testCatalogName) + val txn = TransactionUtils.beginTransaction(catalog) + assert(txn.catalog().name() == testCatalogName) + } + + test("beginTransaction: fails when transaction catalog name does not match") { + val catalog = mockTransactionalCatalog(catalogName = testCatalogName, txnCatalogName = "other") + val e = intercept[SparkException] { + TransactionUtils.beginTransaction(catalog) + } + assert(e.getMessage.contains("other")) + assert(e.getMessage.contains(testCatalogName)) + } + + test("beginTransaction: aborts and closes transaction on catalog name mismatch") { + var aborted = false + var closed = false + val catalog = new TransactionalCatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = testCatalogName + override def beginTransaction(info: TransactionInfo): Transaction = + new TestTransaction( + "other", + onAbort = () => { aborted = true }, + onClose = 
() => { closed = true }) + } + intercept[SparkException] { TransactionUtils.beginTransaction(catalog) } + assert(aborted) + assert(closed) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index fd2c0f6e9c2e..5a1e3a30b150 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, Cus import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram, HistogramBin} import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset} import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.internal.SQLConf @@ -76,6 +77,10 @@ abstract class InMemoryBaseTable( override def columns(): Array[Column] = tableColumns + private[catalog] def updateColumns(newColumns: Array[Column]): Unit = { + tableColumns = newColumns + } + override def version(): String = tableVersion.toString def setVersion(version: String): Unit = { @@ -94,6 +99,8 @@ abstract class InMemoryBaseTable( validatedTableVersion = version } + protected def recordScanEvent(filters: Array[Filter]): Unit = {} + protected object PartitionKeyColumn extends MetadataColumn { override def name: String = "_partition" override def dataType: DataType = StringType @@ -406,6 +413,7 @@ abstract class InMemoryBaseTable( def baseCapabiilities: Set[TableCapability] = Set( TableCapability.BATCH_READ, + TableCapability.MICRO_BATCH_READ, 
TableCapability.BATCH_WRITE, TableCapability.STREAMING_WRITE, TableCapability.OVERWRITE_BY_FILTER, @@ -455,6 +463,7 @@ abstract class InMemoryBaseTable( if (evaluableFilters.nonEmpty) { scan.filter(evaluableFilters) } + recordScanEvent(_pushedFilters) scan } @@ -494,6 +503,33 @@ abstract class InMemoryBaseTable( case class InMemoryHistogram(height: Double, bins: Array[HistogramBin]) extends Histogram + private class InMemoryTableOffset(val rowCount: Long) extends Offset { + override def json(): String = rowCount.toString + } + + class InMemoryMicroBatchStream extends MicroBatchStream { + override def initialOffset(): Offset = new InMemoryTableOffset(0) + override def latestOffset(): Offset = + new InMemoryTableOffset(InMemoryBaseTable.this.rows.size.toLong) + override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = { + val s = start.asInstanceOf[InMemoryTableOffset].rowCount.toInt + val e = end.asInstanceOf[InMemoryTableOffset].rowCount.toInt + Array(InMemoryMicroBatchPartition(InMemoryBaseTable.this.rows.slice(s, e))) + } + override def createReaderFactory(): PartitionReaderFactory = { partition => + val rows = partition.asInstanceOf[InMemoryMicroBatchPartition].rows + new PartitionReader[InternalRow] { + private var idx = -1 + override def next(): Boolean = { idx += 1; idx < rows.size } + override def get(): InternalRow = rows(idx) + override def close(): Unit = {} + } + } + override def deserializeOffset(json: String): Offset = new InMemoryTableOffset(json.toLong) + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} + } + abstract class BatchScanBaseClass( var data: Seq[InputPartition], readSchema: StructType, @@ -579,6 +615,9 @@ abstract class InMemoryBaseTable( override def supportedCustomMetrics(): Array[CustomMetric] = { Array(new RowsReadCustomMetric) } + + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = + new InMemoryMicroBatchStream } case class InMemoryBatchScan( @@ 
-799,6 +838,11 @@ object InMemoryBaseTable { } } +/** + * A partition for [[InMemoryBaseTable]] micro-batch streaming reads, holding a slice of rows. + */ +case class InMemoryMicroBatchPartition(rows: Seq[InternalRow]) extends InputPartition + /** * Represent a set of rows buffered in memory for a given partition key. * @param key partition key diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index 91e899bc1169..406d83aa86ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -38,13 +38,15 @@ class InMemoryRowLevelOperationTable( schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String], - constraints: Array[Constraint] = Array.empty) + constraints: Array[Constraint] = Array.empty, + tableId: String = java.util.UUID.randomUUID().toString) extends InMemoryTable( name, CatalogV2Util.structTypeToV2Columns(schema), partitioning, properties, - constraints) + constraints, + id = tableId) with SupportsRowLevelOperations { private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index bbb9041bab37..95d5975d269f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -17,11 +17,31 @@ package org.apache.spark.sql.connector.catalog +import 
scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.connector.catalog.transactions.{Transaction, TransactionInfo} +import org.apache.spark.sql.types.StructType -class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog { +class InMemoryRowLevelOperationTableCatalog + extends InMemoryTableCatalog + with TransactionalCatalogPlugin { import CatalogV2Implicits._ + // The current active transaction. + var transaction: Txn = _ + // The last completed transaction. + var lastTransaction: Txn = _ + // All transactions in order (committed and aborted), allowing per-statement + // validation in SQL scripting tests. + val observedTransactions: ArrayBuffer[Txn] = new ArrayBuffer[Txn]() + + override def beginTransaction(info: TransactionInfo): Transaction = { + assert(transaction == null || transaction.currentState != Active) + this.transaction = new Txn(new TxnTableCatalog(this)) + transaction + } + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { if (tables.containsKey(ident)) { throw new TableAlreadyExistsException(ident.asMultipartIdentifier) @@ -41,11 +61,7 @@ class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog { override def alterTable(ident: Identifier, changes: TableChange*): Table = { val table = loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes) - val schema = CatalogV2Util.applySchemaChanges( - table.schema, - changes, - tableProvider = Some("in-memory"), - statementType = "ALTER TABLE") + val schema = computeAlterTableSchema(table.schema, changes.toSeq) val partitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes) val constraints = CatalogV2Util.collectConstraintChanges(table, changes) @@ -59,13 +75,24 @@ class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog { schema = schema, 
partitioning = partitioning, properties = properties, - constraints = constraints) + constraints = constraints, + tableId = table.id) newTable.alterTableWithData(table.data, schema) tables.put(ident, newTable) newTable } + + /** + * Computes the schema that would result from applying `changes` to `currentSchema`. + * Can be overridden by subclasses to simulate catalogs that selectively ignore changes + * (e.g. [[PartialSchemaEvolutionCatalog]]). + */ + def computeAlterTableSchema(currentSchema: StructType, changes: Seq[TableChange]): StructType = { + CatalogV2Util.applySchemaChanges( + currentSchema, changes, tableProvider = Some("in-memory"), statementType = "ALTER TABLE") + } } /** @@ -84,9 +111,10 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo case _ => false } val properties = CatalogV2Util.applyPropertiesChanges(table.properties, propertyChanges) + val schema = computeAlterTableSchema(table.schema, changes.toSeq) val newTable = new InMemoryRowLevelOperationTable( name = table.name, - schema = table.schema, + schema = schema, partitioning = table.partitioning, properties = properties, constraints = table.constraints) @@ -94,4 +122,9 @@ class PartialSchemaEvolutionCatalog extends InMemoryRowLevelOperationTableCatalo tables.put(ident, newTable) newTable } + + // Ignores all schema changes and returns the current schema unchanged. 
+ override def computeAlterTableSchema( + currentSchema: StructType, + changes: Seq[TableChange]): StructType = currentSchema } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index d5738475031d..15ed4136dbda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ /** @@ -215,6 +216,15 @@ class InMemoryTable( object InMemoryTable { + // Convert UTF8String to string to make sure equality checks between filters and partitions + // work correctly. 
+ private def valuesEqual(filterValue: Any, partitionValue: Any): Boolean = + (filterValue, partitionValue) match { + case (s: String, u: UTF8String) => u.toString == s + case (u: UTF8String, s: String) => u.toString == s + case _ => filterValue == partitionValue + } + def filtersToKeys( keys: Iterable[Seq[Any]], partitionNames: Seq[String], @@ -222,7 +232,7 @@ object InMemoryTable { keys.filter { partValues => filters.flatMap(splitAnd).forall { case EqualTo(attr, value) => - value == InMemoryBaseTable.extractValue(attr, partitionNames, partValues) + valuesEqual(value, InMemoryBaseTable.extractValue(attr, partitionNames, partValues)) case EqualNullSafe(attr, value) => val attrVal = InMemoryBaseTable.extractValue(attr, partitionNames, partValues) if (attrVal == null && value == null) { @@ -230,7 +240,7 @@ object InMemoryTable { } else if (attrVal == null || value == null) { false } else { - value == attrVal + valuesEqual(value, attrVal) } case IsNull(attr) => null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index ff7995ad6697..5ce1804c11a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -40,7 +40,7 @@ class BasicInMemoryTableCatalog extends TableCatalog { protected val namespaces: util.Map[List[String], Map[String, String]] = new ConcurrentHashMap[List[String], Map[String, String]]() - protected val tables: util.Map[Identifier, Table] = + protected var tables: util.Map[Identifier, Table] = new ConcurrentHashMap[Identifier, Table]() private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet() diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala new file mode 100644 index 000000000000..bd7d6d689f51 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, RowLevelOperationBuilder, RowLevelOperationInfo, WriteBuilder}
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+sealed trait TransactionState
+case object Active extends TransactionState
+case object Committed extends TransactionState
+case object Aborted extends TransactionState
+
+// A test transaction bound to a TxnTableCatalog. It tracks its own lifecycle state (Active,
+// Committed or Aborted) separately from whether it has been closed.
+class Txn(override val catalog: TxnTableCatalog) extends Transaction {
+
+  private[this] var state: TransactionState = Active
+  private[this] var closed: Boolean = false
+
+  def currentState: TransactionState = state
+
+  def isClosed: Boolean = closed
+
+  // Commits through the catalog and marks the transaction as committed. Committing a closed or
+  // an aborted transaction is rejected with an IllegalStateException.
+  override def commit(): Unit = {
+    if (closed) {
+      throw new IllegalStateException("Can't commit, already closed")
+    }
+    state match {
+      case Aborted =>
+        throw new IllegalStateException("Can't commit, already aborted")
+      case _ =>
+        catalog.commit()
+        state = Committed
+    }
+  }
+
+  // Idempotent because nested QEs can trigger several aborts: only an active transaction
+  // changes state, a committed or already aborted one is left untouched.
+  override def abort(): Unit = state match {
+    case Active => state = Aborted
+    case Committed | Aborted => // Already in a terminal state, nothing to do.
+  }
+
+  // Idempotent because nested QEs can trigger several closes: the catalog's transaction
+  // context is cleared on the first call only.
+  override def close(): Unit = {
+    if (closed) return
+    catalog.clearActiveTransaction()
+    closed = true
+  }
+}
+
+// A special table used in row-level operation transactions. It inherits data
+// from the base table upon construction and propagates staged transaction state
+// back after an explicit commit.
+// Note, the in-memory data store does not handle concurrency at the moment. This assumes that the
+// underlying delegate table cannot change from concurrent transactions.
Data sources need to
+// implement isolation semantics and make sure they are enforced.
+class TxnTable(
+    val delegate: InMemoryRowLevelOperationTable,
+    schema: StructType,
+    catalog: TxnTableCatalog)
+  extends InMemoryRowLevelOperationTable(
+    delegate.name,
+    schema,
+    delegate.partitioning,
+    delegate.properties,
+    delegate.constraints) {
+
+  // Expose the same id as the delegate so that identity checks during transaction re-resolution
+  // don't false-positive on the TxnTable wrapper having a different UUID.
+  override val id: String = delegate.id
+
+  // Runs as part of construction, with the delegate's current data and the schema handed to this
+  // wrapper.
+  // NOTE(review): `scanEvents` below is initialized after this call; presumably
+  // alterTableWithData never triggers recordScanEvent - confirm.
+  alterTableWithData(delegate.data, schema)
+
+  // A tracker of filters used in each scan.
+  val scanEvents = new ArrayBuffer[Array[Filter]]()
+
+  // Record scan events. This is invoked when building a scan for the particular table.
+  override protected def recordScanEvent(filters: Array[Filter]): Unit = {
+    scanEvents += filters
+  }
+
+  // Registers this table as the catalog's write target before delegating to the parent
+  // implementation.
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+    catalog.writeTarget = this
+    super.newWriteBuilder(info)
+  }
+
+  // Registers this table as the catalog's write target before delegating to the parent
+  // implementation.
+  override def newRowLevelOperationBuilder(
+      info: RowLevelOperationInfo): RowLevelOperationBuilder = {
+    catalog.writeTarget = this
+    super.newRowLevelOperationBuilder(info)
+  }
+
+  // Registers this table as the catalog's write target before delegating to the parent
+  // implementation.
+  override def deleteWhere(filters: Array[Filter]): Unit = {
+    catalog.writeTarget = this
+    super.deleteWhere(filters)
+  }
+
+  // Propagates staged data and metadata changes to the delegate table.
+  def commit(): Unit = {
+    // The delegate's data map is cleared first; columns and data are then re-applied from this
+    // wrapper, so the order of the next three statements matters.
+    delegate.dataMap.clear()
+    delegate.updateColumns(columns()) // Evolve schema if needed.
+    delegate.alterTableWithData(data, schema)
+    // Copy the write bookkeeping of this wrapper over to the delegate.
+    delegate.replacedPartitions = replacedPartitions
+    delegate.lastWriteInfo = lastWriteInfo
+    delegate.lastWriteLog = lastWriteLog
+    // Commits recorded on this wrapper are appended to the delegate's, not substituted.
+    delegate.commits ++= commits
+    delegate.increaseVersion()
+  }
+}
+
+// A special table catalog used in row-level operation transactions. The lifecycle of this catalog
+// is tied to the transaction.
A new catalog instance is created at the beginning of a transaction
+// and discarded at the end. The catalog is responsible for pinning all tables involved in the
+// transaction. Table changes are initially staged in memory and propagated only after an explicit
+// commit.
+class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends TableCatalog {
+
+  // Tables pinned by this transaction, keyed by identifier.
+  private val tables: util.Map[Identifier, TxnTable] = new ConcurrentHashMap[Identifier, TxnTable]()
+
+  // The table whose staged changes are propagated by commit(). It stays null until a table
+  // registers itself as the write target.
+  var writeTarget: TxnTable = _
+
+  override def name: String = delegate.name
+
+  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {}
+
+  override def listTables(namespace: Array[String]): Array[Identifier] = {
+    throw new UnsupportedOperationException()
+  }
+
+  // Returns the table pinned for `ident`. On first access the table is loaded from the delegate
+  // catalog and wrapped in a TxnTable; a failure of the delegate lookup propagates to the caller.
+  private def pinnedTable(ident: Identifier): TxnTable = {
+    tables.computeIfAbsent(ident, _ => {
+      val table = delegate.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable]
+      new TxnTable(table, table.schema(), this)
+    })
+  }
+
+  // This is where the table pinning logic should occur. In this implementation, a table is loaded
+  // (pinned) the first time it is accessed. All subsequent accesses should return the same pinned
+  // table.
+  override def loadTable(ident: Identifier): Table = pinnedTable(ident)
+
+  override def alterTable(ident: Identifier, changes: TableChange*): Table = {
+    // AlterTable may be called by ResolveSchemaEvolution when schema evolution is enabled. Thus,
+    // it needs to be transactional. The schema changes are only propagated to the delegate at
+    // commit time.
+    //
+    // We delegate schema computation to the underlying catalog so that catalogs with special
+    // handling (e.g. PartialSchemaEvolutionCatalog) have the same behaviour inside a
+    // transaction.
+    //
+    // The table is pinned on demand rather than read straight from the map: a table that has
+    // not been loaded in this transaction yet would otherwise come back as null and fail with
+    // a NullPointerException below.
+    val txnTable = pinnedTable(ident)
+    val schema = delegate.computeAlterTableSchema(txnTable.schema, changes.toSeq)
+
+    if (schema.fields.isEmpty) {
+      throw new IllegalArgumentException("Cannot drop all fields")
+    }
+
+    // TODO: We need to pass all tracked predicates to the new TXN table.
+    val newTxnTable = new TxnTable(txnTable.delegate, schema, this)
+    tables.put(ident, newTxnTable)
+    newTxnTable
+  }
+
+  // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented.
+  override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
+    delegate.createTable(ident, tableInfo)
+    loadTable(ident)
+  }
+
+  // TODO: Currently not transactional. Should be revised when Atomic CTAS/RTAS is implemented.
+  override def dropTable(ident: Identifier): Boolean = {
+    // Unpin first so a later load in this transaction goes back to the delegate.
+    tables.remove(ident)
+    delegate.dropTable(ident)
+  }
+
+  override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = {
+    throw new UnsupportedOperationException()
+  }
+
+  // Returns all tables that participated in this transaction, keyed by identifier.
+  def txnTables: scala.collection.Map[Identifier, TxnTable] = tables.asScala
+
+  // Commit the write target table, propagating staged changes to the delegate. This is a no-op
+  // when no write target has been registered.
+  def commit(): Unit = {
+    if (writeTarget != null) writeTarget.commit()
+  }
+
+  // Clear transaction context. The delegate's current transaction is recorded as its last and
+  // observed transaction before being reset to null.
+  def clearActiveTransaction(): Unit = {
+    val txn = delegate.transaction
+    delegate.lastTransaction = txn
+    delegate.observedTransactions += txn
+    delegate.transaction = null
+  }
+
+  // Equality is by catalog name against any CatalogPlugin, not only other TxnTableCatalogs.
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case that: CatalogPlugin => this.name == that.name
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = name.hashCode()
+}
+
+/**
+ * An InMemoryRowLevelOperationTableCatalog that utilizes tables backed by a shared map. This
+ * simulates the behavior of real catalogs (Delta, Iceberg, etc.) where multiple instances
+ * of the catalog share the same underlying persistent storage, thus, they see the same tables.
+ *
+ * This is needed for testing execution that spans multiple Spark sessions. In particular,
+ * streaming queries execute micro-batches in cloned Spark sessions. Without this, the cloned
+ * spark session catalog will not see any tables created in the original session.
+ *
+ * Tests that use this catalog must call
+ * [[SharedTablesInMemoryRowLevelOperationTableCatalog.reset()]] in `afterEach` to clear the
+ * shared state between test cases.
+ */
+class SharedTablesInMemoryRowLevelOperationTableCatalog
+  extends InMemoryRowLevelOperationTableCatalog {
+  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
+    super.initialize(name, options)
+    // Point this instance at the map held by the companion object, after the parent has
+    // finished its own initialization.
+    tables = SharedTablesInMemoryRowLevelOperationTableCatalog.sharedTables
+  }
+}
+
+// Holds the table map that every SharedTablesInMemoryRowLevelOperationTableCatalog instance
+// switches to in initialize().
+object SharedTablesInMemoryRowLevelOperationTableCatalog {
+  private[catalog] val sharedTables: ConcurrentHashMap[Identifier, Table] =
+    new ConcurrentHashMap[Identifier, Table]()
+
+  // Removes every entry from the shared map.
+  def reset(): Unit = sharedTables.clear()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index c0ab906de484..5bfb73e6b887 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -22,28 +22,32 @@ import java.util.UUID import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import javax.annotation.concurrent.GuardedBy +import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} -import org.apache.spark.sql.catalyst.analysis.{LazyExpression, NameParameterizedQuery, UnsupportedOperationChecker} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubqueryAliases, LazyExpression, NameParameterizedQuery, UnresolvedRelation, UnsupportedOperationChecker} import
org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union, WithCTE} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, TransactionalWrite => TransactionalWritePlan, Union, UnresolvedWith, WithCTE} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, LookupCatalog, SupportsCatalogOptions, TableCatalog, TransactionalCatalogPlugin} +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.execution.SQLExecution.EXECUTION_ROOT_ID_KEY import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan} import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan} -import org.apache.spark.sql.execution.datasources.v2.V2TableRefreshUtil +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, TransactionalExec, V2TableRefreshUtil} import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery @@ -52,6 +56,7 @@ import 
org.apache.spark.sql.execution.streaming.runtime.{IncrementalExecution, W import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.scripting.SqlScriptingExecution import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{LazyTry, Utils, UUIDv7Generator} import org.apache.spark.util.ArrayImplicits._ @@ -69,7 +74,12 @@ class QueryExecution( val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL, val shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None, val refreshPhaseEnabled: Boolean = true, - val queryId: UUID = UUIDv7Generator.generate()) extends Logging { + val queryId: UUID = UUIDv7Generator.generate(), + // When a transaction is active, callers creating nested QueryExecution instances MUST pass + // the enclosing QueryExecution's analyzer here to propagate the transaction context. + // Omitting it causes the nested QE to use sessionState.analyzer, which has no knowledge + // of the transaction and will load tables outside the transaction's catalog scope. + val analyzerOpt: Option[Analyzer] = None) extends LookupCatalog { val id: Long = QueryExecution.nextExecutionId @@ -79,6 +89,8 @@ class QueryExecution( // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner + protected val catalogManager = sparkSession.sessionState.catalogManager + /** * Check whether the query represented by this QueryExecution is a SQL script. * @return True if the query is a SQL script, False otherwise. @@ -90,6 +102,82 @@ class QueryExecution( logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyExpression]))) } + + // 1. At the pre-Analyzed plan we look for nodes that implement the TransactionalWrite trait. + // When a plan contains such a node we initiate a transaction. Note, we should never start + // a transaction for operations that are not executed, e.g. EXPLAIN. + // 2. 
Create an analyzer clone with a transaction aware Catalog Manager. The latter is the
+  //    narrow waist of all catalog accesses, and it is also the transaction context carrier.
+  //    This is then passed to all rules during analysis that need to check the catalog. Rules
+  //    that are specifically interested in transactionality can access the transaction directly
+  //    from the Catalog Manager. The transaction catalog is potentially the place where connectors
+  //    should keep state about the reads (tables+predicates) that occurred during the transaction.
+  // 3. The analyzer instance is passed to nested Query Execution instances. These need to respect
+  //    the open transaction instead of creating their own.
+  private lazy val transactionOpt: Option[Transaction] = {
+    // Maps a transactional write to the transactional catalog that owns its target, if any.
+    // Path-based resolution is tried first; only when the write is not data source based do we
+    // fall back to catalog based resolution.
+    def owningCatalog(write: TransactionalWritePlan): Option[TransactionalCatalogPlugin] =
+      pathBased(write) match {
+        case None => TransactionalWrite.unapply(write)
+        case Some(txnCatalog: TransactionalCatalogPlugin) => Some(txnCatalog)
+        case Some(_) => None
+      }
+
+    // Begins a transaction when the pre-analyzed plan is a transactional write, or an
+    // UnresolvedWith wrapping one, whose target lives in a transactional catalog.
+    def beginForTransactionalWrite(): Option[Transaction] = {
+      val targetCatalog = logical match {
+        case UnresolvedWith(write: TransactionalWritePlan, _, _) => owningCatalog(write)
+        case write: TransactionalWritePlan => owningCatalog(write)
+        case _ => None
+      }
+      targetCatalog.map(TransactionUtils.beginTransaction)
+    }
+
+    analyzerOpt.flatMap(_.catalogManager.transaction) match {
+      // Always inherit an active transaction from the outer analyzer, regardless of mode.
+      case inherited @ Some(_) => inherited
+      // Only begin a new transaction for outer QEs that lead to execution.
+      case None if mode == CommandExecutionMode.SKIP => None
+      case None => beginForTransactionalWrite()
+    }
+  }
+
+  // For path-based tables (e.g. `format.`/path/to/table``) the first identifier part is a
+  // connector name. SupportsCatalogOptions on the connector tells us which catalog actually
+  // owns the table.
Returns Some(catalog) if parts.head is a recognized SupportsCatalogOptions
+  // data source (caller decides whether the catalog is transactional), or None to fall through
+  // to the catalog-based extractor.
+  private def pathBased(write: TransactionalWritePlan): Option[TableCatalog] = {
+    // Looks up the data source named by the head of the identifier and, when that source
+    // supports catalog options, asks it for the catalog that owns the table.
+    def providerCatalog(parts: Seq[String]): Option[TableCatalog] = {
+      val sqlConf = sparkSession.sessionState.conf
+      DataSource.lookupDataSourceV2(parts.head, sqlConf).collect {
+        case provider: SupportsCatalogOptions =>
+          val sessionConfigs = DataSourceV2Utils.extractSessionConfigs(provider, sqlConf)
+          // Pass the entire identifier as option. The connector can decide how to parse it
+          // if needed.
+          val options = sessionConfigs + ("identifier" -> parts.mkString("."))
+          CatalogV2Util.getTableProviderCatalog(
+            provider, catalogManager, new CaseInsensitiveStringMap(options.asJava))
+      }
+    }
+
+    EliminateSubqueryAliases(write.table) match {
+      case UnresolvedRelation(parts, _, _) if parts.size > 1 =>
+        // The head of the multipart identifier may not be a registered data source. In that
+        // case, fallback to catalog-based detection.
+        try providerCatalog(parts)
+        catch { case _: ClassNotFoundException => None }
+      case _ => None
+    }
+  }
+
+  // Per-query analyzer: uses a transaction-aware CatalogManager when a transaction is active,
+  // so that all catalog lookups and rule applications during analysis see the correct state
+  // without relying on thread-local context. Any nested QueryExecution that is created during
+  // analysis or execution of a transactional plan must receive this analyzer via analyzerOpt.
+  private lazy val analyzer: Analyzer = analyzerOpt.getOrElse {
+    // Reached only when no analyzer was supplied through analyzerOpt; an explicitly supplied
+    // analyzer always wins over the choices below.
+    transactionOpt match {
+      case Some(txn) =>
+        // Session analyzer combined with a catalog manager that carries the transaction.
+        sparkSession.sessionState.analyzer.withCatalogManager(catalogManager.withTransaction(txn))
+      case None =>
+        // No transaction: use the session analyzer unchanged.
+        sparkSession.sessionState.analyzer
+    }
+  }
+
 def assertAnalyzed(): Unit = { try { analyzed @@ -102,7 +190,7 @@ class QueryExecution( } } - def assertSupported(): Unit = { + def assertSupported(): Unit = withAbortTransactionOnFailure { if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForBatch(analyzed) } @@ -141,7 +229,7 @@ class QueryExecution( try { val plan = executePhase(QueryPlanningTracker.ANALYSIS) { // We can't clone `logical` here, which will reset the `_analyzed` flag. - sparkSession.sessionState.analyzer.executeAndCheck(sqlScriptExecuted, tracker) + analyzer.executeAndCheck(sqlScriptExecuted, tracker) } tracker.setAnalyzed(plan) plan @@ -152,7 +240,9 @@ class QueryExecution( } } - def analyzed: LogicalPlan = lazyAnalyzed.get + def analyzed: LogicalPlan = withAbortTransactionOnFailure { + lazyAnalyzed.get + } private val lazyCommandExecuted = LazyTry { mode match { @@ -162,7 +252,9 @@ class QueryExecution( } } - def commandExecuted: LogicalPlan = lazyCommandExecuted.get + def commandExecuted: LogicalPlan = withAbortTransactionOnFailure { + lazyCommandExecuted.get + } private def commandExecutionName(command: Command): String = command match { case _: CreateTableAsSelect => "create" @@ -184,7 +276,7 @@ class QueryExecution( // for eagerly executed commands we mark this place as beginning of execution.
tracker.setReadyForExecution() val (qe, result) = QueryExecution.runCommand( - sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode)) + sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode), Some(analyzer)) CommandResult( qe.analyzed.output, qe.commandExecuted, @@ -222,19 +314,31 @@ class QueryExecution( } // The plan that has been normalized by custom rules, so that it's more likely to hit cache. - def normalized: LogicalPlan = lazyNormalized.get + def normalized: LogicalPlan = withAbortTransactionOnFailure { + lazyNormalized.get + } private val lazyWithCachedData = LazyTry { sparkSession.withActive { assertAnalyzed() assertSupported() - // clone the plan to avoid sharing the plan instance between different stages like analyzing, - // optimizing and planning. - sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + + // During a transaction, skip cache substitution. This is to avoid replacing relations + // loaded by the transactional catalog with potentially stale relations cached before + // the transaction was active. + if (transactionOpt.isDefined) { + normalized + } else { + // Clone the plan to avoid sharing the plan instance between different stages like + // analyzing, optimizing and planning. 
+ sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + } } } - def withCachedData: LogicalPlan = lazyWithCachedData.get + def withCachedData: LogicalPlan = withAbortTransactionOnFailure { + lazyWithCachedData.get + } def assertCommandExecuted(): Unit = commandExecuted @@ -256,7 +360,9 @@ class QueryExecution( } } - def optimizedPlan: LogicalPlan = lazyOptimizedPlan.get + def optimizedPlan: LogicalPlan = withAbortTransactionOnFailure { + lazyOptimizedPlan.get + } def assertOptimized(): Unit = optimizedPlan @@ -264,14 +370,17 @@ class QueryExecution( // We need to materialize the optimizedPlan here because sparkPlan is also tracked under // the planning phase assertOptimized() - executePhase(QueryPlanningTracker.PLANNING) { + val plan = executePhase(QueryPlanningTracker.PLANNING) { // Clone the logical plan here, in case the planner rules change the states of the logical // plan. QueryExecution.createSparkPlan(planner, optimizedPlan.clone()) } + attachTransaction(plan) } - def sparkPlan: SparkPlan = lazySparkPlan.get + def sparkPlan: SparkPlan = withAbortTransactionOnFailure { + lazySparkPlan.get + } def assertSparkPlanPrepared(): Unit = sparkPlan @@ -292,7 +401,9 @@ class QueryExecution( // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - def executedPlan: SparkPlan = lazyExecutedPlan.get + def executedPlan: SparkPlan = withAbortTransactionOnFailure { + lazyExecutedPlan.get + } def assertExecutedPlanPrepared(): Unit = executedPlan @@ -310,7 +421,9 @@ class QueryExecution( * Given QueryExecution is not a public class, end users are discouraged to use this: please * use `Dataset.rdd` instead where conversion will be applied. 
*/ - def toRdd: RDD[InternalRow] = lazyToRdd.get + def toRdd: RDD[InternalRow] = withAbortTransactionOnFailure { + lazyToRdd.get + } private val observedMetricsLock = new Object @@ -390,17 +503,24 @@ class QueryExecution( } } + /** + * Returns the QueryExecution to use when generating an explain string. + * Overridden by IncrementalExecution to reuse `this` so that the already-open transaction and + * cached executedPlan are not duplicated. + */ + protected def queryExecutionForExplain: QueryExecution = if (logical.isStreaming) { + // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the + // output mode does not matter since there is no `Sink`. + new IncrementalExecution( + sparkSession, logical, OutputMode.Append(), "", + UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0), + WatermarkPropagator.noop(), false, mode = this.mode) + } else { + this + } + private def explainString(mode: ExplainMode, maxFields: Int, append: String => Unit): Unit = { - val queryExecution = if (logical.isStreaming) { - // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the - // output mode does not matter since there is no `Sink`. - new IncrementalExecution( - sparkSession, logical, OutputMode.Append(), "", - UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0), - WatermarkPropagator.noop(), false, mode = this.mode) - } else { - this - } + val queryExecution = queryExecutionForExplain mode match { case SimpleMode => @@ -535,6 +655,26 @@ class QueryExecution( } } + /** + * Runs the given block, aborting the active transaction if an exception is thrown. + * If no transaction is active, the block is executed as-is. 
+   */
+  private def withAbortTransactionOnFailure[T](block: => T): T = transactionOpt match {
+    case Some(transaction) =>
+      try block
+      catch {
+        case e: Throwable =>
+          // Abort on a best-effort basis. A failure while aborting must not replace the
+          // original error, so it is attached to it as a suppressed exception instead.
+          try TransactionUtils.abort(transaction)
+          catch {
+            case NonFatal(abortFailure) if abortFailure ne e => e.addSuppressed(abortFailure)
+          }
+          throw e
+      }
+    case None => block
+  }
+
+
+  /**
+   * Attaches the active transaction, if any, to every transactional execution node of the
+   * given SparkPlan. The plan is returned unchanged when no transaction is active.
+   */
+  private def attachTransaction(plan: SparkPlan): SparkPlan = transactionOpt match {
+    case Some(txn) => plan.transformDown {
+      case w: TransactionalExec => w.withTransaction(Some(txn))
+    }
+    case None => plan
+  }
+
 /** A special namespace for commands that can be used to debug query execution. */ // scalastyle:off object debug { @@ -819,14 +959,16 @@ object QueryExecution { name: String, refreshPhaseEnabled: Boolean = true, mode: CommandExecutionMode.Value = CommandExecutionMode.SKIP, - shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None) + shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None, + analyzerOpt: Option[Analyzer] = None) : (QueryExecution, Array[InternalRow]) = { val qe = new QueryExecution( sparkSession, command, mode = mode, shuffleCleanupModeOpt = shuffleCleanupModeOpt, - refreshPhaseEnabled = refreshPhaseEnabled) + refreshPhaseEnabled = refreshPhaseEnabled, + analyzerOpt = analyzerOpt) val result = QueryExecution.withInternalError(s"Executed $name failed.") { SQLExecution.withNewExecutionId(qe, Some(name)) { qe.executedPlan.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala index 8d5ee6038e80..c6b1bae89b15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DeleteFromTableExec.scala @@ -19,16 +19,23 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.InternalRow import
org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.connector.catalog.SupportsDeleteV2 +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.connector.expressions.filter.Predicate case class DeleteFromTableExec( table: SupportsDeleteV2, condition: Array[Predicate], - refreshCache: () => Unit) extends LeafV2CommandExec { + refreshCache: () => Unit, + transaction: Option[Transaction] = None) extends LeafV2CommandExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): DeleteFromTableExec = + copy(transaction = txn) override protected def run(): Seq[InternalRow] = { table.deleteWhere(condition) + transaction.foreach(TransactionUtils.commit) refreshCache() Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala index 151329de9e6f..60965453d9ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala @@ -119,12 +119,7 @@ private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging { } private def validateTableIdentity(currentTable: Table, relation: DataSourceV2Relation): Unit = { - if (relation.table.id != null && relation.table.id != currentTable.id) { - throw QueryCompilationErrors.tableIdChangedAfterAnalysis( - relation.name, - capturedTableId = relation.table.id, - currentTableId = currentTable.id) - } + V2TableUtil.validateTableId(relation.name, relation.table.id, currentTable) } private def validateDataColumns( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index d8e871bcf482..581027f95193 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -91,17 +91,18 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { o.copy(write = Some(write), query = newQuery) case WriteToMicroBatchDataSource( - relationOpt, table, query, queryId, options, outputMode, Some(batchId)) => - val writeOptions = mergeOptions( - options, - relationOpt.map(r => r.options.asCaseSensitiveMap.asScala.toMap).getOrElse(Map.empty)) - val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId = queryId) - val write = buildWriteForMicroBatch(table, writeBuilder, outputMode) + relation, query, queryId, options, outputMode, Some(batchId)) => + val v2Relation = relation.asInstanceOf[DataSourceV2Relation] + val writeOptions = mergeOptions(options, v2Relation.options.asCaseSensitiveMap.asScala.toMap) + // Guaranteed to support writes since it is a strict requirement to construct + // WriteToMicroBatchDataSource. 
+ val writeTable = v2Relation.table.asInstanceOf[SupportsWrite] + val writeBuilder = newWriteBuilder(writeTable, writeOptions, query.schema, queryId = queryId) + val write = buildWriteForMicroBatch(writeTable, writeBuilder, outputMode) val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming) val customMetrics = write.supportedCustomMetrics.toImmutableArraySeq - val funCatalogOpt = relationOpt.flatMap(_.funCatalog) - val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt) - WriteToDataSourceV2(relationOpt, microBatchWrite, newQuery, customMetrics) + val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, v2Relation.funCatalog) + WriteToDataSourceV2(Some(v2Relation), microBatchWrite, newQuery, customMetrics) case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, projections, _, None) => val rowSchema = projections.rowProjection.schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 2071024c5b7e..9e579ae779f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -26,9 +26,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, ProjectingInternalRow} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, TableSpec, UnaryNode} +import org.apache.spark.sql.catalyst.transactions.TransactionUtils import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, ReplaceDataProjections, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{COPY_OPERATION, 
DELETE_OPERATION, INSERT_OPERATION, REINSERT_OPERATION, UPDATE_OPERATION} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier, StagedTable, StagingTableCatalog, Table, TableCatalog, TableInfo, TableWritePrivilege} +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeleteSummaryImpl, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperation, RowLevelOperationTable, UpdateSummaryImpl, Write, WriterCommitMessage, WriteSummary} @@ -74,7 +76,12 @@ case class CreateTableAsSelectExec( query: LogicalPlan, tableSpec: TableSpec, writeOptions: Map[String, String], - ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec { + ifNotExists: Boolean, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): CreateTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -92,7 +99,9 @@ case class CreateTableAsSelectExec( .build() val table = Option(catalog.createTable(ident, tableInfo)) .getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable(catalog, table, writeOptions, ident, query, overwrite = false) + val result = writeToTable(catalog, table, writeOptions, ident, query, overwrite = false) + transaction.foreach(TransactionUtils.commit) + result } } @@ -112,7 +121,8 @@ case class AtomicCreateTableAsSelectExec( query: LogicalPlan, tableSpec: TableSpec, writeOptions: Map[String, String], - ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec { + ifNotExists: Boolean) + extends V2CreateTableAsSelectBaseExec { val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -155,8 +165,12 @@ 
case class ReplaceTableAsSelectExec( tableSpec: TableSpec, writeOptions: Map[String, String], orCreate: Boolean, - invalidateCache: (TableCatalog, Identifier) => Unit) - extends V2CreateTableAsSelectBaseExec { + invalidateCache: (TableCatalog, Identifier) => Unit, + transaction: Option[Transaction] = None) + extends V2CreateTableAsSelectBaseExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): ReplaceTableAsSelectExec = + copy(transaction = txn) val properties = CatalogV2Util.convertTableProperties(tableSpec) @@ -192,9 +206,11 @@ case class ReplaceTableAsSelectExec( .build() val table = Option(catalog.createTable(ident, tableInfo)) .getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable( + val result = writeToTable( catalog, table, writeOptions, ident, refreshedQuery, overwrite = true, refreshPhaseEnabled = false) + transaction.foreach(TransactionUtils.commit) + result } } @@ -273,7 +289,9 @@ case class AppendDataExec( query: SparkPlan, refreshCache: () => Unit, write: Write, - tableName: String) extends V2ExistingTableWriteExec { + tableName: String, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): AppendDataExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): AppendDataExec = copy(query = newChild) } @@ -292,7 +310,10 @@ case class OverwriteByExpressionExec( query: SparkPlan, refreshCache: () => Unit, write: Write, - tableName: String) extends V2ExistingTableWriteExec { + tableName: String, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): OverwriteByExpressionExec = + copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): OverwriteByExpressionExec = copy(query = newChild) } @@ -310,7 +331,10 @@ case class OverwritePartitionsDynamicExec( query: 
SparkPlan, refreshCache: () => Unit, write: Write, - tableName: String) extends V2ExistingTableWriteExec { + tableName: String, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): OverwritePartitionsDynamicExec = + copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): OverwritePartitionsDynamicExec = copy(query = newChild) } @@ -324,7 +348,8 @@ case class ReplaceDataExec( projections: ReplaceDataProjections, write: Write, rowLevelCommand: RowLevelOperation.Command, - tableName: String) extends RowLevelWriteExec { + tableName: String, + transaction: Option[Transaction] = None) extends RowLevelWriteExec { override def writingTask: WritingSparkTask[_] = { projections.metadataProjection match { @@ -335,6 +360,7 @@ case class ReplaceDataExec( } } + override def withTransaction(txn: Option[Transaction]): ReplaceDataExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): ReplaceDataExec = { copy(query = newChild) } @@ -369,7 +395,8 @@ case class WriteDeltaExec( projections: WriteDeltaProjections, write: DeltaWrite, rowLevelCommand: RowLevelOperation.Command, - tableName: String) extends RowLevelWriteExec { + tableName: String, + transaction: Option[Transaction] = None) extends RowLevelWriteExec { override lazy val writingTask: WritingSparkTask[_] = { if (projections.metadataProjection.isDefined) { @@ -379,6 +406,7 @@ case class WriteDeltaExec( } } + override def withTransaction(txn: Option[Transaction]): WriteDeltaExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): WriteDeltaExec = { copy(query = newChild) } @@ -388,7 +416,11 @@ case class WriteToDataSourceV2Exec( batchWrite: BatchWrite, refreshCache: () => Unit, query: SparkPlan, - writeMetrics: Seq[CustomMetric]) extends V2TableWriteExec { + writeMetrics: Seq[CustomMetric], + transaction: Option[Transaction] = None) 
extends V2TableWriteExec with TransactionalExec { + + override def withTransaction(txn: Option[Transaction]): WriteToDataSourceV2Exec = + copy(transaction = txn) override def stringArgs: Iterator[Any] = Iterator(batchWrite, query) @@ -398,6 +430,7 @@ case class WriteToDataSourceV2Exec( override protected def run(): Seq[InternalRow] = { val writtenRows = writeWithV2(batchWrite) + transaction.foreach(TransactionUtils.commit) refreshCache() writtenRows } @@ -406,7 +439,16 @@ case class WriteToDataSourceV2Exec( copy(query = newChild) } -trait V2ExistingTableWriteExec extends V2TableWriteExec { +/** + * Trait for physical plan nodes that write to a table as part of a transaction. + * The [[transaction]] is injected post-planning by [[QueryExecution]]. + */ +trait TransactionalExec extends SparkPlan { + def transaction: Option[Transaction] + def withTransaction(txn: Option[Transaction]): SparkPlan +} + +trait V2ExistingTableWriteExec extends V2TableWriteExec with TransactionalExec { def refreshCache: () => Unit def write: Write def tableName: String @@ -426,6 +468,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec { } finally { postDriverMetrics() } + transaction.foreach(TransactionUtils.commit) refreshCache() writtenRows } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala index 1587fd4786a3..9fc72241e83b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala @@ -143,6 +143,9 @@ class IncrementalExecution( } } + // Use `this` for explain so the already-open transaction and executedPlan are reused.
+ override protected def queryExecutionForExplain: QueryExecution = this + private val allowMultipleStatefulOperators: Boolean = sparkSession.sessionState.conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala index 973af04e0430..e6d0666aca25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException} import org.apache.spark.internal.LogKeys import org.apache.spark.internal.LogKeys._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Deduplicate, DeduplicateWithinWatermark, Distinct, FlatMapGroupsInPandasWithState, FlatMapGroupsWithState, GlobalLimit, Join, LeafNode, LocalRelation, LogicalPlan, Project, StreamSourceAwareLogicalPlan, TransformWithState, TransformWithStateInPySpark} @@ -37,7 +38,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.classic.{Dataset, SparkSession} import org.apache.spark.sql.classic.ClassicConversions.castToImpl -import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability} +import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability, TransactionalCatalogPlugin} import 
org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset => OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, SupportsRealTimeMode, SupportsTriggerAvailableNow} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} @@ -46,7 +47,6 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, Real import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, Offset, OneTimeTrigger, ProcessingTimeTrigger, RealTimeModeAllowlist, RealTimeTrigger, Sink, Source, StreamingQueryPlanTraverseHelper} import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CommitMetadata, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataV2} import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, StateStoreWriter} -import org.apache.spark.sql.execution.streaming.runtime.AcceptsLatestSeenOffsetHandler import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE} import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchSink, WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1} import org.apache.spark.sql.execution.streaming.state.{OfflineStateRepartitionUtils, StateSchemaBroadcast, StateStoreErrors} @@ -346,15 +346,20 @@ class MicroBatchExecution( ) } - // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. sink match { case s: SupportsWrite => - val relationOpt = plan.catalogAndIdent.map { - case (catalog, ident) => DataSourceV2Relation.create(s, Some(catalog), Some(ident)) + val relation = plan.catalogAndIdent match { + // When the catalog is transactional, instead of eagerly creating the relation, we + // delegate resolution to ResolveRelations. 
This allows the relation to be resolved against + // a transactional catalog which keeps track of all tables loaded within the transaction. + case Some((catalog: TransactionalCatalogPlugin, ident)) => + UnresolvedRelation(catalog.name +: ident.namespace().toSeq :+ ident.name()) + case Some((catalog, ident)) => + DataSourceV2Relation.create(s, Some(catalog), Some(ident)) + case None => DataSourceV2Relation.create(s, None, None) } WriteToMicroBatchDataSource( - relationOpt, - table = s, + relation, query = _logicalPlan, queryId = id.toString, extraOptions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala index 0a33093dcbce..7aa7a31bb085 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.execution.streaming.sources +import org.apache.spark.sql.catalyst.analysis.NamedRelation import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} -import org.apache.spark.sql.connector.catalog.SupportsWrite -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, StreamingV2WriteCommand, UnaryNode} import org.apache.spark.sql.streaming.OutputMode /** @@ -29,19 +28,36 @@ import org.apache.spark.sql.streaming.OutputMode * Note that this logical plan does not have a corresponding physical plan, as it will be converted * to [[org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 WriteToDataSourceV2]] * with [[MicroBatchWrite]] before execution.
+ * + * [[relation]] starts as [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation]] when the + * sink has a catalog+identifier (transactional catalogs), or as a resolved + * [[org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation]] for non-transactional + * catalog-backed sinks and format-based sinks. + * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveRelations]] + * resolves it to [[org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation]] during + * each micro-batch analysis, going through the transaction-aware catalog when a transaction is + * active. */ case class WriteToMicroBatchDataSource( - relation: Option[DataSourceV2Relation], - table: SupportsWrite, + relation: NamedRelation, query: LogicalPlan, queryId: String, writeOptions: Map[String, String], outputMode: OutputMode, batchId: Option[Long] = None) - extends UnaryNode { + extends UnaryNode with StreamingV2WriteCommand { + override def child: LogicalPlan = query override def output: Seq[Attribute] = Nil + override def simpleString(maxFields: Int): String = + s"WriteToMicroBatchDataSource ${relation.name}" + + override def table: NamedRelation = relation + + override def withNewTable(newTable: NamedRelation): WriteToMicroBatchDataSource = + copy(relation = newTable) + def withNewBatchId(batchId: Long): WriteToMicroBatchDataSource = { copy(batchId = Some(batchId)) } diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index c1fc7234d7c1..0354e545aa90 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -18,6 +18,8 @@ org.apache.spark.sql.sources.FakeSourceOne org.apache.spark.sql.sources.FakeSourceTwo org.apache.spark.sql.sources.FakeSourceThree 
+org.apache.spark.sql.connector.FakePathBasedSource +org.apache.spark.sql.connector.FakePathBasedSourceWithSessionConfig org.apache.spark.sql.sources.FakeSourceFour org.apache.fakesource.FakeExternalSourceOne org.apache.fakesource.FakeExternalSourceTwo diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala new file mode 100644 index 000000000000..aef9c65550fc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AppendDataTransactionSuite.scala @@ -0,0 +1,500 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode +import org.apache.spark.sql.sources + +class AppendDataTransactionSuite extends RowLevelOperationSuiteBase { + + test("writeTo append with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // create a source on top of itself that will be fully resolved and analyzed + val sourceDF = spark.table(tableNameAsString) + .where("pk == 1") + .select(col("pk") + 10 as "pk", col("salary"), col("dep")) + sourceDF.queryExecution.assertAnalyzed() + + // append data using the DataFrame API + val (txn, txnTables) = executeTransaction { + sourceDF.writeTo(tableNameAsString).append() + } + + // check txn was properly committed and closed + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 1) + assert(table.version() === "2") + + // check the source scan was tracked via the transaction catalog + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("pk", 1) => true + case _ => false + }) + + // check data was appended correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 100, "hr"))) // appended + } + + test("SQL INSERT INTO with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, 
"salary": 200, "dep": "software" } + |""".stripMargin) + + // SQL INSERT INTO using VALUES + val (txn, txnTables) = executeTransaction { + sql(s"INSERT INTO $tableNameAsString VALUES (3, 300, 'hr'), (4, 400, 'finance')") + } + + // check txn was properly committed and closed + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(table.version() === "2") + + // VALUES literal - No catalog tables were scanned + assert(txnTables.isEmpty) + + // check data was inserted correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 300, "hr"), + Row(4, 400, "finance"))) + } + + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE with transactional checks - isDynamic: $isDynamic") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val insertOverwrite = if (isDynamic) { + // OverwritePartitionsDynamic + s"""INSERT OVERWRITE $tableNameAsString + |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } else { + // OverwriteByExpression + s"""INSERT OVERWRITE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT pk + 10, salary FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val (txn, txnTables) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + executeTransaction { sql(insertOverwrite) } + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(table.version() === "2") + + // the SELECT reads from the target table once with a dep='hr' filter + assert(txnTables.size == 1) + val targetTxnTable = txnTables(tableNameAsString) + 
assert(targetTxnTable.scanEvents.size == 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(11, 100, "hr"), // overwritten + Row(13, 300, "hr"))) // overwritten + } + + test("writeTo overwrite with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // overwrite using a condition that covers the hr partition -> OverwriteByExpression + val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). + toDF("pk", "salary", "dep") + + val (txn, txnTables) = executeTransaction { + sourceDF.writeTo(tableNameAsString).overwrite(col("dep") === "hr") + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(table.version() === "2") + + // literal DataFrame source - no catalog tables were scanned + assert(txnTables.isEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(11, 999, "hr"), // overwrote hr partition + Row(12, 888, "hr"))) // overwrote hr partition + } + + test("writeTo overwritePartitions with transactional checks") { + // create table with initial data; table is partitioned by dep + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // overwrite partitions dynamically -> OverwritePartitionsDynamic + val sourceDF = spark.createDataFrame(Seq((11, 999, "hr"), (12, 888, "hr"))). 
+ toDF("pk", "salary", "dep") + + val (txn, txnTables) = executeTransaction { + sourceDF.writeTo(tableNameAsString).overwritePartitions() + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(table.version() === "2") + + // literal DataFrame source - no catalog tables were scanned + assert(txnTables.isEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(11, 999, "hr"), // overwrote hr partition + Row(12, 888, "hr"))) // overwrote hr partition + } + + test("SQL INSERT INTO SELECT with transactional checks") { + // create table with initial data + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // SQL INSERT INTO using SELECT from the same table (self-insert) + val (txn, txnTables) = executeTransaction { + sql(s"""INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(table.version() === "2") + + // the SELECT reads from the target table once with a dep='hr' filter + assert(txnTables.size === 1) + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // check data was inserted correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(3, 300, "hr"), + Row(11, 100, "hr"), // inserted from pk=1 + Row(13, 300, "hr"))) // inserted from pk=3 + } + + test("SQL INSERT INTO SELECT with subquery on source table and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary 
INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')") + + // INSERT using a subquery that reads from the target to filter source rows + // both tables are scanned through the transaction catalog + val (txn, txnTables) = executeTransaction { + sql( + s"""INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $sourceNameAsString + |WHERE pk IN (SELECT pk FROM $tableNameAsString WHERE dep = 'hr') + |""".stripMargin) + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 2) + assert(table.version() === "2") + + // target was scanned via the transaction catalog (IN subquery) once with dep='hr' filter + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // source was scanned via the transaction catalog exactly once (no filter) + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row + } + + test("SQL INSERT INTO SELECT with CTE and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 500, 'hr'), (3, 600, 'software')") + + // CTE reads from target; INSERT selects from source filtered by the CTE 
result + // both tables are scanned through the transaction catalog + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH hr_pks AS (SELECT pk FROM $tableNameAsString WHERE dep = 'hr') + |INSERT INTO $tableNameAsString + |SELECT pk + 10, salary, dep FROM $sourceNameAsString + |WHERE pk IN (SELECT pk FROM hr_pks) + |""".stripMargin) + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 2) + assert(table.version() === "2") + + // target was scanned via the transaction catalog (CTE) once with dep='hr' filter + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // source was scanned via the transaction catalog exactly once (no filter) + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"), + Row(11, 500, "hr"))) // inserted: source pk=1 matched target hr row via CTE + } + + test("SQL INSERT with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val e = intercept[AnalysisException] { + sql(s"INSERT INTO $tableNameAsString SELECT nonexistent_col FROM $tableNameAsString") + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + } + + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE with analysis failure and transactional checks" + + s"isDynamic: $isDynamic") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": 
"software" } + |""".stripMargin) + + val insertOverwrite = if (isDynamic) { + s"""INSERT OVERWRITE $tableNameAsString + |SELECT nonexistent_col, salary, dep FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } else { + s"""INSERT OVERWRITE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT nonexistent_col FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val e = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + intercept[AnalysisException] { sql(insertOverwrite) } + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("EXPLAIN INSERT SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"EXPLAIN INSERT INTO $tableNameAsString VALUES (3, 300, 'hr')") + + // EXPLAIN should not start a transaction + assert(catalog.transaction === null) + + // INSERT was not executed; data is unchanged + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + } + + test("SQL INSERT WITH SCHEMA EVOLUTION adds new column with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + sql(s"INSERT INTO $sourceNameAsString VALUES (3, 300, 'hr', true), (4, 400, 'software', false)") + + val (txn, txnTables) = executeTransaction { + sql(s"INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString SELECT * FROM $sourceNameAsString") + } + + 
assert(txn.currentState === Committed) + assert(txn.isClosed) + + // the new column must be visible in the committed delegate's schema + assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep", "active")) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr", null), // pre-existing rows: active is null + Row(2, 200, "software", null), + Row(3, 300, "hr", true), // inserted with active + Row(4, 400, "software", false))) + } + + for (isDynamic <- Seq(false, true)) + test(s"SQL INSERT OVERWRITE WITH SCHEMA EVOLUTION adds new column with transactional checks " + + s"isDynamic: $isDynamic") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + sql(s"INSERT INTO $sourceNameAsString VALUES (11, 999, 'hr', true), (12, 888, 'hr', false)") + + val insertOverwrite = if (isDynamic) { + s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString + |SELECT * FROM $sourceNameAsString + |""".stripMargin + } else { + s"""INSERT WITH SCHEMA EVOLUTION OVERWRITE TABLE $tableNameAsString + |PARTITION (dep = 'hr') + |SELECT pk, salary, active FROM $sourceNameAsString + |""".stripMargin + } + + val confValue = if (isDynamic) PartitionOverwriteMode.DYNAMIC else PartitionOverwriteMode.STATIC + val (txn, _) = withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> confValue.toString) { + executeTransaction { sql(insertOverwrite) } + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // the new column must be visible in the committed delegate's schema + assert(table.schema.fieldNames.contains("active")) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software", null), // unchanged (different 
partition) + Row(11, 999, "hr", true), // overwrote hr partition + Row(12, 888, "hr", false))) + } + + test("SQL INSERT WITH SCHEMA EVOLUTION analysis failure aborts transaction") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql( + s"""CREATE TABLE $sourceNameAsString + |(pk INT NOT NULL, salary INT, dep STRING, active BOOLEAN)""".stripMargin) + + val e = intercept[AnalysisException] { + sql( + s"""INSERT WITH SCHEMA EVOLUTION INTO $tableNameAsString + |SELECT nonexistent_col FROM $sourceNameAsString + |""".stripMargin) + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + // schema must be unchanged after the aborted transaction + assert(table.schema.fieldNames.toSeq === Seq("pk", "salary", "dep")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala new file mode 100644 index 000000000000..8acdd8242ef1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/CTASRTASTransactionSuite.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTable} +import org.apache.spark.sql.sources + +class CTASRTASTransactionSuite extends RowLevelOperationSuiteBase { + + private val newTableNameAsString = "cat.ns1.new_table" + + private def newTable: InMemoryRowLevelOperationTable = + catalog.loadTable(Identifier.of(Array("ns1"), "new_table")) + .asInstanceOf[InMemoryRowLevelOperationTable] + + test("CTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val (txn, txnTables) = executeTransaction { + sql(s"""CREATE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 1) + assert(table.version() === "1") // source table: read-only, version unchanged + assert(newTable.version() === "1") // target table: newly created and written + + // the source table was scanned once through the transaction catalog with a dep='hr' filter + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + checkAnswer( + sql(s"SELECT *
FROM $newTableNameAsString"), + Seq(Row(1, 100, "hr"))) + } + + test("CTAS from literal source with transactional checks") { + // no source catalog table involved - the query is a pure literal SELECT + val (txn, txnTables) = executeTransaction { + sql(s"CREATE TABLE $newTableNameAsString AS SELECT 1 AS pk, 100 AS salary, 'hr' AS dep") + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // literal SELECT - no catalog tables were scanned + assert(txnTables.isEmpty) + assert(newTable.version() === "1") // target table: newly created and written + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq(Row(1, 100, "hr"))) + } + + test("CTAS with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val e = intercept[AnalysisException] { // the table has no nonexistent_col column + sql(s"CREATE TABLE $newTableNameAsString AS SELECT nonexistent_col FROM $tableNameAsString") + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("RTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // pre-create the target so REPLACE TABLE (not CREATE OR REPLACE) is valid + sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + + val (txn, txnTables) = executeTransaction { + sql(s"""REPLACE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 1) // only the source table is tracked + assert(table.version() === "1") // source table: read-only, version unchanged +
assert(newTable.version() === "1") // target table: replaced and written + + // the source table was scanned once through the transaction catalog with a dep='hr' filter + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + checkAnswer( + sql(s"SELECT * FROM $newTableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(3, 300, "hr"))) + } + + test("RTAS self-reference with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // source and target are the same table: reads the old snapshot via TxnTable, + // replaces the table with a filtered version + val (txn, txnTables) = executeTransaction { + sql(s"""CREATE OR REPLACE TABLE $tableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + } + + assert(txn.currentState === Committed) + assert(txn.isClosed) + assert(txnTables.size === 1) // one entry: source and target are the same table + assert(table.version() === "1") // source/target table: replaced in place, version reset to 1 + + // the source/target table was scanned once with a dep='hr' filter + val sourceTxnTable = txnTables(tableNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(3, 300, "hr"))) + } + + test("RTAS with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val e = intercept[AnalysisException] { // the table has no nonexistent_col column + sql(s"""CREATE OR
REPLACE TABLE $tableNameAsString + |AS SELECT nonexistent_col FROM $tableNameAsString + |""".stripMargin) + } + + assert(e.getMessage.contains("nonexistent_col")) + assert(catalog.lastTransaction.currentState === Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("simple CREATE TABLE and DROP TABLE do not create transactions") { + sql(s"CREATE TABLE $newTableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + assert(catalog.transaction === null) + assert(catalog.lastTransaction === null) + + sql(s"DROP TABLE $newTableNameAsString") + assert(catalog.transaction === null) + assert(catalog.lastTransaction === null) + } + + test("EXPLAIN CTAS with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"""EXPLAIN CREATE TABLE $newTableNameAsString + |AS SELECT * FROM $tableNameAsString WHERE dep = 'hr' + |""".stripMargin) + + // EXPLAIN should not start a transaction + assert(catalog.transaction === null) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala index fbcfdfb20c6e..803dd35513f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala @@ -109,7 +109,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, AppendDataExec(_, _, write, _), + _, AppendDataExec(_, _, write, _, _), _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") @@ -141,7 +141,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected =
qe.executedPlan.collect { - case AppendDataExec(_, _, write, _) => + case AppendDataExec(_, _, write, _, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") } @@ -168,7 +168,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case AppendDataExec(_, _, write, _) => + case AppendDataExec(_, _, write, _, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append] assert(append.info.options.get("write.split-size") === "10") } @@ -194,7 +194,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, OverwriteByExpressionExec(_, _, write, _), + _, OverwriteByExpressionExec(_, _, write, _, _), _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") @@ -227,7 +227,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwritePartitionsDynamicExec(_, _, write, _) => + case OverwritePartitionsDynamicExec(_, _, write, _, _) => val dynOverwrite = write.toBatch.asInstanceOf[InMemoryBaseTable#DynamicOverwrite] assert(dynOverwrite.info.options.get("write.split-size") === "10") } @@ -254,7 +254,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { collected = df.queryExecution.executedPlan.collect { case CommandResultExec( - _, OverwriteByExpressionExec(_, _, write, _), + _, OverwriteByExpressionExec(_, _, write, _, _), _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") @@ -287,7 +287,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwriteByExpressionExec(_, 
_, write, _) => + case OverwriteByExpressionExec(_, _, write, _, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") } @@ -317,7 +317,7 @@ class DataSourceV2OptionSuite extends DatasourceV2SQLBase { assert (collected.size == 1) collected = qe.executedPlan.collect { - case OverwriteByExpressionExec(_, _, write, _) => + case OverwriteByExpressionExec(_, _, write, _, _) => val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend] assert(append.info.options.get("write.split-size") === "10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index adc88f5a54a0..f8d81ee08691 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.connector -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions.CheckInvariant import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} import org.apache.spark.sql.connector.catalog.InMemoryTable import org.apache.spark.sql.connector.write.DeleteSummary import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec, WriteDeltaExec} +import org.apache.spark.sql.sources abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { @@ -773,6 +775,190 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { } } + test("delete with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200,
"dep": "software" } + |""".stripMargin) + + val exception = intercept[AnalysisException] { + sql(s"DELETE FROM $tableNameAsString WHERE invalid_column = 1") + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("delete with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // simple predicate delete: goes through SupportsDelete.deleteWhere (no Spark-side scan) + val (txn, _) = executeTransaction { + sql(s"DELETE FROM $tableNameAsString WHERE dep = 'hr'") + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(table.version() == "2") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(2, 200, "software"))) + } + + test("delete with subquery on source table and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + val (txn, txnTables) = executeTransaction { + sql( + s"""DELETE FROM $tableNameAsString + |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 'hr') + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) // target and source tables + assert(table.version() == "2") + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numSubquerySourceScans =
sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numSubquerySourceScans == expectedNumSourceScans) + + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaDelete) 1 else 2 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in subquery result) + } + + test("delete with CTE and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk FROM $sourceNameAsString WHERE dep = 'hr' + |) + |DELETE FROM $tableNameAsString + |WHERE pk IN (SELECT pk FROM cte) + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) // target and source tables + assert(table.version() == "2") + + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaDelete) 1 else 2 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300,
"hr"))) // unchanged (pk 3 not in source) + } + + test("delete using view with transactional checks") { + withView("temp_view") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + sql( + s"""CREATE VIEW temp_view AS + |SELECT pk FROM $sourceNameAsString WHERE dep = 'hr' + |""".stripMargin) + + val (txn, txnTables) = executeTransaction { + sql(s"DELETE FROM $tableNameAsString WHERE pk IN (SELECT pk FROM temp_view)") + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) // target plus the table behind temp_view + assert(table.version() == "2") + + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaDelete) 1 else 2 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaDelete) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in source) + } + } + + test("EXPLAIN DELETE SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"EXPLAIN DELETE FROM $tableNameAsString WHERE dep = 'hr'") + + // EXPLAIN should not start a new transaction + assert(catalog.transaction === null) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), // still present: EXPLAIN deleted nothing + Row(2, 200, "software"))) + } + private def executeDeleteWithFilters(query: String):
Unit = { val executedPlan = executeAndKeepPlan { sql(query) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala index e1c574ec7ba6..c5d82ec3c1d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.connector +import org.apache.spark.sql.{sources, Column, Row} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.Row import org.apache.spark.sql.classic.MergeIntoWriter -import org.apache.spark.sql.connector.catalog.Column +import org.apache.spark.sql.connector.catalog.{Aborted, Committed} import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableInfo import org.apache.spark.sql.functions._ @@ -31,6 +31,134 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { import testImplicits._ + private def targetTableCol(colName: String): Column = { + col(tableNameAsString + "." 
+ colName) + } + + test("self merge with transactional checks") { + // create table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create a source on top of itself that will be fully resolved and analyzed + val sourceDF = spark.table(tableNameAsString) + .where("salary == 100") + .as("source") + sourceDF.queryExecution.assertAnalyzed() + + // merge into table using the source on top of itself + val (txn, txnTables) = executeTransaction { + sourceDF + .mergeInto( + tableNameAsString, + $"source.pk" === targetTableCol("pk") && targetTableCol("dep") === "hr") + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .whenNotMatched() + .insertAll() + .merge() + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) // source and target are the same table + assert(table.version() == "2") + + // check all table scans + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size == 4) // 2 as target + 2 as source, see below + + // check table scans as MERGE target + val numTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numTargetScans == 2) + + // check table scans as MERGE source + val numSourceScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("salary", 100) => true + case _ => false + } + assert(numSourceScans == 2) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged + } + + for (alterClause <- Seq( // each DDL runs after the source was analyzed + "ADD COLUMN new_col INT", + "DROP COLUMN salary", + "ALTER COLUMN salary TYPE BIGINT", + "ALTER COLUMN pk DROP NOT NULL")) + test(s"self merge fails when source
schema changes after analysis - DDL: $alterClause" ) { + withTable(tableNameAsString) { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = spark.table(tableNameAsString).where("salary == 100").as("source") + sourceDF.queryExecution.assertAnalyzed() + + sql(s"ALTER TABLE $tableNameAsString $alterClause") + + val e = intercept[AnalysisException] { + sourceDF + .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk")) + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .merge() + } + + assert( + e.getCondition == "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", + alterClause) + assert(catalog.lastTransaction.currentState == Aborted, alterClause) + assert(catalog.lastTransaction.isClosed, alterClause) + } + } + + test("self merge fails when source table is dropped and recreated after analysis") { + withTable(tableNameAsString) { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + val sourceDF = spark.table(tableNameAsString).where("salary == 100").as("source") + sourceDF.queryExecution.assertAnalyzed() + + val originalId = catalog.loadTable(ident).id + + sql(s"DROP TABLE $tableNameAsString") + sql(s"CREATE TABLE $tableNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + val newId = catalog.loadTable(ident).id + assert(originalId != newId) + + val e = intercept[AnalysisException] { + sourceDF + .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk")) + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .merge() + } + + assert(e.getCondition == "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.TABLE_ID_MISMATCH") + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) +
} + } + test("merge into empty table with NOT MATCHED clause") { withTempView("source") { createTable("pk INT NOT NULL, salary INT, dep STRING") @@ -979,6 +1107,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { } test("SPARK-54157: version is refreshed when source is V2 table") { + import org.apache.spark.sql.connector.catalog.Column val sourceTable = "cat.ns1.source_table" withTable(sourceTable) { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -1026,6 +1155,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { } test("SPARK-54444: any schema changes after analysis are prohibited") { + import org.apache.spark.sql.connector.catalog.Column val sourceTable = "cat.ns1.source_table" withTable(sourceTable) { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 069781e40d8c..91f20885beb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not} import org.apache.spark.sql.catalyst.optimizer.BuildLeft -import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, InMemoryTable, TableInfo} +import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, TableInfo} import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.write.MergeSummary import org.apache.spark.sql.execution.SparkPlan @@ -29,6 +29,7 @@ import 
org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.MergeRowsExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources import org.apache.spark.sql.types.{IntegerType, StringType} abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase @@ -38,6 +39,309 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase protected def deltaMerge: Boolean = false + test("self merge with transactional checks") { + // create table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // merge into table using a subquery on top of itself + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING (SELECT * FROM $tableNameAsString WHERE salary = 100) s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = t.salary + 1 + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 1) // source and target are the same table + assert(table.version() == "2") + + // check all table scans + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumScans = if (deltaMerge) 2 else 4 // target + source scans, see below + assert(targetTxnTable.scanEvents.size == expectedNumScans) + + // check table scans as MERGE target + val numTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumTargetScans = if (deltaMerge) 1 else 2 + assert(numTargetScans == expectedNumTargetScans) + + // check table scans as MERGE source + val numSourceScans =
targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("salary", 100) => true + case _ => false + } + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(numSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged + } + + test("merge into table with analysis failure and transactional checks") { + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')") + + val exception = intercept[AnalysisException] { // the source table has no invalid_column + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk + |WHEN MATCHED THEN + | UPDATE SET salary = s.invalid_column + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("merge into table using view with transactional checks") { + withView("temp_view") { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)") + + // create view on top of source and target tables + sql( +
s"""CREATE VIEW temp_view AS + |SELECT s.pk, s.salary, t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // merge into target table using view + val (txn, txnTables) = executeTransaction { + sql(s"""MERGE INTO $tableNameAsString t + |USING temp_view s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 2 else 4 // MERGE-target scans + view scans + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumMergeTargetScans = if (deltaMerge) 1 else 2 + assert(numMergeTargetScans == expectedNumMergeTargetScans) + + // check target table scans in view as MERGE source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaMerge) 1 else 2 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view as MERGE source (no predicate) + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1,
150, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 400, "pending"))) // new + } + } + + test("merge into table using nested view with transactional checks") { + withView("base_view", "nested_view") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, salary) VALUES (1, 150), (4, 400)") + + // create base view + sql( + s"""CREATE VIEW base_view AS + |SELECT s.pk, s.salary, t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // create nested view on top of base view + sql( + s"""CREATE VIEW nested_view AS + |SELECT * FROM base_view + |""".stripMargin) + + // merge into target table using nested view + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING nested_view s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 2 else 4 // MERGE-target scans + view scans + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans =
targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumMergeTargetScans = if (deltaMerge) 1 else 2 + assert(numMergeTargetScans == expectedNumMergeTargetScans) + + // check target table scans in view as MERGE source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaMerge) 1 else 2 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view as MERGE source (no predicate) + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "hr"), // update + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 400, "pending"))) // new + } + } + } + + test("merge into table rewritten as INSERT with transactional checks") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT, value STRING, dep STRING", + """{ "pk": 1, "value": "a", "dep": "hr" } + |{ "pk": 2, "value": "b", "dep": "finance" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT, value STRING, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (3, 'c', 'hr'), (4, 'd', 'software')") + + // merge into target with only WHEN NOT MATCHED clauses (rewritten as insert) + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk + |WHEN NOT MATCHED AND s.pk < 4 THEN + | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_low'), s.dep) + |WHEN NOT MATCHED AND s.pk >= 4 THEN + | INSERT (pk, value, dep) VALUES (s.pk, concat(s.value, '_high'), 
s.dep) + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.size == 1) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size == 1) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, "a", "hr"), // unchanged + Row(2, "b", "finance"), // unchanged + Row(3, "c_low", "hr"), // inserted via first NOT MATCHED clause + Row(4, "d_high", "software"))) // inserted via second NOT MATCHED clause + } + } + test("merge into table with expression-based default values") { val columns = Array( Column.create("pk", IntegerType), @@ -770,6 +1074,131 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } } + test("merge with CTE with transactional checks") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // merge into target table using CTE + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk, salary + 50 AS salary, dep + | FROM $sourceNameAsString + | WHERE salary > 100 + |) + |MERGE INTO $tableNameAsString t + |USING cte s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary + |WHEN NOT MATCHED THEN + | 
INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaMerge) 1 else 2 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as MERGE target (dep = 'hr') + val numMergeTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numMergeTargetScans == expectedNumTargetScans) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaMerge) 1 else 2 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check source table scans in CTE (salary > 100) + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.GreaterThan("salary", 100) => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 200, "hr"), // updated (150 + 50) + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged + Row(4, 450, "pending"))) // inserted (400 + 50) + } + } + + test("merge with cached source and transactional checks") { + withTable(sourceNameAsString) { + // create target table + createAndInitTable( + "pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create and populate source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + 
sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'support'), (4, 400, 'finance')") + + // Cache source table before the transaction. Make sure when the transaction is active the + // catalog still creates a transaction table. + spark.table(sourceNameAsString).cache() + + try { + val (txn, txnTables) = executeTransaction { + sql( + s"""MERGE INTO $tableNameAsString t + |USING $sourceNameAsString s + |ON t.pk = s.pk AND t.dep = 'hr' + |WHEN MATCHED THEN + | UPDATE SET salary = s.salary, dep = s.dep + |WHEN NOT MATCHED THEN + | INSERT (pk, salary, dep) VALUES (s.pk, s.salary, 'pending') + |""".stripMargin) + } + + assert(txn.currentState == Committed) + assert(txn.isClosed) + // both target and source must have been read through the transaction catalog + assert(txnTables.size == 2) + assert(table.version() == "2") + assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + assert(txnTables(tableNameAsString).scanEvents.nonEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "support"), // matched and updated + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"), // unchanged (no match in source) + Row(4, 400, "pending"))) // not matched, inserted + } finally { + spark.catalog.uncacheTable(sourceNameAsString) + } + } + } + test("merge with subquery as source") { withTempView("source") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -2322,6 +2751,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(query) } assert(e.getMessage.contains("ON search condition of the MERGE statement")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) } private def assertMetric( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala new file mode 100644 index 000000000000..c6b2f33c25fe --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTableCatalog, InMemoryTableCatalog, SessionConfigSupport, SupportsCatalogOptions} +import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Tests for transactional writes to path-based tables, where the table is identified by a + * bare path with no catalog prefix (e.g. `/path/to/t`), or a connector-prefixed path + * (e.g. `pathformat.`/path/to/t``). The transactional catalog is registered as the session + * catalog (`spark_catalog`). 
+ */ +class PathBasedTableTransactionSuite extends RowLevelOperationSuiteBase { + + private val tablePath = "`/path/to/t`" + private val tablePathWithFormat = "pathformat.`/path/to/t`" + + override def beforeEach(): Unit = { + super.beforeEach() + spark.conf.set( + V2_SESSION_CATALOG_IMPLEMENTATION.key, + classOf[InMemoryRowLevelOperationTableCatalog].getName) + } + + override def afterEach(): Unit = { + spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) + super.afterEach() + } + + override protected def catalog: InMemoryRowLevelOperationTableCatalog = { + spark.sessionState.catalogManager.v2SessionCatalog + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + } + + private def createPathTable(name: String): Unit = { + sql(s"CREATE TABLE $name (id INT, data STRING)") + } + + test("SQL insert into bare path-based table participates in transaction") { + createPathTable(tablePath) + val (txn, _) = executeTransaction { + sql(s"INSERT INTO $tablePath VALUES (1, 'a'), (2, 'b')") + } + assert(txn.currentState === Committed) + assert(txn.isClosed) + checkAnswer(spark.table(tablePath), Row(1, "a") :: Row(2, "b") :: Nil) + } + + test("SQL insert with connector-prefixed path participates in transaction") { + createPathTable(tablePathWithFormat) + val (txn, _) = executeTransaction { + sql(s"INSERT INTO $tablePathWithFormat VALUES (1, 'a'), (2, 'b')") + } + assert(txn.currentState === Committed) + assert(txn.isClosed) + checkAnswer(spark.table(tablePathWithFormat), Row(1, "a") :: Row(2, "b") :: Nil) + } + + test("SQL insert with CTE into connector-prefixed path participates in transaction") { + createPathTable(tablePathWithFormat) + val (txn, _) = executeTransaction { + sql(s""" + |WITH cte AS (SELECT 1 AS id, 'a' AS data) + |INSERT INTO $tablePathWithFormat SELECT * FROM cte + |""".stripMargin) + } + assert(txn.currentState === Committed) + assert(txn.isClosed) + checkAnswer(spark.table(tablePathWithFormat), Row(1, "a") :: Nil) + } + + test("session-config catalog 
controls which catalog is enrolled in transaction") { + withSQLConf( + "spark.sql.catalog.txncat" -> classOf[InMemoryRowLevelOperationTableCatalog].getName, + "spark.sql.catalog.nontxncat" -> classOf[InMemoryTableCatalog].getName) { + val txnCat = spark.sessionState.catalogManager.catalog("txncat") + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + + // Non-transactional catalog configured. + withSQLConf("spark.datasource.pathformat2.catalog" -> "nontxncat") { + createPathTable("pathformat2.`/path/to/t1`") + sql("INSERT INTO pathformat2.`/path/to/t1` VALUES (1, 'a')") + // The transaction was not routed to any of the transactional catalogs. + assert(catalog.lastTransaction == null) + assert(txnCat.lastTransaction == null) + } + + // Transactional catalog configured: pathBased resolves txncat as a + // TransactionalCatalogPlugin and opens the transaction there instead. + withSQLConf("spark.datasource.pathformat2.catalog" -> "txncat") { + createPathTable("pathformat2.`/path/to/t2`") + sql("INSERT INTO pathformat2.`/path/to/t2` VALUES (1, 'a')") + assert(txnCat.lastTransaction.currentState === Committed) + assert(txnCat.lastTransaction.isClosed) + } + } + } + + test("SQL insert with unregistered format produces analysis error and aborts transaction") { + createPathTable(tablePathWithFormat) + // "Unregistered" is not a known catalog and not a registered data source. + // So Spark falls back to treating it as a namespace in spark_catalog. The table + // does not exist, causing an AnalysisException. The transaction is started (because + // spark_catalog IS a TransactionalCatalogPlugin) and then aborted on failure. 
+ checkError( + exception = intercept[AnalysisException] { + sql("INSERT INTO unregistered.`/path/to/t` VALUES (1, 'a'), (2, 'b')") + }, + condition = "TABLE_OR_VIEW_NOT_FOUND", + parameters = Map("relationName" -> "`unregistered`.`/path/to/t`"), + context = ExpectedContext( + fragment = "unregistered.`/path/to/t`", + start = -1, + stop = -1)) + val txn = catalog.lastTransaction + assert(txn.currentState === Aborted) + assert(txn.isClosed) + } +} + +/** + * Simulates a path-based connector (e.g. Delta) that implements [[SupportsCatalogOptions]] + * to route `pathformat.\`/path/to/t\`` SQL identifiers to the session catalog. Returning + * null from [[extractCatalog]] signals that the session catalog (`spark_catalog`) owns the + * table, matching Delta's behavior where DeltaCatalog is registered as spark_catalog. + */ +class FakePathBasedSource + extends FakeV2ProviderWithCustomSchema + with SupportsCatalogOptions + with DataSourceRegister { + + override def shortName(): String = "pathformat" + + // Use the session catalog. + override def extractCatalog(options: CaseInsensitiveStringMap): String = null + + // Not used in the transactional path. + override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = null +} + +/** + * Like [[FakePathBasedSource]] but resolves the owning catalog from the session config + * `spark.datasource.pathformat2.catalog` instead of always returning null. This simulates + * a connector that lets users configure the target catalog. + */ +class FakePathBasedSourceWithSessionConfig + extends FakeV2ProviderWithCustomSchema + with SupportsCatalogOptions + with SessionConfigSupport + with DataSourceRegister { + + override def shortName(): String = "pathformat2" + + override def keyPrefix: String = "pathformat2" + + override def extractCatalog(options: CaseInsensitiveStringMap): String = options.get("catalog") + + // Not used in the transactional path. 
+ override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = null +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 79387821bf08..d0209c97cf93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -28,16 +28,16 @@ import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expr import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReplaceData, WriteDelta} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, TableInfo, Update, Write} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, Table, TableInfo, Txn, TxnTable, Update, Write} import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan} +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import 
org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructField, StructType} -import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -82,6 +82,7 @@ abstract class RowLevelOperationSuiteBase protected val namespace: Array[String] = Array("ns1") protected val ident: Identifier = Identifier.of(namespace, "test_table") protected val tableNameAsString: String = "cat." + ident.toString + protected val sourceNameAsString: String = "cat.ns1.source_table" protected def extraTableProps: java.util.Map[String, String] = { Collections.emptyMap[String, String] @@ -135,22 +136,28 @@ abstract class RowLevelOperationSuiteBase // executes an operation and keeps the executed plan protected def executeAndKeepPlan(func: => Unit): SparkPlan = { - var executedPlan: SparkPlan = null + withQueryExecutionsCaptured(spark)(func) match { + case Seq(qe) => stripAQEPlan(qe.executedPlan) + case other => fail(s"expected only one query execution, but got ${other.size}") + } + } - val listener = new QueryExecutionListener { - override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - executedPlan = qe.executedPlan - } - override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + protected def executeTransaction(func: => Unit): (Txn, Map[String, TxnTable]) = { + val tables = withQueryExecutionsCaptured(spark)(func).flatMap { qe => + collectWithSubqueries(qe.executedPlan) { + case BatchScanExec(_, _, _, _, table: TxnTable, _) => table + case BatchScanExec(_, _, _, _, RowLevelOperationTable(table: TxnTable, _), _) => table } } - spark.listenerManager.register(listener) - - func - - sparkContext.listenerBus.waitUntilEmpty() + (catalog.lastTransaction, indexByName(tables)) + } - stripAQEPlan(executedPlan) + protected def indexByName[T <: Table](tables: Seq[T]): Map[String, T] = { + tables.groupBy(_.name).map { + case (name, 
sameNameTables) => + val Seq(table) = sameNameTables.distinct + name -> table + } } // executes an operation and extracts conditions from ReplaceData or WriteDelta diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala new file mode 100644 index 000000000000..13b6267a28ff --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import java.util.Collections + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{Aborted, CatalogV2Util, Committed, Identifier, InMemoryBaseTable, InMemoryRowLevelOperationTableCatalog, InMemoryTableCatalog, SharedTablesInMemoryRowLevelOperationTableCatalog, TableInfo} +import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.types.StructType + +class StreamingTransactionSuite extends RowLevelOperationSuiteBase { + + import testImplicits._ + + override def beforeEach(): Unit = { + super.beforeEach() + spark.conf.set( + "spark.sql.catalog.cat", + classOf[SharedTablesInMemoryRowLevelOperationTableCatalog].getName) + } + + override def afterEach(): Unit = { + SharedTablesInMemoryRowLevelOperationTableCatalog.reset() + super.afterEach() + } + + private def createSimpleTable(schemaString: String): Unit = { + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) + val tableInfo = new TableInfo.Builder().withColumns(columns).build() + catalog.createTable(ident, tableInfo) + } + + private def streamCatalog(query: StreamingQuery): InMemoryRowLevelOperationTableCatalog = { + val session = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sparkSessionForStream + session.sessionState.catalogManager.catalog("cat") + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + } + + test("streaming write commits a transaction") { + createSimpleTable("value INT") + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + + val query = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tableNameAsString) + + assert(table.version() === "0") + + inputData.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + val txn = streamCatalog(query).lastTransaction + assert(txn != null, 
"expected a transaction to have been committed") + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // Pure streaming append: the write target is not read, source is not a TxnTable. + val targetTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(tableNameAsString) + assert(txn.catalog.txnTables.size === 1) + assert(targetTxnTable.scanEvents.isEmpty) + assert(table.version() === "1") + + // Transaction must be scoped to the streaming session; main session catalog is untouched. + assert(catalog.observedTransactions.isEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1), Row(2), Row(3))) + } + } + + test("each micro-batch is an independent transaction") { + createSimpleTable("value INT") + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + + val query = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tableNameAsString) + + assert(table.version() === "0") + + inputData.addData(1, 2, 3) + query.processAllAvailable() + + inputData.addData(4, 5, 6) + query.processAllAvailable() + + query.stop() + + val sc = streamCatalog(query) + assert(sc.observedTransactions.size === 2) + assert(sc.observedTransactions.forall(_.currentState === Committed)) + // Pure streaming append: write target is not read in any micro-batch. + assert(sc.observedTransactions.forall { t => + indexByName(t.catalog.txnTables.values.toSeq)(tableNameAsString).scanEvents.isEmpty + }) + // Each committed micro-batch increments the delegate version exactly once. + assert(table.version() === "2") + + // Transaction must be scoped to the streaming session; main session catalog is untouched. 
+ assert(catalog.observedTransactions.isEmpty) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1), Row(2), Row(3), Row(4), Row(5), Row(6))) + } + } + + test("batch read from catalog-backed table inside streaming query is tracked as a scan event") { + // Target table for the stream. + createSimpleTable("value INT") + + // Catalog-backed static table used as a batch (non-streaming) source. + val sourceIdent = Identifier.of(namespace, "source_table") + val srcColumns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT")) + catalog.createTable(sourceIdent, new TableInfo.Builder().withColumns(srcColumns).build()) + sql(s"INSERT INTO $sourceNameAsString VALUES (1), (2), (3)") + // The INSERT above runs a transaction on the main session catalog; capture the count now + // so we can assert the streaming query does not add more. + val mainTxnsBefore = catalog.observedTransactions.size + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + + // spark.read produces a DataSourceV2Relation (batch), not a streaming source. + // UnresolveRelationsInTransaction converts it to V2TableReference each micro-batch so + // the transaction-aware catalog can record the scan event. + val staticData = spark.read.table(sourceNameAsString) + + val query = inputData.toDF() + .join(staticData, "value") + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tableNameAsString) + + inputData.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + val txn = streamCatalog(query).lastTransaction + assert(txn != null, "expected a transaction to have been committed") + assert(txn.currentState === Committed) + assert(txn.isClosed) + + // Both the write target and the batch source participate in the transaction. 
+ assert(txn.catalog.txnTables.size === 2) + + val targetTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(tableNameAsString) + assert(targetTxnTable.scanEvents.isEmpty) + + // The static source was read exactly once and its scan event was captured. + val sourceTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + + // Streaming must not add transactions to the main session catalog beyond the pre-existing + // INSERT transaction. + assert(catalog.observedTransactions.size === mainTxnsBefore) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(1), Row(2), Row(3))) + } + } + + test("transaction is aborted when micro-batch write fails and no data is written") { + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT")) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(Collections.singletonMap( + InMemoryBaseTable.SIMULATE_FAILED_WRITE_OPTION, "true")) + .build() + catalog.createTable(ident, tableInfo) + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tableNameAsString) + + inputData.addData(1, 2, 3) + intercept[Exception] { query.processAllAvailable() } + query.stop() + + val txn = streamCatalog(query).lastTransaction + assert(txn != null, "expected a transaction to have been recorded") + assert(txn.currentState === Aborted) + assert(txn.isClosed) + // Aborted transaction must not advance the delegate version. + assert(table.version() === "0") + + // Transaction must be scoped to the streaming session; main session catalog is untouched. + assert(catalog.observedTransactions.isEmpty) + + // Writes must not be visible after an aborted transaction. 
+ checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Seq.empty) + } + } + + test("streaming write to non-transactional catalog does not start a transaction") { + withSQLConf("spark.sql.catalog.nonTxnCat" -> classOf[InMemoryTableCatalog].getName) { + val nonTxnCat = spark + .sessionState + .catalogManager + .catalog("nonTxnCat") + .asInstanceOf[InMemoryTableCatalog] + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT")) + nonTxnCat.createTable( + Identifier.of(Array("ns"), "tbl"), + new TableInfo.Builder().withColumns(columns).build()) + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable("nonTxnCat.ns.tbl") + + inputData.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + assert(catalog.observedTransactions.isEmpty, + "no transaction expected for non-transactional catalog") + checkAnswer(spark.table("nonTxnCat.ns.tbl"), Seq(Row(1), Row(2), Row(3))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index d32a1e5c7f56..65c3b68fa8cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.connector import org.apache.spark.SparkRuntimeException -import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, InMemoryTable, TableChange, TableInfo} +import org.apache.spark.sql.{sources, AnalysisException, Row} +import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, TableChange, TableInfo} import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue} 
import org.apache.spark.sql.connector.write.UpdateSummary import org.apache.spark.sql.internal.SQLConf @@ -867,4 +867,325 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { Row(5))) } } + + test("update with analysis failure and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val exception = intercept[AnalysisException] { + sql(s"UPDATE $tableNameAsString SET invalid_column = -1") + } + + assert(exception.getMessage.contains("invalid_column")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + test("update with CTE and transactional checks") { + // create table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using CTE + val (txn, txnTables) = executeTransaction { + sql( + s"""WITH cte AS ( + | SELECT pk, salary + 50 AS adjusted_salary, dep + | FROM $sourceNameAsString + | WHERE salary > 100 + |) + |UPDATE $tableNameAsString t + |SET salary = -1 + |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM cte WHERE cte.pk = t.pk) + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 1 else 3 + assert(targetTxnTable.scanEvents.size == 
expectedNumTargetScans) + + // check target table scans for UPDATE condition (dep = 'hr') + val numUpdateTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numUpdateTargetScans == expectedNumTargetScans) + + // check source table was scanned correctly + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check source table scans in CTE (salary > 100) + val numCteSourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.GreaterThan("salary", 100) => true + case _ => false + } + assert(numCteSourceScans == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (no matching pk in source) + } + + // Verifies that an UPDATE filtered by an uncorrelated IN subquery on the source table runs in one committed transaction that covers both tables. + test("update with subquery on source table and transactional checks") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using an uncorrelated IN subquery that reads from a transactional catalog table + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET salary = -1 + |WHERE pk IN (SELECT pk FROM $sourceNameAsString WHERE dep = 'hr') + |""".stripMargin) + } + + // check txn was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + 
// check source table was scanned correctly (dep = 'hr' filter in the subquery) + val sourceTxnTable = txnTables(sourceNameAsString) + // NOTE(review): the expected scan counts depend on deltaUpdate, which is defined outside this chunk; confirm the literals if planning changes + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + val numSubquerySourceScans = sourceTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + assert(numSubquerySourceScans == expectedNumSourceScans) + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 1 else 3 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated (pk 1 is in subquery result) + Row(2, 200, "software"), // unchanged + Row(3, 300, "hr"))) // unchanged (pk 3 not in subquery result) + } + + // Verifies that an UPDATE whose SET clause uses an uncorrelated scalar subquery scans the source and target tables through the transaction and commits. + test("update with uncorrelated scalar subquery on source table and transactional checks") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT, dep STRING)") + sql(s"INSERT INTO $sourceNameAsString VALUES (1, 150, 'hr'), (4, 400, 'finance')") + + // update using an uncorrelated scalar subquery in the SET clause that reads from a + // transactional catalog table; scalar subqueries are executed as SubqueryExec at runtime + // and cannot be rewritten as joins + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET salary = (SELECT max(salary) FROM $sourceNameAsString WHERE dep = 'hr') + |WHERE dep = 'hr' + |""".stripMargin) + } + + // check txn was properly committed and closed + 
assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check source table was scanned via the transaction catalog + val sourceTxnTable = txnTables(sourceNameAsString) + assert(sourceTxnTable.scanEvents.nonEmpty) + assert(sourceTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + + // check target table was scanned via the transaction catalog + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.nonEmpty) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 150, "hr"), // updated (max salary in source for 'hr' is 150) + Row(2, 200, "software"), // unchanged + Row(3, 150, "hr"))) // updated + } + + // Verifies that an UPDATE assigning NULL to the NOT NULL column pk fails with NOT_NULL_ASSERT_VIOLATION and leaves the last transaction aborted and closed. + test("update with constraint violation and transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val exception = intercept[SparkRuntimeException] { + executeTransaction { + sql( + s"""UPDATE $tableNameAsString + |SET pk = NULL + |WHERE dep = 'hr' + |""".stripMargin) // NULL violates NOT NULL constraint + } + } + + assert(exception.getMessage.contains("NOT_NULL_ASSERT_VIOLATION")) + assert(catalog.lastTransaction.currentState == Aborted) + assert(catalog.lastTransaction.isClosed) + } + + // Verifies that an UPDATE whose condition reads a view over the source and target tables records the scans of both tables in one committed transaction. + test("update using view with transactional checks") { + withView("temp_view") { + // create target table + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // create source table + sql(s"CREATE TABLE $sourceNameAsString (pk INT NOT NULL, salary INT)") + sql(s"INSERT INTO $sourceNameAsString (pk, 
salary) VALUES (1, 150), (4, 400)") + + // create view on top of source and target tables + sql( + s"""CREATE VIEW temp_view AS + |SELECT s.pk, s.salary, t.dep + |FROM $sourceNameAsString s + |LEFT JOIN ( + | SELECT * FROM $tableNameAsString WHERE pk < 10 + |) t ON s.pk = t.pk + |""".stripMargin) + + // update target table using view + val (txn, txnTables) = executeTransaction { + sql( + s"""UPDATE $tableNameAsString t + |SET salary = -1 + |WHERE t.dep = 'hr' AND EXISTS (SELECT 1 FROM temp_view v WHERE v.pk = t.pk) + |""".stripMargin) + } + + // check txn covers both tables and was properly committed and closed + assert(txn.currentState == Committed) + assert(txn.isClosed) + assert(txnTables.size == 2) + assert(table.version() == "2") + + // check target table was scanned correctly + val targetTxnTable = txnTables(tableNameAsString) + val expectedNumTargetScans = if (deltaUpdate) 2 else 7 + assert(targetTxnTable.scanEvents.size == expectedNumTargetScans) + + // check target table scans as UPDATE target (dep = 'hr') + val numUpdateTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.EqualTo("dep", "hr") => true + case _ => false + } + val expectedNumUpdateTargetScans = if (deltaUpdate) 1 else 3 + assert(numUpdateTargetScans == expectedNumUpdateTargetScans) + + // check target table scans in view as source (pk < 10) + val numViewTargetScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("pk", 10L) => true + case _ => false + } + val expectedNumViewTargetScans = if (deltaUpdate) 1 else 4 + assert(numViewTargetScans == expectedNumViewTargetScans) + + // check source table scans in view + val sourceTxnTable = txnTables(sourceNameAsString) + val expectedNumSourceScans = if (deltaUpdate) 1 else 4 + assert(sourceTxnTable.scanEvents.size == expectedNumSourceScans) + + // check txn state was propagated correctly + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated from view + Row(2, 200, 
"software"), // unchanged + Row(3, 300, "hr"))) // unchanged (no matching pk in source) + } + } + + // Verifies that calling explain() on an UPDATE DataFrame runs the command: the last transaction is committed and closed and the data is changed. + test("df.explain() on update with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // sql() is lazy, but explain() forces executedPlan. + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'").explain() + + assert(catalog.lastTransaction != null) + assert(catalog.lastTransaction.currentState == Committed) + assert(catalog.lastTransaction.isClosed) + assert(table.version() == "2") + + // The UPDATE was actually executed, not just planned. + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, -1, "hr"), // updated + Row(2, 200, "software"))) // unchanged + } + + // Verifies that the SQL EXPLAIN UPDATE statement starts no transaction and leaves the table data unchanged. + test("EXPLAIN UPDATE SQL with transactional checks") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // EXPLAIN UPDATE only plans the command, it does not execute the write. + sql(s"EXPLAIN UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + // A transaction should not have started at all. + assert(catalog.transaction === null) + + // The UPDATE was not executed. Data is unchanged. + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 100, "hr"), + Row(2, 200, "software"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala new file mode 100644 index 000000000000..141d5966b4b6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AnalyzerBenchmark.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.concurrent.duration._ + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.classic +import org.apache.spark.sql.execution.QueryExecution + +/** + * Benchmark to measure the overhead of cloning the analyzer for transactional query execution. + * Each transactional query creates a new [[Analyzer]] instance via + * [[Analyzer.withCatalogManager]], which shares all rules with the original but carries a + * per-query [[org.apache.spark.sql.connector.catalog.CatalogManager]]. This benchmark checks + * whether the cloning introduces measurable overhead. 
+ * + * To run this benchmark: + * {{{ + * build/sbt "sql/Test/runMain org.apache.spark.sql.execution.benchmark.AnalyzerBenchmark" + * }}} + */ +object AnalyzerBenchmark extends SqlBasedBenchmark { + + // Number of rows in every temp view; also the number of columns of wide_t. + private val numRows = 100 + + // (label, SQL text) pairs; analysisBenchmark runs one benchmark per pair. + private val queries = Seq( + "simple select" -> "SELECT id, val FROM t1", + "join" -> "SELECT t1.id, t2.val FROM t1 JOIN t2 ON t1.id = t2.id", + // Derive the column list from numRows so it always matches the columns created in setupTables. + "wide schema" -> s"SELECT ${(1 to numRows).map(i => s"col_$i").mkString(", ")} FROM wide_t" + ) + + /** Registers the temp views t1, t2 and wide_t that the benchmark queries read. */ + private def setupTables(): Unit = { + spark.range(numRows).selectExpr("id", "id * 2 as val").createOrReplaceTempView("t1") + spark.range(numRows).selectExpr("id", "id * 3 as val").createOrReplaceTempView("t2") + spark.range(numRows) + .selectExpr((1 to numRows).map(i => s"id as col_$i"): _*) + .createOrReplaceTempView("wide_t") + } + + /** + * Measures the time to parse and analyze each query, comparing the session analyzer against a + * cloned analyzer created via [[Analyzer.withCatalogManager]]. The SQL text is parsed inside + * every case body, so parsing time is included in both measurements. + * + * Two cases: + * - "session analyzer" : baseline, uses the session analyzer directly. + * - "cloned analyzer (per query)": analyzer cloned every iteration; reflects the full + * per-transactional-query cost (clone + analysis). + */ + def analysisBenchmark(): Unit = { + for ((name, sql) <- queries) { + runBenchmark(s"analysis overhead $name") { + val benchmark = new Benchmark( + name = s"analysis overhead $name", + // Per row measurements are not meaningful here. 
+ valuesPerIteration = numRows, + minTime = 10.seconds, + output = output) + val catalogManager = spark.sessionState.catalogManager + + benchmark.addCase("session analyzer") { _ => + val plan = spark.sessionState.sqlParser.parsePlan(sql) + new QueryExecution(spark.asInstanceOf[classic.SparkSession], plan).analyzed + } + + benchmark.addCase("cloned analyzer (per query)") { _ => + val cloned = spark.sessionState.analyzer.withCatalogManager(catalogManager) + val plan = spark.sessionState.sqlParser.parsePlan(sql) + new QueryExecution(spark.asInstanceOf[classic.SparkSession], + plan, analyzerOpt = Some(cloned)).analyzed + } + + benchmark.run() + } + } + } + + /** + * Micro-benchmark for [[Analyzer.withCatalogManager]] in isolation: measures the cost of + * instantiating the anonymous [[Analyzer]] subclass, independent of analysis work. + */ + def cloneCostBenchmark(): Unit = { + runBenchmark("analyzer clone cost") { + val benchmark = new Benchmark( + name = "analyzer clone cost", + // Per row measurements are not meaningful here; a literal avoids shadowing numRows. + valuesPerIteration = 1, + output = output) + val catalogManager = spark.sessionState.catalogManager + + benchmark.addCase("withCatalogManager") { _ => + spark.sessionState.analyzer.withCatalogManager(catalogManager) + } + + benchmark.run() + } + } + + /** Creates the temp views, then runs the clone-cost benchmark followed by the analysis benchmarks. */ + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + setupTables() + cloneCostBenchmark() + analysisBenchmark() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala index a36570467a9d..da697847874d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala @@ -21,8 +21,10 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.{AnalysisException, Row} import 
org.apache.spark.sql.catalyst.plans.logical.CompoundBody import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLConf +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTableCatalog, Txn, TxnTable, TxnTableCatalog} import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -47,6 +49,27 @@ class SqlScriptingE2eSuite extends SharedSparkSession { } // Helpers + // Registers InMemoryRowLevelOperationTableCatalog under the given catalog name, passes the loaded catalog to f, and resets the catalog manager afterwards (also when f throws). + private def withCatalog( + name: String)( + f: InMemoryRowLevelOperationTableCatalog => Unit): Unit = { + withSQLConf(s"spark.sql.catalog.$name" -> + classOf[InMemoryRowLevelOperationTableCatalog].getName) { + val catalog = spark.sessionState.catalogManager + .catalog(name) + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + try f(catalog) finally spark.sessionState.catalogManager.reset() + } + } + + // Loads namespace.tableName through the catalog of the given transaction and returns it as a TxnTable; the namespace defaults to ns1. + private def loadTxnTable( + txn: Txn, + tableName: String, + namespace: Array[String] = Array("ns1")): TxnTable = + txn.catalog + .asInstanceOf[TxnTableCatalog] + .loadTable(Identifier.of(namespace, tableName)) + .asInstanceOf[TxnTable] + private def verifySqlScriptResult( sqlText: String, expected: Seq[Row], @@ -174,6 +197,188 @@ class SqlScriptingE2eSuite extends SharedSparkSession { } } + test("multi statement with transactional checks - insert then delete") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'); + | DELETE FROM cat.ns1.t + | WHERE pk IN (SELECT pk FROM cat.ns1.t WHERE dep = 'hr'); + | SELECT * FROM cat.ns1.t; + |END + |""".stripMargin + + verifySqlScriptResult(sqlScript, Seq(Row(2, 200, "software"))) + + // Each DML 
statement in a script runs in its own independent QE and transaction. + assert(catalog.observedTransactions.size === 2) + assert(catalog.observedTransactions.forall(t => + t.currentState === Committed && t.isClosed)) + + // The DELETE subquery scans the table with a dep='hr' predicate; verify it was tracked. + val deleteTxnTable = loadTxnTable(catalog.observedTransactions(1), "t") + assert(deleteTxnTable.scanEvents.flatten.exists { + case sources.EqualTo("dep", "hr") => true + case _ => false + }) + } + } + } + + test("multi statement with transactional checks - second statement fails") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'); + | DELETE FROM cat.ns1.t WHERE nonexistent_column = 1; + |END + |""".stripMargin + + checkError( + exception = intercept[AnalysisException] { + spark.sql(sqlScript).collect() + }, + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map( + "objectName" -> "`nonexistent_column`", + "proposal" -> ".*"), + matchPVals = true, + queryContext = Array(ExpectedContext("nonexistent_column"))) + + // INSERT committed; DELETE was aborted because analysis failed on the bad column. 
+ assert(catalog.observedTransactions.size === 2) + assert(catalog.observedTransactions(0).currentState === Committed) + assert(catalog.observedTransactions(0).isClosed) + assert(catalog.observedTransactions(1).currentState === Aborted) + assert(catalog.observedTransactions(1).isClosed) + assert(catalog.lastTransaction.currentState === Aborted) + } + } + } + + test("multi statement with transactional checks - insert, merge, update") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t", "cat.ns1.src") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | CREATE TABLE cat.ns1.src (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'), (2, 200, 'software'), (3, 300, 'hr'); + | INSERT INTO cat.ns1.src VALUES (1, 150, 'hr'), (4, 400, 'finance'); + | MERGE INTO cat.ns1.t AS t + | USING cat.ns1.src AS s + | ON t.pk = s.pk + | WHEN MATCHED THEN UPDATE SET salary = s.salary + | WHEN NOT MATCHED THEN INSERT (pk, salary, dep) + | VALUES (s.pk, s.salary, s.dep); + | UPDATE cat.ns1.t SET salary = salary + 50 WHERE dep = 'software'; + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq( + Row(1, 150, "hr"), + Row(2, 250, "software"), + Row(3, 300, "hr"), + Row(4, 400, "finance"))) + + // INSERT (x2), MERGE, and UPDATE each run in their own independent QE and transaction. + assert(catalog.observedTransactions.size === 4) + assert(catalog.observedTransactions.forall(t => t.currentState === Committed && t.isClosed)) + + def txnTable(txnIdx: Int): TxnTable = + loadTxnTable(catalog.observedTransactions(txnIdx), "t") + + // Both inserts are pure writes - no scan. + assert(txnTable(0).scanEvents.isEmpty) + // NOTE(review): assuming observedTransactions follows script order, index 1 is the INSERT INTO cat.ns1.src, yet txnTable loads table t - confirm whether src was intended here. + assert(txnTable(1).scanEvents.isEmpty) + + // MERGE scans the full target table. The join is on pk (not the partition column). 
+ assert(txnTable(2).scanEvents.nonEmpty) + assert(txnTable(2).scanEvents.flatten.isEmpty) + + // UPDATE with WHERE dep='software' pushes an equality predicate on the partition column. + assert(txnTable(3).scanEvents.flatten.exists { + case sources.EqualTo("dep", "software") => true + case _ => false + }) + } + } + } + + test("loop with transactional checks - each iteration runs in its own transaction") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | DECLARE i INT = 1; + | CREATE TABLE + | cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | WHILE i <= 3 DO + | INSERT INTO cat.ns1.t VALUES (i, i * 100, 'hr'); + | SET i = i + 1; + | END WHILE; + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq(Row(1, 100, "hr"), Row(2, 200, "hr"), Row(3, 300, "hr"))) + + // Each loop iteration's INSERT runs in its own independent transaction. + assert(catalog.observedTransactions.size === 3) + assert(catalog.observedTransactions.forall(t => t.currentState === Committed && t.isClosed)) + } + } + } + + test("continue handler with transactional checks - handler DML runs in its own transaction") { + withCatalog("cat") { catalog => + withTable("cat.ns1.t") { + val sqlScript = + """ + |BEGIN + | DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO + | BEGIN + | INSERT INTO cat.ns1.t VALUES (-1, -1, 'error'); + | END; + | CREATE TABLE + | cat.ns1.t (pk INT NOT NULL, salary INT, dep STRING) + | PARTITIONED BY (dep); + | INSERT INTO cat.ns1.t VALUES (1, 100, 'hr'); + | SELECT 1/0; + | INSERT INTO cat.ns1.t VALUES (2, 200, 'software'); + | SELECT * FROM cat.ns1.t ORDER BY pk; + |END + |""".stripMargin + + verifySqlScriptResult( + sqlScript, + Seq(Row(-1, -1, "error"), Row(1, 100, "hr"), Row(2, 200, "software"))) + + // INSERT(1), handler INSERT(-1), INSERT(2) - each in its own transaction. 
+ assert(catalog.observedTransactions.size === 3) + assert(catalog.observedTransactions.forall(t => t.currentState === Committed && t.isClosed)) + } + } + } + test("script without result statement") { val sqlScript = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala index 89f655622952..dab667731019 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.{FakeV2Provider, FakeV2ProviderWithCustomSchema, InMemoryTableSessionCatalog} -import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog, MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability, TableInfo, V2TableWithV1Fallback} +import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTable, InMemoryTableCatalog, MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability, TableInfo, V2TableWithV1Fallback} import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, Transform} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, MemoryStreamScanBuilder, StreamingQueryWrapper} @@ -109,20 +109,23 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { } test("read: read table without streaming capability support") { - val tableIdentifier = "testcat.table_name" + withSQLConf("spark.sql.catalog.testcat" -> + 
classOf[DataStreamTableAPISuite.NonStreamingInMemoryTableCatalog].getName) { + val tableIdentifier = "testcat.table_name" - spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING foo") + spark.sql(s"CREATE TABLE $tableIdentifier (id bigint, data string) USING foo") - checkError( - exception = intercept[AnalysisException] { - spark.readStream.table(tableIdentifier) - }, - condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", - parameters = Map( - "tableName" -> "`testcat`.`table_name`", - "operation" -> "either micro-batch or continuous scan" + checkError( + exception = intercept[AnalysisException] { + spark.readStream.table(tableIdentifier) + }, + condition = "UNSUPPORTED_FEATURE.TABLE_OPERATION", + parameters = Map( + "tableName" -> "`testcat`.`table_name`", + "operation" -> "either micro-batch or continuous scan" + ) ) - ) + } } test("read: read table with custom catalog") { @@ -638,6 +641,25 @@ class DataStreamTableAPISuite extends StreamTest with BeforeAndAfter { object DataStreamTableAPISuite { val V1FallbackTestTableName = "fallbackV1Test" + + // InMemoryTableCatalog whose createTable registers InMemoryTable instances with MICRO_BATCH_READ removed from their base capabilities; it throws TableAlreadyExistsException when the identifier is already present. + class NonStreamingInMemoryTableCatalog extends InMemoryTableCatalog { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident.asMultipartIdentifier) + } + val tableName = s"$name.${ident.quoted}" + val table = new InMemoryTable(tableName, tableInfo.columns(), tableInfo.partitions(), + tableInfo.properties, tableInfo.constraints()) { + override def baseCapabiilities: Set[TableCapability] = + super.baseCapabiilities - TableCapability.MICRO_BATCH_READ + } + tables.put(ident, table) + namespaces.putIfAbsent(ident.namespace.toList, Map()) + table + } + } } class InMemoryStreamTable(override val name: String)