diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 225ad6d45afb1..16f40396490c7 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -477,7 +477,7 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": name=j["name"], timestamp=j["timestamp"], batchId=j["batchId"], - batchDuration=j["batchDuration"], + batchDuration=j.get("batchDuration", None), durationMs=dict(j["durationMs"]) if "durationMs" in j else {}, eventTime=dict(j["eventTime"]) if "eventTime" in j else {}, stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]], diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 4bf58bf7807b3..5069a76cfdb73 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -18,39 +18,31 @@ import unittest import time +import pyspark.cloudpickle from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin -from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent -from pyspark.sql.types import StructType, StructField, StringType +from pyspark.sql.streaming.listener import StreamingQueryListener +from pyspark.sql.functions import count, lit from pyspark.testing.connectutils import ReusedConnectTestCase -def get_start_event_schema(): - return StructType( - [ - StructField("id", StringType(), True), - StructField("runId", StringType(), True), - StructField("name", StringType(), True), - StructField("timestamp", StringType(), True), - ] - ) - - class TestListener(StreamingQueryListener): def onQueryStarted(self, event): - df = self.spark.createDataFrame( - data=[(str(event.id), str(event.runId), event.name, event.timestamp)], - schema=get_start_event_schema(), - ) - df.write.saveAsTable("listener_start_events") + e = pyspark.cloudpickle.dumps(event) + df = self.spark.createDataFrame(data=[(e,)]) + df.write.mode("append").saveAsTable("listener_start_events") def onQueryProgress(self, event): - pass + e = pyspark.cloudpickle.dumps(event) + df = self.spark.createDataFrame(data=[(e,)]) + df.write.mode("append").saveAsTable("listener_progress_events") def onQueryIdle(self, event): pass def onQueryTerminated(self, event): - pass + e = pyspark.cloudpickle.dumps(event) + df = self.spark.createDataFrame(data=[(e,)]) + df.write.mode("append").saveAsTable("listener_terminated_events") class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase): @@ -65,17 +57,36 @@ def test_listener_events(self): time.sleep(30) df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() - q = df.writeStream.format("noop").queryName("test").start() + df_observe = df.observe("my_event", count(lit(1)).alias("rc")) + df_stateful = df_observe.groupBy().count() # make query stateful + q = ( + df_stateful.writeStream.format("noop") + .queryName("test") + .outputMode("complete") + .start() + ) self.assertTrue(q.isActive) time.sleep(10) + self.assertTrue(q.lastProgress["batchId"] > 0) # ensure at least one batch is ran q.stop() + self.assertFalse(q.isActive) + + start_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_start_events").collect()[0][0] + ) + + progress_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_progress_events").collect()[0][0] + ) - start_event = QueryStartedEvent.fromJson( - self.spark.read.table("listener_start_events").collect()[0].asDict() + terminated_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_terminated_events").collect()[0][0] ) self.check_start_event(start_event) + self.check_progress_event(progress_event) + self.check_terminated_event(terminated_event) finally: self.spark.streams.removeListener(test_listener) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index cbbdc2955e59f..87d0dae00d8bd 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -88,7 +88,7 @@ def check_streaming_query_progress(self, progress): except Exception: self.fail("'%s' is not in ISO 8601 format.") self.assertTrue(isinstance(progress.batchId, int)) - self.assertTrue(isinstance(progress.batchDuration, int)) + self.assertTrue(progress.batchDuration is None or isinstance(progress.batchDuration, int)) self.assertTrue(isinstance(progress.durationMs, dict)) self.assertTrue( set(progress.durationMs.keys()).issubset(