From 8c6d34660e22dc7d64928073e2847c4628390a7b Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Wed, 16 Aug 2023 14:08:34 -0700 Subject: [PATCH 01/11] add tests for listener --- python/pyspark/sql/streaming/listener.py | 81 ++++++++- .../connect/streaming/test_parity_listener.py | 154 ++++++++++++++++-- .../streaming/test_streaming_foreachBatch.py | 99 ++++++++++- 3 files changed, 319 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 225ad6d45afb1..5f56d956162b9 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -22,6 +22,15 @@ from py4j.java_gateway import JavaObject from pyspark.sql import Row +from pyspark.sql.types import ( + ArrayType, + StructType, + StructField, + StringType, + IntegerType, + FloatType, + MapType, +) from pyspark import cloudpickle __all__ = ["StreamingQueryListener"] @@ -197,6 +206,26 @@ def fromJson(cls, j: Dict[str, Any]) -> "QueryStartedEvent": timestamp=j["timestamp"], ) + def asDict(self) -> Dict[str, Any]: + def conv(obj: Any) -> Any: + if isinstance(obj, uuid.UUID): + return str(obj) + else: + return obj + + return {k[1:]: conv(v) for k, v in self.__dict__.items()} + + @staticmethod + def schema(): + return StructType( + [ + StructField("id", StringType(), False), + StructField("runId", StringType(), False), + StructField("name", StringType(), True), + StructField("timestamp", StringType(), False), + ] + ) + @property def id(self) -> uuid.UUID: """ @@ -257,6 +286,9 @@ def progress(self) -> "StreamingQueryProgress": """ return self._progress + def asDict(self) -> Dict[str, Any]: + return {"progress": self.progress.asDict()} + class QueryIdleEvent: """ @@ -286,6 +318,15 @@ def fromJObject(cls, jevent: JavaObject) -> "QueryIdleEvent": def fromJson(cls, j: Dict[str, Any]) -> "QueryIdleEvent": return cls(id=uuid.UUID(j["id"]), runId=uuid.UUID(j["runId"]), timestamp=j["timestamp"]) + def asDict(self) -> Dict[str, Any]: + def conv(obj: Any) -> Any: + if isinstance(obj, uuid.UUID): + return str(obj) + else: + return obj + + return {k[1:]: conv(v) for k, v in self.__dict__.items()} + @property def id(self) -> uuid.UUID: """ @@ -353,6 +394,15 @@ def fromJson(cls, j: Dict[str, Any]) -> "QueryTerminatedEvent": errorClassOnException=j["errorClassOnException"], ) + def asDict(self) -> Dict[str, Any]: + def conv(obj: Any) -> Any: + if isinstance(obj, uuid.UUID): + return str(obj) + else: + return obj + + return {k[1:]: conv(v) for k, v in self.__dict__.items()} + @property def id(self) -> uuid.UUID: """ @@ -487,13 +537,31 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": inputRowsPerSecond=j["inputRowsPerSecond"], processedRowsPerSecond=j["processedRowsPerSecond"], observedMetrics={ - k: Row(*row_dict.keys())(*row_dict.values()) # Assume no nested rows - for k, row_dict in j["observedMetrics"].items() + k: Row(row) if isinstance(row, str) else + Row(*row.keys())(*row.values()) # Assume no nested rows + for k, row in j["observedMetrics"].items() } if "observedMetrics" in j else {}, ) + def asDict(self) -> Dict[str, Any]: + def conv(obj: Any) -> Any: + if isinstance(obj, uuid.UUID): + return str(obj) + elif isinstance(obj, (SourceProgress, SinkProgress, StateOperatorProgress)): + return obj.asDict() + elif isinstance(obj, Row): + return json.dumps(obj.asDict()) # Assume no nested row in observed metrics + elif isinstance(obj, list): + return [conv(o) for o in obj] + elif isinstance(obj, dict): + return dict((k, conv(v)) for k, v in obj.items()) + else: + return obj + + return {k[1:]: conv(v) for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]} + @property def id(self) -> uuid.UUID: """ @@ -716,6 +784,9 @@ def fromJson(cls, j: Dict[str, Any]) -> "StateOperatorProgress": customMetrics=dict(j["customMetrics"]) if "customMetrics" in j else {}, ) + def asDict(self) -> Dict[str, Any]: + return {k[1:]: v for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]} + @property def operatorName(self) -> str: return self._operatorName @@ -851,6 +922,9 @@ def fromJson(cls, j: Dict[str, Any]) -> "SourceProgress": metrics=dict(j["metrics"]) if "metrics" in j else {}, ) + def asDict(self) -> Dict[str, Any]: + return {k[1:]: v for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]} + @property def description(self) -> str: """ @@ -962,6 +1036,9 @@ def fromJObject(cls, jprogress: JavaObject) -> "SinkProgress": metrics=dict(jprogress.metrics()), ) + def asDict(self) -> Dict[str, Any]: + return {k[1:]: v for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]} + @classmethod def fromJson(cls, j: Dict[str, Any]) -> "SinkProgress": return cls( diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 4bf58bf7807b3..7f6211b8ee465 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -19,38 +19,151 @@ import time from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin -from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent -from pyspark.sql.types import StructType, StructField, StringType +from pyspark.sql.streaming.listener import ( + StreamingQueryListener, + QueryStartedEvent, + QueryProgressEvent, + QueryIdleEvent, + QueryTerminatedEvent, +) +from pyspark.sql.types import ( + ArrayType, + StructType, + StructField, + StringType, + IntegerType, + FloatType, + MapType, +) +from pyspark.sql.functions import count, lit from pyspark.testing.connectutils import ReusedConnectTestCase def get_start_event_schema(): return StructType( [ - StructField("id", StringType(), True), - StructField("runId", StringType(), True), + StructField("id", StringType(), False), + StructField("runId", StringType(), False), StructField("name", StringType(), True), - StructField("timestamp", StringType(), True), + StructField("timestamp", StringType(), False), ] ) +def get_idle_event_schema(): + return StructType( + [ + StructField("id", StringType(), False), + StructField("runId", StringType(), False), + StructField("timestamp", StringType(), False), + ] + ) + +def get_terminated_event_schema(): + return StructType( + [ + StructField("id", StringType(), False), + StructField("runId", StringType(), False), + StructField("exception", StringType(), True), + StructField("errorClassOnException", StringType(), True), + ] + ) + +def get_state_operators_progress_schema(): + return StructType( + [ + StructField("operatorName", StringType(), False), + StructField("numRowsTotal", IntegerType(), False), + StructField("numRowsUpdated", IntegerType(), False), + StructField("numRowsRemoved", IntegerType(), False), + StructField("allUpdatesTimeMs", IntegerType(), False), + StructField("allRemovalsTimeMs", IntegerType(), False), + StructField("commitTimeMs", IntegerType(), False), + StructField("memoryUsedBytes", IntegerType(), False), + StructField("numRowsDroppedByWatermark", IntegerType(), False), + StructField("numShufflePartitions", IntegerType(), False), + StructField("numStateStoreInstances", IntegerType(), False), + StructField("customMetrics", MapType(StringType(), IntegerType(), True), True), + ] + ) + + +def get_source_progress_schema(): + return StructType( + [ + StructField("description", StringType(), False), + StructField("startOffset", StringType(), False), + StructField("endOffset", StringType(), False), + StructField("latestOffset", StringType(), False), + StructField("numInputRows", IntegerType(), False), + StructField("inputRowsPerSecond", FloatType(), False), + StructField("processedRowsPerSecond", FloatType(), False), + StructField("metrics", MapType(StringType(), StringType(), True), True), + ] + ) + + +def get_sink_progress_schema(): + return StructType( + [ + StructField("description", StringType(), False), + StructField("numOutputRows", IntegerType(), False), + StructField("metrics", MapType(StringType(), StringType(), True), True), + ] + ) + + +def get_streaming_query_progress_schema(): + return StructType( + [ + StructField("id", StringType(), False), + StructField("runId", StringType(), False), + StructField("name", StringType(), True), + StructField("timestamp", StringType(), False), + StructField("batchId", IntegerType(), False), + StructField("batchDuration", IntegerType(), False), + StructField("durationMs", MapType(StringType(), IntegerType(), True), True), + StructField("eventTime", MapType(StringType(), StringType(), True), True), + StructField("stateOperators", ArrayType(get_state_operators_progress_schema()), True), + StructField("sources", ArrayType(get_source_progress_schema()), True), + StructField("sink", get_sink_progress_schema(), True), # TODO: false? + StructField("numInputRows", IntegerType(), False), + StructField("inputRowsPerSecond", FloatType(), False), + StructField("processedRowsPerSecond", FloatType(), False), + StructField("observedMetrics", MapType(StringType(), StringType()), False), + ] + ) + + +def get_progress_event_schema(): + return StructType([StructField("progress", get_streaming_query_progress_schema(), False)]) + class TestListener(StreamingQueryListener): + def onQueryStarted(self, event): df = self.spark.createDataFrame( - data=[(str(event.id), str(event.runId), event.name, event.timestamp)], - schema=get_start_event_schema(), + data=[(event.asDict())], + schema=event.schema(), ) df.write.saveAsTable("listener_start_events") def onQueryProgress(self, event): - pass + print(event.asDict()) + df = self.spark.createDataFrame( + data=[event.asDict()], + schema=get_progress_event_schema(), + ) + df.write.mode("append").saveAsTable("listener_progress_events") def onQueryIdle(self, event): pass def onQueryTerminated(self, event): - pass + df = self.spark.createDataFrame( + data=[event.asDict()], + schema=get_terminated_event_schema(), + ) + df.write.saveAsTable("listener_terminated_events") class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase): @@ -62,20 +175,39 @@ def test_listener_events(self): # This ensures the read socket on the server won't crash (i.e. because of timeout) # when there hasn't been a new event for a long time - time.sleep(30) + # time.sleep(30) df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() - q = df.writeStream.format("noop").queryName("test").start() + df_observe = df.observe("my_event", count(lit(1)).alias("rc")) + df_stateful = df_observe.groupBy().count() # make query stateful + q = ( + df_stateful.writeStream.format("noop") + .queryName("test") + .outputMode("complete") + .start() + ) self.assertTrue(q.isActive) time.sleep(10) + self.assertTrue(q.lastProgress["batchId"] > 0) # ensure at least one batch is ran q.stop() + self.assertFalse(q.isActive) start_event = QueryStartedEvent.fromJson( self.spark.read.table("listener_start_events").collect()[0].asDict() ) + progress_event = QueryProgressEvent.fromJson( + self.spark.read.table("listener_progress_events").collect()[0].asDict() + ) + + terminated_event = QueryTerminatedEvent.fromJson( + self.spark.read.table("listener_terminated_events").collect()[0].asDict() + ) + self.check_start_event(start_event) + self.check_progress_event(progress_event) + self.check_terminated_event(terminated_event) finally: self.spark.streams.removeListener(test_listener) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py index d4e185c3d856d..d43eb26c5b6f3 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py @@ -15,11 +15,14 @@ # limitations under the License. # -import time - +from pyspark.sql.dataframe import DataFrame from pyspark.testing.sqlutils import ReusedSQLTestCase +def my_test_function_1(): + return 1 + + class StreamingTestsForeachBatchMixin: def test_streaming_foreachBatch(self): q = None @@ -88,6 +91,98 @@ def func(batch_df, _): q.stop() self.assertIsNone(q.exception(), "No exception has to be propagated.") + def test_streaming_foreachBatch_spark_session(self): + table_name = "testTable-foreachBatch" + + def func(df: DataFrame, _): + spark = df.sparkSession + df1 = spark.createDataFrame([("structured",), ("streaming",)]) + df1.union(df).write.saveAsTable(table_name) + + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(func).start() + q.processAllAvailable() + q.stop() + + actual = self.spark.read.table(table_name) + df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/").union( + self.spark.createDataFrame([("structured",), ("streaming",)]) + ) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + def test_streaming_foreachBatch_path_access(self): + table_name = "testTable-foreachBatch-path" + + def func(df: DataFrame, _): + spark = df.sparkSession + df1 = spark.read.format("text").load("python/test_support/sql/streaming") + df1.union(df).write.saveAsTable(table_name) + + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + q = df.writeStream.foreachBatch(func).start() + q.processAllAvailable() + q.stop() + + actual = self.spark.read.table(table_name) + df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/") + df = df.union(df) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + # write to delta table? + + @staticmethod + def my_test_function_2(): + return 2 + + def test_streaming_foreachBatch_fuction_calling(self): + def my_test_function_3(): + return 3 + + table_name = "testTable-foreachBatch-function" + + def func(df: DataFrame, _): + spark = df.sparkSession + df1 = spark.createDataFrame([ + (my_test_function_1(),), + (StreamingTestsForeachBatchMixin.my_test_function_2(),), + (my_test_function_3(),), + ]) + df1.write.saveAsTable(table_name) + + df = self.spark.readStream.format("rate").load() + q = df.writeStream.foreachBatch(func).start() + q.processAllAvailable() + q.stop() + + actual = self.spark.read.table(table_name) + df = self.spark.createDataFrame([ + (my_test_function_1(),), + (StreamingTestsForeachBatchMixin.my_test_function_2(),), + (my_test_function_3(),), + ]) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + def test_streaming_foreachBatch_import(self): + import time # not imported in foreachBatch_worker + table_name = "testTable-foreachBatch-import" + + def func(df: DataFrame, _): + time.sleep(1) + spark = df.sparkSession + df1 = spark.read.format("text").load("python/test_support/sql/streaming") + df1.write.saveAsTable(table_name) + + df = self.spark.readStream.format("rate").load() + q = df.writeStream.foreachBatch(func).start() + q.processAllAvailable() + q.stop() + + actual = self.spark.read.table(table_name) + df = self.spark.read.format("text").load("python/test_support/sql/streaming") + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + + class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase): pass From d2a9d887c7f32312191ae30eb0ce694b13cf4b43 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Wed, 16 Aug 2023 14:24:56 -0700 Subject: [PATCH 02/11] minor update foreachBatch, fmt --- python/pyspark/sql/streaming/listener.py | 9 ++- .../connect/streaming/test_parity_listener.py | 6 +- .../streaming/test_streaming_foreachBatch.py | 62 ++++++++++++------- 3 files changed, 48 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 5f56d956162b9..76639b355a78d 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -537,8 +537,9 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": inputRowsPerSecond=j["inputRowsPerSecond"], processedRowsPerSecond=j["processedRowsPerSecond"], observedMetrics={ - k: Row(row) if isinstance(row, str) else - Row(*row.keys())(*row.values()) # Assume no nested rows + k: Row(row) + if isinstance(row, str) + else Row(*row.keys())(*row.values()) # Assume no nested rows for k, row in j["observedMetrics"].items() } if "observedMetrics" in j @@ -560,7 +561,9 @@ def conv(obj: Any) -> Any: else: return obj - return {k[1:]: conv(v) for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]} + return { + k[1:]: conv(v) for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"] + } @property def id(self) -> uuid.UUID: diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 7f6211b8ee465..4146773926083 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -49,6 +49,7 @@ def get_start_event_schema(): ] ) + def get_idle_event_schema(): return StructType( [ @@ -58,6 +59,7 @@ def get_idle_event_schema(): ] ) + def get_terminated_event_schema(): return StructType( [ @@ -68,6 +70,7 @@ def get_terminated_event_schema(): ] ) + def get_state_operators_progress_schema(): return StructType( [ @@ -125,7 +128,7 @@ def get_streaming_query_progress_schema(): StructField("eventTime", MapType(StringType(), StringType(), True), True), StructField("stateOperators", ArrayType(get_state_operators_progress_schema()), True), StructField("sources", ArrayType(get_source_progress_schema()), True), - StructField("sink", get_sink_progress_schema(), True), # TODO: false? + StructField("sink", get_sink_progress_schema(), True), # TODO: false? StructField("numInputRows", IntegerType(), False), StructField("inputRowsPerSecond", FloatType(), False), StructField("processedRowsPerSecond", FloatType(), False), @@ -139,7 +142,6 @@ def get_progress_event_schema(): class TestListener(StreamingQueryListener): - def onQueryStarted(self, event): df = self.spark.createDataFrame( data=[(event.asDict())], diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py index d43eb26c5b6f3..65a0f6279fb08 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py @@ -15,6 +15,7 @@ # limitations under the License. # +import time from pyspark.sql.dataframe import DataFrame from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -92,12 +93,14 @@ def func(batch_df, _): self.assertIsNone(q.exception(), "No exception has to be propagated.") def test_streaming_foreachBatch_spark_session(self): - table_name = "testTable-foreachBatch" + table_name = "testTable_foreachBatch" - def func(df: DataFrame, _): + def func(df: DataFrame, batch_id: int): + if batch_id > 0: # only process once + return spark = df.sparkSession df1 = spark.createDataFrame([("structured",), ("streaming",)]) - df1.union(df).write.saveAsTable(table_name) + df1.union(df).write.mode("append").saveAsTable(table_name) df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") q = df.writeStream.foreachBatch(func).start() @@ -105,18 +108,22 @@ def func(df: DataFrame, _): q.stop() actual = self.spark.read.table(table_name) - df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/").union( - self.spark.createDataFrame([("structured",), ("streaming",)]) + df = ( + self.spark.read.format("text") + .load(path="python/test_support/sql/streaming/") + .union(self.spark.createDataFrame([("structured",), ("streaming",)])) ) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) def test_streaming_foreachBatch_path_access(self): - table_name = "testTable-foreachBatch-path" + table_name = "testTable_foreachBatch_path" - def func(df: DataFrame, _): + def func(df: DataFrame, batch_id: int): + if batch_id > 0: # only process once + return spark = df.sparkSession df1 = spark.read.format("text").load("python/test_support/sql/streaming") - df1.union(df).write.saveAsTable(table_name) + df1.union(df).write.mode("append").saveAsTable(table_name) df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") q = df.writeStream.foreachBatch(func).start() @@ -138,16 +145,20 @@ def test_streaming_foreachBatch_fuction_calling(self): def my_test_function_3(): return 3 - table_name = "testTable-foreachBatch-function" + table_name = "testTable_foreachBatch_function" - def func(df: DataFrame, _): + def func(df: DataFrame, batch_id: int): + if batch_id > 0: # only process once + return spark = df.sparkSession - df1 = spark.createDataFrame([ - (my_test_function_1(),), - (StreamingTestsForeachBatchMixin.my_test_function_2(),), - (my_test_function_3(),), - ]) - df1.write.saveAsTable(table_name) + df1 = spark.createDataFrame( + [ + (my_test_function_1(),), + (StreamingTestsForeachBatchMixin.my_test_function_2(),), + (my_test_function_3(),), + ] + ) + df1.write.mode("append").saveAsTable(table_name) df = self.spark.readStream.format("rate").load() q = df.writeStream.foreachBatch(func).start() @@ -155,22 +166,27 @@ def func(df: DataFrame, _): q.stop() actual = self.spark.read.table(table_name) - df = self.spark.createDataFrame([ + df = self.spark.createDataFrame( + [ (my_test_function_1(),), (StreamingTestsForeachBatchMixin.my_test_function_2(),), (my_test_function_3(),), - ]) + ] + ) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) def test_streaming_foreachBatch_import(self): - import time # not imported in foreachBatch_worker - table_name = "testTable-foreachBatch-import" + import time # not imported in foreachBatch_worker - def func(df: DataFrame, _): + table_name = "testTable_foreachBatch_import" + + def func(df: DataFrame, batch_id: int): + if batch_id > 0: # only process once + return time.sleep(1) spark = df.sparkSession df1 = spark.read.format("text").load("python/test_support/sql/streaming") - df1.write.saveAsTable(table_name) + df1.write.mode("append").saveAsTable(table_name) df = self.spark.readStream.format("rate").load() q = df.writeStream.foreachBatch(func).start() @@ -182,8 +198,6 @@ def func(df: DataFrame, _): self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - - class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase): pass From b3529ebb0c01c82d55933737b101fcac5716c4f2 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Wed, 16 Aug 2023 14:33:41 -0700 Subject: [PATCH 03/11] minor --- python/pyspark/sql/streaming/listener.py | 11 ----------- .../tests/connect/streaming/test_parity_listener.py | 4 ++-- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 76639b355a78d..dd2d290fed542 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -215,17 +215,6 @@ def conv(obj: Any) -> Any: return {k[1:]: conv(v) for k, v in self.__dict__.items()} - @staticmethod - def schema(): - return StructType( - [ - StructField("id", StringType(), False), - StructField("runId", StringType(), False), - StructField("name", StringType(), True), - StructField("timestamp", StringType(), False), - ] - ) - @property def id(self) -> uuid.UUID: """ diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 4146773926083..d9e780a4f2735 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -145,7 +145,7 @@ class TestListener(StreamingQueryListener): def onQueryStarted(self, event): df = self.spark.createDataFrame( data=[(event.asDict())], - schema=event.schema(), + schema=get_start_event_schema(), ) df.write.saveAsTable("listener_start_events") @@ -177,7 +177,7 @@ def test_listener_events(self): # This ensures the read socket on the server won't crash (i.e. because of timeout) # when there hasn't been a new event for a long time - # time.sleep(30) + time.sleep(30) df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() df_observe = df.observe("my_event", count(lit(1)).alias("rc")) From 600623ccfaa2a6795bae33c68ca5d5f2b64fd971 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Wed, 16 Aug 2023 14:35:09 -0700 Subject: [PATCH 04/11] minor --- .../sql/tests/connect/streaming/test_parity_listener.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index d9e780a4f2735..3d509148cccfa 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -128,11 +128,11 @@ def get_streaming_query_progress_schema(): StructField("eventTime", MapType(StringType(), StringType(), True), True), StructField("stateOperators", ArrayType(get_state_operators_progress_schema()), True), StructField("sources", ArrayType(get_source_progress_schema()), True), - StructField("sink", get_sink_progress_schema(), True), # TODO: false? + StructField("sink", get_sink_progress_schema(), True), StructField("numInputRows", IntegerType(), False), StructField("inputRowsPerSecond", FloatType(), False), StructField("processedRowsPerSecond", FloatType(), False), - StructField("observedMetrics", MapType(StringType(), StringType()), False), + StructField("observedMetrics", MapType(StringType(), StringType()), True), ] ) From a6304c9683767c95436a8264b98f837a3b274641 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Wed, 16 Aug 2023 14:42:21 -0700 Subject: [PATCH 05/11] add comment --- .../sql/tests/connect/streaming/test_parity_listener.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 3d509148cccfa..a7381e935e54f 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -132,6 +132,9 @@ def get_streaming_query_progress_schema(): StructField("numInputRows", IntegerType(), False), StructField("inputRowsPerSecond", FloatType(), False), StructField("processedRowsPerSecond", FloatType(), False), + # it's difficult to get the schema of observed metrics. + # Just serialize the row to string for now. + # Things would be easier if there is a method to get the schema of Row in PySpark StructField("observedMetrics", MapType(StringType(), StringType()), True), ] ) From ffac7ec49c3e67abe62b821ac072d0c992655108 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Wed, 16 Aug 2023 14:42:41 -0700 Subject: [PATCH 06/11] minor --- .../pyspark/sql/tests/connect/streaming/test_parity_listener.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index a7381e935e54f..f96ff65cedf32 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -153,7 +153,6 @@ def onQueryStarted(self, event): df.write.saveAsTable("listener_start_events") def onQueryProgress(self, event): - print(event.asDict()) df = self.spark.createDataFrame( data=[event.asDict()], schema=get_progress_event_schema(), From fa8be5cb7a4105c0668de389cdd622fcfbfbf28b Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 21 Aug 2023 11:17:42 -0700 Subject: [PATCH 07/11] move asDict methods to test suite --- python/pyspark/sql/streaming/listener.py | 67 ------------------- .../connect/streaming/test_parity_listener.py | 58 ++++++++++++++-- 2 files changed, 53 insertions(+), 72 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index dd2d290fed542..f14f4b26be861 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -22,15 +22,6 @@ from py4j.java_gateway import JavaObject from pyspark.sql import Row -from pyspark.sql.types import ( - ArrayType, - StructType, - StructField, - StringType, - IntegerType, - FloatType, - MapType, -) from pyspark import cloudpickle __all__ = ["StreamingQueryListener"] @@ -206,15 +197,6 @@ def fromJson(cls, j: Dict[str, Any]) -> "QueryStartedEvent": timestamp=j["timestamp"], ) - def asDict(self) -> Dict[str, Any]: - def conv(obj: Any) -> Any: - if isinstance(obj, uuid.UUID): - return str(obj) - else: - return obj - - return {k[1:]: conv(v) for k, v in self.__dict__.items()} - @property def id(self) -> uuid.UUID: """ @@ -275,9 +257,6 @@ def progress(self) -> "StreamingQueryProgress": """ return self._progress - def asDict(self) -> Dict[str, Any]: - return {"progress": self.progress.asDict()} - class QueryIdleEvent: """ @@ -307,15 +286,6 @@ def fromJObject(cls, jevent: JavaObject) -> "QueryIdleEvent": def fromJson(cls, j: Dict[str, Any]) -> "QueryIdleEvent": return cls(id=uuid.UUID(j["id"]), runId=uuid.UUID(j["runId"]), timestamp=j["timestamp"]) - def asDict(self) -> Dict[str, Any]: - def conv(obj: Any) -> Any: - if isinstance(obj, uuid.UUID): - return str(obj) - else: - return obj - - return {k[1:]: conv(v) for k, v in self.__dict__.items()} - @property def id(self) -> uuid.UUID: """ @@ -383,15 +353,6 @@ def fromJson(cls, j: Dict[str, Any]) -> "QueryTerminatedEvent": errorClassOnException=j["errorClassOnException"], ) - def asDict(self) -> Dict[str, Any]: - def conv(obj: Any) -> Any: - if isinstance(obj, uuid.UUID): - return str(obj) - else: - return obj - - return {k[1:]: conv(v) for k, v in self.__dict__.items()} - @property def id(self) -> uuid.UUID: """ @@ -535,25 +496,6 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": else {}, ) - def asDict(self) -> Dict[str, Any]: - def conv(obj: Any) -> Any: - if isinstance(obj, uuid.UUID): - return str(obj) - elif isinstance(obj, (SourceProgress, SinkProgress, StateOperatorProgress)): - return obj.asDict() - elif isinstance(obj, Row): - return json.dumps(obj.asDict()) # Assume no nested row in observed metrics - elif isinstance(obj, list): - return [conv(o) for o in obj] - elif isinstance(obj, dict): - return dict((k, conv(v)) for k, v in obj.items()) - else: - return obj - - return { - k[1:]: conv(v) for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"] - } - @property def id(self) -> uuid.UUID: """ @@ -776,9 +718,6 @@ def fromJson(cls, j: Dict[str, Any]) -> "StateOperatorProgress": customMetrics=dict(j["customMetrics"]) if "customMetrics" in j else {}, ) - def asDict(self) -> Dict[str, Any]: - return {k[1:]: v for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]} - @property def operatorName(self) -> str: return self._operatorName @@ -914,9 +853,6 @@ def fromJson(cls, j: Dict[str, Any]) -> "SourceProgress": metrics=dict(j["metrics"]) if "metrics" in j else {}, ) - def asDict(self) -> Dict[str, Any]: - return {k[1:]: v for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]} - @property def description(self) -> str: """ @@ -1028,9 +964,6 @@ def fromJObject(cls, jprogress: JavaObject) -> "SinkProgress": metrics=dict(jprogress.metrics()), ) - def asDict(self) -> Dict[str, Any]: - return {k[1:]: v for k, v in self.__dict__.items() if k not in ["_jprogress", "_jdict"]} - @classmethod def fromJson(cls, j: Dict[str, Any]) -> "SinkProgress": return cls( diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index f96ff65cedf32..e7e74fbe0ce64 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -17,6 +17,9 @@ import unittest import time +import uuid +import json +from typing import Any, Dict, Union from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin from pyspark.sql.streaming.listener import ( @@ -25,6 +28,10 @@ QueryProgressEvent, QueryIdleEvent, QueryTerminatedEvent, + StateOperatorProgress, + StreamingQueryProgress, + SourceProgress, + SinkProgress, ) from pyspark.sql.types import ( ArrayType, @@ -35,10 +42,51 @@ FloatType, MapType, ) +from pyspark.sql import Row from pyspark.sql.functions import count, lit from pyspark.testing.connectutils import ReusedConnectTestCase +def listener_event_as_dict( + e: Union[QueryStartedEvent, QueryProgressEvent, QueryIdleEvent, QueryTerminatedEvent] +) -> Dict[str, Any]: + if isinstance(e, QueryProgressEvent): + return {"progress": streaming_query_progress_as_dict(e.progress)} + else: + + def conv(obj: Any) -> Any: + if isinstance(obj, uuid.UUID): + return str(obj) + else: + return obj + + return {k[1:]: conv(v) for k, v in e.__dict__.items()} + + +def streaming_query_progress_as_dict(e: StreamingQueryProgress) -> Dict[str, Any]: + def conv(obj: Any) -> Any: + if isinstance(obj, uuid.UUID): + return str(obj) + elif isinstance(obj, (SourceProgress, SinkProgress, StateOperatorProgress)): + return other_progress_as_dict(obj) + elif isinstance(obj, Row): + return json.dumps(obj.asDict()) # Assume no nested row in observed metrics + elif isinstance(obj, list): + return [conv(o) for o in obj] + elif isinstance(obj, dict): + return dict((k, conv(v)) for k, v in obj.items()) + else: + return obj + + return {k[1:]: conv(v) for k, v in e.__dict__.items() if k not in ["_jprogress", "_jdict"]} + + +def other_progress_as_dict( + e: Union[StateOperatorProgress, SourceProgress, SinkProgress] +) -> Dict[str, Any]: + return {k[1:]: v for k, v in e.__dict__.items() if k not in ["_jprogress", "_jdict"]} + + def get_start_event_schema(): return StructType( [ @@ -147,14 +195,14 @@ def get_progress_event_schema(): class TestListener(StreamingQueryListener): def onQueryStarted(self, event): df = self.spark.createDataFrame( - data=[(event.asDict())], + data=[listener_event_as_dict(event)], schema=get_start_event_schema(), ) - df.write.saveAsTable("listener_start_events") + df.write.mode("append").saveAsTable("listener_start_events") def onQueryProgress(self, event): df = self.spark.createDataFrame( - data=[event.asDict()], + data=[listener_event_as_dict(event)], schema=get_progress_event_schema(), ) df.write.mode("append").saveAsTable("listener_progress_events") @@ -164,10 +212,10 @@ def onQueryIdle(self, event): def onQueryTerminated(self, event): df = self.spark.createDataFrame( - data=[event.asDict()], + data=[listener_event_as_dict(event)], schema=get_terminated_event_schema(), ) - df.write.saveAsTable("listener_terminated_events") + df.write.mode("append").saveAsTable("listener_terminated_events") class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase): From 132b2c86da330802b3d45d5c5e2bac355658b29a Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 21 Aug 2023 11:25:10 -0700 Subject: [PATCH 08/11] add comments for observed metrics --- python/pyspark/sql/streaming/listener.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index f14f4b26be861..1f675c2d7560b 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -487,8 +487,10 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": inputRowsPerSecond=j["inputRowsPerSecond"], processedRowsPerSecond=j["processedRowsPerSecond"], observedMetrics={ - k: Row(row) - if isinstance(row, str) + # in test_parity_listener, observed metrics is serialized into string, + # this won't happen in production. + k: Row(row) # for test only, + if isinstance(row, str) # for test only, else Row(*row.keys())(*row.values()) # Assume no nested rows for k, row in j["observedMetrics"].items() } From 51b55805a12618cfc955f3186ba733122c3756d2 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 21 Aug 2023 14:18:39 -0700 Subject: [PATCH 09/11] wip --- .../spark/sql/connect/planner/StreamingForeachBatchHelper.scala | 1 - .../pyspark/sql/connect/streaming/worker/foreachBatch_worker.py | 1 - python/pyspark/sql/connect/streaming/worker/listener_worker.py | 1 - 3 files changed, 3 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index 21e4adb9896b6..ef7195439f9cd 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -113,7 +113,6 @@ object StreamingForeachBatchHelper extends Logging { val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => { - // TODO(SPARK-44460): Support Auth credentials // TODO(SPARK-44462): A new session id pointing to args.df.sparkSession needs to be created. // This is because MicroBatch execution clones the session during start. // The session attached to the foreachBatch dataframe is different from the one the one diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py index cf61463cd6870..72037f1263dbf 100644 --- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py @@ -51,7 +51,6 @@ def main(infile: IO, outfile: IO) -> None: spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate() spark_connect_session._client._session_id = session_id # type: ignore[attr-defined] - # TODO(SPARK-44460): Pass credentials. # TODO(SPARK-44461): Enable Process Isolation func = worker.read_command(pickle_ser, infile) diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py index e1f4678e42f16..c026945767d94 100644 --- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py @@ -59,7 +59,6 @@ def main(infile: IO, outfile: IO) -> None: spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate() spark_connect_session._client._session_id = session_id # type: ignore[attr-defined] - # TODO(SPARK-44460): Pass credentials. # TODO(SPARK-44461): Enable Process Isolation listener = worker.read_command(pickle_ser, infile) From 85adb41730b45dfa4a851f9a9a5684834445caf4 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Wed, 23 Aug 2023 13:41:13 -0700 Subject: [PATCH 10/11] address Hyukjin's comment, use cloudPickle for events --- python/pyspark/sql/streaming/listener.py | 6 +- .../connect/streaming/test_parity_listener.py | 173 ++---------------- 2 files changed, 17 insertions(+), 162 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 1f675c2d7560b..4f94a45cdcb44 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -487,11 +487,7 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": inputRowsPerSecond=j["inputRowsPerSecond"], processedRowsPerSecond=j["processedRowsPerSecond"], observedMetrics={ - # in test_parity_listener, observed metrics is serialized into string, - # this won't happen in production. - k: Row(row) # for test only, - if isinstance(row, str) # for test only, - else Row(*row.keys())(*row.values()) # Assume no nested rows + Row(*row.keys())(*row.values()) # Assume no nested rows for k, row in j["observedMetrics"].items() } if "observedMetrics" in j diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index e7e74fbe0ce64..059c064e712bf 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -21,6 +21,7 @@ import json from typing import Any, Dict, Union +import pyspark.cloudpickle from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin from pyspark.sql.streaming.listener import ( StreamingQueryListener, @@ -47,163 +48,20 @@ from pyspark.testing.connectutils import ReusedConnectTestCase -def listener_event_as_dict( - e: Union[QueryStartedEvent, QueryProgressEvent, QueryIdleEvent, QueryTerminatedEvent] -) -> Dict[str, Any]: - if isinstance(e, QueryProgressEvent): - return {"progress": streaming_query_progress_as_dict(e.progress)} - else: - - def conv(obj: Any) -> Any: - if isinstance(obj, uuid.UUID): - return str(obj) - else: - return obj - - return {k[1:]: conv(v) for k, v in e.__dict__.items()} - - -def streaming_query_progress_as_dict(e: StreamingQueryProgress) -> Dict[str, Any]: - def conv(obj: Any) -> Any: - if isinstance(obj, uuid.UUID): - return str(obj) - elif isinstance(obj, (SourceProgress, SinkProgress, StateOperatorProgress)): - return other_progress_as_dict(obj) - elif isinstance(obj, Row): - return json.dumps(obj.asDict()) # Assume no nested row in observed metrics - elif isinstance(obj, list): - return [conv(o) for o in obj] - elif isinstance(obj, dict): - return dict((k, conv(v)) for k, v in obj.items()) - else: - return obj - - return {k[1:]: conv(v) for k, v in e.__dict__.items() if k not in ["_jprogress", "_jdict"]} - - -def other_progress_as_dict( - e: Union[StateOperatorProgress, SourceProgress, SinkProgress] -) -> Dict[str, Any]: - return {k[1:]: v for k, v in e.__dict__.items() if k not in ["_jprogress", "_jdict"]} - - -def get_start_event_schema(): - return StructType( - [ - StructField("id", StringType(), False), - StructField("runId", StringType(), False), - StructField("name", StringType(), True), - StructField("timestamp", StringType(), False), - ] - ) - - -def get_idle_event_schema(): - return StructType( - [ - StructField("id", StringType(), False), - StructField("runId", StringType(), False), - StructField("timestamp", StringType(), False), - ] - ) - - -def get_terminated_event_schema(): - return StructType( - [ - StructField("id", StringType(), False), - StructField("runId", StringType(), False), - StructField("exception", StringType(), True), - StructField("errorClassOnException", StringType(), True), - ] - ) - - -def get_state_operators_progress_schema(): - return StructType( - [ - StructField("operatorName", StringType(), False), - StructField("numRowsTotal", IntegerType(), False), - StructField("numRowsUpdated", IntegerType(), False), - StructField("numRowsRemoved", IntegerType(), False), - StructField("allUpdatesTimeMs", IntegerType(), False), - StructField("allRemovalsTimeMs", IntegerType(), False), - StructField("commitTimeMs", IntegerType(), False), - StructField("memoryUsedBytes", IntegerType(), False), - StructField("numRowsDroppedByWatermark", IntegerType(), False), - StructField("numShufflePartitions", IntegerType(), False), - StructField("numStateStoreInstances", IntegerType(), False), - StructField("customMetrics", MapType(StringType(), IntegerType(), True), True), - ] - ) - - -def get_source_progress_schema(): - return StructType( - [ - StructField("description", StringType(), False), - StructField("startOffset", StringType(), False), - StructField("endOffset", StringType(), False), - StructField("latestOffset", StringType(), False), - StructField("numInputRows", IntegerType(), False), - StructField("inputRowsPerSecond", FloatType(), False), - StructField("processedRowsPerSecond", FloatType(), False), - StructField("metrics", MapType(StringType(), StringType(), True), True), - ] - ) - - -def get_sink_progress_schema(): - return StructType( - [ - StructField("description", StringType(), False), - StructField("numOutputRows", IntegerType(), False), - StructField("metrics", MapType(StringType(), StringType(), True), True), - ] - ) - - -def get_streaming_query_progress_schema(): - return StructType( - [ - StructField("id", StringType(), False), - StructField("runId", StringType(), False), - StructField("name", StringType(), True), - StructField("timestamp", StringType(), False), - StructField("batchId", IntegerType(), False), - StructField("batchDuration", IntegerType(), False), - StructField("durationMs", MapType(StringType(), IntegerType(), True), True), - StructField("eventTime", MapType(StringType(), StringType(), True), True), - StructField("stateOperators", ArrayType(get_state_operators_progress_schema()), True), - StructField("sources", ArrayType(get_source_progress_schema()), True), - StructField("sink", get_sink_progress_schema(), True), - StructField("numInputRows", IntegerType(), False), - StructField("inputRowsPerSecond", FloatType(), False), - StructField("processedRowsPerSecond", FloatType(), False), - # it's difficult to get the schema of observed metrics. - # Just serialize the row to string for now. - # Things would be easier if there is a method to get the schema of Row in PySpark - StructField("observedMetrics", MapType(StringType(), StringType()), True), - ] - ) - - -def get_progress_event_schema(): - return StructType([StructField("progress", get_streaming_query_progress_schema(), False)]) - - class TestListener(StreamingQueryListener): def onQueryStarted(self, event): + e = pyspark.cloudpickle.dumps(event) df = self.spark.createDataFrame( - data=[listener_event_as_dict(event)], - schema=get_start_event_schema(), + data=[(e,)], + schema=StructField("event", StringType(), False), ) df.write.mode("append").saveAsTable("listener_start_events") def onQueryProgress(self, event): + e = pyspark.cloudpickle.dumps(event) df = self.spark.createDataFrame( - data=[listener_event_as_dict(event)], - schema=get_progress_event_schema(), + data=[(e,)], + schema=StructField("event", StringType(), False), ) df.write.mode("append").saveAsTable("listener_progress_events") @@ -211,9 +69,10 @@ def onQueryIdle(self, event): pass def onQueryTerminated(self, event): + e = pyspark.cloudpickle.dumps(event) df = self.spark.createDataFrame( - data=[listener_event_as_dict(event)], - schema=get_terminated_event_schema(), + data=[(e,)], + schema=StructField("event", StringType(), False), ) df.write.mode("append").saveAsTable("listener_terminated_events") @@ -245,16 +104,16 @@ def test_listener_events(self): q.stop() self.assertFalse(q.isActive) - start_event = QueryStartedEvent.fromJson( - self.spark.read.table("listener_start_events").collect()[0].asDict() + start_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_start_events").collect()[0][0] ) - progress_event = QueryProgressEvent.fromJson( - self.spark.read.table("listener_progress_events").collect()[0].asDict() + progress_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_progress_events").collect()[0][0] ) - terminated_event = QueryTerminatedEvent.fromJson( - self.spark.read.table("listener_terminated_events").collect()[0].asDict() + terminated_event = pyspark.cloudpickle.loads( + self.spark.read.table("listener_terminated_events").collect()[0][0] ) self.check_start_event(start_event) From 8f656377b833a6127bfdb91a6c4221a8693fa54b Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Wed, 23 Aug 2023 13:49:17 -0700 Subject: [PATCH 11/11] remove unused import, simplify code --- python/pyspark/sql/streaming/listener.py | 4 +- .../connect/streaming/test_parity_listener.py | 40 ++----------------- 2 files changed, 6 insertions(+), 38 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 4f94a45cdcb44..225ad6d45afb1 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -487,8 +487,8 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": inputRowsPerSecond=j["inputRowsPerSecond"], processedRowsPerSecond=j["processedRowsPerSecond"], observedMetrics={ - Row(*row.keys())(*row.values()) # Assume no nested rows - for k, row in j["observedMetrics"].items() + k: Row(*row_dict.keys())(*row_dict.values()) # Assume no nested rows + for k, row_dict in j["observedMetrics"].items() } if "observedMetrics" in j else {}, diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 059c064e712bf..5069a76cfdb73 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -17,33 +17,10 @@ import unittest import time -import uuid -import json -from typing import Any, Dict, Union import pyspark.cloudpickle from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin -from pyspark.sql.streaming.listener import ( - StreamingQueryListener, - QueryStartedEvent, - QueryProgressEvent, - QueryIdleEvent, - QueryTerminatedEvent, - StateOperatorProgress, - StreamingQueryProgress, - SourceProgress, - SinkProgress, -) -from pyspark.sql.types import ( - ArrayType, - StructType, - StructField, - StringType, - IntegerType, - FloatType, - MapType, -) -from pyspark.sql import Row +from pyspark.sql.streaming.listener import StreamingQueryListener from pyspark.sql.functions import count, lit from pyspark.testing.connectutils import ReusedConnectTestCase @@ -51,18 +28,12 @@ class TestListener(StreamingQueryListener): def onQueryStarted(self, event): e = pyspark.cloudpickle.dumps(event) - df = self.spark.createDataFrame( - data=[(e,)], - schema=StructField("event", StringType(), False), - ) + df = self.spark.createDataFrame(data=[(e,)]) df.write.mode("append").saveAsTable("listener_start_events") def onQueryProgress(self, event): e = pyspark.cloudpickle.dumps(event) - df = self.spark.createDataFrame( - data=[(e,)], - schema=StructField("event", StringType(), False), - ) + df = self.spark.createDataFrame(data=[(e,)]) df.write.mode("append").saveAsTable("listener_progress_events") def onQueryIdle(self, event): @@ -70,10 +41,7 @@ def onQueryIdle(self, event): def onQueryTerminated(self, event): e = pyspark.cloudpickle.dumps(event) - df = self.spark.createDataFrame( - data=[(e,)], - schema=StructField("event", StringType(), False), - ) + df = self.spark.createDataFrame(data=[(e,)]) df.write.mode("append").saveAsTable("listener_terminated_events")