diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 2227a7391..e54696b28 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -507,7 +507,7 @@ def on_partition_get_start_offset(self, event): assert msg.seqno == current_offset + 1 current_offset += 2 - reader._reconnector._stream_reader._set_first_error(ydb.Unavailable("some retriable error")) + reader._reconnector._conn._set_first_error(ydb.Unavailable("some retriable error")) await asyncio.sleep(0) @@ -629,7 +629,7 @@ def on_partition_get_start_offset(self, event): assert msg.seqno == current_offset + 1 current_offset += 2 - reader._async_reader._reconnector._stream_reader._set_first_error(ydb.Unavailable("some retriable error")) + reader._async_reader._reconnector._conn._set_first_error(ydb.Unavailable("some retriable error")) msg = reader.receive_message() diff --git a/tests/topics/test_topic_transactions.py b/tests/topics/test_topic_transactions.py index 1263e7d49..03a451ec1 100644 --- a/tests/topics/test_topic_transactions.py +++ b/tests/topics/test_topic_transactions.py @@ -151,7 +151,7 @@ async def test_tx_commit_after_reconnect_does_not_commit_stale_offsets( assert batch.messages[0].data.decode() == "123" reconnector = reader._reconnector - old_stream = reconnector._stream_reader + old_stream = reconnector._conn with mock.patch.object( reconnector, @@ -163,10 +163,10 @@ async def test_tx_commit_after_reconnect_does_not_commit_stale_offsets( old_stream._set_first_error(ydb.issues.ConnectionLost("forced reconnect")) for _ in range(100): await asyncio.sleep(0.05) - current = reconnector._stream_reader + current = reconnector._conn if current is not None and current is not old_stream and current._started: break - assert reconnector._stream_reader is not old_stream + assert reconnector._conn is not old_stream # Committing the stale batch must fail loudly instead of silently # sending a gapped UpdateOffsetsInTransaction for the dead 
session. diff --git a/ydb/_topic_common/STREAM_DESIGN.md b/ydb/_topic_common/STREAM_DESIGN.md new file mode 100644 index 000000000..696b17320 --- /dev/null +++ b/ydb/_topic_common/STREAM_DESIGN.md @@ -0,0 +1,89 @@ +# Topic stream stack + +The topic reader and writer share their reconnect + bidi-stream lifecycle through two base +classes in this package. The concrete reader/writer reconnectors and stream objects are thin +subclasses. + +``` +StreamReconnector (abc, Generic[ConnT]) — WHEN to (re)connect + • the single reconnect loop, backoff, the reconnector-level fatal signal, close ordering + • _new_connection() (sync) + _handshake(conn) + _run(conn) + │ creates / drives + ▼ +StreamConnection (abc) — WHAT a bidi stream is + • owns the gRPC wrapper (built SYNC), connect() = start + init handshake, + the per-connection death signal (wait_error), the update-token loop, close() + ├── ReaderStream ("fat": read/decode loops, partition sessions, commits) + └── WriterAsyncIOStream ("thin": receive()/write(); send/read loops in the reconnector) + +ReaderReconnector(StreamReconnector["ReaderStream"]) +WriterAsyncIOReconnector(StreamReconnector["WriterAsyncIOStream"]) +``` + +## StreamReconnector — the reconnect loop + +```python +attempt = 0 +while not closed: + conn = self._new_connection() # SYNC: builds the connection, which owns its gRPC stream + try: + await self._handshake(conn) # open call + init handshake + self._conn = conn # publish only after a successful handshake + attempt = 0 # reset ONLY on a successful connect + await self._on_connected(conn) + self._state_changed.set() + await self._run(conn) # block until this connection dies + except BaseException as err: + if CancelledError and not closed: err = ConnectionLost(...) 
+ info = self._classify_error(err, attempt) + if not info.is_retriable: + self._signal_fatal(err); return + await asyncio.sleep(info.sleep); attempt += 1 + finally: + await self._close_connection(conn, flush=False) # local `conn`: an interrupted + # handshake is still closed (no zombie) +``` + +Hooks: `_new_connection` (sync), `_handshake`, `_run`, `_close_connection`, `_terminal_error` +(required); `_on_connected`, `_on_fatal`, `_classify_error` (defaults). `_run(conn)` is +`await conn.wait_error()` for **both** reader and writer. + +## StreamConnection — the bidi stream + +Constructed synchronously (no network in `__init__`; it builds its `GrpcWrapperAsyncIO` +immediately). `connect(driver)` = `stream.start(...)` + `_init_and_spawn()`. Owns the +per-connection death signal (`_first_error` / `wait_error` / `_set_first_error`) and the shared +update-token loop. Hooks: `_init_and_spawn`, `_make_update_token_request`, `_on_first_error`. + +## Two error signals + +- **connection-level** `StreamConnection._first_error` (`wait_error`): *this stream* died → + `_run` raises that error → classify + backoff → reconnect. +- **reconnector-level** `StreamReconnector._fatal` (`_signal_fatal`): the reconnector is + terminally done (non-retriable / `close()`) → surfaces to the public API. + +## Invariants + +1. **Single live stream, no zombie — structural.** Only the one loop assigns `_conn`, and the + `finally` closes the current connection before the next is created. Because `_new_connection()` + is synchronous and the connection owns its gRPC stream, the reconnector holds a closeable + reference *before* the first cancellable network await — so a cancel during the handshake + (e.g. `close()` mid-reconnect) always closes the stream instead of leaking it. No per-`create` + cleanup guard, and it covers the writer too. (`test_reconnect_handshake_cancel_closes_stream`.) +2. **Backoff grows.** `attempt` lives in the base loop and resets only on a successful connect. +3. 
**close ordering.** Mark closed, flush + close the live connection, wake waiters, cancel the + loop. The writer flushes *before* `super().close()` while `_closed` is still False (flushing + after would deadlock — the loop wouldn't bring up the connection the buffered writes need). + +## Reader/writer asymmetry (intentional) + +`_run` is symmetric, but the run-loops live in different places, dictated by **data ownership +across reconnects**: + +- the reader's per-stream state (partition sessions, read-ahead) *must die* on reconnect → it + lives on `ReaderStream`, with the read/decode loops; +- the writer's loops drive the unacked outbox (`_messages`, `_messages_future`, seqno dedup) that + *must survive* reconnect → it lives on the reconnector, with the send/read loops next to it. + +Moving the writer's loops onto the stream would require extracting that outbox into its own +object (the reconnector keeps it; each stream pumps it). Not done. diff --git a/ydb/_topic_common/_stream_connection.py b/ydb/_topic_common/_stream_connection.py new file mode 100644 index 000000000..8db0472c4 --- /dev/null +++ b/ydb/_topic_common/_stream_connection.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import abc +import asyncio +import logging +from typing import Any, Awaitable, Callable, Optional, Set, Union + +from .._grpc.grpcwrapper.common_utils import GrpcWrapperAsyncIO, IGrpcWrapperAsyncIO, SupportedDriverType + +logger = logging.getLogger(__name__) + + +class StreamConnection(abc.ABC): + """The thin bidi-stream lifecycle shared by the topic reader and writer streams. + + It owns the gRPC wrapper, performs ``connect`` = open call + init handshake, runs the + update-token loop, and is constructed SYNCHRONOUSLY (no network in ``__init__``). 
The + latter is what makes the reconnector's single-stream / no-zombie guarantee structural: + the connection object exists and owns its gRPC stream before the first cancellable + network await, so the reconnector can always close it on cancel. + + Subclasses add their protocol-specific message handling and implement the hooks below. + """ + + def __init__( + self, + *, + from_proto: Callable[[Any], Any], + stub: Any, + method: Any, + update_token_interval: Optional[Union[int, float]] = None, + get_token_function: Optional[Callable[[], Union[str, Awaitable[str]]]] = None, + ): + self._loop = asyncio.get_running_loop() + self._stub = stub + self._method = method + # Built (not started) here so the connection owns its transport before connect()'s first + # network await — that is what makes the no-zombie guarantee structural. The legacy + # _start(stream, ...) injection path overrides this with an externally provided stream. + self._stream: IGrpcWrapperAsyncIO = GrpcWrapperAsyncIO(from_proto) + self._background_tasks: Set[asyncio.Task] = set() + self._closed = False + self._first_error: asyncio.Future = self._loop.create_future() + self._update_token_interval = update_token_interval + self._get_token_function = get_token_function + self._update_token_event = asyncio.Event() + + async def connect(self, driver: SupportedDriverType) -> None: + """Open the gRPC call and run the init handshake on the already-owned stream.""" + await self._stream.start(driver, self._stub, self._method) # type: ignore[attr-defined] + await self._init_and_spawn() + + # ------------------------------------------------------------------ death signal + + def _set_first_error(self, err: BaseException) -> None: + """Record the first error that ended this connection; later errors are ignored.""" + if self._first_error.done(): + return + self._first_error.set_result(err) + self._on_first_error() + + def _get_first_error(self) -> Optional[BaseException]: + return self._first_error.result() if 
self._first_error.done() else None + + async def wait_error(self) -> None: + """Block until this connection fails, then raise that error (the reconnect signal).""" + raise await self._first_error + + def _on_first_error(self) -> None: + """Hook: wake local subscribers (e.g. the reader's wait_messages). Default: nothing.""" + + # ------------------------------------------------------------------ shared update-token loop + + async def _update_token_loop(self) -> None: + if self._update_token_interval is None: + return # nothing to refresh on a cadence; avoids a hot loop + while True: + await asyncio.sleep(self._update_token_interval) + if self._get_token_function is None: + return + token = self._get_token_function() + if not isinstance(token, str): + token = await token # async token providers are supported + await self._update_token(token) + + async def _update_token(self, token: str) -> None: + await self._update_token_event.wait() + try: + if self._stream is not None: + self._stream.write(self._make_update_token_request(token)) + finally: + self._update_token_event.clear() + + # ------------------------------------------------------------------ hooks + + @abc.abstractmethod + async def _init_and_spawn(self, init_message: Any = None) -> None: + """Send the init request on ``self._stream``, validate the response, and spawn the + protocol's background workers. 
``init_message`` is supplied by the legacy ``_start`` + injection path; when ``None`` (the ``connect`` path) the subclass derives it itself.""" + + @abc.abstractmethod + def _make_update_token_request(self, token: str) -> Any: + """Wrap ``token`` in the protocol's FromClient(UpdateTokenRequest) message.""" diff --git a/ydb/_topic_common/_stream_reconnector.py b/ydb/_topic_common/_stream_reconnector.py new file mode 100644 index 000000000..6e6e3a2d0 --- /dev/null +++ b/ydb/_topic_common/_stream_reconnector.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import abc +import asyncio +import logging +from typing import Generic, Optional, Set, TypeVar + +from .. import issues +from .._errors import check_retriable_error, ErrorRetryInfo +from .._utilities import AtomicCounter +from ..retries import RetrySettings + +logger = logging.getLogger(__name__) + +ConnT = TypeVar("ConnT") + + +class StreamReconnector(abc.ABC, Generic[ConnT]): + """Connection lifecycle shared by the topic reader and writer. + + It owns exactly one background task — the reconnect loop — plus the fatal signal + and the state-change event consumers wait on. Everything about *when* to reconnect + (connect, run-until-error, classify, backoff, reconnect, teardown ordering) lives + here once; subclasses plug in the per-protocol parts through the hooks below. + + Subclasses must set up their own attributes BEFORE calling ``super().__init__()``, + because that call schedules the connection loop, which immediately uses ``_new_connection`` + and ``_handshake``. 
+ """ + + _static_counter = AtomicCounter() + + def __init__( + self, + *, + retry_settings: RetrySettings, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + self._id = StreamReconnector._static_counter.inc_and_get() + self._retry_settings = retry_settings + self._loop = loop if loop is not None else asyncio.get_running_loop() + self._closed = False + self._conn: Optional[ConnT] = None + self._state_changed = asyncio.Event() + self._fatal: asyncio.Future = self._loop.create_future() + self._background_tasks: Set[asyncio.Task] = set() + self._background_tasks.add(asyncio.create_task(self._connection_loop())) + logger.debug("%s init id=%s", type(self).__name__, self._id) + + # ------------------------------------------------------------------ helpers for subclasses + + @property + def connection(self) -> Optional[ConnT]: + return self._conn + + def _signal_fatal(self, err: BaseException) -> None: + """Record the terminal error (first one wins), run teardown, wake every waiter.""" + if self._fatal.done(): + self._state_changed.set() + return + self._fatal.set_result(err) + self._on_fatal(err) + self._state_changed.set() + + def _fatal_error(self) -> Optional[BaseException]: + return self._fatal.result() if self._fatal.done() else None + + async def _wait_state_change(self) -> None: + await self._state_changed.wait() + self._state_changed.clear() + + # ------------------------------------------------------------------ the one reconnect loop + + async def _connection_loop(self) -> None: + attempt = 0 + while True: + if self._closed: + return + conn = None + try: + logger.debug("%s %s connect attempt %s", type(self).__name__, self._id, attempt) + # Construct synchronously and take ownership BEFORE the first cancellable network + # await, so a cancel during the handshake still reaches close() in the finally — + # this is what makes "one stream, no zombie" structural rather than a contract. 
+ conn = self._new_connection() + await self._handshake(conn) + # Publish only after a successful handshake so consumers never observe a + # half-initialized connection. Teardown still uses the local `conn` in the + # finally, so an interrupted handshake is closed regardless (no zombie). + self._conn = conn + attempt = 0 # reset only on a successful connect — backoff grows across failures + await self._on_connected(conn) + self._state_changed.set() + await self._run(conn) + except BaseException as err: + if isinstance(err, asyncio.CancelledError): + if self._closed: + raise # let close() tear the loop down + # a cancelled gRPC call surfaces as CancelledError; treat it as a lost + # connection and reconnect instead of dying + err = issues.ConnectionLost("gRPC stream cancelled") + logger.debug("%s %s loop error: %s", type(self).__name__, self._id, err) + retry_info = self._classify_error(err, attempt) + if not retry_info.is_retriable: + self._signal_fatal(err) + return + await asyncio.sleep(retry_info.sleep_timeout_seconds or 0) + attempt += 1 + finally: + if conn is not None: + # noinspection PyBroadException + try: + await self._close_connection(conn, flush=False) + except asyncio.CancelledError: + # propagate cancellation (e.g. from close()) so the loop stops instead + # of swallowing it and reconnecting into a zombie connection + raise + except Exception: + pass # suppress any error while closing the dead connection + + async def close(self, flush: bool) -> None: + if self._closed: + return + logger.debug("%s %s close", type(self).__name__, self._id) + # Mark closed first so the loop won't bring up a new connection, then close the live + # one with the requested flush BEFORE cancelling the loop — cancelling first would let + # the finally close it with flush=False and skip the flush. + self._closed = True + if self._conn is not None: + await self._close_connection(self._conn, flush) + # Wake any pending waiter so it doesn't hang if the loop was mid-reconnect. 
+ self._signal_fatal(self._terminal_error()) + for task in self._background_tasks: + task.cancel() + await asyncio.wait(self._background_tasks) + + # ------------------------------------------------------------------ hooks + + @abc.abstractmethod + def _new_connection(self) -> ConnT: + """Construct the connection object SYNCHRONOUSLY (no network). It must already own its + transport, so close() is safe even if the handshake is later cancelled.""" + + @abc.abstractmethod + async def _handshake(self, conn: ConnT) -> None: + """Open the call and run the init handshake on an already-constructed connection.""" + + async def _on_connected(self, conn: ConnT) -> None: + """Run right after a connection is established — restore state, capture init info, + re-send buffered work. Default: nothing.""" + + @abc.abstractmethod + async def _run(self, conn: ConnT) -> None: + """Block until this connection ends (it errors or is torn down).""" + + @abc.abstractmethod + async def _close_connection(self, conn: ConnT, flush: bool) -> None: + """Tear down a connection, optionally flushing its own per-stream state.""" + + @abc.abstractmethod + def _terminal_error(self) -> BaseException: + """Error used to wake waiters once close() is called.""" + + def _on_fatal(self, err: BaseException) -> None: + """Synchronous teardown run once, when the reconnector terminates (e.g. fail pending + writes). Default: nothing.""" + + def _classify_error(self, err: BaseException, attempt: int) -> ErrorRetryInfo: + """Decide whether ``err`` is retriable and how long to back off. Override to add + protocol-specific rules (e.g. 
the writer disables retries inside a transaction).""" + return check_retriable_error(err, self._retry_settings, attempt) diff --git a/ydb/_topic_common/common_test.py b/ydb/_topic_common/common_test.py index b31f9af9a..35eb37bd2 100644 --- a/ydb/_topic_common/common_test.py +++ b/ydb/_topic_common/common_test.py @@ -7,6 +7,10 @@ import pytest from .common import CallFromSyncToAsync +from ._stream_reconnector import StreamReconnector +from ._stream_connection import StreamConnection +from .test_helpers import wait_condition +from ..retries import RetrySettings from .._grpc.grpcwrapper.common_utils import ( GrpcWrapperAsyncIO, ServerStatus, @@ -288,3 +292,135 @@ def callback(): with pytest.raises(TestError): caller.call_sync(callback) assert callback_eventloop is separate_loop + + +class _FakeConn: + def __init__(self): + self.close_calls = 0 + self.close_raises = None + self._err = asyncio.Future() + + async def wait_error(self): + raise await self._err + + def fail(self, err): + if not self._err.done(): + self._err.set_result(err) + + async def close(self, flush=False): + self.close_calls += 1 + if self.close_raises is not None: + raise self.close_raises + + +class _FakeReconnector(StreamReconnector): + def __init__(self, conns): + self._conns = list(conns) + super().__init__(retry_settings=RetrySettings()) + + def _new_connection(self): + if not self._conns: + raise issues.Error("no more connections") + return self._conns.pop(0) + + async def _handshake(self, conn): + pass + + async def _run(self, conn): + await conn.wait_error() + + async def _close_connection(self, conn, flush): + await conn.close(flush) + + def _terminal_error(self): + return issues.Error("reconnector closed") + + +class _FakeConnection(StreamConnection): + async def _init_and_spawn(self, init_message=None): + pass + + def _make_update_token_request(self, token): + return ("update-token", token) + + +@pytest.mark.asyncio +class TestStreamReconnectorBase: + async def 
test_reconnect_then_fatal(self): + c1, c2 = _FakeConn(), _FakeConn() + r = _FakeReconnector([c1, c2]) + await wait_condition(lambda: r.connection is c1) + c1.fail(issues.Unavailable("retriable")) # reconnect + await wait_condition(lambda: r.connection is c2) + c2.fail(issues.Error("fatal")) # not retriable -> stop + await wait_condition(lambda: r._fatal_error() is not None) + assert isinstance(r._fatal_error(), issues.Error) + assert c1.close_calls >= 1 + await r.close(False) + + async def test_cancelled_error_reconnects(self): + c1, c2 = _FakeConn(), _FakeConn() + r = _FakeReconnector([c1, c2]) + await wait_condition(lambda: r.connection is c1) + c1.fail(asyncio.CancelledError()) # not closed -> ConnectionLost -> reconnect + await wait_condition(lambda: r.connection is c2) + await r.close(False) + + async def test_finally_swallows_close_error(self): + c1, c2 = _FakeConn(), _FakeConn() + c1.close_raises = RuntimeError("boom") + r = _FakeReconnector([c1, c2]) + await wait_condition(lambda: r.connection is c1) + c1.fail(issues.Unavailable("retriable")) + await wait_condition(lambda: r.connection is c2) + await r.close(False) + + async def test_close_without_connection_is_idempotent(self): + r = _FakeReconnector([]) # connect fails immediately, _conn stays None + await wait_condition(lambda: r._fatal_error() is not None) + await r.close(False) # _conn is None + await r.close(False) # already closed + + async def test_retriable_connect_failure_keeps_looping(self): + c = _FakeConn() + calls = {"n": 0} + + class R(_FakeReconnector): + def _new_connection(self): + calls["n"] += 1 + if calls["n"] == 1: + raise issues.Unavailable("retriable connect failure") # conn stays None, loop retries + return c + + r = R([]) + await wait_condition(lambda: r.connection is c) + await r.close(False) + + +@pytest.mark.asyncio +class TestStreamConnectionBase: + def _make(self, **kw): + return _FakeConnection(from_proto=lambda x: x, stub=None, method="m", **kw) + + async def 
test_update_token_loop_without_token_function(self): + # no interval -> guard returns immediately (no hot loop) + conn = self._make(update_token_interval=None, get_token_function=lambda: "t") + await conn._update_token_loop() + # interval set but no token function -> loop returns after the first tick + conn2 = self._make(update_token_interval=0, get_token_function=None) + await conn2._update_token_loop() + + async def test_update_token_coroutine_and_no_stream(self): + async def coro_token(): + return "T" + + conn = self._make(update_token_interval=0, get_token_function=coro_token) + conn._stream = None # exercise the "no stream" branch of _update_token + conn._update_token_event.set() + task = asyncio.create_task(conn._update_token_loop()) + await asyncio.sleep(0.02) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index ee988a1ec..f0a681290 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -5,9 +5,8 @@ import gzip import math import typing -from asyncio import Task from collections import defaultdict, OrderedDict -from typing import Optional, Set, Dict, Union, Callable +from typing import Optional, Dict, Callable import ydb from .. import _apis, issues @@ -18,6 +17,8 @@ from . import datatypes from . import events from . 
import topic_reader +from .._topic_common._stream_reconnector import StreamReconnector +from .._topic_common._stream_connection import StreamConnection from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, @@ -31,7 +32,6 @@ UpdateOffsetsInTransactionRequest, Codec, ) -from .._errors import check_retriable_error import logging from ..query.base import TxEvent @@ -205,19 +205,10 @@ def read_session_id(self) -> Optional[str]: return self._reconnector.read_session_id -class ReaderReconnector: - _static_reader_reconnector_counter = AtomicCounter() - - _id: int +class ReaderReconnector(StreamReconnector["ReaderStream"]): _settings: topic_reader.PublicReaderSettings _driver: Driver - _background_tasks: Set[Task] - - _state_changed: asyncio.Event - _stream_reader: Optional["ReaderStream"] - _first_error: asyncio.Future[YdbError] _tx_to_batches_map: Dict[str, typing.List[datatypes.PublicBatch]] - _closed: bool def __init__( self, @@ -225,87 +216,65 @@ def __init__( settings: topic_reader.PublicReaderSettings, loop: Optional[asyncio.AbstractEventLoop] = None, ): - self._id = ReaderReconnector._static_reader_reconnector_counter.inc_and_get() self._settings = settings self._driver = driver - self._loop = loop if loop is not None else asyncio.get_running_loop() - self._background_tasks = set() - logger.debug("init reader reconnector id=%s", self._id) + self._tx_to_batches_map = dict() + # StreamReconnector.__init__ schedules the reconnect loop, so our attributes + # (used by _new_connection) must already be set before this call. 
+ super().__init__(retry_settings=settings._retry_settings(), loop=loop) - self._state_changed = asyncio.Event() - self._stream_reader = None - self._closed = False - self._background_tasks.add(asyncio.create_task(self._connection_loop())) - self._first_error = asyncio.get_running_loop().create_future() + # ---- reconnect hooks (StreamReconnector) ---- - self._tx_to_batches_map = dict() + def _new_connection(self) -> "ReaderStream": + creds = self._driver._credentials + return ReaderStream( + self._id, + self._settings, + get_token_function=creds.get_auth_token if creds else None, + ) - async def _connection_loop(self): - attempt = 0 - while True: - if self._closed: - return - try: - logger.debug("reader %s connect attempt %s", self._id, attempt) - self._stream_reader = await ReaderStream.create(self._id, self._driver, self._settings) - logger.debug("reader %s connected stream %s", self._id, self._stream_reader._id) - attempt = 0 - self._state_changed.set() - await self._stream_reader.wait_error() - except BaseException as err: - logger.debug("reader %s, attempt %s connection loop error %s", self._id, attempt, err) - retry_info = check_retriable_error(err, self._settings._retry_settings(), attempt) - if not retry_info.is_retriable: - logger.debug("reader %s stop connection loop due to %s", self._id, err) - self._set_first_error(err) - return + async def _handshake(self, conn: "ReaderStream") -> None: + await conn.connect(self._driver) - logger.debug("sleep before retry for %s seconds", retry_info.sleep_timeout_seconds) + async def _run(self, conn: "ReaderStream") -> None: + await conn.wait_error() - await asyncio.sleep(retry_info.sleep_timeout_seconds) + async def _close_connection(self, conn: "ReaderStream", flush: bool) -> None: + await conn.close(flush) - attempt += 1 - finally: - if self._stream_reader is not None: - # noinspection PyBroadException - try: - await self._stream_reader.close(flush=False) - except asyncio.CancelledError: - # propagate cancellation 
(e.g. from reader.close()) so the loop stops - # instead of swallowing it and reconnecting into a zombie stream - raise - except Exception: - # suppress any error on close stream reader - pass + def _terminal_error(self) -> BaseException: + return TopicReaderStreamClosedError() async def wait_message(self): while True: - if self._first_error.done(): - raise self._first_error.result() + if self._fatal.done(): + raise self._fatal.result() - if self._stream_reader: + if self._conn: try: - await self._stream_reader.wait_messages() + await self._conn.wait_messages() return except YdbError: pass # handle errors in reconnection loop - await self._state_changed.wait() - self._state_changed.clear() + await self._wait_state_change() def receive_batch_nowait(self, max_messages: Optional[int] = None): - if self._stream_reader is None: + if self._conn is None: return None - return self._stream_reader.receive_batch_nowait( + return self._conn.receive_batch_nowait( max_messages=max_messages, ) def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: Optional[int] = None): - if self._stream_reader is None: + if self._conn is None: return None - batch = self._stream_reader.receive_batch_nowait( + batch = self._conn.receive_batch_nowait( max_messages=max_messages, ) + if batch is None: + # the queue was drained concurrently between wait_message() and now; nothing to bind + return None self._init_tx(tx) @@ -319,7 +288,9 @@ def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: O return batch def receive_message_nowait(self): - return self._stream_reader.receive_message_nowait() + if self._conn is None: + return None + return self._conn.receive_message_nowait() def _init_tx(self, tx: "BaseQueryTxContext"): tx_id = tx.tx_id @@ -335,7 +306,7 @@ def _batch_partition_session_expired(self, batch: datatypes.PublicBatch) -> bool # A batch is expired if the reader reconnected after it was received: its partition # session no longer belongs to the 
current stream. Mirrors the guard in # ReaderStream.commit() for the non-transactional commit path. - stream = self._stream_reader + stream = self._conn partition_session = batch._partition_session return ( stream is None @@ -393,8 +364,8 @@ async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"): except BaseException: err = issues.ClientInternalError("Failed to update offsets in tx.") tx._set_external_error(err) - if self._stream_reader is not None: - self._stream_reader._set_first_error(err) + if self._conn is not None: + self._conn._set_first_error(err) finally: if tx_id in self._tx_to_batches_map: del self._tx_to_batches_map[tx_id] @@ -419,70 +390,40 @@ async def _handle_after_tx_rollback(self, tx: "BaseQueryTxContext", exc: Optiona if tx_id is not None and tx_id in self._tx_to_batches_map: del self._tx_to_batches_map[tx_id] err = issues.ClientInternalError("Reconnect due to transaction rollback") - if self._stream_reader is not None: - self._stream_reader._set_first_error(err) + if self._conn is not None: + self._conn._set_first_error(err) async def _handle_after_tx_commit(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None: tx_id = tx.tx_id if tx_id is not None and tx_id in self._tx_to_batches_map: del self._tx_to_batches_map[tx_id] - if exc is not None and self._stream_reader is not None: - self._stream_reader._set_first_error( - issues.ClientInternalError("Reconnect due to transaction commit failed") - ) + if exc is not None and self._conn is not None: + self._conn._set_first_error(issues.ClientInternalError("Reconnect due to transaction commit failed")) def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.CommitAckWaiter: - if self._stream_reader is None: + if self._conn is None: raise TopicReaderError("Stream reader is not connected") - return self._stream_reader.commit(batch) - - async def close(self, flush: bool): - logger.debug("reader reconnector %s close", self._id) - # Mark closed so the 
connection loop won't start a new stream, then close the - # current stream with the requested flush before cancelling the loop. On a normal - # close this flushes pending commits; cancelling the loop first would let it close - # the stream with flush=False instead and skip the flush. - self._closed = True - if self._stream_reader: - await self._stream_reader.close(flush) - # Wake any pending wait_message() waiter (e.g. a concurrent receive) so it doesn't - # hang if the loop was reconnecting when close() cancelled it. - self._set_first_error(TopicReaderStreamClosedError()) - for task in self._background_tasks: - task.cancel() - - await asyncio.wait(self._background_tasks) + return self._conn.commit(batch) async def flush(self): - if self._stream_reader: - await self._stream_reader.flush() - - def _set_first_error(self, err: issues.Error): - try: - self._first_error.set_result(err) - self._state_changed.set() - except asyncio.InvalidStateError: - # skip if already has result - pass + if self._conn: + await self._conn.flush() @property def read_session_id(self) -> Optional[str]: - if not self._stream_reader: + if not self._conn: return None - return self._stream_reader._session_id + return self._conn._session_id -class ReaderStream: +class ReaderStream(StreamConnection): _static_id_counter = AtomicCounter() - _loop: asyncio.AbstractEventLoop _id: int _reader_reconnector_id: int _session_id: str - _stream: Optional[IGrpcWrapperAsyncIO] _started: bool - _background_tasks: Set[asyncio.Task] _partition_sessions: Dict[int, datatypes.PartitionSession] _buffer_size_bytes: int # use for init request, then for debug purposes only _min_buffer_release_bytes: int @@ -496,13 +437,7 @@ class ReaderStream: _batches_to_decode: asyncio.Queue _state_changed: asyncio.Event - _closed: bool _message_batches: "OrderedDict[int, datatypes.PublicBatch]" # keys are partition session ID - _first_error: asyncio.Future[YdbError] - - _update_token_interval: Union[int, float] - 
_update_token_event: asyncio.Event - _get_token_function: Optional[Callable[[], str]] _settings: topic_reader.PublicReaderSettings def __init__( @@ -511,7 +446,15 @@ def __init__( settings: topic_reader.PublicReaderSettings, get_token_function: Optional[Callable[[], str]] = None, ): - self._loop = asyncio.get_running_loop() + # StreamConnection owns the gRPC wrapper (sync), _stream / _background_tasks / + # _closed / _update_token_* and _loop. + super().__init__( + from_proto=StreamReadMessage.FromServer.from_proto, + stub=_apis.TopicService.Stub, + method=_apis.TopicService.StreamRead, + update_token_interval=settings.update_token_interval, + get_token_function=get_token_function, + ) self._id = ReaderStream._static_id_counter.inc_and_get() self._reader_reconnector_id = reader_reconnector_id self._session_id = "not initialized" @@ -520,9 +463,7 @@ def __init__( self._id, self._session_id, ) - self._stream = None self._started = False - self._background_tasks = set() self._partition_sessions = dict() self._buffer_size_bytes = settings.buffer_size_bytes self._min_buffer_release_bytes = math.ceil(settings.buffer_size_bytes * settings.buffer_release_threshold) @@ -534,15 +475,9 @@ def __init__( self._decoders.update(settings.decoders) self._state_changed = asyncio.Event() - self._closed = False - self._first_error = asyncio.get_running_loop().create_future() self._batches_to_decode = asyncio.Queue() self._message_batches = OrderedDict() - self._update_token_interval = settings.update_token_interval - self._get_token_function = get_token_function - self._update_token_event = asyncio.Event() - self._settings = settings logger.debug("created ReaderStream id=%s reconnector=%s", self._id, self._reader_reconnector_id) @@ -579,16 +514,23 @@ async def create( return reader async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest): + # Legacy/injection entry point (kept for existing fixtures): adopt an externally + # provided stream, then 
run the shared connect path. + self._stream = stream + await self._init_and_spawn(init_message) + + async def _init_and_spawn(self, init_message: Optional[StreamReadMessage.InitRequest] = None): if self._started: raise TopicReaderError("Double start ReaderStream") self._started = True - self._stream = stream + if init_message is None: + init_message = self._settings._init_message() logger.debug("%s send init request", self._log_prefix) - stream.write(StreamReadMessage.FromClient(client_message=init_message)) + self._stream.write(StreamReadMessage.FromClient(client_message=init_message)) try: - init_response = await stream.receive( + init_response = await self._stream.receive( timeout=DEFAULT_INITIAL_RESPONSE_TIMEOUT ) # type: StreamReadMessage.FromServer except asyncio.TimeoutError: @@ -624,9 +566,6 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess errors_task.set_name("handle_background_errors") self._background_tasks.add(errors_task) - async def wait_error(self): - raise await self._first_error - async def wait_messages(self): while True: first_error = self._get_first_error() @@ -809,24 +748,8 @@ async def _read_messages_loop(self): self._set_first_error(e) return - async def _update_token_loop(self): - while True: - await asyncio.sleep(self._update_token_interval) - if self._get_token_function is None: - return - token = self._get_token_function() - if asyncio.iscoroutine(token): - token = await token - await self._update_token(token=token) - - async def _update_token(self, token: str): - await self._update_token_event.wait() - try: - msg = StreamReadMessage.FromClient(UpdateTokenRequest(token)) - if self._stream is not None: - self._stream.write(msg) - finally: - self._update_token_event.clear() + def _make_update_token_request(self, token: str) -> StreamReadMessage.FromClient: + return StreamReadMessage.FromClient(UpdateTokenRequest(token)) async def _on_start_partition_session(self, message: 
StreamReadMessage.StartPartitionSessionRequest): try: @@ -1020,18 +943,9 @@ async def _decode_batch_inplace(self, batch): batch._codec = Codec.CODEC_RAW - def _set_first_error(self, err: YdbError): - try: - self._first_error.set_result(err) - self._state_changed.set() - except asyncio.InvalidStateError: - # skip later set errors - pass - - def _get_first_error(self) -> Optional[YdbError]: - if self._first_error.done(): - return self._first_error.result() - return None + def _on_first_error(self) -> None: + # wake wait_messages() so it re-checks and raises the recorded error + self._state_changed.set() async def flush(self): futures = [] diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index ffc249479..e848d5f5e 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -12,10 +12,11 @@ from ydb import issues from . import datatypes, topic_reader_asyncio +from .._topic_common import _stream_connection from .datatypes import PublicBatch, PublicMessage from .topic_reader import PublicReaderSettings from .topic_reader_asyncio import ReaderStream, ReaderReconnector, TopicReaderError -from .._grpc.grpcwrapper.common_utils import SupportedDriverType, ServerStatus +from .._grpc.grpcwrapper.common_utils import ServerStatus from .._grpc.grpcwrapper.ydb_topic import ( StreamReadMessage, Codec, @@ -1542,11 +1543,9 @@ async def wait_forever(): stream_index = 0 - async def stream_create( - reader_reconnector_id: int, - driver: SupportedDriverType, - settings: PublicReaderSettings, - ): + def new_connection(self): + # the reconnect path is now _new_connection() (sync) + _handshake()=conn.connect(); + # conn.connect/wait_* are AsyncMocks via the ReaderStream spec. 
nonlocal stream_index stream_index += 1 if stream_index == 1: @@ -1556,9 +1555,13 @@ async def stream_create( else: raise Exception("unexpected create stream") - with mock.patch.object(ReaderStream, "create", stream_create): + with mock.patch.object(ReaderReconnector, "_new_connection", new_connection): reconnector = ReaderReconnector(mock.Mock(), PublicReaderSettings("", "")) - await wait_for_fast(reconnector.wait_message()) + try: + # generous window: the reconnect crosses one retry backoff sleep + await wait_for_fast(reconnector.wait_message(), timeout=5) + finally: + await reconnector.close(flush=False) reader_stream_mock_with_error.wait_error.assert_any_await() reader_stream_mock_with_error.wait_messages.assert_any_await() @@ -1594,12 +1597,12 @@ async def wait_forever(): create_calls = 0 - async def stream_create(reader_reconnector_id, driver, settings): + def new_connection(self): nonlocal create_calls create_calls += 1 return stream1 if create_calls == 1 else stream2 - with mock.patch.object(ReaderStream, "create", stream_create): + with mock.patch.object(ReaderReconnector, "_new_connection", new_connection): reconnector = ReaderReconnector(mock.Mock(), PublicReaderSettings("", "")) await asyncio.wait_for(finally_close_started.wait(), timeout=2) await asyncio.wait_for(reconnector.close(flush=False), timeout=5) @@ -1638,6 +1641,33 @@ async def start(self, driver, stub, method): assert built[0]._closed, "create() leaked the in-flight gRPC stream on cancel" + async def test_reconnect_handshake_cancel_closes_stream(self, default_reader_settings): + # Structural no-zombie guarantee provided by StreamReconnector: the connection (and its + # gRPC stream) is constructed synchronously and owned BEFORE the handshake's first network + # await, so cancelling the loop mid-handshake (reader.close()) closes the stream instead of + # leaking it — without any per-create cleanup guard. 
+ built = [] + + class FakeStream(StreamMock): + def __init__(self, *args, **kwargs): + super().__init__() + built.append(self) + + async def start(self, driver, stub, method): + return None + + driver = mock.Mock() + driver._credentials = None + + with mock.patch.object(_stream_connection, "GrpcWrapperAsyncIO", FakeStream): + reconnector = ReaderReconnector(driver, default_reader_settings) + # the loop builds the connection (and its FakeStream) and parks in the init handshake + await wait_condition(lambda: bool(built) and not built[0].from_client.empty()) + assert not built[0]._closed + await asyncio.wait_for(reconnector.close(flush=False), timeout=5) + + assert built[0]._closed, "handshake cancel leaked the in-flight gRPC stream" + async def test_wait_error_returns_on_cancelled_error_from_receive(self, default_reader_settings): receive_call = 0 diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 5e55917d8..d71fd3707 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -6,7 +6,7 @@ import gzip import typing from collections import deque -from typing import Deque, AsyncIterator, Union, List, Optional, Dict, Callable +from typing import Deque, Union, List, Optional, Dict, Callable import logging @@ -26,13 +26,12 @@ PublicWriteResultTypes, Message, ) -from .. import ( - _apis, - issues, -) +from .. 
import _apis, issues from .._utilities import AtomicCounter -from .._errors import check_retriable_error +from .._errors import check_retriable_error, ErrorRetryInfo from ..retries import RetrySettings +from .._topic_common._stream_reconnector import StreamReconnector +from .._topic_common._stream_connection import StreamConnection from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._grpc.grpcwrapper.ydb_topic import ( UpdateTokenRequest, @@ -246,11 +245,7 @@ async def _on_before_rollback(self, tx: "BaseQueryTxContext"): await self.close(flush=False) -class WriterAsyncIOReconnector: - _static_id_counter = AtomicCounter() - - _closed: bool - _loop: asyncio.AbstractEventLoop +class WriterAsyncIOReconnector(StreamReconnector["WriterAsyncIOStream"]): _credentials: Union[ydb.credentials.Credentials, None] _driver: ydb.aio.Driver _init_message: StreamWriteMessage.InitRequest @@ -271,13 +266,7 @@ class WriterAsyncIOReconnector: _messages: Deque[InternalMessage] _messages_future: Deque[asyncio.Future] _new_messages: asyncio.Queue - _background_tasks: List[asyncio.Task] - _state_changed: asyncio.Event - if typing.TYPE_CHECKING: - _stop_reason: asyncio.Future[BaseException] - else: - _stop_reason: asyncio.Future _init_info: Optional[PublicWriterInitInfo] _buffer_bytes: int _buffer_messages: int @@ -286,9 +275,6 @@ class WriterAsyncIOReconnector: def __init__( self, driver: SupportedDriverType, settings: WriterSettings, tx: Optional["BaseQueryTxContext"] = None ): - self._closed = False - self._id = WriterAsyncIOReconnector._static_id_counter.inc_and_get() - self._loop = asyncio.get_running_loop() self._driver = driver # type: ignore[assignment] self._credentials = driver._credentials self._init_message = settings.create_init_request() @@ -328,46 +314,37 @@ def __init__( self._buffer_bytes = 0 self._buffer_messages = 0 self._buffer_updated = asyncio.Event() - self._stop_reason = self._loop.create_future() - connection_task = 
asyncio.create_task(self._connection_loop()) - connection_task.set_name("connection_loop") + + # StreamReconnector.__init__ schedules the reconnect loop and provides _loop / + # _state_changed / _fatal / _background_tasks; our attributes (used by the hooks) + # must already be set before this call. + super().__init__(retry_settings=RetrySettings(retry_cancelled=True), loop=None) + encode_task = asyncio.create_task(self._encode_loop()) encode_task.set_name("encode_loop") - self._background_tasks = [connection_task, encode_task] - - self._state_changed = asyncio.Event() + self._background_tasks.add(encode_task) logger.debug("init writer reconnector id=%s", self._id) async def close(self, flush: bool): if self._closed: return - self._closed = True - logger.debug("Close writer reconnector id=%s", self._id) - + # Flush BEFORE closing: while _closed is still False the reconnect loop is free to + # (re)connect and drain buffered messages. Flushing after _closed=True could deadlock, + # since the loop would refuse to bring up the connection the pending writes need. 
if flush: await self.flush() - - self._stop(TopicWriterStopped()) - - for task in self._background_tasks: - task.cancel() - await asyncio.wait(self._background_tasks) - - # if work was stopped before close by error - raise the error + await super().close(flush=False) + # if work was stopped before close by an error, surface it (but not the normal stop) try: self._check_stop() except TopicWriterStopped: pass - logger.debug("Writer reconnector id=%s was closed", self._id) async def wait_init(self) -> PublicWriterInitInfo: while True: - if self._stop_reason.done(): - exc = self._stop_reason.exception() - if exc is not None: - raise exc - raise TopicWriterError("Writer stopped without exception") + if self._fatal.done(): + raise self._fatal.result() if self._init_info: return self._init_info @@ -375,11 +352,9 @@ async def wait_init(self) -> PublicWriterInitInfo: await self._state_changed.wait() async def wait_stop(self) -> BaseException: - try: - await self._stop_reason - return TopicWriterError("Writer stopped without exception") - except BaseException as stop_reason: - return stop_reason + if not self._fatal.done(): + await self._fatal + return self._fatal.result() async def write_with_ack_future(self, messages: List[PublicMessage]) -> List[asyncio.Future]: self._check_stop() @@ -483,83 +458,58 @@ def _prepare_internal_messages(self, messages: List[PublicMessage]) -> List[Inte return res def _check_stop(self): - if self._stop_reason.done(): - raise self._stop_reason.exception() + if self._fatal.done(): + raise self._fatal.result() - async def _connection_loop(self): - retry_settings = RetrySettings(retry_cancelled=True) # todo + # ---- reconnect hooks (StreamReconnector) ---- - while True: - attempt = 0 # todo calc and reset - tasks = [] - - # noinspection PyBroadException - stream_writer = None - try: - logger.debug("writer reconnector %s connect attempt %s", self._id, attempt) - tx_identity = None if self._tx is None else self._tx._tx_identity() - stream_writer = 
await WriterAsyncIOStream.create( - self._driver, - self._init_message, - self._settings.update_token_interval, - tx_identity=tx_identity, - ) - logger.debug( - "writer reconnector %s connected stream %s", - self._id, - stream_writer._id, - ) - try: - if self._init_info is None: - self._last_known_seq_no = stream_writer.last_seqno - self._init_info = PublicWriterInitInfo( - last_seqno=stream_writer.last_seqno, - supported_codecs=stream_writer.supported_codecs, - ) - self._state_changed.set() - - except asyncio.InvalidStateError: - pass + def _new_connection(self) -> "WriterAsyncIOStream": + creds = self._driver._credentials + tx_identity = None if self._tx is None else self._tx._tx_identity() + return WriterAsyncIOStream( + update_token_interval=self._settings.update_token_interval, + get_token_function=creds.get_auth_token if creds else (lambda: ""), + tx_identity=tx_identity, + init_request=self._init_message, + ) - self._stream_connected.set() + async def _handshake(self, conn: "WriterAsyncIOStream") -> None: + await conn.connect(self._driver) - send_loop = asyncio.create_task(self._send_loop(stream_writer)) - send_loop.set_name("writer send loop") - receive_loop = asyncio.create_task(self._read_loop(stream_writer)) - receive_loop.set_name("writer receive loop") + async def _on_connected(self, conn: "WriterAsyncIOStream") -> None: + if self._init_info is None: + self._last_known_seq_no = conn.last_seqno + self._init_info = PublicWriterInitInfo( + last_seqno=conn.last_seqno, + supported_codecs=conn.supported_codecs or [], + ) + self._state_changed.set() + self._stream_connected.set() + + async def _run(self, conn: "WriterAsyncIOStream") -> None: + # Symmetric with the reader: the send/read loops funnel their terminal condition into + # conn._first_error, and we just wait for the connection to fail (or be torn down). 
+ send_loop = asyncio.create_task(self._send_loop(conn)) + send_loop.set_name("writer send loop") + receive_loop = asyncio.create_task(self._read_loop(conn)) + receive_loop.set_name("writer receive loop") + try: + await conn.wait_error() + finally: + for task in (send_loop, receive_loop): + task.cancel() + await asyncio.wait([send_loop, receive_loop]) - tasks = [send_loop, receive_loop] - done, _ = await asyncio.wait([send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED) - done.pop().result() # need for raise exception - reason of stop task - except (asyncio.CancelledError, issues.Error) as err: - if isinstance(err, asyncio.CancelledError): - if self._closed: - return - err = issues.ConnectionLost("gRPC stream cancelled") + async def _close_connection(self, conn: "WriterAsyncIOStream", flush: bool) -> None: + await conn.close() - err_info = check_retriable_error(err, retry_settings, attempt) - if not err_info.is_retriable or self._tx is not None: # no retries in tx writer - logger.debug("writer reconnector %s stop connection loop due to %s", self._id, err) - self._stop(err) - return + def _terminal_error(self) -> BaseException: + return TopicWriterStopped() - logger.debug( - "writer reconnector %s retry in %s seconds", - self._id, - err_info.sleep_timeout_seconds, - ) - await asyncio.sleep(err_info.sleep_timeout_seconds) - - except Exception as err: - self._stop(err) - return - finally: - for task in tasks: - task.cancel() - if tasks: - await asyncio.wait(tasks) - if stream_writer: - await stream_writer.close() + def _classify_error(self, err: BaseException, attempt: int) -> ErrorRetryInfo: + if self._tx is not None: + return ErrorRetryInfo(False, None) # no retries in a tx writer + return check_retriable_error(err, self._retry_settings, attempt) async def _encode_loop(self): try: @@ -682,16 +632,22 @@ def select_codec() -> PublicCodec: codec = await loop.run_in_executor(self._encode_executor, select_codec) return codec - async def _read_loop(self, writer: 
"WriterAsyncIOStream"): - while True: - resp = await writer.receive() - - logger.debug("writer reconnector %s received %s acks", self._id, len(resp.acks)) - - for ack in resp.acks: - self._handle_receive_ack(ack) - - logger.debug("writer reconnector %s handled %s acks", self._id, len(resp.acks)) + async def _read_loop(self, conn: "WriterAsyncIOStream"): + try: + while True: + resp = await conn.receive() + logger.debug("writer reconnector %s received %s acks", self._id, len(resp.acks)) + for ack in resp.acks: + self._handle_receive_ack(ack) + logger.debug("writer reconnector %s handled %s acks", self._id, len(resp.acks)) + except asyncio.CancelledError: + # gRPC stream death surfaces here as CancelledError; record it as a lost connection + # so wait_error() fires and the reconnector reconnects. If this is our own teardown + # cancel, _first_error is already set and this is a no-op. + conn._set_first_error(issues.ConnectionLost("gRPC stream cancelled")) + raise + except BaseException as e: + conn._set_first_error(e) def _handle_receive_ack(self, ack): current_message = self._messages.popleft() @@ -755,28 +711,28 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"): batch[-1].seq_no, ) except asyncio.CancelledError: - # the loop task cancelled be parent code, for example for reconnection + # the loop task cancelled by parent code, for example for reconnection # no need to stop all work. raise except BaseException as e: - self._stop(e) - raise + # report the failure through the connection's death signal; the reconnector + # classifies it (retriable -> reconnect, otherwise -> stop). + writer._set_first_error(e) def _stop(self, reason: BaseException): if reason is None: raise Exception("writer stop reason can not be None") + self._signal_fatal(reason) - if self._stop_reason.done(): - return - - self._stop_reason.set_exception(reason) - + def _on_fatal(self, reason: BaseException) -> None: + # StreamReconnector calls this once, when the writer terminates. 
for f in self._messages_future: + if f.done(): + continue # already resolved or cancelled by the caller — don't raise InvalidStateError f.set_exception(reason) f.exception() # mark as retrieved so asyncio does not log "Future exception was never retrieved" self._buffer_updated.set() # wake any tasks blocked in _acquire_buffer_space - self._state_changed.set() logger.info("Stop topic writer %s: %s" % (self._id, reason)) async def flush(self): @@ -787,24 +743,13 @@ async def flush(self): await asyncio.wait(self._messages_future) -class WriterAsyncIOStream: +class WriterAsyncIOStream(StreamConnection): _static_id_counter = AtomicCounter() - # todo slots - _closed: bool - last_seqno: int supported_codecs: Optional[List[PublicCodec]] - _stream: IGrpcWrapperAsyncIO - _requests: asyncio.Queue - _responses: AsyncIterator - - _update_token_interval: Optional[Union[int, float]] _update_token_task: Optional[asyncio.Task] - _update_token_event: asyncio.Event - _get_token_function: Optional[Callable[[], str]] - _tx_identity: Optional[TransactionIdentity] def __init__( @@ -812,16 +757,20 @@ def __init__( update_token_interval: Optional[Union[int, float]] = None, get_token_function: Optional[Callable[[], str]] = None, tx_identity: Optional[TransactionIdentity] = None, + init_request: Optional[StreamWriteMessage.InitRequest] = None, ): - self._closed = False + # StreamConnection owns the gRPC wrapper (sync), _stream / _closed / _update_token_*. 
+ super().__init__( + from_proto=StreamWriteMessage.FromServer.from_proto, + stub=_apis.TopicService.Stub, + method=_apis.TopicService.StreamWrite, + update_token_interval=update_token_interval, + get_token_function=get_token_function, + ) self._id = WriterAsyncIOStream._static_id_counter.inc_and_get() - - self._update_token_interval = update_token_interval - self._get_token_function = get_token_function - self._update_token_event = asyncio.Event() self._update_token_task = None - self._tx_identity = tx_identity + self._init_request = init_request async def close(self): if self._closed: @@ -875,11 +824,20 @@ async def receive(self) -> StreamWriteMessage.WriteResponse: raise Exception("Unknown message while read writer answers: %s" % item) async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest): + # Legacy/injection entry point (kept for existing fixtures): adopt an externally + # provided stream, then run the shared connect path. + self._stream = stream + await self._init_and_spawn(init_message) + + async def _init_and_spawn(self, init_message: Optional[StreamWriteMessage.InitRequest] = None): + if init_message is None: + init_message = self._init_request + assert init_message is not None logger.debug("writer stream %s send init request", self._id) - stream.write(StreamWriteMessage.FromClient(init_message)) + self._stream.write(StreamWriteMessage.FromClient(init_message)) try: - resp = await stream.receive(timeout=DEFAULT_INITIAL_RESPONSE_TIMEOUT) + resp = await self._stream.receive(timeout=DEFAULT_INITIAL_RESPONSE_TIMEOUT) except asyncio.TimeoutError: raise TopicWriterError("Timeout waiting for init response") @@ -895,8 +853,6 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMes self.last_seqno, ) - self._stream = stream - if self._update_token_interval is not None: self._update_token_event.set() self._update_token_task = asyncio.create_task(self._update_token_loop()) @@ -916,20 +872,5 @@ def 
write(self, messages: List[InternalMessage]): for request in messages_to_proto_requests(messages, self._tx_identity): self._stream.write(request) - async def _update_token_loop(self): - while True: - await asyncio.sleep(self._update_token_interval) - token = self._get_token_function() - if asyncio.iscoroutine(token): - token = await token - logger.debug("writer stream %s update token", self._id) - await self._update_token(token=token) - - async def _update_token(self, token: str): - await self._update_token_event.wait() - try: - msg = StreamWriteMessage.FromClient(UpdateTokenRequest(token)) - self._stream.write(msg) - logger.debug("writer stream %s token sent", self._id) - finally: - self._update_token_event.clear() + def _make_update_token_request(self, token: str) -> StreamWriteMessage.FromClient: + return StreamWriteMessage.FromClient(UpdateTokenRequest(token)) diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 848632044..691e662e0 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -34,7 +34,8 @@ TopicWriterBufferFullError, ) from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec -from .._topic_common.test_helpers import StreamMock, wait_for_fast +from .._topic_common.test_helpers import StreamMock, wait_for_fast, wait_condition +from .._topic_common import _stream_connection from .topic_writer_asyncio import ( WriterAsyncIOStream, @@ -309,6 +310,14 @@ def __init__( self.from_client = asyncio.Queue() self._closed = False self.supported_codecs = [] + self._first_error = asyncio.Future() + + def _set_first_error(self, err): + if not self._first_error.done(): + self._first_error.set_result(err) + + async def wait_error(self): + raise await self._first_error def write(self, messages: typing.List[InternalMessage]): if self._closed: @@ -330,6 +339,11 @@ async def close(self): return self._closed = True + async def connect(self, 
driver=None): + # new connect path: _new_connection() returns this mock (already "connected"), + # _handshake() awaits connect(); nothing to do for the double. + pass + @pytest.fixture(autouse=True) async def stream_writer_double_queue(self, monkeypatch): class DoubleQueueWriters: @@ -362,10 +376,10 @@ def _create(self): res = DoubleQueueWriters() - async def async_create(driver, init_message, token_getter, tx_identity): + def new_connection(self): return res.get_first() - monkeypatch.setattr(WriterAsyncIOStream, "create", async_create) + monkeypatch.setattr(WriterAsyncIOReconnector, "_new_connection", new_connection) return res @pytest.fixture @@ -471,7 +485,7 @@ async def receive(self): raise asyncio.CancelledError() await asyncio.Future() # stream 2 stays alive - async def create_mock(*args, **kwargs): + def new_connection(self): nonlocal stream_creates stream_creates += 1 writer = StreamWriterCancelOnFirstReceive() @@ -480,7 +494,7 @@ async def create_mock(*args, **kwargs): stream_2_created.set() return writer - with mock.patch.object(WriterAsyncIOStream, "create", create_mock): + with mock.patch.object(WriterAsyncIOReconnector, "_new_connection", new_connection): reconnector = WriterAsyncIOReconnector(default_driver, default_settings) try: # Bug: stream 2 is never created — _stop(CancelledError) kills the writer permanently. @@ -937,6 +951,45 @@ async def test_custom_encoder(self, default_driver, default_settings, get_stream await reconnector.close(flush=False) +@pytest.mark.asyncio +class TestWriterStreamConnection: + async def test_reconnect_handshake_cancel_closes_stream(self): + # Structural no-zombie guarantee: the writer connection (and its gRPC stream) is owned + # before the handshake's first network await, so cancelling the loop mid-handshake + # (writer.close()) closes the stream instead of leaking it. Defined before TestWriterAsyncIO + # so it runs before that class's autouse WriterAsyncIOReconnector.__new__ patch. 
+ built = [] + + class FakeStream(StreamMock): + def __init__(self, *args, **kwargs): + super().__init__() + built.append(self) + + async def start(self, driver, stub, method): + return None + + settings = WriterSettings( + PublicWriterSettings( + topic="/local/topic", + producer_id="test-producer", + auto_seqno=False, + auto_created_at=False, + codec=PublicCodec.RAW, + update_token_interval=3600, + ) + ) + driver = mock.Mock() + driver._credentials = None + + with mock.patch.object(_stream_connection, "GrpcWrapperAsyncIO", FakeStream): + reconnector = WriterAsyncIOReconnector(driver, settings) + await wait_condition(lambda: bool(built) and not built[0].from_client.empty()) + assert not built[0]._closed + await asyncio.wait_for(reconnector.close(False), timeout=5) + + assert built[0]._closed, "handshake cancel leaked the in-flight gRPC stream" + + @pytest.mark.asyncio class TestWriterAsyncIO: class ReconnectorMock: