Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/topics/test_topic_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def on_partition_get_start_offset(self, event):
assert msg.seqno == current_offset + 1

current_offset += 2
reader._reconnector._stream_reader._set_first_error(ydb.Unavailable("some retriable error"))
reader._reconnector._conn._set_first_error(ydb.Unavailable("some retriable error"))

await asyncio.sleep(0)

Expand Down Expand Up @@ -629,7 +629,7 @@ def on_partition_get_start_offset(self, event):
assert msg.seqno == current_offset + 1

current_offset += 2
reader._async_reader._reconnector._stream_reader._set_first_error(ydb.Unavailable("some retriable error"))
reader._async_reader._reconnector._conn._set_first_error(ydb.Unavailable("some retriable error"))

msg = reader.receive_message()

Expand Down
6 changes: 3 additions & 3 deletions tests/topics/test_topic_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async def test_tx_commit_after_reconnect_does_not_commit_stale_offsets(
assert batch.messages[0].data.decode() == "123"

reconnector = reader._reconnector
old_stream = reconnector._stream_reader
old_stream = reconnector._conn

with mock.patch.object(
reconnector,
Expand All @@ -163,10 +163,10 @@ async def test_tx_commit_after_reconnect_does_not_commit_stale_offsets(
old_stream._set_first_error(ydb.issues.ConnectionLost("forced reconnect"))
for _ in range(100):
await asyncio.sleep(0.05)
current = reconnector._stream_reader
current = reconnector._conn
if current is not None and current is not old_stream and current._started:
break
assert reconnector._stream_reader is not old_stream
assert reconnector._conn is not old_stream

# Committing the stale batch must fail loudly instead of silently
# sending a gapped UpdateOffsetsInTransaction for the dead session.
Expand Down
89 changes: 89 additions & 0 deletions ydb/_topic_common/STREAM_DESIGN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Topic stream stack

The topic reader and writer share their reconnect + bidi-stream lifecycle through two base
classes in this package. The concrete reader/writer reconnectors and stream objects are thin
subclasses.

```
StreamReconnector (abc, Generic[ConnT]) — WHEN to (re)connect
• the single reconnect loop, backoff, the reconnector-level fatal signal, close ordering
• _new_connection() (sync) + _handshake(conn) + _run(conn)
│ creates / drives
StreamConnection (abc) — WHAT a bidi stream is
• owns the gRPC wrapper (built SYNC), connect() = start + init handshake,
the per-connection death signal (wait_error), the update-token loop, close()
├── ReaderStream ("fat": read/decode loops, partition sessions, commits)
└── WriterAsyncIOStream ("thin": receive()/write(); send/read loops in the reconnector)

ReaderReconnector(StreamReconnector["ReaderStream"])
WriterAsyncIOReconnector(StreamReconnector["WriterAsyncIOStream"])
```

## StreamReconnector — the reconnect loop

```python
attempt = 0
while not closed:
    conn = self._new_connection()        # SYNC: builds the connection, which owns its gRPC stream
    try:
        await self._handshake(conn)      # open call + init handshake
        self._conn = conn                # publish only after a successful handshake
        attempt = 0                      # reset ONLY on a successful connect
        await self._on_connected(conn)
        self._state_changed.set()
        await self._run(conn)            # block until this connection dies
    except BaseException as err:
        if isinstance(err, CancelledError) and not closed: err = ConnectionLost(...)
        info = self._classify_error(err, attempt)
        if not info.is_retriable:
            self._signal_fatal(err); return
        await asyncio.sleep(info.sleep); attempt += 1
    finally:
        await self._close_connection(conn, flush=False)  # local `conn`: an interrupted
                                                         # handshake is still closed (no zombie)
```

Hooks: `_new_connection` (sync), `_handshake`, `_run`, `_close_connection`, `_terminal_error`
(required); `_on_connected`, `_on_fatal`, `_classify_error` (defaults). `_run(conn)` is
`await conn.wait_error()` for **both** reader and writer.

## StreamConnection — the bidi stream

Constructed synchronously (no network in `__init__`; it builds its `GrpcWrapperAsyncIO`
immediately). `connect(driver)` = `stream.start(...)` + `_init_and_spawn()`. Owns the
per-connection death signal (`_first_error` / `wait_error` / `_set_first_error`) and the shared
update-token loop. Hooks: `_init_and_spawn`, `_make_update_token_request`, `_on_first_error`.

## Two error signals

- **connection-level** `StreamConnection._first_error` (`wait_error`): *this stream* died →
`_run` returns → reconnect.
- **reconnector-level** `StreamReconnector._fatal` (`_signal_fatal`): the reconnector is
terminally done (non-retriable / `close()`) → surfaces to the public API.

## Invariants

1. **Single live stream, no zombie — structural.** Only the one loop assigns `_conn`, and the
`finally` closes the current connection before the next is created. Because `_new_connection()`
is synchronous and the connection owns its gRPC stream, the reconnector holds a closeable
reference *before* the first cancellable network await — so a cancel during the handshake
(e.g. `close()` mid-reconnect) always closes the stream instead of leaking it. No per-`create`
cleanup guard, and it covers the writer too. (`test_reconnect_handshake_cancel_closes_stream`.)
2. **Backoff grows.** `attempt` lives in the base loop and resets only on a successful connect.
3. **close ordering.** Mark closed, flush + close the live connection, wake waiters, cancel the
loop. The writer flushes *before* `super().close()` while `_closed` is still False (flushing
after would deadlock — the loop wouldn't bring up the connection the buffered writes need).

## Reader/writer asymmetry (intentional)

`_run` is symmetric, but the run-loops live in different places, dictated by **data ownership
across reconnects**:

- the reader's per-stream state (partition sessions, read-ahead) *must die* on reconnect → it
lives on `ReaderStream`, with the read/decode loops;
- the writer's loops drive the unacked outbox (`_messages`, `_messages_future`, seqno dedup) that
*must survive* reconnect → it lives on the reconnector, with the send/read loops next to it.

Moving the writer's loops onto the stream would require extracting that outbox into its own
object (the reconnector keeps it; each stream pumps it). Not done.
104 changes: 104 additions & 0 deletions ydb/_topic_common/_stream_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

import abc
import asyncio
import logging
from typing import Any, Awaitable, Callable, Optional, Set, Union

from .._grpc.grpcwrapper.common_utils import GrpcWrapperAsyncIO, IGrpcWrapperAsyncIO, SupportedDriverType

logger = logging.getLogger(__name__)


class StreamConnection(abc.ABC):
    """The thin bidi-stream lifecycle shared by the topic reader and writer streams.

    It owns the gRPC wrapper, performs ``connect`` = open call + init handshake, runs the
    update-token loop, and is constructed SYNCHRONOUSLY (no network in ``__init__``). The
    latter is what makes the reconnector's single-stream / no-zombie guarantee structural:
    the connection object exists and owns its gRPC stream before the first cancellable
    network await, so the reconnector can always close it on cancel.

    Subclasses add their protocol-specific message handling and implement the hooks below.
    """

    def __init__(
        self,
        *,
        from_proto: Callable[[Any], Any],
        stub: Any,
        method: Any,
        update_token_interval: Optional[Union[int, float]] = None,
        get_token_function: Optional[Callable[[], Union[str, Awaitable[str]]]] = None,
    ):
        """Build the connection without touching the network.

        Must be called while an event loop is running (``asyncio.get_running_loop``).

        Args:
            from_proto: Passed to ``GrpcWrapperAsyncIO`` when the stream wrapper is built.
            stub: Forwarded to ``self._stream.start`` by ``connect``.
            method: Forwarded to ``self._stream.start`` by ``connect``.
            update_token_interval: Seconds between token refreshes; ``None`` disables the
                update-token loop.
            get_token_function: Returns the current token, either directly as ``str`` or as
                an awaitable resolving to it.
        """
        self._loop = asyncio.get_running_loop()
        self._stub = stub
        self._method = method
        # Built (not started) here so the connection owns its transport before connect()'s first
        # network await — that is what makes the no-zombie guarantee structural. The legacy
        # _start(stream, ...) injection path overrides this with an externally provided stream.
        self._stream: IGrpcWrapperAsyncIO = GrpcWrapperAsyncIO(from_proto)
        self._background_tasks: Set[asyncio.Task] = set()
        self._closed = False
        # Resolved (via set_result) with the first error that ended this connection.
        self._first_error: asyncio.Future = self._loop.create_future()
        self._update_token_interval = update_token_interval
        self._get_token_function = get_token_function
        # Gate for _update_token(): it waits for this event and clears it after writing.
        # NOTE(review): presumably set by the subclass when the server can accept the next
        # update-token request — confirm against ReaderStream / WriterAsyncIOStream.
        self._update_token_event = asyncio.Event()

    async def connect(self, driver: SupportedDriverType) -> None:
        """Open the gRPC call and run the init handshake on the already-owned stream."""
        await self._stream.start(driver, self._stub, self._method)  # type: ignore[attr-defined]
        await self._init_and_spawn()

    # ------------------------------------------------------------------ death signal

    def _set_first_error(self, err: BaseException) -> None:
        """Record the first error that ended this connection; later errors are ignored."""
        if self._first_error.done():
            return
        self._first_error.set_result(err)
        self._on_first_error()

    def _get_first_error(self) -> Optional[BaseException]:
        """Return the recorded error, or ``None`` while the connection has not failed."""
        return self._first_error.result() if self._first_error.done() else None

    async def wait_error(self) -> None:
        """Block until this connection fails, then raise that error (the reconnect signal).

        The shared future is awaited through ``asyncio.shield``. Cancelling a task that
        awaits a bare future also cancels that future; here that would leave
        ``_first_error`` done-but-cancelled, so ``_set_first_error`` would silently drop the
        real error, ``_get_first_error`` would raise ``CancelledError``, and every other
        waiter would be cancelled too. With the shield only the cancelled caller sees
        ``CancelledError``.
        """
        raise await asyncio.shield(self._first_error)

    def _on_first_error(self) -> None:
        """Hook: wake local subscribers (e.g. the reader's wait_messages). Default: nothing."""

    # ------------------------------------------------------------------ shared update-token loop

    async def _update_token_loop(self) -> None:
        """Periodically fetch a fresh token and send it on the stream.

        Returns immediately when no interval is configured, and stops after the first sleep
        when no token provider is configured.
        """
        if self._update_token_interval is None:
            return  # nothing to refresh on a cadence; avoids a hot loop
        while True:
            await asyncio.sleep(self._update_token_interval)
            if self._get_token_function is None:
                return
            token = self._get_token_function()
            if not isinstance(token, str):
                token = await token  # async token providers are supported
            await self._update_token(token)

    async def _update_token(self, token: str) -> None:
        """Wait for the update-token gate, write the request, then close the gate again."""
        await self._update_token_event.wait()
        try:
            if self._stream is not None:
                self._stream.write(self._make_update_token_request(token))
        finally:
            # Cleared even if the write fails, so the next update waits for the gate again.
            self._update_token_event.clear()

    # ------------------------------------------------------------------ hooks

    @abc.abstractmethod
    async def _init_and_spawn(self, init_message: Any = None) -> None:
        """Send the init request on ``self._stream``, validate the response, and spawn the
        protocol's background workers. ``init_message`` is supplied by the legacy ``_start``
        injection path; when ``None`` (the ``connect`` path) the subclass derives it itself."""

    @abc.abstractmethod
    def _make_update_token_request(self, token: str) -> Any:
        """Wrap ``token`` in the protocol's FromClient(UpdateTokenRequest) message."""
Loading
Loading