# Patch summary: close the gRPC stream when WriterAsyncIOStream.create() fails.
#
# ydb/_topic_writer/topic_writer_asyncio.py
#   * close(): calls self._stream.close() only when a non-None `_stream`
#     attribute exists on the instance.
#   * create(): stream.start(), the WriterAsyncIOStream construction and
#     writer._start() now run inside try/except BaseException. On any failure
#     the writer is closed if it was already constructed, otherwise the raw
#     stream is closed; the exception is then re-raised unchanged.
#   * _start(): `self._stream = stream` moves from just before the update-token
#     task setup to the top of the method, ahead of the init request write, so
#     close() can reach the stream if the handshake fails.
#
# ydb/_topic_writer/topic_writer_asyncio_test.py
#   * Adds test_writer_create_failure_does_not_leak_grpc_thread. It runs ten
#     failing create() calls against an in-process grpc server that aborts
#     StreamWrite with UNAVAILABLE, then polls (up to 30 x 0.1 s) until the
#     number of threads whose stack contains
#     AsyncQueueToSyncIteratorAsyncIO.__next__ is back at the baseline.
#
# NOTE(review): this copy was re-flowed from a whitespace-collapsed paste. Line
# breaks and blank lines agree with every hunk header's line counts, but
# leading indentation was inferred -- in particular on context lines such as
# `raise Exception(...)` and `res = await writer.write_with_ack(...)`. Verify
# against the target files, or apply with whitespace tolerance
# (e.g. `git apply --ignore-whitespace`).

diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py
index 5e55917d..ff9c59f7 100644
--- a/ydb/_topic_writer/topic_writer_asyncio.py
+++ b/ydb/_topic_writer/topic_writer_asyncio.py
@@ -833,7 +833,8 @@ async def close(self):
             self._update_token_task.cancel()
             await asyncio.wait([self._update_token_task])
 
-        self._stream.close()
+        if getattr(self, "_stream", None) is not None:
+            self._stream.close()
         logger.debug("writer stream %s was closed", self._id)
 
     @staticmethod
@@ -845,15 +846,27 @@ async def create(
     ) -> "WriterAsyncIOStream":
         stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto)
 
-        await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite)
+        writer = None
+        try:
+            await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite)
 
-        creds = driver._credentials
-        writer = WriterAsyncIOStream(
-            update_token_interval=update_token_interval,
-            get_token_function=creds.get_auth_token if creds else lambda: "",
-            tx_identity=tx_identity,
-        )
-        await writer._start(stream, init_request)
+            creds = driver._credentials
+            writer = WriterAsyncIOStream(
+                update_token_interval=update_token_interval,
+                get_token_function=creds.get_auth_token if creds else lambda: "",
+                tx_identity=tx_identity,
+            )
+            await writer._start(stream, init_request)
+        except BaseException:
+            # If create() fails after stream.start() (e.g. the connection is lost while
+            # waiting for the init response), the stream is not yet handed to the
+            # reconnector, so its finally cannot reach it. Close it here to avoid a
+            # stranded gRPC consumption thread that blocks forever on the request queue.
+            if writer is not None:
+                await writer.close()
+            else:
+                stream.close()
+            raise
         logger.debug(
             "writer stream %s started seqno=%s",
             writer._id,
@@ -875,6 +888,10 @@ async def receive(self) -> StreamWriteMessage.WriteResponse:
             raise Exception("Unknown message while read writer answers: %s" % item)
 
     async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest):
+        # Assign before the init handshake (mirrors ReaderStream._start) so that if _start
+        # fails mid-handshake, close() can still reach the stream and release its gRPC thread.
+        self._stream = stream
+
         logger.debug("writer stream %s send init request", self._id)
         stream.write(StreamWriteMessage.FromClient(init_message))
 
@@ -895,8 +912,6 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMes
             self.last_seqno,
         )
 
-        self._stream = stream
-
         if self._update_token_interval is not None:
             self._update_token_event.set()
             self._update_token_task = asyncio.create_task(self._update_token_loop())
diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py
index a4ded384..fb095dd3 100644
--- a/ydb/_topic_writer/topic_writer_asyncio_test.py
+++ b/ydb/_topic_writer/topic_writer_asyncio_test.py
@@ -4,13 +4,17 @@
 import copy
 import dataclasses
 import datetime
+import gc
 import gzip
+import sys
 import typing
+from concurrent.futures import ThreadPoolExecutor
 from queue import Queue, Empty
 from typing import List, Callable, Optional
 from unittest import mock
 
 import freezegun
+import grpc
 import pytest
 
 from .. import aio
@@ -22,7 +26,7 @@
     UpdateTokenRequest,
     UpdateTokenResponse,
 )
-from .._grpc.grpcwrapper.common_utils import ServerStatus
+from .._grpc.grpcwrapper.common_utils import AsyncQueueToSyncIteratorAsyncIO, ServerStatus
 from .topic_writer import (
     InternalMessage,
     PublicMessage,
@@ -1035,3 +1039,110 @@ async def ack_next_messages():
 
         res = await writer.write_with_ack([PublicMessage(seqno=2, data="123"), PublicMessage(seqno=3, data="123")])
         assert res == [PublicWriteResult.Written(offset=2), PublicWriteResult.Skipped()]
+
+
+_STREAM_WRITE_METHOD = "/Ydb.Topic.V1.TopicService/StreamWrite"
+
+
+# The exact code object the leaked gRPC consumption thread blocks in. Matching the code
+# object (instead of a module filename) avoids false positives from same-named modules in
+# other dependencies and survives file renames / refactors.
+_CONSUMER_NEXT_CODE = AsyncQueueToSyncIteratorAsyncIO.__next__.__code__
+
+
+def _count_stranded_consumer_threads() -> int:
+    """Number of threads parked in AsyncQueueToSyncIteratorAsyncIO.__next__ (the leak)."""
+    count = 0
+    for frame in sys._current_frames().values():
+        f: typing.Optional[typing.Any] = frame
+        while f is not None:
+            if f.f_code is _CONSUMER_NEXT_CODE:
+                count += 1
+                break
+            f = f.f_back
+    return count
+
+
+class _AbortingStreamServer:
+    """In-process gRPC server that accepts StreamWrite then immediately drops the stream."""
+
+    def __init__(self):
+        def handler(request_iterator, context):
+            try:
+                next(request_iterator)  # consume the client's init request
+            except Exception:
+                pass
+            context.abort(grpc.StatusCode.UNAVAILABLE, "simulated node down")
+
+        rpc = grpc.stream_stream_rpc_method_handler(
+            handler,
+            request_deserializer=lambda b: b,
+            response_serializer=lambda b: b,
+        )
+
+        class _Generic(grpc.GenericRpcHandler):
+            def service(self, details):
+                return rpc if details.method == _STREAM_WRITE_METHOD else None
+
+        self._server = grpc.server(ThreadPoolExecutor(max_workers=4))
+        self.port = self._server.add_insecure_port("127.0.0.1:0")
+        self._server.add_generic_rpc_handlers((_Generic(),))
+        self._server.start()
+
+    def stop(self):
+        self._server.stop(grace=1).wait(timeout=10)
+
+
+class _FakeSyncDriver:
+    """Minimal stand-in for ydb.Driver's call interface used by _start_sync_driver."""
+
+    _credentials = None
+
+    def __init__(self, channel: grpc.Channel):
+        self._channel = channel
+
+    def __call__(self, request_iterator, stub, method, executor=None, settings=None, **kwargs):
+        multicallable = self._channel.stream_stream(
+            _STREAM_WRITE_METHOD,
+            request_serializer=lambda m: m.SerializeToString(),
+            response_deserializer=lambda b: b,
+        )
+        return multicallable(request_iterator)
+
+
+@pytest.mark.asyncio
+async def test_writer_create_failure_does_not_leak_grpc_thread():
+    """Regression: a failed WriterAsyncIOStream.create() must not strand a gRPC consumer thread.
+
+    Uses a real in-process gRPC stream so the consumption thread is actually spawned;
+    mocked-create tests cannot catch this leak.
+    """
+    server = _AbortingStreamServer()
+    channel = grpc.insecure_channel("127.0.0.1:%d" % server.port)
+    driver = _FakeSyncDriver(channel)
+    init = WriterSettings(PublicWriterSettings("/local/topic", "producer-id")).create_init_request()
+
+    try:
+        baseline = _count_stranded_consumer_threads()
+        attempts = 10
+        for _ in range(attempts):
+            with pytest.raises(issues.Error):
+                await WriterAsyncIOStream.create(driver, init)  # type: ignore[arg-type]
+
+        # Give closed streams a moment to let their consumption threads exit, then assert
+        # the count returned to the baseline (no net new stranded threads vs other tests).
+        leaked = attempts
+        for _ in range(30):
+            gc.collect()
+            await asyncio.sleep(0.1)
+            leaked = _count_stranded_consumer_threads() - baseline
+            if leaked <= 0:
+                break
+
+        assert leaked <= 0, "%d gRPC consumer threads leaked after %d failed create() calls" % (
+            leaked,
+            attempts,
+        )
+    finally:
+        channel.close()
+        server.stop()