diff --git a/python/docs/source/tutorial/sql/python_data_source.rst b/python/docs/source/tutorial/sql/python_data_source.rst index b3267405ffdd7..07f35722e73ff 100644 --- a/python/docs/source/tutorial/sql/python_data_source.rst +++ b/python/docs/source/tutorial/sql/python_data_source.rst @@ -309,7 +309,13 @@ This is the same dummy streaming reader that generates 2 rows every batch implem def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]: """ Takes start offset as an input, return an iterator of tuples and - the start offset of next read. + the end offset (start offset for the next read). The end offset must + advance past the start offset when returning data; otherwise Spark + raises a validation exception. + For example, returning 2 records from start_idx 0 means end should + be {"offset": 2} (i.e. start + 2). + When there is no data to read, you may return the same offset as end and + start, but you must provide an empty iterator. """ start_idx = start["offset"] it = iter([(i,) for i in range(start_idx, start_idx + 2)]) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index ee35e237b8983..bbc4d005b490a 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -1185,6 +1185,11 @@ "SparkContext or SparkSession should be created first." ] }, + "SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE": { + "message": [ + "SimpleDataSourceStreamReader.read() returned a non-empty batch but the end offset: did not advance past the start offset: . The end offset must represent the position after the last record returned." + ] + }, "SLICE_WITH_STEP": { "message": [ "Slice with step is not supported." diff --git a/python/pyspark/sql/datasource_internal.py b/python/pyspark/sql/datasource_internal.py index 92a968cf05723..2ac6c280e822e 100644 --- a/python/pyspark/sql/datasource_internal.py +++ b/python/pyspark/sql/datasource_internal.py @@ -93,6 +93,30 @@ def getDefaultReadLimit(self) -> ReadLimit: # We do not consider providing different read limit on simple stream reader. return ReadAllAvailable() + def add_result_to_cache(self, start: dict, end: dict, it: Iterator[Tuple]) -> None: + """ + Validates that read() did not return a non-empty batch with end equal to start, + which would cause the same batch to be processed repeatedly. When end != start, + appends the result to the cache; when end == start with empty iterator, does not + cache (avoids unbounded cache growth). + """ + start_str = json.dumps(start) + end_str = json.dumps(end) + if end_str != start_str: + self.cache.append(PrefetchedCacheEntry(start, end, it)) + return + try: + next(it) + except StopIteration: + return + raise PySparkException( + errorClass="SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE", + messageParameters={ + "start_offset": start_str, + "end_offset": end_str, + }, + ) + def latestOffset(self, start: dict, limit: ReadLimit) -> dict: assert start is not None, "start offset should not be None" assert isinstance( @@ -100,7 +124,7 @@ def latestOffset(self, start: dict, limit: ReadLimit) -> dict: ), "simple stream reader does not support read limit" (iter, end) = self.simple_reader.read(start) - self.cache.append(PrefetchedCacheEntry(start, end, iter)) + self.add_result_to_cache(start, end, iter) return end def commit(self, end: dict) -> None: diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py b/python/pyspark/sql/tests/test_python_streaming_datasource.py index bef85f7ba8457..5f6aaf10fe01a 100644 --- a/python/pyspark/sql/tests/test_python_streaming_datasource.py +++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py @@ -41,6 +41,7 @@ have_pyarrow, pyarrow_requirement_message, ) +from pyspark.errors import PySparkException from pyspark.testing import assertDataFrameEqual from pyspark.testing.utils import eventually from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -509,6 +510,60 @@ def check_batch(df, batch_id): q.awaitTermination(timeout=30) self.assertIsNone(q.exception(), "No exception has to be propagated.") + def test_simple_stream_reader_offset_did_not_advance_raises(self): + """Validate that returning end == start with non-empty data raises SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE.""" + from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper + + class BuggySimpleStreamReader(SimpleDataSourceStreamReader): + def initialOffset(self): + return {"offset": 0} + + def read(self, start: dict): + # Bug: return same offset as end despite returning data + start_idx = start["offset"] + it = iter([(i,) for i in range(start_idx, start_idx + 3)]) + return (it, start) + + def readBetweenOffsets(self, start: dict, end: dict): + return iter([]) + + def commit(self, end: dict): + pass + + reader = BuggySimpleStreamReader() + wrapper = _SimpleStreamReaderWrapper(reader) + with self.assertRaises(PySparkException) as cm: + wrapper.latestOffset({"offset": 0}, ReadAllAvailable()) + self.assertEqual( + cm.exception.getCondition(), + "SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE", + ) + + def test_simple_stream_reader_empty_iterator_start_equals_end_allowed(self): + """When read() returns end == start with an empty iterator, no exception and no cache entry.""" + from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper + + class EmptyBatchReader(SimpleDataSourceStreamReader): + def initialOffset(self): + return {"offset": 0} + + def read(self, start: dict): + # Valid: same offset as end but empty iterator (no data) + return (iter([]), start) + + def readBetweenOffsets(self, start: dict, end: dict): + return iter([]) + + def commit(self, end: dict): + pass + + reader = EmptyBatchReader() + wrapper = _SimpleStreamReaderWrapper(reader) + start = {"offset": 0} + end = wrapper.latestOffset(start, ReadAllAvailable()) + self.assertEqual(end, start) + self.assertEqual(len(wrapper.cache), 0) + def test_stream_writer(self): input_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_input") output_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_output")