Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
* Query session attach stream handles `NodeShutdown` and `SessionShutdown` session hints: on `NodeShutdown` the session's node connection is pessimized and the session is retired, on `SessionShutdown` the session is retired without touching the node
* Bumped `ydb-api-protos` and regenerated gRPC/protobuf stubs (v3–v6) to include query service session hints
* Fix incompatibility with protobuf 6.30–6.31.0: regenerate v6 stubs with the lowest 6.x gencode floor (6.30.0) instead of 6.31.1
* Query session attach stream handles `NodeShutdown` and `SessionShutdown` session hints: on `NodeShutdown` the session's node connection is pessimized and the session is retired, on `SessionShutdown` the session is retired without touching the node

## 3.29.4 ##
* Fix leaked topic reader stream when close interrupts stream creation during reconnect
Expand Down
127 changes: 127 additions & 0 deletions tests/query/test_query_session_hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import asyncio
from unittest import mock

import pytest

from ydb._grpc.common.protos import ydb_query_pb2
from ydb.aio.pool import ConnectionPool as AsyncConnectionPool
from ydb.pool import ConnectionPool
from ydb.query.session import QuerySession


def _make_session(node_id=42):
driver = mock.Mock()
driver._pessimize_node = mock.Mock()
session = QuerySession(driver)
session._session_id = "test-session"
session._node_id = node_id
return session, driver


class TestQuerySessionAttachHints:
def test_node_shutdown_pessimizes_node_and_invalidates_session(self):
session, driver = _make_session(node_id=42)

session._handle_attach_session_state(
ydb_query_pb2.SessionState(
status=0,
node_shutdown=ydb_query_pb2.NodeShutdownHint(),
)
)

driver._pessimize_node.assert_called_once_with(42)
assert session._invalidated
assert session._closed

def test_session_shutdown_invalidates_without_pessimizing_node(self):
session, driver = _make_session(node_id=42)

session._handle_attach_session_state(
ydb_query_pb2.SessionState(
status=0,
session_shutdown=ydb_query_pb2.SessionShutdownHint(),
)
)

driver._pessimize_node.assert_not_called()
assert session._invalidated
assert session._closed

def test_node_shutdown_with_zero_node_id_delegates_to_driver(self):
session, driver = _make_session(node_id=0)

session._handle_attach_session_state(
ydb_query_pb2.SessionState(
status=0,
node_shutdown=ydb_query_pb2.NodeShutdownHint(),
)
)

driver._pessimize_node.assert_called_once_with(0)
assert session._invalidated

def test_node_shutdown_without_node_id_skips_pessimization(self):
session, driver = _make_session(node_id=None)

session._handle_attach_session_state(
ydb_query_pb2.SessionState(
status=0,
node_shutdown=ydb_query_pb2.NodeShutdownHint(),
)
)

driver._pessimize_node.assert_not_called()
assert session._invalidated
assert session._closed

def test_no_hint_does_not_invalidate(self):
session, driver = _make_session()

session._handle_attach_session_state(
ydb_query_pb2.SessionState(status=0),
)

driver._pessimize_node.assert_not_called()
assert not session._invalidated
assert not session._closed


class TestConnectionPoolAttachHintPessimization:
def test_sync_pool_pessimizes_node_connection(self):
pool = ConnectionPool.__new__(ConnectionPool)
connection = mock.Mock()
pool._store = mock.Mock()
pool._store.get_connection_by_node_id.return_value = connection
pool._on_disconnected = mock.Mock()

pool._pessimize_node(42)

pool._store.get_connection_by_node_id.assert_called_once_with(42)
pool._on_disconnected.assert_called_once_with(connection)

def test_sync_pool_ignores_missing_node_connection(self):
pool = ConnectionPool.__new__(ConnectionPool)
pool._store = mock.Mock()
pool._store.get_connection_by_node_id.return_value = None
pool._on_disconnected = mock.Mock()

pool._pessimize_node(42)
pool._pessimize_node(0)

pool._on_disconnected.assert_not_called()

@pytest.mark.asyncio
async def test_async_pool_pessimizes_node_connection(self):
pool = AsyncConnectionPool.__new__(AsyncConnectionPool)
connection = mock.Mock()
disconnect = mock.AsyncMock()
pool._store = mock.Mock()
pool._store.get_connection_by_node_id.return_value = connection
pool._on_disconnected = mock.Mock(return_value=disconnect)

pool._pessimize_node(42)
await asyncio.sleep(0)

pool._store.get_connection_by_node_id.assert_called_once_with(42)
pool._on_disconnected.assert_called_once_with(connection)
disconnect.assert_awaited_once_with()
11 changes: 10 additions & 1 deletion ydb/aio/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging
import random
from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING
from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING, cast

from ydb import issues
from ydb.opentelemetry.tracing import SpanName, create_ydb_span
Expand Down Expand Up @@ -311,6 +311,15 @@ async def __wrapper__() -> None:

return __wrapper__

def _pessimize_node(self, node_id: int) -> None:
"""Deprioritize the connection attached to the given YDB node."""
if node_id <= 0:
return

connection = cast(Optional[Connection], self._store.get_connection_by_node_id(node_id))
if connection is not None:
asyncio.get_running_loop().create_task(self._on_disconnected(connection)())

async def wait(self, timeout: Optional[float] = 7.0, fail_fast: bool = False) -> None: # type: ignore[override] # async override of sync method
with create_ydb_span(SpanName.DRIVER_INITIALIZE, self._driver_config, kind="internal").attach_context():
await self._store.get(fast_fail=fail_fast, wait_timeout=timeout if timeout is not None else 7.0)
Expand Down
3 changes: 1 addition & 2 deletions ydb/aio/query/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from .. import _utilities
from ... import issues
from ...settings import BaseRequestSettings
from ..._grpc.grpcwrapper import common_utils
from ..._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public

from ...query import base
Expand Down Expand Up @@ -53,7 +52,7 @@ async def _attach(self) -> None:
self._stream = await self._attach_call()
self._status_stream = _utilities.AsyncResponseIterator(
self._stream,
lambda response: common_utils.ServerStatus.from_proto(response),
self._attach_stream_wrapper,
)

try:
Expand Down
13 changes: 13 additions & 0 deletions ydb/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def remove(self, connection: Connection) -> None:
self.connections.pop(connection.endpoint, None)
self.outdated.pop(connection.endpoint, None)

def get_connection_by_node_id(self, node_id: Optional[int]) -> Optional[Connection]:
with self.lock:
return self.connections_by_node_id.get(node_id)


class Discovery(threading.Thread):
def __init__(self, store: ConnectionsCache, driver_config: "DriverConfig") -> None:
Expand Down Expand Up @@ -494,6 +498,15 @@ def _on_disconnected(self, connection: Connection) -> None:
if self._discovery_thread:
self._discovery_thread.notify_disconnected()

def _pessimize_node(self, node_id: int) -> None:
"""Deprioritize the connection attached to the given YDB node."""
if node_id <= 0:
return

connection = self._store.get_connection_by_node_id(node_id)
if connection is not None:
self._on_disconnected(connection)

def discovery_debug_details(self) -> str:
"""
Returns debug string about last errors
Expand Down
20 changes: 19 additions & 1 deletion ydb/query/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,24 @@ def _check_session_ready_to_use(self) -> None:
if self._closed:
raise RuntimeError(f"Session is not active, session_id: {self._session_id}, closed: {self._closed}")

def _attach_stream_wrapper(self, response_pb):
"""Map attach-stream protobuf frames to ServerStatus and handle session hints."""
self._handle_attach_session_state(response_pb)
return common_utils.ServerStatus.from_proto(response_pb)

def _handle_attach_session_state(self, response_pb) -> None:
"""Retire the session when the server sends a shutdown hint on the attach stream."""
if response_pb is None:
return

hint = response_pb.WhichOneof("session_hint")
if hint == "node_shutdown":
if self._node_id is not None:
self._driver._pessimize_node(self._node_id)
self._close_session(invalidate=True)
elif hint == "session_shutdown":
self._close_session(invalidate=True)

def _close_session(self, invalidate: bool = False) -> None:
if self._closed:
return
Expand Down Expand Up @@ -362,7 +380,7 @@ def _attach(self, first_resp_timeout: int = DEFAULT_INITIAL_RESPONSE_TIMEOUT) ->
self._stream = self._attach_call()
status_stream = _utilities.SyncResponseIterator(
self._stream,
lambda response: common_utils.ServerStatus.from_proto(response),
self._attach_stream_wrapper,
)

try:
Expand Down
Loading