diff --git a/CHANGELOG.md b/CHANGELOG.md
index 23662f0a9..725c203d9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,6 @@
+* Query session attach stream handles `NodeShutdown` and `SessionShutdown` session hints: on `NodeShutdown` the session's node connection is pessimized and the session is retired, on `SessionShutdown` the session is retired without touching the node
+* Bumped `ydb-api-protos` and regenerated gRPC/protobuf stubs (v3–v6) to include query service session hints
 * Fix incompatibility with protobuf 6.30–6.31.0: regenerate v6 stubs with the lowest 6.x gencode floor (6.30.0) instead of 6.31.1
 
 ## 3.29.4 ##
 * Fix leaked topic reader stream when close interrupts stream creation during reconnect
diff --git a/tests/query/test_query_session_hints.py b/tests/query/test_query_session_hints.py
new file mode 100644
index 000000000..ad9c4bce2
--- /dev/null
+++ b/tests/query/test_query_session_hints.py
@@ -0,0 +1,129 @@
+import asyncio
+from unittest import mock
+
+import pytest
+
+from ydb._grpc.common.protos import ydb_query_pb2
+from ydb.aio.pool import ConnectionPool as AsyncConnectionPool
+from ydb.pool import ConnectionPool
+from ydb.query.session import QuerySession
+
+
+def _make_session(node_id=42):
+    driver = mock.Mock()
+    driver._pessimize_node = mock.Mock()
+    session = QuerySession(driver)
+    session._session_id = "test-session"
+    session._node_id = node_id
+    return session, driver
+
+
+class TestQuerySessionAttachHints:
+    def test_node_shutdown_pessimizes_node_and_invalidates_session(self):
+        session, driver = _make_session(node_id=42)
+
+        session._handle_attach_session_state(
+            ydb_query_pb2.SessionState(
+                status=0,
+                node_shutdown=ydb_query_pb2.NodeShutdownHint(),
+            )
+        )
+
+        driver._pessimize_node.assert_called_once_with(42)
+        assert session._invalidated
+        assert session._closed
+
+    def test_session_shutdown_invalidates_without_pessimizing_node(self):
+        session, driver = _make_session(node_id=42)
+
+        session._handle_attach_session_state(
+            ydb_query_pb2.SessionState(
+                status=0,
+                session_shutdown=ydb_query_pb2.SessionShutdownHint(),
+            )
+        )
+
+        driver._pessimize_node.assert_not_called()
+        assert session._invalidated
+        assert session._closed
+
+    def test_node_shutdown_with_zero_node_id_delegates_to_driver(self):
+        session, driver = _make_session(node_id=0)
+
+        session._handle_attach_session_state(
+            ydb_query_pb2.SessionState(
+                status=0,
+                node_shutdown=ydb_query_pb2.NodeShutdownHint(),
+            )
+        )
+
+        driver._pessimize_node.assert_called_once_with(0)
+        assert session._invalidated
+
+    def test_node_shutdown_without_node_id_skips_pessimization(self):
+        session, driver = _make_session(node_id=None)
+
+        session._handle_attach_session_state(
+            ydb_query_pb2.SessionState(
+                status=0,
+                node_shutdown=ydb_query_pb2.NodeShutdownHint(),
+            )
+        )
+
+        driver._pessimize_node.assert_not_called()
+        assert session._invalidated
+        assert session._closed
+
+    def test_no_hint_does_not_invalidate(self):
+        session, driver = _make_session()
+
+        session._handle_attach_session_state(
+            ydb_query_pb2.SessionState(status=0),
+        )
+
+        driver._pessimize_node.assert_not_called()
+        assert not session._invalidated
+        assert not session._closed
+
+
+class TestConnectionPoolAttachHintPessimization:
+    def test_sync_pool_pessimizes_node_connection(self):
+        pool = ConnectionPool.__new__(ConnectionPool)
+        connection = mock.Mock()
+        pool._store = mock.Mock()
+        pool._store.get_connection_by_node_id.return_value = connection
+        pool._on_disconnected = mock.Mock()
+
+        pool._pessimize_node(42)
+
+        pool._store.get_connection_by_node_id.assert_called_once_with(42)
+        pool._on_disconnected.assert_called_once_with(connection)
+
+    def test_sync_pool_ignores_missing_node_connection(self):
+        pool = ConnectionPool.__new__(ConnectionPool)
+        pool._store = mock.Mock()
+        pool._store.get_connection_by_node_id.return_value = None
+        pool._on_disconnected = mock.Mock()
+
+        pool._pessimize_node(42)
+        pool._pessimize_node(0)
+
+        pool._on_disconnected.assert_not_called()
+        # Node id 0 must short-circuit before the store lookup, so only 42 is queried.
+        pool._store.get_connection_by_node_id.assert_called_once_with(42)
+
+    @pytest.mark.asyncio
+    async def test_async_pool_pessimizes_node_connection(self):
+        pool = AsyncConnectionPool.__new__(AsyncConnectionPool)
+        connection = mock.Mock()
+        disconnect = mock.AsyncMock()
+        pool._store = mock.Mock()
+        pool._store.get_connection_by_node_id.return_value = connection
+        pool._on_disconnected = mock.Mock(return_value=disconnect)
+
+        pool._pessimize_node(42)
+        await asyncio.sleep(0)
+
+        pool._store.get_connection_by_node_id.assert_called_once_with(42)
+        pool._on_disconnected.assert_called_once_with(connection)
+        disconnect.assert_awaited_once_with()
diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py
index 4d952086c..8d5b0b227 100644
--- a/ydb/aio/pool.py
+++ b/ydb/aio/pool.py
@@ -3,7 +3,7 @@
 import asyncio
 import logging
 import random
-from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING
+from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING, cast
 
 from ydb import issues
 from ydb.opentelemetry.tracing import SpanName, create_ydb_span
@@ -311,6 +311,26 @@ async def __wrapper__() -> None:
 
         return __wrapper__
 
+    def _pessimize_node(self, node_id: int) -> None:
+        """Deprioritize the connection attached to the given YDB node.
+
+        Non-positive node ids and nodes without a known connection are ignored.
+        The disconnect handler is scheduled as a task on the running event loop.
+        """
+        if node_id <= 0:
+            return
+
+        connection = cast(Optional[Connection], self._store.get_connection_by_node_id(node_id))
+        if connection is None:
+            return
+
+        task = asyncio.get_running_loop().create_task(self._on_disconnected(connection)())
+        # The event loop only keeps weak references to tasks; hold a strong one
+        # until the handler finishes so it cannot be garbage-collected mid-run.
+        pending = self.__dict__.setdefault("_pessimize_tasks", set())
+        pending.add(task)
+        task.add_done_callback(pending.discard)
+
     async def wait(self, timeout: Optional[float] = 7.0, fail_fast: bool = False) -> None:  # type: ignore[override]  # async override of sync method
         with create_ydb_span(SpanName.DRIVER_INITIALIZE, self._driver_config, kind="internal").attach_context():
             await self._store.get(fast_fail=fail_fast, wait_timeout=timeout if timeout is not None else 7.0)
diff --git a/ydb/aio/query/session.py b/ydb/aio/query/session.py
index b776b6382..1c0db50ad 100644
--- a/ydb/aio/query/session.py
+++ b/ydb/aio/query/session.py
@@ -14,7 +14,6 @@
 from .. import _utilities
 from ... import issues
 from ...settings import BaseRequestSettings
-from ..._grpc.grpcwrapper import common_utils
 from ..._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public
 
 from ...query import base
@@ -53,7 +52,7 @@ async def _attach(self) -> None:
         self._stream = await self._attach_call()
         self._status_stream = _utilities.AsyncResponseIterator(
             self._stream,
-            lambda response: common_utils.ServerStatus.from_proto(response),
+            self._attach_stream_wrapper,
         )
 
         try:
diff --git a/ydb/pool.py b/ydb/pool.py
index f1a4bdf94..a77359c23 100644
--- a/ydb/pool.py
+++ b/ydb/pool.py
@@ -160,6 +160,10 @@ def remove(self, connection: Connection) -> None:
         self.connections.pop(connection.endpoint, None)
         self.outdated.pop(connection.endpoint, None)
 
+    def get_connection_by_node_id(self, node_id: Optional[int]) -> Optional[Connection]:
+        with self.lock:
+            return self.connections_by_node_id.get(node_id)
+
 
 class Discovery(threading.Thread):
     def __init__(self, store: ConnectionsCache, driver_config: "DriverConfig") -> None:
@@ -494,6 +498,15 @@ def _on_disconnected(self, connection: Connection) -> None:
         if self._discovery_thread:
             self._discovery_thread.notify_disconnected()
 
+    def _pessimize_node(self, node_id: int) -> None:
+        """Deprioritize the connection attached to the given YDB node."""
+        if node_id <= 0:
+            return
+
+        connection = self._store.get_connection_by_node_id(node_id)
+        if connection is not None:
+            self._on_disconnected(connection)
+
     def discovery_debug_details(self) -> str:
         """
         Returns debug string about last errors
diff --git a/ydb/query/session.py b/ydb/query/session.py
index a9c1b4a50..661e24011 100644
--- a/ydb/query/session.py
+++ b/ydb/query/session.py
@@ -156,6 +156,24 @@ def _check_session_ready_to_use(self) -> None:
         if self._closed:
             raise RuntimeError(f"Session is not active, session_id: {self._session_id}, closed: {self._closed}")
 
+    def _attach_stream_wrapper(self, response_pb):
+        """Map attach-stream protobuf frames to ServerStatus and handle session hints."""
+        self._handle_attach_session_state(response_pb)
+        return common_utils.ServerStatus.from_proto(response_pb)
+
+    def _handle_attach_session_state(self, response_pb) -> None:
+        """Retire the session when the server sends a shutdown hint on the attach stream."""
+        if response_pb is None:
+            return
+
+        hint = response_pb.WhichOneof("session_hint")
+        if hint == "node_shutdown":
+            if self._node_id is not None:
+                self._driver._pessimize_node(self._node_id)
+            self._close_session(invalidate=True)
+        elif hint == "session_shutdown":
+            self._close_session(invalidate=True)
+
     def _close_session(self, invalidate: bool = False) -> None:
         if self._closed:
             return
@@ -362,7 +380,7 @@ def _attach(self, first_resp_timeout: int = DEFAULT_INITIAL_RESPONSE_TIMEOUT) ->
         self._stream = self._attach_call()
         status_stream = _utilities.SyncResponseIterator(
             self._stream,
-            lambda response: common_utils.ServerStatus.from_proto(response),
+            self._attach_stream_wrapper,
         )
 
         try: