diff --git a/packages/reflex-base/src/reflex_base/event/processor/event_processor.py b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py index 864d3f40d79..7d9296fe4dd 100644 --- a/packages/reflex-base/src/reflex_base/event/processor/event_processor.py +++ b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py @@ -10,9 +10,9 @@ import sys import time import traceback -from collections.abc import AsyncGenerator, Callable, Mapping, Sequence +from collections.abc import AsyncGenerator, Callable, Coroutine, Mapping, Sequence from contextvars import Token, copy_context -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar import rich.markup from typing_extensions import Self @@ -41,6 +41,41 @@ class QueueShutDown(Exception): # noqa: N818 """Exception raised when trying to put an item into a shut down queue.""" +_StreamItemT = TypeVar("_StreamItemT") + + +async def _stream_queue_until_done( + queue: asyncio.Queue[_StreamItemT], + done_when: Coroutine[Any, Any, Any], +) -> AsyncGenerator[_StreamItemT]: + """Yield items from ``queue`` until ``done_when`` completes. + + Items are yielded in the order they were enqueued. Completion of + ``done_when`` is signalled through the same queue via a private sentinel, + so items enqueued before completion are never lost to a race between + "watcher done" and "item available". + + Args: + queue: The queue to drain. + done_when: Coroutine whose completion marks end-of-stream. Wrapped in + a task owned by this helper and cancelled on exit. + + Yields: + Each item pulled from the queue, in order. + """ + end_of_stream: Any = object() + watcher = asyncio.create_task(done_when) + watcher.add_done_callback(lambda _: queue.put_nowait(end_of_stream)) + try: + while True: + item = await queue.get() + if item is end_of_stream: + return + yield item + finally: + watcher.cancel() + + @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) class EventQueueEntry: """An entry in the event queue.""" @@ -429,23 +464,13 @@ async def _emit_delta_impl( emit_delta_impl=_emit_delta_impl, ), ) - all_task_futures = asyncio.create_task(task_future.wait_all()) - waiting_for = {all_task_futures, asyncio.create_task(deltas.get())} + try: - while not all_task_futures.done() or not deltas.empty(): - with contextlib.suppress(asyncio.TimeoutError): - async for result in as_completed( - waiting_for, - timeout=1, - ): - waiting_for.remove(result) - if result is not all_task_futures: - yield await result - waiting_for.add(asyncio.create_task(deltas.get())) - break + async for delta in _stream_queue_until_done( + queue=deltas, done_when=task_future.wait_all() + ): + yield delta finally: - for future in waiting_for: - future.cancel() # Cancel the event chain if the streaming consumer exits early. if not task_future.done(): task_future.cancel() diff --git a/tests/units/reflex_base/event/processor/test_event_processor.py b/tests/units/reflex_base/event/processor/test_event_processor.py index de1ea4dcb23..bcc1108be98 100644 --- a/tests/units/reflex_base/event/processor/test_event_processor.py +++ b/tests/units/reflex_base/event/processor/test_event_processor.py @@ -6,7 +6,11 @@ import pytest from reflex_base.event.context import EventContext -from reflex_base.event.processor.event_processor import EventProcessor, QueueShutDown +from reflex_base.event.processor.event_processor import ( + EventProcessor, + QueueShutDown, + _stream_queue_until_done, +) from reflex_base.registry import RegistrationContext from reflex.event import Event, EventHandler @@ -64,6 +68,19 @@ async def _multi_delta_handler(): await asyncio.sleep(0.01) +async def _rapid_multi_delta_handler(): + """A handler that emits multiple deltas back-to-back with no intervening awaits. + + This is the pattern an ``@rx.event`` async-generator handler produces when + it yields once early (emitting an intermediate state delta) and then runs + to completion synchronously before the framework emits the final state + delta on its behalf — exactly what ``rx.upload``-driven handlers do. + """ + ctx = EventContext.get() + for i in range(2): + await ctx.emit_delta({"state": {"i": i}}) + + async def _slow_logging_handler(value: str = "default"): """A slow logging handler that pauses before recording. @@ -117,6 +134,7 @@ async def _background_slow_logging_handler(value: str = "default"): chaining_event = EventHandler(fn=_chaining_handler) delta_event = EventHandler(fn=_delta_handler) multi_delta_event = EventHandler(fn=_multi_delta_handler) +rapid_multi_delta_event = EventHandler(fn=_rapid_multi_delta_handler) slow_logging_event = EventHandler(fn=_slow_logging_handler) multi_chaining_event = EventHandler(fn=_multi_chaining_handler) background_slow_logging_event = EventHandler(fn=_background_slow_logging_handler) @@ -140,6 +158,7 @@ def _register_handlers(forked_registration_context: RegistrationContext): chaining_event, delta_event, multi_delta_event, + rapid_multi_delta_event, slow_logging_event, multi_chaining_event, background_slow_logging_event, @@ -496,6 +515,31 @@ async def test_stream_delta_yields_multiple_deltas(token: str): ] +async def test_stream_delta_yields_rapid_back_to_back_deltas(token: str): + """Regression: back-to-back deltas must not be dropped. + + When a handler emits multiple deltas without an intervening await that + yields to the event loop, the final delta could race with the handler's + own completion: during the ``as_completed`` tick that yields + ``all_task_futures``, a pending ``deltas.get()`` task in ``waiting_for`` + could silently consume the last queued delta. The outer loop would then + exit (``all_task_futures.done()`` and queue empty) and the ``finally`` + cancel would drop the delta held in that get-task's result. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + event = Event.from_event_type(rapid_multi_delta_event())[0] + collected = [d async for d in ep.enqueue_stream_delta(token, event)] + assert collected == [ + {"state": {"i": 0}}, + {"state": {"i": 1}}, + ] + + async def test_stream_delta_noop_handler_yields_nothing(token: str): """enqueue_stream_delta with a handler that emits no deltas yields nothing. @@ -590,3 +634,104 @@ async def test_sequential_chain_continues_after_error(token: str): with contextlib.suppress(Exception): await future.wait_all() assert _CALL_LOG == [{"value": "after_chain_error"}] + + +async def test_stream_queue_until_done_yields_in_order(): + """Items are yielded in enqueue order until the watcher completes.""" + queue: asyncio.Queue[int] = asyncio.Queue() + done_event = asyncio.Event() + + async def _watcher(): + await done_event.wait() + + async def _producer(): + for i in range(3): + await queue.put(i) + await asyncio.sleep(0) + done_event.set() + + producer = asyncio.create_task(_producer()) + collected = [v async for v in _stream_queue_until_done(queue, _watcher())] + await producer + assert collected == [0, 1, 2] + + +async def test_stream_queue_until_done_drains_items_put_before_completion(): + """Items enqueued before the watcher completes are never lost.""" + queue: asyncio.Queue[int] = asyncio.Queue() + + async def _watcher(): # noqa: RUF029 + # Fill the queue synchronously then return immediately so the + # watcher's done-callback fires in the same tick as the last put. + for i in range(5): + queue.put_nowait(i) + + collected = [v async for v in _stream_queue_until_done(queue, _watcher())] + assert collected == [0, 1, 2, 3, 4] + + +async def test_stream_queue_until_done_empty_when_watcher_completes_immediately(): + """Watcher completing with no items produces an empty stream.""" + queue: asyncio.Queue[int] = asyncio.Queue() + + async def _watcher(): # noqa: RUF029 + return + + collected = [v async for v in _stream_queue_until_done(queue, _watcher())] + assert collected == [] + + +async def test_stream_queue_until_done_watcher_exception_still_terminates(): + """A watcher that raises still signals end-of-stream cleanly.""" + queue: asyncio.Queue[int] = asyncio.Queue() + queue.put_nowait(42) + + async def _watcher(): # noqa: RUF029 + msg = "boom" + raise RuntimeError(msg) + + collected = [v async for v in _stream_queue_until_done(queue, _watcher())] + assert collected == [42] + + +async def test_stream_queue_until_done_cancels_watcher_on_early_exit(): + """Stopping iteration early cancels the in-flight watcher task.""" + queue: asyncio.Queue[int] = asyncio.Queue() + watcher_started = asyncio.Event() + watcher_cancelled = False + + async def _watcher(): + nonlocal watcher_cancelled + watcher_started.set() + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + watcher_cancelled = True + raise + + queue.put_nowait(1) + gen = _stream_queue_until_done(queue, _watcher()) + first = await gen.__anext__() + await watcher_started.wait() + await gen.aclose() + # Give the cancelled watcher a tick to observe the cancellation. + await asyncio.sleep(0) + assert first == 1 + assert watcher_cancelled + + +async def test_stream_queue_until_done_handles_concurrent_put_and_completion(): + """Race regression: item put and watcher completion in the same tick. + + This is the exact scenario that motivated the helper — an item arriving + at the queue in the same event-loop tick that the watcher resolves must + still be yielded, not silently dropped. + """ + queue: asyncio.Queue[int] = asyncio.Queue() + + async def _watcher(): # noqa: RUF029 + # No await — completes immediately after posting the final item. + queue.put_nowait(99) + + collected = [v async for v in _stream_queue_until_done(queue, _watcher())] + assert collected == [99] diff --git a/uv.lock b/uv.lock index 5365963e7e3..819f50d25b4 100644 --- a/uv.lock +++ b/uv.lock @@ -3868,14 +3868,22 @@ wheels = [ name = "reflex-site-shared" source = { editable = "packages/reflex-site-shared" } dependencies = [ + { name = "email-validator" }, + { name = "httpx" }, + { name = "pyyaml" }, { name = "reflex" }, { name = "reflex-components-internal" }, + { name = "ruff-format" }, ] [package.metadata] requires-dist = [ + { name = "email-validator" }, + { name = "httpx" }, + { name = "pyyaml" }, { name = "reflex", editable = "." }, { name = "reflex-components-internal", editable = "packages/reflex-components-internal" }, + { name = "ruff-format" }, ] [[package]]