diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala index bf67ed670ec81..c7556ed478599 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala @@ -43,36 +43,51 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap -object MemoryStream extends LowPriorityMemoryStreamImplicits { +object MemoryStream { protected val currentBlockId = new AtomicInteger(0) protected val memoryStreamId = new AtomicInteger(0) - def apply[A : Encoder](implicit sparkSession: SparkSession): MemoryStream[A] = - new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) - - def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession): MemoryStream[A] = - new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, Some(numPartitions)) -} - -/** - * Provides lower-priority implicits for MemoryStream to prevent ambiguity when both - * SparkSession and SQLContext are in scope. The implicits in the companion object, - * which use SparkSession, take higher precedence. - */ -trait LowPriorityMemoryStreamImplicits { - this: MemoryStream.type => - - // Deprecated: Used when an implicit SQLContext is in scope - @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0") - def apply[A: Encoder]()(implicit sqlContext: SQLContext): MemoryStream[A] = + /** + * Creates a MemoryStream with an implicit SQLContext (backward compatible). + * Usage: `MemoryStream[Int]` + */ + def apply[A: Encoder](implicit sqlContext: SQLContext): MemoryStream[A] = new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession) - @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0") - def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] = + /** + * Creates a MemoryStream with specified partitions using implicit SQLContext. + * Usage: `MemoryStream[Int](numPartitions)` + */ + def apply[A: Encoder](numPartitions: Int)( + implicit sqlContext: SQLContext): MemoryStream[A] = new MemoryStream[A]( memoryStreamId.getAndIncrement(), sqlContext.sparkSession, Some(numPartitions)) + + /** + * Creates a MemoryStream with explicit SparkSession. + * Usage: `MemoryStream[Int](spark)` + */ + def apply[A: Encoder](sparkSession: SparkSession): MemoryStream[A] = + new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) + + /** + * Creates a MemoryStream with specified partitions using explicit SparkSession. + * Usage: `MemoryStream[Int](spark, numPartitions)` + */ + def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): MemoryStream[A] = + new MemoryStream[A]( + memoryStreamId.getAndIncrement(), + sparkSession, + Some(numPartitions)) + + /** + * Creates a MemoryStream with explicit encoder and SparkSession. + * Usage: `MemoryStream(Encoders.scalaInt, spark)` + */ + def apply[A](encoder: Encoder[A], sparkSession: SparkSession): MemoryStream[A] = + new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)(encoder) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 8042cacf1374b..885f9ada22c9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -112,47 +112,36 @@ class ContinuousMemoryStream[A : Encoder]( override def commit(end: Offset): Unit = {} } -object ContinuousMemoryStream extends LowPriorityContinuousMemoryStreamImplicits { +object ContinuousMemoryStream { protected val memoryStreamId = new AtomicInteger(0) - def apply[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] = - new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) - - def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession): - ContinuousMemoryStream[A] = - new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, numPartitions) - - def singlePartition[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] = - new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1) -} - -/** - * Provides lower-priority implicits for ContinuousMemoryStream to prevent ambiguity when both - * SparkSession and SQLContext are in scope. The implicits in the companion object, - * which use SparkSession, take higher precedence. - */ -trait LowPriorityContinuousMemoryStreamImplicits { - this: ContinuousMemoryStream.type => - - // Deprecated: Used when an implicit SQLContext is in scope - @deprecated("Use ContinuousMemoryStream with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def apply[A: Encoder]()(implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + /** Creates a ContinuousMemoryStream with an implicit SQLContext (backward compatible). */ + def apply[A: Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession) - @deprecated("Use ContinuousMemoryStream with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): - ContinuousMemoryStream[A] = + /** Creates a ContinuousMemoryStream with specified partitions (SQLContext). */ + def apply[A: Encoder](numPartitions: Int)( + implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A]( memoryStreamId.getAndIncrement(), sqlContext.sparkSession, numPartitions) - @deprecated("Use ContinuousMemoryStream.singlePartition with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def singlePartition[A: Encoder]()(implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + /** Creates a ContinuousMemoryStream with explicit SparkSession. */ + def apply[A: Encoder](sparkSession: SparkSession): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) + + /** Creates a ContinuousMemoryStream with specified partitions (SparkSession). */ + def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, numPartitions) + + /** Creates a single partition ContinuousMemoryStream (SQLContext). */ + def singlePartition[A: Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession, 1) + + /** Creates a single partition ContinuousMemoryStream (SparkSession). */ + def singlePartition[A: Encoder](sparkSession: SparkSession): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala index 65b202894ec46..0e34b46a56c58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala @@ -183,53 +183,39 @@ class LowLatencyMemoryStream[A: Encoder]( } } -object LowLatencyMemoryStream extends LowPriorityLowLatencyMemoryStreamImplicits { +object LowLatencyMemoryStream { protected val memoryStreamId = new AtomicInteger(0) - def apply[A: Encoder](implicit sparkSession: SparkSession): LowLatencyMemoryStream[A] = - new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) + /** Creates a LowLatencyMemoryStream with an implicit SQLContext (backward compatible). */ + def apply[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = + new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession) + /** Creates a LowLatencyMemoryStream with specified partitions (SQLContext). */ def apply[A: Encoder](numPartitions: Int)( - implicit - sparkSession: SparkSession): LowLatencyMemoryStream[A] = + implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = new LowLatencyMemoryStream[A]( memoryStreamId.getAndIncrement(), - sparkSession, - numPartitions = numPartitions - ) - - def singlePartition[A: Encoder](implicit sparkSession: SparkSession): LowLatencyMemoryStream[A] = - new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1) -} - -/** - * Provides lower-priority implicits for LowLatencyMemoryStream to prevent ambiguity when both - * SparkSession and SQLContext are in scope. The implicits in the companion object, - * which use SparkSession, take higher precedence. - */ -trait LowPriorityLowLatencyMemoryStreamImplicits { - this: LowLatencyMemoryStream.type => + sqlContext.sparkSession, + numPartitions = numPartitions) - // Deprecated: Used when an implicit SQLContext is in scope - @deprecated("Use LowLatencyMemoryStream with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def apply[A: Encoder]()(implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = - new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession) + /** Creates a LowLatencyMemoryStream with explicit SparkSession. */ + def apply[A: Encoder](sparkSession: SparkSession): LowLatencyMemoryStream[A] = + new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) - @deprecated("Use LowLatencyMemoryStream with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): - LowLatencyMemoryStream[A] = + /** Creates a LowLatencyMemoryStream with specified partitions (SparkSession). */ + def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): LowLatencyMemoryStream[A] = new LowLatencyMemoryStream[A]( memoryStreamId.getAndIncrement(), - sqlContext.sparkSession, - numPartitions = numPartitions - ) + sparkSession, + numPartitions = numPartitions) - @deprecated("Use LowLatencyMemoryStream.singlePartition with an implicit SparkSession " + - "instead of SQLContext", "4.1.0") - def singlePartition[A: Encoder]()(implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = + /** Creates a single partition LowLatencyMemoryStream (SQLContext). */ + def singlePartition[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] = new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession, 1) + + /** Creates a single partition LowLatencyMemoryStream (SparkSession). */ + def singlePartition[A: Encoder](sparkSession: SparkSession): LowLatencyMemoryStream[A] = + new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala index 0e33b6e55a432..a5bc52beffaa7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala @@ -914,7 +914,7 @@ class PythonStreamingDataSourceWriteSuite extends PythonDataSourceSuiteBase { val dataSource = createUserDefinedPythonDataSource(dataSourceName, simpleDataStreamWriterScript) spark.dataSource.registerPython(dataSourceName, dataSource) - val inputData = MemoryStream[Int](numPartitions = 3) + val inputData = MemoryStream[Int](spark, numPartitions = 3) val df = inputData.toDF() withTempDir { dir => val path = dir.getAbsolutePath @@ -998,7 +998,7 @@ class PythonStreamingDataSourceWriteSuite extends PythonDataSourceSuiteBase { |""".stripMargin val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) spark.dataSource.registerPython(dataSourceName, dataSource) - val inputData = MemoryStream[Int](numPartitions = 3) + val inputData = MemoryStream[Int](spark, numPartitions = 3) withTempDir { dir => val path = dir.getAbsolutePath val checkpointDir = new File(path, "checkpoint") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index e0ec3fd1b907b..4ec44eac22e36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -343,84 +343,6 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { intsToDF(expected)(schema)) } - test("LowPriorityMemoryStreamImplicits works with implicit sqlContext") { - // Test that MemoryStream can be created using implicit sqlContext - implicit val sqlContext: SQLContext = spark.sqlContext - - // Test MemoryStream[A]() with implicit sqlContext - val stream1 = MemoryStream[Int]() - assert(stream1 != null) - - // Test MemoryStream[A](numPartitions) with implicit sqlContext - val stream2 = MemoryStream[String](3) - assert(stream2 != null) - - // Verify the streams work correctly - stream1.addData(1, 2, 3) - val df1 = stream1.toDF() - assert(df1.schema.fieldNames.contains("value")) - - stream2.addData("a", "b", "c") - val df2 = stream2.toDF() - assert(df2.schema.fieldNames.contains("value")) - } - - test("LowPriorityContinuousMemoryStreamImplicits works with implicit sqlContext") { - import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream - // Test that ContinuousMemoryStream can be created using implicit sqlContext - implicit val sqlContext: SQLContext = spark.sqlContext - - // Test ContinuousMemoryStream[A]() with implicit sqlContext - val stream1 = ContinuousMemoryStream[Int]() - assert(stream1 != null) - - // Test ContinuousMemoryStream[A](numPartitions) with implicit sqlContext - val stream2 = ContinuousMemoryStream[String](3) - assert(stream2 != null) - - // Test ContinuousMemoryStream.singlePartition with implicit sqlContext - val stream3 = ContinuousMemoryStream.singlePartition[Int]() - assert(stream3 != null) - - // Verify the streams work correctly - stream1.addData(Seq(1, 2, 3)) - stream2.addData(Seq("a", "b", "c")) - stream3.addData(Seq(10, 20)) - - // Basic verification that streams are functional - assert(stream1.initialOffset() != null) - assert(stream2.initialOffset() != null) - assert(stream3.initialOffset() != null) - } - - test("LowPriorityLowLatencyMemoryStreamImplicits works with implicit sqlContext") { - import org.apache.spark.sql.execution.streaming.LowLatencyMemoryStream - // Test that LowLatencyMemoryStream can be created using implicit sqlContext - implicit val sqlContext: SQLContext = spark.sqlContext - - // Test LowLatencyMemoryStream[A]() with implicit sqlContext - val stream1 = LowLatencyMemoryStream[Int]() - assert(stream1 != null) - - // Test LowLatencyMemoryStream[A](numPartitions) with implicit sqlContext - val stream2 = LowLatencyMemoryStream[String](3) - assert(stream2 != null) - - // Test LowLatencyMemoryStream.singlePartition with implicit sqlContext - val stream3 = LowLatencyMemoryStream.singlePartition[Int]() - assert(stream3 != null) - - // Verify the streams work correctly - stream1.addData(Seq(1, 2, 3)) - stream2.addData(Seq("a", "b", "c")) - stream3.addData(Seq(10, 20)) - - // Basic verification that streams are functional - assert(stream1.initialOffset() != null) - assert(stream2.initialOffset() != null) - assert(stream3.initialOffset() != null) - } - private implicit def intsToDF(seq: Seq[Int])(implicit schema: StructType): DataFrame = { require(schema.fields.length === 1) sqlContext.createDataset(seq).toDF(schema.fieldNames.head) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 7446390e8d068..03257d3da373b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -156,14 +156,14 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("query stop deactivates related store providers") { var coordRef: StateStoreCoordinatorRef = null try { - implicit val spark: SparkSession = SparkSession.builder().sparkContext(sc).getOrCreate() + val spark: SparkSession = SparkSession.builder().sparkContext(sc).getOrCreate() SparkSession.setActiveSession(spark) import spark.implicits._ coordRef = spark.streams.stateStoreCoordinator spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") // Start a query and run a batch to load state stores - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // stateful query val checkpointLocation = Utils.createTempDir().getAbsoluteFile val query = aggregated.writeStream @@ -279,10 +279,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark - val inputData = MemoryStream[Int] + + val inputData = MemoryStream[Int](spark) val query = setUpStatefulQuery(inputData, "query") - val inputData2 = MemoryStream[Int] + val inputData2 = MemoryStream[Int](spark) val query2 = setUpStatefulQuery(inputData2, "query2") // Add, commit, and wait multiple times to force snapshot versions and time difference // we will detect state store with partition 0 and 1 to be lagged on version 5 @@ -378,8 +378,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark - val inputData = MemoryStream[Int] + + val inputData = MemoryStream[Int](spark) val aggregated = inputData.toDF().groupBy("value").agg(count("*")) val checkpointLocation = Utils.createTempDir().getAbsoluteFile val query = aggregated.writeStream @@ -444,8 +444,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark - val inputData = MemoryStream[Int] + + val inputData = MemoryStream[Int](spark) val query = setUpStatefulQuery(inputData, "query") // Add, commit, and wait multiple times to force snapshot versions and time difference (0 until 6).foreach { _ => @@ -481,10 +481,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark + // Start a join query and run some data to force snapshot uploads - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] + val input1 = MemoryStream[Int](spark) + val input2 = MemoryStream[Int](spark) val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") val joined = df1.join(df2, expr("leftKey = rightKey")) @@ -525,10 +525,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark + // Start and run two queries together with some data to force snapshot uploads - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] + val input1 = MemoryStream[Int](spark) + val input2 = MemoryStream[Int](spark) val query1 = setUpStatefulQuery(input1, "query1") val query2 = setUpStatefulQuery(input2, "query2") @@ -593,9 +593,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark + // Start a query and run some data to force snapshot uploads - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val query = setUpStatefulQuery(inputData, "query") // Go through two batches to force two snapshot uploads. @@ -638,9 +638,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sparkSession: SparkSession = spark + // Start a query and run some data to force snapshot uploads - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val query = setUpStatefulQuery(inputData, "query") // Go through several rounds of input to force snapshot uploads @@ -682,7 +682,7 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { withTempDir { srcDir => - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val query = inputData.toDF().dropDuplicates() val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) // Keep track of state checkpoint directory for the second run @@ -805,7 +805,7 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { withTempDir { srcDir => - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val query = inputData.toDF().dropDuplicates() testStream(query)( @@ -884,7 +884,7 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { withTempDir { srcDir => - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val query = inputData.toDF().dropDuplicates() // Populate state stores with an initial snapshot, so that timestamp isn't marked diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index b13998708b615..d35263b655d2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1234,11 +1234,11 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("SPARK-21145: Restarted queries create new provider instances") { try { val checkpointLocation = Utils.createTempDir().getAbsoluteFile - implicit val spark: SparkSession = SparkSession.builder().master("local[2]").getOrCreate() + val spark: SparkSession = SparkSession.builder().master("local[2]").getOrCreate() SparkSession.setActiveSession(spark) spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") import spark.implicits._ - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = { val aggregated = inputData.toDF().groupBy("value").agg(count("*")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 5269edc682210..8dbe8e334ad69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -832,7 +832,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { def constructUnionDf(desiredPartitionsForInput1: Int) : (MemoryStream[String], MemoryStream[String], DataFrame) = { - val input1 = MemoryStream[String](desiredPartitionsForInput1) + val input1 = MemoryStream[String](spark, desiredPartitionsForInput1) val input2 = MemoryStream[String] val df1 = input1.toDF() .select($"value", $"value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 7825730d901da..f065f1de5cdc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -347,7 +347,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { " shifted partition IDs") { def constructUnionDf(desiredPartitionsForInput1: Int) : (MemoryStream[Int], MemoryStream[Int], DataFrame) = { - val input1 = MemoryStream[Int](desiredPartitionsForInput1) + val input1 = MemoryStream[Int](spark, desiredPartitionsForInput1) val input2 = MemoryStream[Int] val df1 = input1.toDF() .select($"value", $"value" + 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index dbc2b767b0f9a..7aec3353cd4dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -343,7 +343,7 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest " shifted partition IDs") { def constructUnionDf(desiredPartitionsForInput1: Int) : (MemoryStream[Int], MemoryStream[Int], DataFrame) = { - val input1 = MemoryStream[Int](desiredPartitionsForInput1) + val input1 = MemoryStream[Int](spark, desiredPartitionsForInput1) val input2 = MemoryStream[Int] val df1 = input1.toDF().select($"value") val df2 = input2.toDF().dropDuplicates("value") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 22028a585e229..6cdca9fb5309f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -1609,7 +1609,7 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { test("SPARK-29438: ensure UNION doesn't lead stream-stream join to use shifted partition IDs") { def constructUnionDf(desiredPartitionsForInput1: Int) : (MemoryStream[Int], MemoryStream[Int], MemoryStream[Int], DataFrame) = { - val input1 = MemoryStream[Int](desiredPartitionsForInput1) + val input1 = MemoryStream[Int](spark, desiredPartitionsForInput1) val df1 = input1.toDF() .select( $"value" as "key", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 5c23192b1d3fd..43f15a12cad19 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -2653,7 +2653,7 @@ class HiveDDLSuite |SELECT word, number from t1 """.stripMargin) - val inputData = MemoryStream[Int] + val inputData = MemoryStream[Int](spark) val joined = inputData.toDS().toDF() .join(spark.table("smallTable"), $"value" === $"number") diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala index 7c8181b5b72a5..f37716b4a24d3 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.pipelines.graph -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.streaming.runtime.MemoryStream import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext} import org.apache.spark.sql.test.SharedSparkSession @@ -423,7 +423,6 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ val p = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem = MemoryStream[Int] mem.addData(1) registerPersistedView("a", query = dfFlowFunc(mem.toDF())) @@ -467,7 +466,6 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark registerMaterializedView("a", query = dfFlowFunc(MemoryStream[Int].toDF())) }.resolveToDataflowGraph() @@ -491,7 +489,6 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession { val graph = new TestGraphRegistrationContext(spark) { registerTable("a") - implicit val sparkSession: SparkSession = spark registerFlow( destinationName = "a", name = "once_flow", diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala index a4bb7c067d875..3ac3c09017506 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.pipelines.graph -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.Union @@ -159,7 +158,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ class P extends TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val ints = MemoryStream[Int] ints.addData(1, 2, 3, 4) registerPersistedView("a", query = dfFlowFunc(ints.toDF())) @@ -201,7 +199,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ class P extends TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val ints1 = MemoryStream[Int] ints1.addData(1, 2, 3, 4) val ints2 = MemoryStream[Int] @@ -362,7 +359,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ class P extends TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem = MemoryStream[Int] registerPersistedView("a", query = dfFlowFunc(mem.toDF())) registerTable("b") @@ -406,7 +402,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem = MemoryStream[Int] mem.addData(1, 2) registerPersistedView("complete-view", query = dfFlowFunc(Seq(1, 2).toDF("x"))) @@ -499,7 +494,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ val P = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem = MemoryStream[Int] mem.addData(1, 2) registerTemporaryView("a", query = dfFlowFunc(mem.toDF().select($"value" as "x"))) diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala index ba8419eb6e9c8..72cc644e57684 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.pipelines.graph import scala.jdk.CollectionConverters._ import org.apache.spark.SparkThrowable -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} import org.apache.spark.sql.connector.expressions.{ClusterByTransform, Expressions, FieldReference} import org.apache.spark.sql.execution.streaming.runtime.MemoryStream @@ -269,7 +268,7 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest { test("invalid schema merge") { val session = spark - implicit val sparkSession: SparkSession = spark + implicit val sqlCtx: SQLContext = spark.sqlContext import session.implicits._ val streamInts = MemoryStream[Int] @@ -353,7 +352,6 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest { val ex = intercept[TableMaterializationException] { materializeGraph(new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val source: MemoryStream[Int] = MemoryStream[Int] source.addData(1, 2) registerTable( @@ -646,7 +644,7 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest { s"Streaming tables should evolve schema only if not full refresh = $isFullRefresh" ) { val session = spark - implicit val sparkSession: SparkSession = spark + implicit val sqlCtx: SQLContext = spark.sqlContext import session.implicits._ val streamInts = MemoryStream[Int] diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala index c37a6fb52f95d..71301c34c14ef 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.pipelines.graph import org.apache.hadoop.fs.Path -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryWrapper} import org.apache.spark.sql.pipelines.utils.{ExecutionTest, TestGraphRegistrationContext} @@ -39,7 +38,6 @@ class SystemMetadataSuite // create a pipeline with only a single ST val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem: MemoryStream[Int] = MemoryStream[Int] mem.addData(1, 2, 3) registerView("a", query = dfFlowFunc(mem.toDF())) @@ -107,7 +105,6 @@ class SystemMetadataSuite import session.implicits._ val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem: MemoryStream[Int] = MemoryStream[Int] mem.addData(1, 2, 3) registerView("a", query = dfFlowFunc(mem.toDF())) @@ -172,7 +169,6 @@ class SystemMetadataSuite import session.implicits._ val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem: MemoryStream[Int] = MemoryStream[Int] mem.addData(1, 2, 3) registerView("a", query = dfFlowFunc(mem.toDF())) @@ -234,7 +230,6 @@ class SystemMetadataSuite // create a pipeline with only a single ST val graph = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark val mem: MemoryStream[Int] = MemoryStream[Int] mem.addData(1, 2, 3) registerView("a", query = dfFlowFunc(mem.toDF())) diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala index 36b749cc84d9f..57baf4c2d5b11 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.pipelines.graph import org.scalatest.time.{Seconds, Span} -import org.apache.spark.sql.{functions, Row, SparkSession} +import org.apache.spark.sql.{functions, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} @@ -183,7 +183,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession // Construct pipeline val pipelineDef = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark private val ints = MemoryStream[Int] ints.addData(1 until 10: _*) registerView("input", query = dfFlowFunc(ints.toDF())) @@ -260,7 +259,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession // Construct pipeline val pipelineDef = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark private val ints = MemoryStream[Int] registerView("input", query = dfFlowFunc(ints.toDF())) registerTable( @@ -311,7 +309,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession }) val pipelineDef = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark private val memoryStream = MemoryStream[Int] memoryStream.addData(1, 2) registerView("input_view", query = dfFlowFunc(memoryStream.toDF())) @@ -551,7 +548,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession // Construct pipeline val pipelineDef = new TestGraphRegistrationContext(spark) { - implicit val sparkSession: SparkSession = spark private val memoryStream = MemoryStream[Int] memoryStream.addData(1, 2) registerView("input_view", query = dfFlowFunc(memoryStream.toDF())) diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala index e7c0956385135..9ff92ee895b1d 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.pipelines.utils +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{LocalTempView, PersistedView => PersistedViewType, UnresolvedRelation, ViewType} import org.apache.spark.sql.classic.{DataFrame, SparkSession} @@ -28,7 +29,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * A test class to simplify the creation of pipelines and datasets for unit testing. */ class TestGraphRegistrationContext( - val spark: SparkSession, + val _spark: SparkSession, val sqlConf: Map[String, String] = Map.empty) extends GraphRegistrationContext( defaultCatalog = TestGraphRegistrationContext.DEFAULT_CATALOG, @@ -36,6 +37,10 @@ class TestGraphRegistrationContext( defaultSqlConf = sqlConf ) { + /** Re-expose as implicit so nested anonymous classes can use it without shadowing issues */ + implicit def spark: SparkSession = _spark + implicit def sqlContext: SQLContext = _spark.sqlContext + // scalastyle:off // Disable scalastyle to ignore argument count. /** Registers a streaming table in this [[TestGraphRegistrationContext]] */ @@ -145,7 +150,7 @@ class TestGraphRegistrationContext( val qualifiedIdentifier = GraphIdentifierManager .parseAndQualifyTableIdentifier( rawTableIdentifier = GraphIdentifierManager - .parseTableIdentifier(name, spark), + .parseTableIdentifier(name, _spark), currentCatalog = catalog.orElse(Some(defaultCatalog)), currentDatabase = database.orElse(Some(defaultDatabase))) .identifier @@ -304,9 +309,9 @@ class TestGraphRegistrationContext( catalog: Option[String] = None, database: Option[String] = None ): Unit = { - val rawFlowIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark) + val rawFlowIdentifier = GraphIdentifierManager.parseTableIdentifier(name, _spark) val rawDestinationIdentifier = - GraphIdentifierManager.parseTableIdentifier(destinationName, spark) + GraphIdentifierManager.parseTableIdentifier(destinationName, _spark) val flowWritesToView = getViews .filter(_.isInstanceOf[TemporaryView])