From b373a6617ffb9e396ccfc869c7c65450d900f761 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 30 Jun 2026 11:16:38 +0300 Subject: [PATCH 1/4] Fix gRPC thread leak on failed topic writer reconnect WriterAsyncIOStream.create() left the stream open when startup failed after stream.start(), leaking a gRPC consumer thread per reconnect attempt. Close it on failure (mirrors ReaderStream.create) and add a regression test. --- ydb/_topic_writer/topic_writer_asyncio.py | 28 +++-- .../topic_writer_asyncio_test.py | 103 ++++++++++++++++++ 2 files changed, 123 insertions(+), 8 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 5e55917d8..f93e8a7ad 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -845,15 +845,27 @@ async def create( ) -> "WriterAsyncIOStream": stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) - await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite) + writer = None + try: + await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite) - creds = driver._credentials - writer = WriterAsyncIOStream( - update_token_interval=update_token_interval, - get_token_function=creds.get_auth_token if creds else lambda: "", - tx_identity=tx_identity, - ) - await writer._start(stream, init_request) + creds = driver._credentials + writer = WriterAsyncIOStream( + update_token_interval=update_token_interval, + get_token_function=creds.get_auth_token if creds else lambda: "", + tx_identity=tx_identity, + ) + await writer._start(stream, init_request) + except BaseException: + # If create() fails after stream.start() (e.g. the connection is lost while + # waiting for the init response), the stream is not yet assigned to the + # reconnector, so its finally cannot reach it. Close it here to avoid a + # stranded gRPC consumption thread that blocks forever on the request queue. + if writer is not None and getattr(writer, "_stream", None) is not None: + await writer.close() + else: + stream.close() + raise logger.debug( "writer stream %s started seqno=%s", writer._id, diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index a4ded3847..aa82c0588 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -4,13 +4,17 @@ import copy import dataclasses import datetime +import gc import gzip +import sys import typing +from concurrent.futures import ThreadPoolExecutor from queue import Queue, Empty from typing import List, Callable, Optional from unittest import mock import freezegun +import grpc import pytest from .. import aio @@ -1035,3 +1039,102 @@ async def ack_next_messages(): res = await writer.write_with_ack([PublicMessage(seqno=2, data="123"), PublicMessage(seqno=3, data="123")]) assert res == [PublicWriteResult.Written(offset=2), PublicWriteResult.Skipped()] + + +_STREAM_WRITE_METHOD = "/Ydb.Topic.V1.TopicService/StreamWrite" + + +def _count_stranded_consumer_threads() -> int: + """Number of threads parked in AsyncQueueToSyncIteratorAsyncIO.__next__ (the leak).""" + count = 0 + for frame in sys._current_frames().values(): + f: typing.Optional[typing.Any] = frame + while f is not None: + if f.f_code.co_name == "__next__" and f.f_code.co_filename.endswith("common_utils.py"): + count += 1 + break + f = f.f_back + return count + + +class _AbortingStreamServer: + """In-process gRPC server that accepts StreamWrite then immediately drops the stream.""" + + def __init__(self): + def handler(request_iterator, context): + try: + next(request_iterator) # consume the client's init request + except Exception: + pass + context.abort(grpc.StatusCode.UNAVAILABLE, "simulated node down") + + rpc = grpc.stream_stream_rpc_method_handler( + handler, + request_deserializer=lambda b: b, + response_serializer=lambda b: b, + ) + + class _Generic(grpc.GenericRpcHandler): + def service(self, details): + return rpc + + self._server = grpc.server(ThreadPoolExecutor(max_workers=4)) + self.port = self._server.add_insecure_port("127.0.0.1:0") + self._server.add_generic_rpc_handlers((_Generic(),)) + self._server.start() + + def stop(self): + self._server.stop(None) + + +class _FakeSyncDriver: + """Minimal stand-in for ydb.Driver's call interface used by _start_sync_driver.""" + + _credentials = None + + def __init__(self, channel: grpc.Channel): + self._channel = channel + + def __call__(self, request_iterator, stub, method, executor=None, settings=None, **kwargs): + multicallable = self._channel.stream_stream( + _STREAM_WRITE_METHOD, + request_serializer=lambda m: m.SerializeToString(), + response_deserializer=lambda b: b, + ) + return multicallable(request_iterator) + + +@pytest.mark.asyncio +async def test_writer_create_failure_does_not_leak_grpc_thread(): + """Regression: a failed WriterAsyncIOStream.create() must not strand a gRPC consumer thread. + + Uses a real in-process gRPC stream so the consumption thread is actually spawned; + mocked-create tests cannot catch this leak. + """ + server = _AbortingStreamServer() + channel = grpc.insecure_channel("127.0.0.1:%d" % server.port) + driver = _FakeSyncDriver(channel) + init = WriterSettings(PublicWriterSettings("/local/topic", "producer-id")).create_init_request() + + try: + attempts = 10 + for _ in range(attempts): + with pytest.raises(issues.Error): + await WriterAsyncIOStream.create(driver, init) # type: ignore[arg-type] + + # Give closed streams a moment to let their consumption threads exit. + stranded = attempts + for _ in range(30): + gc.collect() + await asyncio.sleep(0.1) + stranded = _count_stranded_consumer_threads() + if stranded == 0: + break + + assert stranded == 0, "%d gRPC consumer threads leaked after %d failed create() calls" % ( + stranded, + attempts, + ) + finally: + channel.close() + server.stop() From b2db056f7eaec6b7fe534553424e2fbd955f76ec Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 30 Jun 2026 11:50:25 +0300 Subject: [PATCH 2/4] Address review comments Set _stream at the start of _start() and guard close() (mirrors ReaderStream); match the exact __next__ code object for leak detection and wait for the test gRPC server to stop. --- ydb/_topic_writer/topic_writer_asyncio.py | 13 ++++++++----- ydb/_topic_writer/topic_writer_asyncio_test.py | 12 +++++++++--- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index f93e8a7ad..ff9c59f79 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -833,7 +833,8 @@ async def close(self): self._update_token_task.cancel() await asyncio.wait([self._update_token_task]) - self._stream.close() + if getattr(self, "_stream", None) is not None: + self._stream.close() logger.debug("writer stream %s was closed", self._id) @staticmethod @@ -858,10 +859,10 @@ async def create( await writer._start(stream, init_request) except BaseException: # If create() fails after stream.start() (e.g. the connection is lost while - # waiting for the init response), the stream is not yet assigned to the + # waiting for the init response), the stream is not yet handed to the # reconnector, so its finally cannot reach it. Close it here to avoid a # stranded gRPC consumption thread that blocks forever on the request queue. - if writer is not None and getattr(writer, "_stream", None) is not None: + if writer is not None: await writer.close() else: stream.close() @@ -887,6 +888,10 @@ async def receive(self) -> StreamWriteMessage.WriteResponse: raise Exception("Unknown message while read writer answers: %s" % item) async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest): + # Assign before the init handshake (mirrors ReaderStream._start) so that if _start + # fails mid-handshake, close() can still reach the stream and release its gRPC thread. + self._stream = stream + logger.debug("writer stream %s send init request", self._id) stream.write(StreamWriteMessage.FromClient(init_message)) @@ -907,8 +912,6 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMes self.last_seqno, ) - self._stream = stream - if self._update_token_interval is not None: self._update_token_event.set() self._update_token_task = asyncio.create_task(self._update_token_loop()) diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index aa82c0588..949082811 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -26,7 +26,7 @@ UpdateTokenRequest, UpdateTokenResponse, ) -from .._grpc.grpcwrapper.common_utils import ServerStatus +from .._grpc.grpcwrapper.common_utils import AsyncQueueToSyncIteratorAsyncIO, ServerStatus from .topic_writer import ( InternalMessage, PublicMessage, @@ -1044,13 +1044,19 @@ async def ack_next_messages(): _STREAM_WRITE_METHOD = "/Ydb.Topic.V1.TopicService/StreamWrite" +# The exact code object the leaked gRPC consumption thread blocks in. Matching the code +# object (instead of a module filename) avoids false positives from same-named modules in +# other dependencies and survives file renames / refactors. +_CONSUMER_NEXT_CODE = AsyncQueueToSyncIteratorAsyncIO.__next__.__code__ + + def _count_stranded_consumer_threads() -> int: """Number of threads parked in AsyncQueueToSyncIteratorAsyncIO.__next__ (the leak).""" count = 0 for frame in sys._current_frames().values(): f: typing.Optional[typing.Any] = frame while f is not None: - if f.f_code.co_name == "__next__" and f.f_code.co_filename.endswith("common_utils.py"): + if f.f_code is _CONSUMER_NEXT_CODE: count += 1 break f = f.f_back @@ -1084,7 +1090,7 @@ def service(self, details): self._server.start() def stop(self): - self._server.stop(None) + self._server.stop(grace=1).wait(timeout=10) class _FakeSyncDriver: From ed64d738a05f9052ec5a07257f93e350595093b1 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 30 Jun 2026 13:14:11 +0300 Subject: [PATCH 3/4] Address review comments (round 2) Restrict the test gRPC handler to StreamWrite and assert against a baseline thread count. --- ydb/_topic_writer/topic_writer_asyncio.py | 37 ++++++------------- .../topic_writer_asyncio_test.py | 16 ++++---- 2 files changed, 20 insertions(+), 33 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index ff9c59f79..5e55917d8 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -833,8 +833,7 @@ async def close(self): self._update_token_task.cancel() await asyncio.wait([self._update_token_task]) - if getattr(self, "_stream", None) is not None: - self._stream.close() + self._stream.close() logger.debug("writer stream %s was closed", self._id) @staticmethod @@ -846,27 +845,15 @@ async def create( ) -> "WriterAsyncIOStream": stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) - writer = None - try: - await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite) + await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite) - creds = driver._credentials - writer = WriterAsyncIOStream( - update_token_interval=update_token_interval, - get_token_function=creds.get_auth_token if creds else lambda: "", - tx_identity=tx_identity, - ) - await writer._start(stream, init_request) - except BaseException: - # If create() fails after stream.start() (e.g. the connection is lost while - # waiting for the init response), the stream is not yet handed to the - # reconnector, so its finally cannot reach it. Close it here to avoid a - # stranded gRPC consumption thread that blocks forever on the request queue. - if writer is not None: - await writer.close() - else: - stream.close() - raise + creds = driver._credentials + writer = WriterAsyncIOStream( + update_token_interval=update_token_interval, + get_token_function=creds.get_auth_token if creds else lambda: "", + tx_identity=tx_identity, + ) + await writer._start(stream, init_request) logger.debug( "writer stream %s started seqno=%s", writer._id, @@ -888,10 +875,6 @@ async def receive(self) -> StreamWriteMessage.WriteResponse: raise Exception("Unknown message while read writer answers: %s" % item) async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest): - # Assign before the init handshake (mirrors ReaderStream._start) so that if _start - # fails mid-handshake, close() can still reach the stream and release its gRPC thread. - self._stream = stream - logger.debug("writer stream %s send init request", self._id) stream.write(StreamWriteMessage.FromClient(init_message)) @@ -912,6 +895,8 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMes self.last_seqno, ) + self._stream = stream + if self._update_token_interval is not None: self._update_token_event.set() self._update_token_task = asyncio.create_task(self._update_token_loop()) diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index 949082811..fb095dd39 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -1082,7 +1082,7 @@ def handler(request_iterator, context): class _Generic(grpc.GenericRpcHandler): def service(self, details): - return rpc + return rpc if details.method == _STREAM_WRITE_METHOD else None self._server = grpc.server(ThreadPoolExecutor(max_workers=4)) self.port = self._server.add_insecure_port("127.0.0.1:0") @@ -1123,22 +1123,24 @@ async def test_writer_create_failure_does_not_leak_grpc_thread(): init = WriterSettings(PublicWriterSettings("/local/topic", "producer-id")).create_init_request() try: + baseline = _count_stranded_consumer_threads() attempts = 10 for _ in range(attempts): with pytest.raises(issues.Error): await WriterAsyncIOStream.create(driver, init) # type: ignore[arg-type] - # Give closed streams a moment to let their consumption threads exit. - stranded = attempts + # Give closed streams a moment to let their consumption threads exit, then assert + # the count returned to the baseline (no net new stranded threads vs other tests). + leaked = attempts for _ in range(30): gc.collect() await asyncio.sleep(0.1) - stranded = _count_stranded_consumer_threads() - if stranded == 0: + leaked = _count_stranded_consumer_threads() - baseline + if leaked <= 0: break - assert stranded == 0, "%d gRPC consumer threads leaked after %d failed create() calls" % ( - stranded, + assert leaked <= 0, "%d gRPC consumer threads leaked after %d failed create() calls" % ( + leaked, attempts, ) finally: From 02228a7895aa8c3620488e158a8db969c8a128e8 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 30 Jun 2026 13:15:37 +0300 Subject: [PATCH 4/4] Restore writer fix dropped by previous commit --- ydb/_topic_writer/topic_writer_asyncio.py | 37 ++++++++++++++++------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 5e55917d8..ff9c59f79 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -833,7 +833,8 @@ async def close(self): self._update_token_task.cancel() await asyncio.wait([self._update_token_task]) - self._stream.close() + if getattr(self, "_stream", None) is not None: + self._stream.close() logger.debug("writer stream %s was closed", self._id) @staticmethod @@ -845,15 +846,27 @@ async def create( ) -> "WriterAsyncIOStream": stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) - await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite) + writer = None + try: + await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite) - creds = driver._credentials - writer = WriterAsyncIOStream( - update_token_interval=update_token_interval, - get_token_function=creds.get_auth_token if creds else lambda: "", - tx_identity=tx_identity, - ) - await writer._start(stream, init_request) + creds = driver._credentials + writer = WriterAsyncIOStream( + update_token_interval=update_token_interval, + get_token_function=creds.get_auth_token if creds else lambda: "", + tx_identity=tx_identity, + ) + await writer._start(stream, init_request) + except BaseException: + # If create() fails after stream.start() (e.g. the connection is lost while + # waiting for the init response), the stream is not yet handed to the + # reconnector, so its finally cannot reach it. Close it here to avoid a + # stranded gRPC consumption thread that blocks forever on the request queue. + if writer is not None: + await writer.close() + else: + stream.close() + raise logger.debug( "writer stream %s started seqno=%s", writer._id, @@ -875,6 +888,10 @@ async def receive(self) -> StreamWriteMessage.WriteResponse: raise Exception("Unknown message while read writer answers: %s" % item) async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest): + # Assign before the init handshake (mirrors ReaderStream._start) so that if _start + # fails mid-handshake, close() can still reach the stream and release its gRPC thread. + self._stream = stream + logger.debug("writer stream %s send init request", self._id) stream.write(StreamWriteMessage.FromClient(init_message)) @@ -895,8 +912,6 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMes self.last_seqno, ) - self._stream = stream - if self._update_token_interval is not None: self._update_token_event.set() self._update_token_task = asyncio.create_task(self._update_token_loop())