Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import sys
import time
import traceback
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from collections.abc import AsyncGenerator, Callable, Coroutine, Mapping, Sequence
from contextvars import Token, copy_context
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeVar

import rich.markup
from typing_extensions import Self
Expand Down Expand Up @@ -41,6 +41,41 @@ class QueueShutDown(Exception): # noqa: N818
"""Exception raised when trying to put an item into a shut down queue."""


_StreamItemT = TypeVar("_StreamItemT")


async def _stream_queue_until_done(
queue: asyncio.Queue[_StreamItemT],
done_when: Coroutine[Any, Any, Any],
) -> AsyncGenerator[_StreamItemT]:
"""Yield items from ``queue`` until ``done_when`` completes.

Items are yielded in the order they were enqueued. Completion of
``done_when`` is signalled through the same queue via a private sentinel,
so items enqueued before completion are never lost to a race between
"watcher done" and "item available".

Args:
queue: The queue to drain.
done_when: Coroutine whose completion marks end-of-stream. Wrapped in
a task owned by this helper and cancelled on exit.

Yields:
Each item pulled from the queue, in order.
"""
end_of_stream: Any = object()
watcher = asyncio.create_task(done_when)
watcher.add_done_callback(lambda _: queue.put_nowait(end_of_stream))
try:
while True:
item = await queue.get()
if item is end_of_stream:
return
yield item
finally:
watcher.cancel()


@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class EventQueueEntry:
"""An entry in the event queue."""
Expand Down Expand Up @@ -429,23 +464,13 @@ async def _emit_delta_impl(
emit_delta_impl=_emit_delta_impl,
),
)
all_task_futures = asyncio.create_task(task_future.wait_all())
waiting_for = {all_task_futures, asyncio.create_task(deltas.get())}

try:
while not all_task_futures.done() or not deltas.empty():
with contextlib.suppress(asyncio.TimeoutError):
async for result in as_completed(
waiting_for,
timeout=1,
):
waiting_for.remove(result)
if result is not all_task_futures:
yield await result
waiting_for.add(asyncio.create_task(deltas.get()))
break
async for delta in _stream_queue_until_done(
queue=deltas, done_when=task_future.wait_all()
):
yield delta
finally:
for future in waiting_for:
future.cancel()
# Cancel the event chain if the streaming consumer exits early.
if not task_future.done():
task_future.cancel()
Expand Down
147 changes: 146 additions & 1 deletion tests/units/reflex_base/event/processor/test_event_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

import pytest
from reflex_base.event.context import EventContext
from reflex_base.event.processor.event_processor import EventProcessor, QueueShutDown
from reflex_base.event.processor.event_processor import (
EventProcessor,
QueueShutDown,
_stream_queue_until_done,
)
from reflex_base.registry import RegistrationContext

from reflex.event import Event, EventHandler
Expand Down Expand Up @@ -64,6 +68,19 @@ async def _multi_delta_handler():
await asyncio.sleep(0.01)


async def _rapid_multi_delta_handler():
"""A handler that emits multiple deltas back-to-back with no intervening awaits.

This is the pattern an ``@rx.event`` async-generator handler produces when
it yields once early (emitting an intermediate state delta) and then runs
to completion synchronously before the framework emits the final state
delta on its behalf — exactly what ``rx.upload``-driven handlers do.
"""
ctx = EventContext.get()
for i in range(2):
await ctx.emit_delta({"state": {"i": i}})


async def _slow_logging_handler(value: str = "default"):
"""A slow logging handler that pauses before recording.

Expand Down Expand Up @@ -117,6 +134,7 @@ async def _background_slow_logging_handler(value: str = "default"):
chaining_event = EventHandler(fn=_chaining_handler)
delta_event = EventHandler(fn=_delta_handler)
multi_delta_event = EventHandler(fn=_multi_delta_handler)
rapid_multi_delta_event = EventHandler(fn=_rapid_multi_delta_handler)
slow_logging_event = EventHandler(fn=_slow_logging_handler)
multi_chaining_event = EventHandler(fn=_multi_chaining_handler)
background_slow_logging_event = EventHandler(fn=_background_slow_logging_handler)
Expand All @@ -140,6 +158,7 @@ def _register_handlers(forked_registration_context: RegistrationContext):
chaining_event,
delta_event,
multi_delta_event,
rapid_multi_delta_event,
slow_logging_event,
multi_chaining_event,
background_slow_logging_event,
Expand Down Expand Up @@ -496,6 +515,31 @@ async def test_stream_delta_yields_multiple_deltas(token: str):
]


async def test_stream_delta_yields_rapid_back_to_back_deltas(token: str):
"""Regression: back-to-back deltas must not be dropped.

When a handler emits multiple deltas without an intervening await that
yields to the event loop, the final delta could race with the handler's
own completion: during the ``as_completed`` tick that yields
``all_task_futures``, a pending ``deltas.get()`` task in ``waiting_for``
could silently consume the last queued delta. The outer loop would then
exit (``all_task_futures.done()`` and queue empty) and the ``finally``
cancel would drop the delta held in that get-task's result.

Args:
token: The client token.
"""
ep = EventProcessor(graceful_shutdown_timeout=2)
ep.configure()
async with ep:
event = Event.from_event_type(rapid_multi_delta_event())[0]
collected = [d async for d in ep.enqueue_stream_delta(token, event)]
assert collected == [
{"state": {"i": 0}},
{"state": {"i": 1}},
]


async def test_stream_delta_noop_handler_yields_nothing(token: str):
"""enqueue_stream_delta with a handler that emits no deltas yields nothing.

Expand Down Expand Up @@ -590,3 +634,104 @@ async def test_sequential_chain_continues_after_error(token: str):
with contextlib.suppress(Exception):
await future.wait_all()
assert _CALL_LOG == [{"value": "after_chain_error"}]


async def test_stream_queue_until_done_yields_in_order():
"""Items are yielded in enqueue order until the watcher completes."""
queue: asyncio.Queue[int] = asyncio.Queue()
done_event = asyncio.Event()

async def _watcher():
await done_event.wait()

async def _producer():
for i in range(3):
await queue.put(i)
await asyncio.sleep(0)
done_event.set()

producer = asyncio.create_task(_producer())
collected = [v async for v in _stream_queue_until_done(queue, _watcher())]
await producer
assert collected == [0, 1, 2]


async def test_stream_queue_until_done_drains_items_put_before_completion():
"""Items enqueued before the watcher completes are never lost."""
queue: asyncio.Queue[int] = asyncio.Queue()

async def _watcher(): # noqa: RUF029
# Fill the queue synchronously then return immediately so the
# watcher's done-callback fires in the same tick as the last put.
for i in range(5):
queue.put_nowait(i)

collected = [v async for v in _stream_queue_until_done(queue, _watcher())]
assert collected == [0, 1, 2, 3, 4]


async def test_stream_queue_until_done_empty_when_watcher_completes_immediately():
"""Watcher completing with no items produces an empty stream."""
queue: asyncio.Queue[int] = asyncio.Queue()

async def _watcher(): # noqa: RUF029
return

collected = [v async for v in _stream_queue_until_done(queue, _watcher())]
assert collected == []


async def test_stream_queue_until_done_watcher_exception_still_terminates():
"""A watcher that raises still signals end-of-stream cleanly."""
queue: asyncio.Queue[int] = asyncio.Queue()
queue.put_nowait(42)

async def _watcher(): # noqa: RUF029
msg = "boom"
raise RuntimeError(msg)

collected = [v async for v in _stream_queue_until_done(queue, _watcher())]
assert collected == [42]


async def test_stream_queue_until_done_cancels_watcher_on_early_exit():
"""Stopping iteration early cancels the in-flight watcher task."""
queue: asyncio.Queue[int] = asyncio.Queue()
watcher_started = asyncio.Event()
watcher_cancelled = False

async def _watcher():
nonlocal watcher_cancelled
watcher_started.set()
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
watcher_cancelled = True
raise

queue.put_nowait(1)
gen = _stream_queue_until_done(queue, _watcher())
first = await gen.__anext__()
await watcher_started.wait()
await gen.aclose()
# Give the cancelled watcher a tick to observe the cancellation.
await asyncio.sleep(0)
assert first == 1
assert watcher_cancelled


async def test_stream_queue_until_done_handles_concurrent_put_and_completion():
"""Race regression: item put and watcher completion in the same tick.

This is the exact scenario that motivated the helper — an item arriving
at the queue in the same event-loop tick that the watcher resolves must
still be yielded, not silently dropped.
"""
queue: asyncio.Queue[int] = asyncio.Queue()

async def _watcher(): # noqa: RUF029
# No await — completes immediately after posting the final item.
queue.put_nowait(99)

collected = [v async for v in _stream_queue_until_done(queue, _watcher())]
assert collected == [99]
8 changes: 8 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading