From 823133a826882ada8d2ac0118cd39c55ff78368a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 3 Feb 2026 13:27:18 +0800 Subject: [PATCH 1/4] [SPARK-53656][SS][FOLLOWUP] Improve MemoryStream backward compatibility ### What changes were proposed in this pull request? This is a followup to #52402 that addresses backward compatibility concerns: 1. Keep the original `implicit SQLContext` factory methods for full backward compatibility 2. Add new overloads with explicit `SparkSession` parameter for new code 3. Fix `TestGraphRegistrationContext` to provide implicit `spark` and `sqlContext` to avoid name shadowing issues in nested classes 4. Remove redundant `implicit val sparkSession` declarations from pipeline tests that are no longer needed with the fix ### Why are the changes needed? PR #52402 changed the MemoryStream API to use `implicit SparkSession` which broke backward compatibility for code that only has `implicit SQLContext` available. This followup ensures: - Old code continues to work without modification - New code can use SparkSession with explicit parameters - Internal implementation uses SparkSession (modernization from #52402) ### Does this PR introduce _any_ user-facing change? No. This maintains full backward compatibility while adding new API options. ### How was this patch tested? Existing tests pass. The API changes are additive. ### Was this patch authored or co-authored using generative AI tooling? Yes Co-authored-by: Cursor --- .../execution/streaming/runtime/memory.scala | 54 ++++++++----- .../sources/ContinuousMemoryStream.scala | 51 +++++------- .../sources/LowLatencyMemoryStream.scala | 56 +++++-------- .../execution/streaming/MemorySinkSuite.scala | 78 ------------------- .../state/StateStoreCoordinatorSuite.scala | 42 +++++----- .../sql/hive/execution/HiveDDLSuite.scala | 2 +- .../graph/ConnectInvalidPipelineSuite.scala | 5 +- .../graph/ConnectValidPipelineSuite.scala | 6 -- .../graph/MaterializeTablesSuite.scala | 8 +- .../pipelines/graph/SystemMetadataSuite.scala | 5 -- .../graph/TriggeredGraphExecutionSuite.scala | 6 +- .../utils/TestGraphRegistrationContext.scala | 13 +++- 12 files changed, 110 insertions(+), 216 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala index bf67ed670ec81..306fd45db2bb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala @@ -43,36 +43,48 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap -object MemoryStream extends LowPriorityMemoryStreamImplicits { +object MemoryStream { protected val currentBlockId = new AtomicInteger(0) protected val memoryStreamId = new AtomicInteger(0) - def apply[A : Encoder](implicit sparkSession: SparkSession): MemoryStream[A] = - new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) - - def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession): MemoryStream[A] = - new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, Some(numPartitions)) -} - -/** - * Provides lower-priority implicits for MemoryStream to prevent ambiguity when both - * SparkSession and SQLContext are in scope. The implicits in the companion object, - * which use SparkSession, take higher precedence. - */ -trait LowPriorityMemoryStreamImplicits { - this: MemoryStream.type => - - // Deprecated: Used when an implicit SQLContext is in scope - @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0") - def apply[A: Encoder]()(implicit sqlContext: SQLContext): MemoryStream[A] = + /** + * Creates a MemoryStream with an implicit SQLContext (backward compatible). + * Usage: `MemoryStream[Int]` + */ + def apply[A: Encoder](implicit sqlContext: SQLContext): MemoryStream[A] = new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession) - @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0") - def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] = + /** + * Creates a MemoryStream with specified partitions using implicit SQLContext. + * Usage: `MemoryStream[Int](numPartitions)` + */ + def apply[A: Encoder](numPartitions: Int)( + implicit sqlContext: SQLContext): MemoryStream[A] = new MemoryStream[A]( memoryStreamId.getAndIncrement(), sqlContext.sparkSession, Some(numPartitions)) + + /** + * Creates a MemoryStream with explicit SparkSession. + * Usage: `MemoryStream[Int](spark)` + */ + def apply[A: Encoder](sparkSession: SparkSession): MemoryStream[A] = + new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) + + /** + * Creates a MemoryStream with explicit encoder and SparkSession. + * Usage: `MemoryStream(Encoders.scalaInt, spark)` + */ + def apply[A](encoder: Encoder[A], sparkSession: SparkSession): MemoryStream[A] = + new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)(encoder) + + /** + * Creates a MemoryStream with specified partitions using explicit SparkSession. + * Usage: `MemoryStream[Int](numPartitions, spark)` + */ + def apply[A: Encoder](numPartitions: Int, sparkSession: SparkSession): MemoryStream[A] = + new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, Some(numPartitions)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 8042cacf1374b..8187563e178bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -112,47 +112,36 @@ class ContinuousMemoryStream[A : Encoder]( override def commit(end: Offset): Unit = {} } -object ContinuousMemoryStream extends LowPriorityContinuousMemoryStreamImplicits { +object ContinuousMemoryStream { protected val memoryStreamId = new AtomicInteger(0) - def apply[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] = - new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) - - def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession): - ContinuousMemoryStream[A] = - new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, numPartitions) - - def singlePartition[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] = - new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1) -} - -/** - * Provides lower-priority implicits for ContinuousMemoryStream to prevent ambiguity when both - * SparkSession and SQLContext are in scope. The implicits in the companion object, - * which use SparkSession, take higher precedence. - */ -trait LowPriorityContinuousMemoryStreamImplicits { - this: ContinuousMemoryStream.type => - - // Deprecated: Used when an implicit SQLContext is in scope - @deprecated("Use ContinuousMemoryStream with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def apply[A: Encoder]()(implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + /** Creates a ContinuousMemoryStream with an implicit SQLContext (backward compatible). */ + def apply[A: Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession) - @deprecated("Use ContinuousMemoryStream with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): - ContinuousMemoryStream[A] = + /** Creates a ContinuousMemoryStream with specified partitions (SQLContext). */ + def apply[A: Encoder](numPartitions: Int)( + implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A]( memoryStreamId.getAndIncrement(), sqlContext.sparkSession, numPartitions) - @deprecated("Use ContinuousMemoryStream.singlePartition with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def singlePartition[A: Encoder]()(implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + /** Creates a ContinuousMemoryStream with explicit SparkSession. */ + def apply[A: Encoder](sparkSession: SparkSession): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) + + /** Creates a ContinuousMemoryStream with specified partitions (SparkSession). */ + def apply[A: Encoder](numPartitions: Int, sparkSession: SparkSession): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, numPartitions) + + /** Creates a single partition ContinuousMemoryStream (SQLContext). */ + def singlePartition[A: Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession, 1) + + /** Creates a single partition ContinuousMemoryStream (SparkSession). */ + def singlePartition[A: Encoder](sparkSession: SparkSession): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala index 65b202894ec46..97fb074a1190f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala @@ -183,53 +183,39 @@ class LowLatencyMemoryStream[A: Encoder]( } } -object LowLatencyMemoryStream extends LowPriorityLowLatencyMemoryStreamImplicits { +object LowLatencyMemoryStream { protected val memoryStreamId = new AtomicInteger(0) - def apply[A: Encoder](implicit sparkSession: SparkSession): LowLatencyMemoryStream[A] = - new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) + /** Creates a LowLatencyMemoryStream with an implicit SQLContext (backward compatible). */ + def apply[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = + new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession) + /** Creates a LowLatencyMemoryStream with specified partitions (SQLContext). */ def apply[A: Encoder](numPartitions: Int)( - implicit - sparkSession: SparkSession): LowLatencyMemoryStream[A] = + implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = new LowLatencyMemoryStream[A]( memoryStreamId.getAndIncrement(), - sparkSession, - numPartitions = numPartitions - ) - - def singlePartition[A: Encoder](implicit sparkSession: SparkSession): LowLatencyMemoryStream[A] = - new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1) -} - -/** - * Provides lower-priority implicits for LowLatencyMemoryStream to prevent ambiguity when both - * SparkSession and SQLContext are in scope. The implicits in the companion object, - * which use SparkSession, take higher precedence. - */ -trait LowPriorityLowLatencyMemoryStreamImplicits { - this: LowLatencyMemoryStream.type => + sqlContext.sparkSession, + numPartitions = numPartitions) - // Deprecated: Used when an implicit SQLContext is in scope - @deprecated("Use LowLatencyMemoryStream with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def apply[A: Encoder]()(implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = - new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession) + /** Creates a LowLatencyMemoryStream with explicit SparkSession. */ + def apply[A: Encoder](sparkSession: SparkSession): LowLatencyMemoryStream[A] = + new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) - @deprecated("Use LowLatencyMemoryStream with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): - LowLatencyMemoryStream[A] = + /** Creates a LowLatencyMemoryStream with specified partitions (SparkSession). */ + def apply[A: Encoder](numPartitions: Int, sparkSession: SparkSession): LowLatencyMemoryStream[A] = new LowLatencyMemoryStream[A]( memoryStreamId.getAndIncrement(), - sqlContext.sparkSession, - numPartitions = numPartitions - ) + sparkSession, + numPartitions = numPartitions) - @deprecated("Use LowLatencyMemoryStream.singlePartition with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def singlePartition[A: Encoder]()(implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = + /** Creates a single partition LowLatencyMemoryStream (SQLContext). */ + def singlePartition[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession, 1) + + /** Creates a single partition LowLatencyMemoryStream (SparkSession). */ + def singlePartition[A: Encoder](sparkSession: SparkSession): LowLatencyMemoryStream[A] = + new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index e0ec3fd1b907b..4ec44eac22e36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -343,84 +343,6 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { intsToDF(expected)(schema)) } - test("LowPriorityMemoryStreamImplicits works with implicit sqlContext") { - // Test that MemoryStream can be created using implicit sqlContext - implicit val sqlContext: SQLContext = spark.sqlContext - - // Test MemoryStream[A]() with implicit sqlContext - val stream1 = MemoryStream[Int]() - assert(stream1 != null) - - // Test MemoryStream[A](numPartitions) with implicit sqlContext - val stream2 = MemoryStream[String](3) - assert(stream2 != null) - - // Verify the streams work correctly - stream1.addData(1, 2, 3) - val df1 = stream1.toDF() - assert(df1.schema.fieldNames.contains("value")) - - stream2.addData("a", "b", "c") - val df2 = stream2.toDF() - assert(df2.schema.fieldNames.contains("value")) - } - - test("LowPriorityContinuousMemoryStreamImplicits works with implicit sqlContext") { - import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream - // Test that ContinuousMemoryStream can be created using implicit sqlContext - implicit val sqlContext: SQLContext = spark.sqlContext - - // Test ContinuousMemoryStream[A]() with implicit sqlContext - val stream1 = ContinuousMemoryStream[Int]() - assert(stream1 != null) - - // Test ContinuousMemoryStream[A](numPartitions) with implicit sqlContext - val stream2 = ContinuousMemoryStream[String](3) - assert(stream2 != null) - - // Test ContinuousMemoryStream.singlePartition with implicit sqlContext - val stream3 = ContinuousMemoryStream.singlePartition[Int]() - assert(stream3 != null) - - // Verify the streams work correctly - stream1.addData(Seq(1, 2, 3)) - stream2.addData(Seq("a", "b", "c")) - stream3.addData(Seq(10, 20)) - - // Basic verification that streams are functional - assert(stream1.initialOffset() != null) - assert(stream2.initialOffset() != null) - assert(stream3.initialOffset() != null) - } - - test("LowPriorityLowLatencyMemoryStreamImplicits works with implicit sqlContext") { - import org.apache.spark.sql.execution.streaming.LowLatencyMemoryStream - // Test that LowLatencyMemoryStream can be created using implicit sqlContext - implicit val sqlContext: SQLContext = spark.sqlContext - - // Test LowLatencyMemoryStream[A]() with implicit sqlContext - val stream1 = LowLatencyMemoryStream[Int]() - assert(stream1 != null) - - // Test LowLatencyMemoryStream[A](numPartitions) with implicit sqlContext - val stream2 = LowLatencyMemoryStream[String](3) - assert(stream2 != null) - - // Test LowLatencyMemoryStream.singlePartition with implicit sqlContext - val stream3 = LowLatencyMemoryStream.singlePartition[Int]() - assert(stream3 != null) - - // Verify the streams work correctly - stream1.addData(Seq(1, 2, 3)) - stream2.addData(Seq("a", "b", "c")) - stream3.addData(Seq(10, 20)) - - // Basic verification that streams are functional - assert(stream1.initialOffset() != null) - assert(stream2.initialOffset() != null) - assert(stream3.initialOffset() != null) - } - private implicit def intsToDF(seq: Seq[Int])(implicit schema: StructType): DataFrame = { require(schema.fields.length === 1) sqlContext.createDataset(seq).toDF(schema.fieldNames.head) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 7446390e8d068..a15521333614f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -163,7 +163,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") // Start a query and run a batch to load state stores - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // stateful query val checkpointLocation = Utils.createTempDir().getAbsoluteFile val query = aggregated.writeStream @@ -279,10 +279,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark - val inputData = MemoryStream[Int] + + val inputData = MemoryStream[Int](spark) val query = setUpStatefulQuery(inputData, "query") - val inputData2 = MemoryStream[Int] + val inputData2 = MemoryStream[Int](spark) val query2 = setUpStatefulQuery(inputData2, "query2") // Add, commit, and wait multiple times to force snapshot versions and time difference // we will detect state store with partition 0 and 1 to be lagged on version 5 @@ -378,8 +378,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark - val inputData = MemoryStream[Int] + + val inputData = MemoryStream[Int](spark) val aggregated = inputData.toDF().groupBy("value").agg(count("*")) val checkpointLocation = Utils.createTempDir().getAbsoluteFile val query = aggregated.writeStream @@ -444,8 +444,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark - val inputData = MemoryStream[Int] + + val inputData = MemoryStream[Int](spark) val query = setUpStatefulQuery(inputData, "query") // Add, commit, and wait multiple times to force snapshot versions and time difference (0 until 6).foreach { _ => @@ -481,10 +481,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark + // Start a join query and run some data to force snapshot uploads - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] + val input1 = MemoryStream[Int](spark) + val input2 = MemoryStream[Int](spark) val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") val joined = df1.join(df2, expr("leftKey = rightKey")) @@ -525,10 +525,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark + // Start and run two queries together with some data to force snapshot uploads - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] + val input1 = MemoryStream[Int](spark) + val input2 = MemoryStream[Int](spark) val query1 = setUpStatefulQuery(input1, "query1") val query2 = setUpStatefulQuery(input2, "query2") @@ -593,9 +593,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark + // Start a query and run some data to force snapshot uploads - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val query = setUpStatefulQuery(inputData, "query") // Go through two batches to force two snapshot uploads. @@ -638,9 +638,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark + // Start a query and run some data to force snapshot uploads - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val query = setUpStatefulQuery(inputData, "query") // Go through several rounds of input to force snapshot uploads @@ -682,7 +682,7 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { withTempDir { srcDir => - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val query = inputData.toDF().dropDuplicates() val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) // Keep track of state checkpoint directory for the second run @@ -805,7 +805,7 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { withTempDir { srcDir => - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val query = inputData.toDF().dropDuplicates() testStream(query)( @@ -884,7 +884,7 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { withTempDir { srcDir => - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val query = inputData.toDF().dropDuplicates() // Populate state stores with an initial snapshot, so that timestamp isn't marked diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 5c23192b1d3fd..43f15a12cad19 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -2653,7 +2653,7 @@ class HiveDDLSuite |SELECT word, number from t1 """.stripMargin) - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val joined = inputData.toDS().toDF() .join(spark.table("smallTable"), $"value" === $"number") diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala index 7c8181b5b72a5..f37716b4a24d3 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.pipelines.graph -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.streaming.runtime.MemoryStream import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext} import org.apache.spark.sql.test.SharedSparkSession @@ -423,7 +423,6 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ val p = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem = MemoryStream[Int] mem.addData(1) registerPersistedView("a", query = dfFlowFunc(mem.toDF())) @@ -467,7 +466,6 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark registerMaterializedView("a", query = dfFlowFunc(MemoryStream[Int].toDF())) }.resolveToDataflowGraph() @@ -491,7 +489,6 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession { val graph = new TestGraphRegistrationContext(spark) { registerTable("a") - implicit val sparkSession: SparkSession = spark registerFlow( destinationName = "a", name = "once_flow", diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala index a4bb7c067d875..3ac3c09017506 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.pipelines.graph -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.Union @@ -159,7 +158,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ class P extends TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val ints = MemoryStream[Int] ints.addData(1, 2, 3, 4) registerPersistedView("a", query = dfFlowFunc(ints.toDF())) @@ -201,7 +199,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ class P extends TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val ints1 = MemoryStream[Int] ints1.addData(1, 2, 3, 4) val ints2 = MemoryStream[Int] @@ -362,7 +359,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ class P extends TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem = MemoryStream[Int] registerPersistedView("a", query = dfFlowFunc(mem.toDF())) registerTable("b") @@ -406,7 +402,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem = MemoryStream[Int] mem.addData(1, 2) registerPersistedView("complete-view", query = dfFlowFunc(Seq(1, 2).toDF("x"))) @@ -499,7 +494,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ val P = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem = MemoryStream[Int] mem.addData(1, 2) registerTemporaryView("a", query = dfFlowFunc(mem.toDF().select($"value" as "x"))) diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala index ba8419eb6e9c8..72cc644e57684 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.pipelines.graph import scala.jdk.CollectionConverters._ import org.apache.spark.SparkThrowable -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} import org.apache.spark.sql.connector.expressions.{ClusterByTransform, Expressions, FieldReference} import org.apache.spark.sql.execution.streaming.runtime.MemoryStream @@ -269,7 +268,7 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest { test("invalid schema merge") { val session = spark - implicit val sparkSession: SparkSession = spark + implicit val sqlCtx: SQLContext = spark.sqlContext import session.implicits._ val streamInts = MemoryStream[Int] @@ -353,7 +352,6 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest { val ex = intercept[TableMaterializationException] { materializeGraph(new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val source: MemoryStream[Int] = MemoryStream[Int] source.addData(1, 2) registerTable( @@ -646,7 +644,7 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest { s"Streaming tables should evolve schema only if not full refresh = $isFullRefresh" ) { val session = spark - implicit val sparkSession: SparkSession = spark + implicit val sqlCtx: SQLContext = spark.sqlContext import session.implicits._ val streamInts = MemoryStream[Int] diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala index c37a6fb52f95d..71301c34c14ef 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.pipelines.graph import org.apache.hadoop.fs.Path -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryWrapper} import org.apache.spark.sql.pipelines.utils.{ExecutionTest, TestGraphRegistrationContext} @@ -39,7 +38,6 @@ class SystemMetadataSuite // create a pipeline with only a single ST val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem: MemoryStream[Int] = MemoryStream[Int] mem.addData(1, 2, 3) registerView("a", query = dfFlowFunc(mem.toDF())) @@ -107,7 +105,6 @@ class SystemMetadataSuite import session.implicits._ val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem: MemoryStream[Int] = MemoryStream[Int] mem.addData(1, 2, 3) registerView("a", query = dfFlowFunc(mem.toDF())) @@ -172,7 +169,6 @@ class SystemMetadataSuite import session.implicits._ val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem: MemoryStream[Int] = MemoryStream[Int] mem.addData(1, 2, 3) registerView("a", query = dfFlowFunc(mem.toDF())) @@ -234,7 +230,6 @@ class SystemMetadataSuite // create a pipeline with only a single ST val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem: MemoryStream[Int] = MemoryStream[Int] mem.addData(1, 2, 3) registerView("a", query = dfFlowFunc(mem.toDF())) diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala index 36b749cc84d9f..57baf4c2d5b11 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.pipelines.graph import org.scalatest.time.{Seconds, Span} -import org.apache.spark.sql.{functions, Row, SparkSession} +import org.apache.spark.sql.{functions, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} @@ -183,7 +183,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession // Construct pipeline val pipelineDef = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark private val ints = MemoryStream[Int] ints.addData(1 until 10: _*) registerView("input", query = dfFlowFunc(ints.toDF())) @@ -260,7 +259,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession // Construct pipeline val pipelineDef = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark private val ints = MemoryStream[Int] registerView("input", query = dfFlowFunc(ints.toDF())) registerTable( @@ -311,7 +309,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession }) val pipelineDef = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark private val memoryStream = MemoryStream[Int] memoryStream.addData(1, 2) registerView("input_view", query = dfFlowFunc(memoryStream.toDF())) @@ -551,7 +548,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession // Construct pipeline val pipelineDef = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark private val memoryStream = MemoryStream[Int] memoryStream.addData(1, 2) registerView("input_view", query = dfFlowFunc(memoryStream.toDF())) diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala index e7c0956385135..9ff92ee895b1d 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.pipelines.utils +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{LocalTempView, PersistedView => PersistedViewType, UnresolvedRelation, ViewType} import org.apache.spark.sql.classic.{DataFrame, SparkSession} @@ -28,7 +29,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * A test class to simplify the creation of pipelines and datasets for unit testing. */ class TestGraphRegistrationContext( - val spark: SparkSession, + val _spark: SparkSession, val sqlConf: Map[String, String] = Map.empty) extends GraphRegistrationContext( defaultCatalog = TestGraphRegistrationContext.DEFAULT_CATALOG, @@ -36,6 +37,10 @@ class TestGraphRegistrationContext( defaultSqlConf = sqlConf ) { + /** Re-expose as implicit so nested anonymous classes can use it without shadowing issues */ + implicit def spark: SparkSession = _spark + implicit def sqlContext: SQLContext = _spark.sqlContext + // scalastyle:off // Disable scalastyle to ignore argument count. /** Registers a streaming table in this [[TestGraphRegistrationContext]] */ @@ -145,7 +150,7 @@ class TestGraphRegistrationContext( val qualifiedIdentifier = GraphIdentifierManager .parseAndQualifyTableIdentifier( rawTableIdentifier = GraphIdentifierManager - .parseTableIdentifier(name, spark), + .parseTableIdentifier(name, _spark), currentCatalog = catalog.orElse(Some(defaultCatalog)), currentDatabase = database.orElse(Some(defaultDatabase))) .identifier @@ -304,9 +309,9 @@ class TestGraphRegistrationContext( catalog: Option[String] = None, database: Option[String] = None ): Unit = { - val rawFlowIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark) + val rawFlowIdentifier = GraphIdentifierManager.parseTableIdentifier(name, _spark) val rawDestinationIdentifier = - GraphIdentifierManager.parseTableIdentifier(destinationName, spark) + GraphIdentifierManager.parseTableIdentifier(destinationName, _spark) val flowWritesToView = getViews .filter(_.isInstanceOf[TemporaryView]) From 3ca4eb6f5577d3f8c67ad0d19a9cff9a6a8e1e5f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 3 Feb 2026 18:54:00 +0800 Subject: [PATCH 2/4] [SPARK-53656][SS][FOLLOWUP] Remove confusing MemoryStream factory method Remove the `apply[A: Encoder](numPartitions: Int, sparkSession: SparkSession)` factory method that creates a semantic trap - it can accidentally match calls like `MemoryStream[T](0, spark)` interpreting the first argument as `numPartitions` instead of `id`, causing zero partitions to be created and no data to flow. Users who need both `numPartitions` and explicit `SparkSession` can use the case class constructor directly: `new MemoryStream[A](id, sparkSession, Some(numPartitions))`. Co-authored-by: Cursor --- .../spark/sql/execution/streaming/runtime/memory.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala index 306fd45db2bb9..e4487b03bd41b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala @@ -78,13 +78,6 @@ object MemoryStream { */ def apply[A](encoder: Encoder[A], sparkSession: SparkSession): MemoryStream[A] = new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)(encoder) - - /** - * Creates a MemoryStream with specified partitions using explicit SparkSession. - * Usage: `MemoryStream[Int](numPartitions, spark)` - */ - def apply[A: Encoder](numPartitions: Int, sparkSession: SparkSession): MemoryStream[A] = - new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, Some(numPartitions)) } /** From fdb8c9ead1156bdb0dbc5b8b7a42e83012a7b661 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 3 Feb 2026 21:34:09 +0800 Subject: [PATCH 3/4] Update StateStoreCoordinatorSuite.scala --- .../execution/streaming/state/StateStoreCoordinatorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index a15521333614f..03257d3da373b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -156,7 +156,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("query stop deactivates related store providers") { var coordRef: StateStoreCoordinatorRef = null try { - implicit val spark: SparkSession = SparkSession.builder().sparkContext(sc).getOrCreate() + val spark: SparkSession = SparkSession.builder().sparkContext(sc).getOrCreate() SparkSession.setActiveSession(spark) import spark.implicits._ coordRef = spark.streams.stateStoreCoordinator From e1c0b1eb0a4a9ddbbce5682792d1fc1fcfd85b91 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 5 Feb 2026 00:47:05 +0800 Subject: [PATCH 4/4] address comments --- .../spark/sql/execution/streaming/runtime/memory.scala | 10 ++++++++++ .../streaming/sources/ContinuousMemoryStream.scala | 2 +- .../streaming/sources/LowLatencyMemoryStream.scala | 2 +- .../streaming/PythonStreamingDataSourceSuite.scala | 4 ++-- .../execution/streaming/state/StateStoreSuite.scala | 4 ++-- .../sql/streaming/FlatMapGroupsWithStateSuite.scala | 2 +- .../sql/streaming/StreamingAggregationSuite.scala | 2 +- .../sql/streaming/StreamingDeduplicationSuite.scala | 2 +- .../spark/sql/streaming/StreamingJoinSuite.scala | 2 +- 9 files changed, 20 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala index e4487b03bd41b..c7556ed478599 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala @@ -72,6 +72,16 @@ object MemoryStream { def apply[A: Encoder](sparkSession: SparkSession): MemoryStream[A] = new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) + /** + * Creates a MemoryStream with specified partitions using explicit SparkSession. + * Usage: `MemoryStream[Int](spark, numPartitions)` + */ + def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): MemoryStream[A] = + new MemoryStream[A]( + memoryStreamId.getAndIncrement(), + sparkSession, + Some(numPartitions)) + /** * Creates a MemoryStream with explicit encoder and SparkSession. * Usage: `MemoryStream(Encoders.scalaInt, spark)` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 8187563e178bf..885f9ada22c9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -132,7 +132,7 @@ object ContinuousMemoryStream { new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) /** Creates a ContinuousMemoryStream with specified partitions (SparkSession). */ - def apply[A: Encoder](numPartitions: Int, sparkSession: SparkSession): ContinuousMemoryStream[A] = + def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, numPartitions) /** Creates a single partition ContinuousMemoryStream (SQLContext). */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala index 97fb074a1190f..0e34b46a56c58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala @@ -203,7 +203,7 @@ object LowLatencyMemoryStream { new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) /** Creates a LowLatencyMemoryStream with specified partitions (SparkSession). */ - def apply[A: Encoder](numPartitions: Int, sparkSession: SparkSession): LowLatencyMemoryStream[A] = + def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): LowLatencyMemoryStream[A] = new LowLatencyMemoryStream[A]( memoryStreamId.getAndIncrement(), sparkSession, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala index 0e33b6e55a432..a5bc52beffaa7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala @@ -914,7 +914,7 @@ class PythonStreamingDataSourceWriteSuite extends PythonDataSourceSuiteBase { val dataSource = createUserDefinedPythonDataSource(dataSourceName, simpleDataStreamWriterScript) spark.dataSource.registerPython(dataSourceName, dataSource) - val inputData = MemoryStream[Int](numPartitions = 3) + val inputData = MemoryStream[Int](spark, numPartitions = 3) val df = inputData.toDF() withTempDir { dir => val path = dir.getAbsolutePath @@ -998,7 +998,7 @@ class PythonStreamingDataSourceWriteSuite extends PythonDataSourceSuiteBase { |""".stripMargin val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) spark.dataSource.registerPython(dataSourceName, dataSource) - val inputData = MemoryStream[Int](numPartitions = 3) + val inputData = MemoryStream[Int](spark, numPartitions = 3) withTempDir { dir => val path = dir.getAbsolutePath val checkpointDir = new File(path, "checkpoint") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index b13998708b615..d35263b655d2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1234,11 +1234,11 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("SPARK-21145: Restarted queries create new provider instances") { try { val checkpointLocation = Utils.createTempDir().getAbsoluteFile - implicit val spark: SparkSession = SparkSession.builder().master("local[2]").getOrCreate() + val spark: SparkSession = SparkSession.builder().master("local[2]").getOrCreate() SparkSession.setActiveSession(spark) spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") import spark.implicits._ - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = { val aggregated = inputData.toDF().groupBy("value").agg(count("*")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 5269edc682210..8dbe8e334ad69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -832,7 +832,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { def constructUnionDf(desiredPartitionsForInput1: Int) : (MemoryStream[String], MemoryStream[String], DataFrame) = { - val input1 = MemoryStream[String](desiredPartitionsForInput1) + val input1 = MemoryStream[String](spark, desiredPartitionsForInput1) val input2 = MemoryStream[String] val df1 = input1.toDF() .select($"value", $"value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 7825730d901da..f065f1de5cdc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -347,7 +347,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { " shifted partition IDs") { def constructUnionDf(desiredPartitionsForInput1: Int) : (MemoryStream[Int], MemoryStream[Int], DataFrame) = { - val input1 = MemoryStream[Int](desiredPartitionsForInput1) + val input1 = MemoryStream[Int](spark, desiredPartitionsForInput1) val input2 = MemoryStream[Int] val df1 = input1.toDF() .select($"value", $"value" + 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index dbc2b767b0f9a..7aec3353cd4dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -343,7 +343,7 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest " shifted partition IDs") { def constructUnionDf(desiredPartitionsForInput1: Int) : (MemoryStream[Int], MemoryStream[Int], DataFrame) = { - val input1 = MemoryStream[Int](desiredPartitionsForInput1) + val input1 = MemoryStream[Int](spark, desiredPartitionsForInput1) val input2 = MemoryStream[Int] val df1 = input1.toDF().select($"value") val df2 = input2.toDF().dropDuplicates("value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 22028a585e229..6cdca9fb5309f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -1609,7 +1609,7 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { test("SPARK-29438: ensure UNION doesn't lead stream-stream join to use shifted partition IDs") { def constructUnionDf(desiredPartitionsForInput1: Int) : (MemoryStream[Int], MemoryStream[Int], MemoryStream[Int], DataFrame) = { - val input1 = MemoryStream[Int](desiredPartitionsForInput1) + val input1 = MemoryStream[Int](spark, desiredPartitionsForInput1) val df1 = input1.toDF() .select( $"value" as "key",