Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions ydb/_topic_writer/topic_writer_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,8 @@ async def close(self):
self._update_token_task.cancel()
await asyncio.wait([self._update_token_task])

self._stream.close()
if getattr(self, "_stream", None) is not None:
self._stream.close()
logger.debug("writer stream %s was closed", self._id)

@staticmethod
Expand All @@ -845,15 +846,27 @@ async def create(
) -> "WriterAsyncIOStream":
stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto)

await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite)
writer = None
try:
await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite)

creds = driver._credentials
writer = WriterAsyncIOStream(
update_token_interval=update_token_interval,
get_token_function=creds.get_auth_token if creds else lambda: "",
tx_identity=tx_identity,
)
await writer._start(stream, init_request)
creds = driver._credentials
writer = WriterAsyncIOStream(
update_token_interval=update_token_interval,
get_token_function=creds.get_auth_token if creds else lambda: "",
tx_identity=tx_identity,
)
await writer._start(stream, init_request)
except BaseException:
# If create() fails after stream.start() (e.g. the connection is lost while
# waiting for the init response), the stream is not yet handed to the
# reconnector, so its finally cannot reach it. Close it here to avoid a
# stranded gRPC consumption thread that blocks forever on the request queue.
if writer is not None:
await writer.close()
else:
stream.close()
raise
logger.debug(
"writer stream %s started seqno=%s",
writer._id,
Expand All @@ -875,6 +888,10 @@ async def receive(self) -> StreamWriteMessage.WriteResponse:
raise Exception("Unknown message while read writer answers: %s" % item)

async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMessage.InitRequest):
# Assign before the init handshake (mirrors ReaderStream._start) so that if _start
# fails mid-handshake, close() can still reach the stream and release its gRPC thread.
self._stream = stream

logger.debug("writer stream %s send init request", self._id)
stream.write(StreamWriteMessage.FromClient(init_message))

Expand All @@ -895,8 +912,6 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMes
self.last_seqno,
)

self._stream = stream

if self._update_token_interval is not None:
self._update_token_event.set()
self._update_token_task = asyncio.create_task(self._update_token_loop())
Expand Down
113 changes: 112 additions & 1 deletion ydb/_topic_writer/topic_writer_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
import copy
import dataclasses
import datetime
import gc
import gzip
import sys
import typing
from concurrent.futures import ThreadPoolExecutor
from queue import Queue, Empty
from typing import List, Callable, Optional
from unittest import mock

import freezegun
import grpc
import pytest

from .. import aio
Expand All @@ -22,7 +26,7 @@
UpdateTokenRequest,
UpdateTokenResponse,
)
from .._grpc.grpcwrapper.common_utils import ServerStatus
from .._grpc.grpcwrapper.common_utils import AsyncQueueToSyncIteratorAsyncIO, ServerStatus
from .topic_writer import (
InternalMessage,
PublicMessage,
Expand Down Expand Up @@ -1035,3 +1039,110 @@ async def ack_next_messages():

res = await writer.write_with_ack([PublicMessage(seqno=2, data="123"), PublicMessage(seqno=3, data="123")])
assert res == [PublicWriteResult.Written(offset=2), PublicWriteResult.Skipped()]


_STREAM_WRITE_METHOD = "/Ydb.Topic.V1.TopicService/StreamWrite"


# The exact code object the leaked gRPC consumption thread blocks in. Matching the code
# object (instead of a module filename) avoids false positives from same-named modules in
# other dependencies and survives file renames / refactors.
_CONSUMER_NEXT_CODE = AsyncQueueToSyncIteratorAsyncIO.__next__.__code__


def _count_stranded_consumer_threads() -> int:
"""Number of threads parked in AsyncQueueToSyncIteratorAsyncIO.__next__ (the leak)."""
count = 0
for frame in sys._current_frames().values():
f: typing.Optional[typing.Any] = frame
while f is not None:
if f.f_code is _CONSUMER_NEXT_CODE:
count += 1
break
f = f.f_back
return count


class _AbortingStreamServer:
"""In-process gRPC server that accepts StreamWrite then immediately drops the stream."""

def __init__(self):
def handler(request_iterator, context):
try:
next(request_iterator) # consume the client's init request
except Exception:
pass
context.abort(grpc.StatusCode.UNAVAILABLE, "simulated node down")

rpc = grpc.stream_stream_rpc_method_handler(
handler,
request_deserializer=lambda b: b,
response_serializer=lambda b: b,
)

class _Generic(grpc.GenericRpcHandler):
def service(self, details):
return rpc if details.method == _STREAM_WRITE_METHOD else None

self._server = grpc.server(ThreadPoolExecutor(max_workers=4))
self.port = self._server.add_insecure_port("127.0.0.1:0")
self._server.add_generic_rpc_handlers((_Generic(),))
self._server.start()

def stop(self):
self._server.stop(grace=1).wait(timeout=10)


class _FakeSyncDriver:
"""Minimal stand-in for ydb.Driver's call interface used by _start_sync_driver."""

_credentials = None

def __init__(self, channel: grpc.Channel):
self._channel = channel

def __call__(self, request_iterator, stub, method, executor=None, settings=None, **kwargs):
multicallable = self._channel.stream_stream(
_STREAM_WRITE_METHOD,
request_serializer=lambda m: m.SerializeToString(),
response_deserializer=lambda b: b,
)
return multicallable(request_iterator)


@pytest.mark.asyncio
async def test_writer_create_failure_does_not_leak_grpc_thread():
"""Regression: a failed WriterAsyncIOStream.create() must not strand a gRPC consumer thread.

Uses a real in-process gRPC stream so the consumption thread is actually spawned;
mocked-create tests cannot catch this leak.
"""
server = _AbortingStreamServer()
channel = grpc.insecure_channel("127.0.0.1:%d" % server.port)
driver = _FakeSyncDriver(channel)
init = WriterSettings(PublicWriterSettings("/local/topic", "producer-id")).create_init_request()

try:
baseline = _count_stranded_consumer_threads()
attempts = 10
for _ in range(attempts):
with pytest.raises(issues.Error):
await WriterAsyncIOStream.create(driver, init) # type: ignore[arg-type]

# Give closed streams a moment to let their consumption threads exit, then assert
# the count returned to the baseline (no net new stranded threads vs other tests).
leaked = attempts
for _ in range(30):
gc.collect()
await asyncio.sleep(0.1)
leaked = _count_stranded_consumer_threads() - baseline
if leaked <= 0:
break

assert leaked <= 0, "%d gRPC consumer threads leaked after %d failed create() calls" % (
leaked,
attempts,
)
finally:
channel.close()
server.stop()
Loading