From 8c5ad710d3f36bac37e7401d65d190b35810d554 Mon Sep 17 00:00:00 2001 From: I763688 Date: Thu, 21 May 2026 08:30:03 +0100 Subject: [PATCH 1/6] feat: add configurable token cache for customer flow --- src/sap_cloud_sdk/agentgateway/_customer.py | 180 ++++++++++-- .../agentgateway/_token_cache.py | 189 ++++++++++++ src/sap_cloud_sdk/agentgateway/agw_client.py | 19 +- src/sap_cloud_sdk/agentgateway/config.py | 18 ++ tests/agentgateway/unit/test_agw_client.py | 278 +++++++++++++++++- tests/agentgateway/unit/test_customer.py | 77 ++++- tests/agentgateway/unit/test_token_cache.py | 114 +++++++ 7 files changed, 844 insertions(+), 31 deletions(-) create mode 100644 src/sap_cloud_sdk/agentgateway/_token_cache.py create mode 100644 tests/agentgateway/unit/test_token_cache.py diff --git a/src/sap_cloud_sdk/agentgateway/_customer.py b/src/sap_cloud_sdk/agentgateway/_customer.py index 0f6ffb4..6f076ce 100644 --- a/src/sap_cloud_sdk/agentgateway/_customer.py +++ b/src/sap_cloud_sdk/agentgateway/_customer.py @@ -25,6 +25,8 @@ IntegrationDependency, MCPTool, ) +from sap_cloud_sdk.agentgateway._token_cache import _TokenCache, compute_expires_at +from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError logger = logging.getLogger(__name__) @@ -211,19 +213,24 @@ def _request_token_mtls( credentials: CustomerCredentials, grant_type: str, timeout: float, + config: ClientConfig, app_tid: str | None = None, extra_data: dict | None = None, -) -> str: +) -> tuple[str, float]: """Make mTLS token request to IAS. Args: credentials: Customer credentials with certificate and private key. grant_type: OAuth2 grant type. + timeout: HTTP timeout in seconds. + config: Client configuration (used to compute cache expiry). app_tid: BTP Application Tenant ID of subscriber (optional). extra_data: Additional form data for the token request. Returns: - Access token string. + Tuple of (access_token, expires_at) where expires_at is a + time.monotonic() value indicating when the cached token should + be refreshed (already includes the configured buffer). Raises: AgentGatewaySDKError: If token request fails. @@ -282,8 +289,10 @@ def _request_token_mtls( f"Token response missing 'access_token'. Keys: {list(token_data.keys())}" ) + expires_at = compute_expires_at(token_data, config) + logger.debug("Token acquired successfully (length: %d)", len(access_token)) - return access_token + return access_token, expires_at except httpx.RequestError as e: raise AgentGatewaySDKError(f"Token request failed: {e}") @@ -292,61 +301,87 @@ def _request_token_mtls( def get_system_token_mtls( credentials: CustomerCredentials, timeout: float, + config: ClientConfig, + cache: _TokenCache, app_tid: str | None = None, ) -> str: """Get system-scoped token using mTLS client credentials flow. - Used for tool discovery where user identity is not needed. + Used for tool discovery where user identity is not needed. Returns + a cached token if still valid; otherwise acquires a fresh one. Args: credentials: Customer credentials. timeout: HTTP timeout in seconds. + config: Client configuration. + cache: Token cache to consult and update. app_tid: BTP Application Tenant ID of subscriber (optional). Returns: System-scoped access token. """ + cached = cache.get_system_token(app_tid) + if cached: + logger.debug("Using cached system token (app_tid=%s)", app_tid) + return cached + logger.info("Acquiring system token via mTLS client credentials") - return _request_token_mtls( + token, expires_at = _request_token_mtls( credentials, grant_type=_GRANT_TYPE_CLIENT_CREDENTIALS, timeout=timeout, + config=config, app_tid=app_tid, extra_data={"response_type": "token"}, ) + cache.set_system_token(token, expires_at, app_tid) + return token def exchange_user_token( credentials: CustomerCredentials, user_token: str, timeout: float, + config: ClientConfig, + cache: _TokenCache, app_tid: str | None = None, ) -> str: """Exchange user token for AGW-scoped token using jwt-bearer grant. Used for tool invocation where user identity must be preserved - for principal propagation. + for principal propagation. Returns a cached exchanged token if + still valid; otherwise acquires a fresh one. Args: credentials: Customer credentials. user_token: User's JWT token to exchange. timeout: HTTP timeout in seconds. + config: Client configuration. + cache: Token cache to consult and update. app_tid: BTP Application Tenant ID of subscriber (optional). Returns: AGW-scoped access token with user identity. """ + cached = cache.get_user_token(user_token, app_tid) + if cached: + logger.debug("Using cached user token (app_tid=%s)", app_tid) + return cached + logger.info("Exchanging user token for AGW-scoped token via jwt-bearer grant") - return _request_token_mtls( + token, expires_at = _request_token_mtls( credentials, grant_type=_GRANT_TYPE_JWT_BEARER, timeout=timeout, + config=config, app_tid=app_tid, extra_data={ "assertion": user_token, "token_format": "jwt", }, ) + cache.set_user_token(user_token, token, expires_at, app_tid) + return token def _build_mcp_url(gateway_url: str, ord_id: str, gt_id: str) -> str: @@ -433,6 +468,8 @@ async def _list_server_tools( async def get_mcp_tools_customer( credentials: CustomerCredentials, timeout: float, + config: ClientConfig, + cache: _TokenCache, app_tid: str | None = None, ) -> list[MCPTool]: """List all MCP tools from servers defined in credentials. @@ -442,6 +479,9 @@ async def get_mcp_tools_customer( Args: credentials: Customer credentials with integrationDependencies. + timeout: HTTP timeout in seconds. + config: Client configuration. + cache: Token cache shared across calls. app_tid: BTP Application Tenant ID of subscriber (optional). Returns: @@ -462,7 +502,7 @@ async def get_mcp_tools_customer( # Get system token for discovery loop = asyncio.get_running_loop() system_token = await loop.run_in_executor( - None, get_system_token_mtls, credentials, timeout, app_tid + None, get_system_token_mtls, credentials, timeout, config, cache, app_tid ) tools: list[MCPTool] = [] @@ -480,7 +520,42 @@ async def get_mcp_tools_customer( server_tools = await _list_server_tools(url, system_token, dep, timeout) tools.extend(server_tools) logger.debug("Loaded %d tool(s) from %s", len(server_tools), dep.ord_id) - except Exception: + except Exception as exc: + unwrapped = _unwrap_exception_group(exc) + if _is_unauthorized(unwrapped): + logger.info( + "401 from %s — invalidating cached system token and retrying", + dep.ord_id, + ) + cache.invalidate_system_token(app_tid) + try: + fresh_token = await loop.run_in_executor( + None, + get_system_token_mtls, + credentials, + timeout, + config, + cache, + app_tid, + ) + server_tools = await _list_server_tools( + url, fresh_token, dep, timeout + ) + tools.extend(server_tools) + # Replace stale token for remaining iterations + system_token = fresh_token + logger.debug( + "Loaded %d tool(s) from %s after retry", + len(server_tools), + dep.ord_id, + ) + continue + except Exception: + logger.exception( + "Failed to load tools from %s after retry — skipping", + dep.ord_id, + ) + continue logger.exception("Failed to load tools from %s — skipping", dep.ord_id) logger.info( @@ -494,6 +569,8 @@ async def call_mcp_tool_customer( tool: MCPTool, user_token: str | None, timeout: float, + config: ClientConfig, + cache: _TokenCache, app_tid: str | None = None, **kwargs, ) -> str: @@ -502,11 +579,16 @@ async def call_mcp_tool_customer( If user_token is provided, exchanges it for an AGW-scoped token to preserve user identity for principal propagation. Otherwise, falls back to system token. + On a 401 from the MCP server, drops the cached token and retries once. + Args: credentials: Customer credentials. tool: MCPTool to invoke. user_token: User's JWT token for principal propagation (optional). If None, system token is used instead (no principal propagation). + timeout: HTTP timeout in seconds. + config: Client configuration. + cache: Token cache shared across calls. app_tid: BTP Application Tenant ID of subscriber (optional). **kwargs: Tool input parameters. @@ -517,12 +599,18 @@ async def call_mcp_tool_customer( loop = asyncio.get_running_loop() - if user_token: - # Exchange user token for AGW-scoped token (with principal propagation) - agw_token = await loop.run_in_executor( - None, exchange_user_token, credentials, user_token, timeout, app_tid - ) - else: + async def _acquire_token() -> str: + if user_token: + return await loop.run_in_executor( + None, + exchange_user_token, + credentials, + user_token, + timeout, + config, + cache, + app_tid, + ) # TODO: IBD workaround - use system token when user_token is not available. # This bypasses principal propagation. Remove this fallback once IBD # supports proper user token flow. @@ -530,13 +618,55 @@ async def call_mcp_tool_customer( "No user_token provided - using system token for tool invocation. " "Principal propagation will NOT work." ) - agw_token = await loop.run_in_executor( - None, get_system_token_mtls, credentials, timeout, app_tid + return await loop.run_in_executor( + None, get_system_token_mtls, credentials, timeout, config, cache, app_tid ) + def _invalidate_token() -> None: + if user_token: + cache.invalidate_user_token(user_token, app_tid) + else: + cache.invalidate_system_token(app_tid) + + last_exc: Exception | None = None + for attempt in (1, 2): + agw_token = await _acquire_token() + try: + return await _invoke_tool(tool, agw_token, timeout, **kwargs) + except Exception as exc: + unwrapped = _unwrap_exception_group(exc) + if _is_unauthorized(unwrapped) and attempt == 1: + logger.info( + "401 from MCP server for tool '%s' — invalidating cached token and retrying", + tool.name, + ) + _invalidate_token() + last_exc = exc + continue + raise + + # Defensive — should not be reachable; second attempt either returns or raises. + raise AgentGatewaySDKError( + f"Tool invocation for '{tool.name}' failed after 401 retry: {last_exc}" + ) + + +async def _invoke_tool( + tool: MCPTool, + auth_token: str, + timeout: float, + **kwargs, +) -> str: + """Open an MCP session to `tool.url` and invoke `tool.name` with `kwargs`. + + Returns the first content block's text, or empty string when content is + empty. Raises whatever the MCP transport / session raises (notably + `httpx.HTTPStatusError` on 401, which the caller uses to drive cache + invalidation and retry). + """ async with httpx.AsyncClient( headers={ - "Authorization": f"Bearer {agw_token}", + "Authorization": f"Bearer {auth_token}", "x-correlation-id": str(uuid.uuid4()), }, timeout=timeout, @@ -556,3 +686,17 @@ async def call_mcp_tool_customer( first = result.content[0] return str(getattr(first, "text", "")) + + +def _unwrap_exception_group(exc: BaseException) -> BaseException: + """Unwrap nested ExceptionGroups to find the underlying cause.""" + while isinstance(exc, BaseExceptionGroup) and exc.exceptions: + exc = exc.exceptions[0] + return exc + + +def _is_unauthorized(exc: BaseException) -> bool: + """Detect a 401 response from the MCP server (httpx-based).""" + if isinstance(exc, httpx.HTTPStatusError): + return exc.response is not None and exc.response.status_code == 401 + return False diff --git a/src/sap_cloud_sdk/agentgateway/_token_cache.py b/src/sap_cloud_sdk/agentgateway/_token_cache.py new file mode 100644 index 0000000..7e1c3e4 --- /dev/null +++ b/src/sap_cloud_sdk/agentgateway/_token_cache.py @@ -0,0 +1,189 @@ +"""Token cache for Agent Gateway customer flow. + +Caches IAS tokens (system + user-exchanged) per client to avoid redundant +mTLS token requests during agentic loops. LoB flow uses BTP Destination +Service which has its own caching, so this module only serves the customer +flow. + +Keying: +- System tokens are keyed by `app_tid` (or "_default" when unset). +- User tokens are keyed by `sha256(user_jwt + "|" + (app_tid or ""))[:16]`. + +The `app_tid` component is required because `_request_token_mtls` includes +it in the form payload, producing a tenant-scoped token. Mixing tokens +across tenants would break principal propagation. + +Thread safety: +Token fetches run in the default `ThreadPoolExecutor` via +`loop.run_in_executor`. CPython GIL makes individual dict / OrderedDict +operations atomic, but compound check-then-set is not. Two concurrent +coroutines for the same key may both miss and both fetch; the race +produces redundant token requests, not corruption. +""" + +import base64 +import hashlib +import json +import logging +import time +from collections import OrderedDict +from dataclasses import dataclass + +from sap_cloud_sdk.agentgateway.config import ClientConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class _CachedToken: + """A cached token with monotonic expiry.""" + + token: str + expires_at: float # time.monotonic() value + + def is_valid(self) -> bool: + """Return True if the token has not yet reached its monotonic expiry.""" + return time.monotonic() < self.expires_at + + +def _parse_jwt_exp(jwt: str) -> int | None: + """Extract `exp` claim (seconds since epoch) from a JWT without verification. + + Returns None if the JWT is malformed or has no `exp` claim. The result + is used only as a hint for cache TTL — never for security decisions. + """ + try: + parts = jwt.split(".") + if len(parts) < 2: + return None + payload_b64 = parts[1] + # Pad base64 + payload_b64 += "=" * (-len(payload_b64) % 4) + claims = json.loads(base64.urlsafe_b64decode(payload_b64)) + exp = claims.get("exp") + return int(exp) if exp is not None else None + except (ValueError, KeyError, TypeError, json.JSONDecodeError): + return None + + +def compute_expires_at(token_data: dict, config: ClientConfig) -> float: + """Resolve the cache expiry timestamp (monotonic) for a token response. + + Resolution order: + 1. `expires_in` from the response, minus the buffer. + 2. `exp` claim from `id_token` (translated from wall clock to monotonic), + minus the buffer. + 3. Config-provided fallback TTL. + """ + now_mono = time.monotonic() + buffer = config.token_expiry_buffer_seconds + + expires_in = token_data.get("expires_in") + if expires_in is not None: + try: + return now_mono + int(expires_in) - buffer + except (ValueError, TypeError): + pass + + id_token = token_data.get("id_token") + if id_token: + exp = _parse_jwt_exp(id_token) + if exp is not None: + remaining = exp - time.time() + if remaining > buffer: + return now_mono + remaining - buffer + + return now_mono + config.fallback_token_ttl_seconds + + +class _TokenCache: + """Per-client token cache with TTL and LRU eviction. + + Both system and user tokens use OrderedDict for LRU ordering. Keys + include `app_tid` so tenant-scoped tokens never leak across tenants. + """ + + _SYSTEM_DEFAULT_KEY = "_default" + + def __init__(self, config: ClientConfig): + """Initialize empty caches bounded by sizes from `config`.""" + self._config = config + self._system_tokens: OrderedDict[str, _CachedToken] = OrderedDict() + self._user_tokens: OrderedDict[str, _CachedToken] = OrderedDict() + + # --- System Token --- + + def get_system_token(self, app_tid: str | None) -> str | None: + """Return a valid cached system token for `app_tid`, or None on miss/expiry.""" + key = app_tid or self._SYSTEM_DEFAULT_KEY + cached = self._system_tokens.get(key) + if cached and cached.is_valid(): + self._system_tokens.move_to_end(key) + return cached.token + if cached: + del self._system_tokens[key] + return None + + def set_system_token( + self, token: str, expires_at: float, app_tid: str | None + ) -> None: + """Cache a system token under `app_tid`; evict LRU once size exceeds limit.""" + key = app_tid or self._SYSTEM_DEFAULT_KEY + self._system_tokens[key] = _CachedToken(token=token, expires_at=expires_at) + self._system_tokens.move_to_end(key) + while len(self._system_tokens) > self._config.max_system_token_cache_size: + evicted, _ = self._system_tokens.popitem(last=False) + logger.debug("System token cache full — evicted '%s'", evicted) + + def invalidate_system_token(self, app_tid: str | None) -> None: + """Drop the cached system token for `app_tid` (no-op if absent).""" + key = app_tid or self._SYSTEM_DEFAULT_KEY + if self._system_tokens.pop(key, None): + logger.debug("Invalidated system token (app_tid=%s)", app_tid) + + # --- User Tokens --- + + def get_user_token(self, user_jwt: str, app_tid: str | None) -> str | None: + """Return a valid cached exchanged token for `(user_jwt, app_tid)`, or None.""" + key = self._hash_key(user_jwt, app_tid) + cached = self._user_tokens.get(key) + if cached and cached.is_valid(): + self._user_tokens.move_to_end(key) + return cached.token + if cached: + del self._user_tokens[key] + return None + + def set_user_token( + self, + user_jwt: str, + token: str, + expires_at: float, + app_tid: str | None, + ) -> None: + """Cache an exchanged user token; evict LRU once size exceeds limit.""" + key = self._hash_key(user_jwt, app_tid) + self._user_tokens[key] = _CachedToken(token=token, expires_at=expires_at) + self._user_tokens.move_to_end(key) + while len(self._user_tokens) > self._config.max_user_token_cache_size: + evicted, _ = self._user_tokens.popitem(last=False) + logger.debug("User token cache full — evicted '%s'", evicted) + + def invalidate_user_token(self, user_jwt: str, app_tid: str | None) -> None: + """Drop the cached user token for `(user_jwt, app_tid)` (no-op if absent).""" + key = self._hash_key(user_jwt, app_tid) + if self._user_tokens.pop(key, None): + logger.debug("Invalidated user token (app_tid=%s)", app_tid) + + # --- Maintenance --- + + def clear(self) -> None: + """Drop all cached tokens. Forces a fresh fetch on next access.""" + self._system_tokens.clear() + self._user_tokens.clear() + + @staticmethod + def _hash_key(user_jwt: str, app_tid: str | None) -> str: + """Derive a short, stable cache key from `(user_jwt, app_tid)` via sha256.""" + material = f"{user_jwt}|{app_tid or ''}" + return hashlib.sha256(material.encode()).hexdigest()[:16] diff --git a/src/sap_cloud_sdk/agentgateway/agw_client.py b/src/sap_cloud_sdk/agentgateway/agw_client.py index a601d88..baf4022 100644 --- a/src/sap_cloud_sdk/agentgateway/agw_client.py +++ b/src/sap_cloud_sdk/agentgateway/agw_client.py @@ -11,6 +11,7 @@ from typing import Callable from sap_cloud_sdk.agentgateway._models import MCPTool +from sap_cloud_sdk.agentgateway._token_cache import _TokenCache from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway._customer import ( detect_customer_agent_credentials, @@ -85,6 +86,16 @@ def __init__( """ self._tenant_subdomain = tenant_subdomain self._config = config or ClientConfig() + self._token_cache = _TokenCache(self._config) + + def clear_token_cache(self) -> None: + """Drop all cached tokens. Forces a fresh token fetch on the next call. + + Useful when external state (revoked credentials, tenant change) makes + cached tokens unsafe to reuse, or for testing. No-op for LoB flow, + which delegates caching to BTP Destination Service. + """ + self._token_cache.clear() @staticmethod def _resolve_value( @@ -158,7 +169,11 @@ async def list_mcp_tools( ) credentials = load_customer_credentials(credentials_path) return await get_mcp_tools_customer( - credentials, self._config.timeout, app_tid + credentials, + self._config.timeout, + self._config, + self._token_cache, + app_tid, ) # LoB flow - requires tenant_subdomain @@ -251,6 +266,8 @@ async def call_mcp_tool( tool, resolved_user_token, self._config.timeout, + self._config, + self._token_cache, app_tid, **kwargs, ) diff --git a/src/sap_cloud_sdk/agentgateway/config.py b/src/sap_cloud_sdk/agentgateway/config.py index 427f96b..b44af1f 100644 --- a/src/sap_cloud_sdk/agentgateway/config.py +++ b/src/sap_cloud_sdk/agentgateway/config.py @@ -3,6 +3,10 @@ from dataclasses import dataclass DEFAULT_TIMEOUT_SECONDS = 60.0 +DEFAULT_TOKEN_EXPIRY_BUFFER_SECONDS = 60 +DEFAULT_MAX_USER_TOKEN_CACHE_SIZE = 10 +DEFAULT_MAX_SYSTEM_TOKEN_CACHE_SIZE = 10 +DEFAULT_FALLBACK_TOKEN_TTL_SECONDS = 300 @dataclass @@ -12,6 +16,20 @@ class ClientConfig: Attributes: timeout: HTTP timeout in seconds for token requests and MCP server calls. Defaults to 60 seconds. + token_expiry_buffer_seconds: Refresh tokens this many seconds before + their reported expiry. Defaults to 60 seconds. + max_user_token_cache_size: Maximum number of user tokens cached + per client. LRU eviction once exceeded. Defaults to 10. + max_system_token_cache_size: Maximum number of system tokens cached + per client (one per app_tid). LRU eviction once exceeded. + Defaults to 10. + fallback_token_ttl_seconds: TTL applied when neither `expires_in` + nor a parseable `id_token` exp claim is available in the token + response. Defaults to 300 seconds. """ timeout: float = DEFAULT_TIMEOUT_SECONDS + token_expiry_buffer_seconds: int = DEFAULT_TOKEN_EXPIRY_BUFFER_SECONDS + max_user_token_cache_size: int = DEFAULT_MAX_USER_TOKEN_CACHE_SIZE + max_system_token_cache_size: int = DEFAULT_MAX_SYSTEM_TOKEN_CACHE_SIZE + fallback_token_ttl_seconds: int = DEFAULT_FALLBACK_TOKEN_TTL_SECONDS diff --git a/tests/agentgateway/unit/test_agw_client.py b/tests/agentgateway/unit/test_agw_client.py index c60e8bb..f54c242 100644 --- a/tests/agentgateway/unit/test_agw_client.py +++ b/tests/agentgateway/unit/test_agw_client.py @@ -1,7 +1,9 @@ """Unit tests for Agent Gateway client.""" -from unittest.mock import patch, AsyncMock +import time +from unittest.mock import patch, AsyncMock, MagicMock +import httpx import pytest from sap_cloud_sdk.agentgateway import ( @@ -10,6 +12,10 @@ MCPTool, AgentGatewaySDKError, ) +from sap_cloud_sdk.agentgateway._models import ( + CustomerCredentials, + IntegrationDependency, +) # ============================================================ @@ -411,3 +417,273 @@ async def test_returns_result_from_lob_flow(self, mock_tool): ) assert result == "Success: Order created" + + +# ============================================================ +# Test: Token cache behavior through the public API +# ============================================================ + + +def _customer_credentials() -> CustomerCredentials: + """Build a minimal CustomerCredentials fixture for cache-behavior tests.""" + return CustomerCredentials( + token_service_url="https://ias.example.com/oauth2/token", + client_id="test-client", + certificate="cert", + private_key="key", + gateway_url="https://agw.example.com", + integration_dependencies=[ + IntegrationDependency( + ord_id="sap.test:apiResource:demo:v1", + global_tenant_id="250695", + ), + ], + ) + + +def _build_streaming_mocks( + initialize_side_effect=None, + call_tool_side_effect=None, + list_tools_side_effect=None, +): + """Build the chain of mocks needed to drive customer flow MCP calls.""" + http_client = AsyncMock() + http_client.__aenter__ = AsyncMock(return_value=http_client) + http_client.__aexit__ = AsyncMock(return_value=None) + + stream_ctx = AsyncMock() + stream_ctx.__aenter__ = AsyncMock(return_value=(AsyncMock(), AsyncMock(), None)) + stream_ctx.__aexit__ = AsyncMock(return_value=None) + + session = AsyncMock() + if initialize_side_effect is not None: + session.initialize = AsyncMock(side_effect=initialize_side_effect) + else: + init_result = MagicMock() + init_result.serverInfo.name = "demo-server" + session.initialize = AsyncMock(return_value=init_result) + + if list_tools_side_effect is not None: + session.list_tools = AsyncMock(side_effect=list_tools_side_effect) + else: + list_result = MagicMock() + list_result.tools = [] + session.list_tools = AsyncMock(return_value=list_result) + + if call_tool_side_effect is not None: + session.call_tool = AsyncMock(side_effect=call_tool_side_effect) + else: + call_result = MagicMock() + content = MagicMock() + content.text = "ok" + call_result.content = [content] + session.call_tool = AsyncMock(return_value=call_result) + + session_ctx = AsyncMock() + session_ctx.__aenter__ = AsyncMock(return_value=session) + session_ctx.__aexit__ = AsyncMock(return_value=None) + + return http_client, stream_ctx, session_ctx + + +def _make_401() -> httpx.HTTPStatusError: + """Construct an httpx 401 HTTPStatusError for simulating MCP auth failures.""" + request = httpx.Request("POST", "https://example.com") + response = httpx.Response(401, request=request) + return httpx.HTTPStatusError("Unauthorized", request=request, response=response) + + +def _patch_customer_flow(token_request_side_effect): + """Patch detection/loading + IAS request + MCP transport for customer flow. + + Returns the http/stream/session mocks plus the IAS request mock so callers + can assert on call counts. + """ + http_client, stream_ctx, session_ctx = _build_streaming_mocks() + + request_mock = MagicMock(side_effect=token_request_side_effect) + + patches = [ + patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + return_value=_customer_credentials(), + ), + patch( + "sap_cloud_sdk.agentgateway._customer._request_token_mtls", + request_mock, + ), + patch("httpx.AsyncClient", return_value=http_client), + patch( + "sap_cloud_sdk.agentgateway._customer.streamable_http_client", + return_value=stream_ctx, + ), + patch( + "sap_cloud_sdk.agentgateway._customer.ClientSession", + return_value=session_ctx, + ), + ] + return patches, request_mock, session_ctx + + +class TestTokenCacheBehavior: + """Cache behavior verified through AgentGatewayClient public API.""" + + @pytest.mark.asyncio + async def test_list_mcp_tools_twice_hits_ias_once(self, mock_tool): + """Two list_mcp_tools calls share one cached system token.""" + patches, request_mock, _ = _patch_customer_flow( + token_request_side_effect=lambda *a, **kw: ( + "system-token", + time.monotonic() + 600, + ) + ) + + with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5]: + agw_client = create_client() + + await agw_client.list_mcp_tools() + await agw_client.list_mcp_tools() + + assert request_mock.call_count == 1 + + @pytest.mark.asyncio + async def test_call_mcp_tool_twice_same_user_token_hits_ias_once(self, mock_tool): + """Two call_mcp_tool calls with same user_token reuse exchanged token.""" + patches, request_mock, _ = _patch_customer_flow( + token_request_side_effect=lambda *a, **kw: ( + "exchanged-token", + time.monotonic() + 600, + ) + ) + + with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5]: + agw_client = create_client() + + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt-A") + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt-A") + + assert request_mock.call_count == 1 + + @pytest.mark.asyncio + async def test_different_user_tokens_isolated(self, mock_tool): + """Different user_tokens trigger separate exchanges.""" + patches, request_mock, _ = _patch_customer_flow( + token_request_side_effect=[ + ("tok-A", time.monotonic() + 600), + ("tok-B", time.monotonic() + 600), + ] + ) + + with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5]: + agw_client = create_client() + + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt-A") + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt-B") + + assert request_mock.call_count == 2 + + @pytest.mark.asyncio + async def test_app_tid_isolation(self, mock_tool): + """Same user_token across different app_tid values stays isolated.""" + patches, request_mock, _ = _patch_customer_flow( + token_request_side_effect=[ + ("tok-tenant-a", time.monotonic() + 600), + ("tok-tenant-b", time.monotonic() + 600), + ] + ) + + with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5]: + agw_client = create_client() + + await agw_client.call_mcp_tool( + tool=mock_tool, user_token="user-jwt", app_tid="tenant-a" + ) + await agw_client.call_mcp_tool( + tool=mock_tool, user_token="user-jwt", app_tid="tenant-b" + ) + + assert request_mock.call_count == 2 + + @pytest.mark.asyncio + async def test_clear_token_cache_forces_refetch(self, mock_tool): + """clear_token_cache drops cached tokens, next call refetches.""" + patches, request_mock, _ = _patch_customer_flow( + token_request_side_effect=lambda *a, **kw: ( + "any-token", + time.monotonic() + 600, + ) + ) + + with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5]: + agw_client = create_client() + + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt") + agw_client.clear_token_cache() + await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt") + + assert request_mock.call_count == 2 + + @pytest.mark.asyncio + async def test_401_invalidates_cache_and_retries(self, mock_tool): + """A 401 from the MCP server drops the cached token and retries once.""" + http_client, stream_ctx, _ = _build_streaming_mocks() + + # First call_tool raises 401, second returns success + success = MagicMock() + content = MagicMock() + content.text = "ok-after-retry" + success.content = [content] + + session = AsyncMock() + init_result = MagicMock() + init_result.serverInfo.name = "demo-server" + session.initialize = AsyncMock(return_value=init_result) + session.call_tool = AsyncMock(side_effect=[_make_401(), success]) + session_ctx = AsyncMock() + session_ctx.__aenter__ = AsyncMock(return_value=session) + session_ctx.__aexit__ = AsyncMock(return_value=None) + + request_mock = MagicMock( + side_effect=[ + ("stale-token", time.monotonic() + 600), + ("fresh-token", time.monotonic() + 600), + ] + ) + + with ( + patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + return_value=_customer_credentials(), + ), + patch( + "sap_cloud_sdk.agentgateway._customer._request_token_mtls", + request_mock, + ), + patch("httpx.AsyncClient", return_value=http_client), + patch( + "sap_cloud_sdk.agentgateway._customer.streamable_http_client", + return_value=stream_ctx, + ), + patch( + "sap_cloud_sdk.agentgateway._customer.ClientSession", + return_value=session_ctx, + ), + ): + agw_client = create_client() + + result = await agw_client.call_mcp_tool( + tool=mock_tool, user_token="user-jwt" + ) + + assert result == "ok-after-retry" + # Stale exchange + fresh exchange after invalidation + assert request_mock.call_count == 2 + diff --git a/tests/agentgateway/unit/test_customer.py b/tests/agentgateway/unit/test_customer.py index 4ed170b..57a3cfa 100644 --- a/tests/agentgateway/unit/test_customer.py +++ b/tests/agentgateway/unit/test_customer.py @@ -16,6 +16,8 @@ _CREDENTIALS_PATH_ENV, _CREDENTIALS_DEFAULT_PATH, ) +from sap_cloud_sdk.agentgateway._token_cache import _TokenCache +from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway._models import ( CustomerCredentials, IntegrationDependency, @@ -305,7 +307,9 @@ def test_requests_client_credentials_token(self, credentials): mock_client.post.return_value = mock_response mock_client_class.return_value = mock_client - result = get_system_token_mtls(credentials, timeout=60.0) + result = get_system_token_mtls( + credentials, timeout=60.0, config=ClientConfig(), cache=_TokenCache(ClientConfig()) + ) assert result == "system-token-123" mock_client.post.assert_called_once() @@ -332,7 +336,12 @@ def test_raises_on_failed_request(self, credentials): mock_client_class.return_value = mock_client with pytest.raises(AgentGatewaySDKError, match="Token request failed"): - get_system_token_mtls(credentials, timeout=60.0) + get_system_token_mtls( + credentials, + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + ) # ============================================================ @@ -374,7 +383,13 @@ def test_exchanges_user_token_with_jwt_bearer(self, credentials): mock_client.post.return_value = mock_response mock_client_class.return_value = mock_client - result = exchange_user_token(credentials, "user-jwt-token", timeout=60.0) + result = exchange_user_token( + credentials, + "user-jwt-token", + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + ) assert result == "exchanged-token-123" call_args = mock_client.post.call_args @@ -403,7 +418,12 @@ def test_passes_app_tid_when_provided(self, credentials): mock_client_class.return_value = mock_client result = exchange_user_token( - credentials, "user-jwt", timeout=60.0, app_tid="test-tid" + credentials, + "user-jwt", + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + app_tid="test-tid", ) assert result == "token-with-tid" @@ -451,7 +471,12 @@ async def test_raises_when_empty_dependencies(self): with pytest.raises( AgentGatewaySDKError, match="integrationDependencies is empty" ): - await get_mcp_tools_customer(credentials, timeout=60.0) + await get_mcp_tools_customer( + credentials, + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + ) @pytest.mark.asyncio async def test_discovers_tools_from_credentials(self, credentials): @@ -477,7 +502,12 @@ async def test_discovers_tools_from_credentials(self, credentials): return_value=mock_tools, ) as mock_list, ): - result = await get_mcp_tools_customer(credentials, timeout=60.0) + result = await get_mcp_tools_customer( + credentials, + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + ) assert len(result) == 1 assert result[0].name == "list_cost_centers" @@ -525,7 +555,12 @@ async def mock_list_tools(*args, **kwargs): side_effect=mock_list_tools, ), ): - result = await get_mcp_tools_customer(credentials, timeout=60.0) + result = await get_mcp_tools_customer( + credentials, + timeout=60.0, + config=ClientConfig(), + cache=_TokenCache(ClientConfig()), + ) # Should still return tools from server2 assert len(result) == 1 @@ -615,11 +650,21 @@ async def test_exchanges_user_token_before_call(self, credentials, mock_tool): mock_session_class.return_value = mock_session_ctx result = await call_mcp_tool_customer( - credentials, mock_tool, "user-jwt", 60.0, order_id="12345" + credentials, + mock_tool, + "user-jwt", + 60.0, + ClientConfig(), + _TokenCache(ClientConfig()), + order_id="12345", ) assert result == "Order created successfully" - mock_exchange.assert_called_once_with(credentials, "user-jwt", 60.0, None) + mock_exchange.assert_called_once() + args, _ = mock_exchange.call_args + assert args[0] is credentials + assert args[1] == "user-jwt" + assert args[2] == 60.0 @pytest.mark.asyncio async def test_uses_system_token_when_user_token_not_provided( @@ -671,10 +716,20 @@ async def test_uses_system_token_when_user_token_not_provided( # Call without user_token (None) result = await call_mcp_tool_customer( - credentials, mock_tool, None, 60.0, order_id="12345" + credentials, + mock_tool, + None, + 60.0, + ClientConfig(), + _TokenCache(ClientConfig()), + order_id="12345", ) assert result == "Result with system token" # Should use system token, not exchange - mock_system_token.assert_called_once_with(credentials, 60.0, None) + mock_system_token.assert_called_once() + args, _ = mock_system_token.call_args + assert args[0] is credentials + assert args[1] == 60.0 mock_exchange.assert_not_called() + diff --git a/tests/agentgateway/unit/test_token_cache.py b/tests/agentgateway/unit/test_token_cache.py new file mode 100644 index 0000000..056e196 --- /dev/null +++ b/tests/agentgateway/unit/test_token_cache.py @@ -0,0 +1,114 @@ +"""Unit tests for token cache helpers with non-trivial logic. + +Cache class behavior is tested through AgentGatewayClient (test_agw_client.py) +to keep coverage focused on observable functionality. Only `_parse_jwt_exp` +and `compute_expires_at` are exercised here directly because they contain +parsing/branching logic that is hard to drive through the public API. +""" + +import base64 +import json +import time + +from sap_cloud_sdk.agentgateway._token_cache import ( + _parse_jwt_exp, + compute_expires_at, +) +from sap_cloud_sdk.agentgateway.config import ClientConfig + + +def _make_jwt(claims: dict) -> str: + """Build a non-signed JWT for testing (header.payload.signature).""" + header = base64.urlsafe_b64encode(json.dumps({"alg": "none"}).encode()).rstrip(b"=") + payload = base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=") + return f"{header.decode()}.{payload.decode()}.signature" + + +class TestParseJwtExp: + """Tests for the unverified JWT `exp` claim parser.""" + + def test_extracts_exp(self): + """Extract `exp` claim from a well-formed JWT payload.""" + jwt = _make_jwt({"exp": 1700000000, "iat": 1699996400}) + assert _parse_jwt_exp(jwt) == 1700000000 + + def test_returns_none_when_exp_missing(self): + """Return None when payload has no `exp` claim.""" + jwt = _make_jwt({"iat": 1699996400}) + assert _parse_jwt_exp(jwt) is None + + def test_returns_none_for_malformed_jwt(self): + """Return None for strings that are not three-part JWTs.""" + assert _parse_jwt_exp("not-a-jwt") is None + assert _parse_jwt_exp("") is None + assert _parse_jwt_exp("only.two") is None + + def test_returns_none_for_garbage_payload(self): + """Return None when the payload segment is not valid base64/JSON.""" + assert _parse_jwt_exp("aaa.@@not-base64@@.bbb") is None + + +class TestComputeExpiresAt: + """Tests for cache expiry resolution from token responses.""" + + def test_uses_expires_in_when_present(self): + """Prefer `expires_in` from the response and subtract the buffer.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"expires_in": 3600}, cfg) + assert before + 3540 - 1 <= result <= before + 3540 + 1 + + def test_expires_in_equal_to_buffer_expires_immediately(self): + """Token whose `expires_in` equals the buffer is treated as already expiring now.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"expires_in": 60}, cfg) + after = time.monotonic() + assert before - 1 <= result <= after + 1 + + def test_expires_in_below_buffer_is_already_stale(self): + """Token whose `expires_in` is below the buffer resolves to a past timestamp.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"expires_in": 30}, cfg) + assert before - 31 <= result <= before - 29 + + def test_falls_back_to_id_token_exp(self): + """Fall back to the `exp` claim of `id_token` when `expires_in` is absent.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + future_exp = int(time.time()) + 600 + jwt = _make_jwt({"exp": future_exp}) + before = time.monotonic() + result = compute_expires_at({"id_token": jwt}, cfg) + assert before + 540 - 5 <= result <= before + 540 + 5 + + def test_uses_fallback_when_no_expiry_info(self): + """Use config fallback TTL when neither `expires_in` nor `id_token` is present.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"access_token": "opaque"}, cfg) + assert before + 300 - 1 <= result <= before + 300 + 1 + + def test_uses_fallback_when_id_token_malformed(self): + """Use fallback TTL when the `id_token` cannot be parsed.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"id_token": "garbage"}, cfg) + assert before + 300 - 1 <= result <= before + 300 + 1 + + def test_uses_fallback_when_id_token_exp_within_buffer(self): + """Skip the `id_token` path when remaining lifetime is below the buffer.""" + # If remaining time is below the buffer, the id_token path is skipped. + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + soon_exp = int(time.time()) + 30 + jwt = _make_jwt({"exp": soon_exp}) + before = time.monotonic() + result = compute_expires_at({"id_token": jwt}, cfg) + assert before + 300 - 1 <= result <= before + 300 + 1 + + def test_handles_invalid_expires_in_value(self): + """Use fallback TTL when `expires_in` is not coercible to int.""" + cfg = ClientConfig(token_expiry_buffer_seconds=60, fallback_token_ttl_seconds=300) + before = time.monotonic() + result = compute_expires_at({"expires_in": "not-a-number"}, cfg) + assert before + 300 - 1 <= result <= before + 300 + 1 From 48671d38c41e464a8e846404f2f5b7f725380837 Mon Sep 17 00:00:00 2001 From: I763688 Date: Thu, 21 May 2026 08:48:00 +0100 Subject: [PATCH 2/6] feat: add metrics --- src/sap_cloud_sdk/agentgateway/_token_cache.py | 1 - src/sap_cloud_sdk/agentgateway/agw_client.py | 1 + src/sap_cloud_sdk/core/telemetry/operation.py | 1 + 3 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sap_cloud_sdk/agentgateway/_token_cache.py b/src/sap_cloud_sdk/agentgateway/_token_cache.py index 7e1c3e4..324ff3f 100644 --- a/src/sap_cloud_sdk/agentgateway/_token_cache.py +++ b/src/sap_cloud_sdk/agentgateway/_token_cache.py @@ -57,7 +57,6 @@ def _parse_jwt_exp(jwt: str) -> int | None: if len(parts) < 2: return None payload_b64 = parts[1] - # Pad base64 payload_b64 += "=" * (-len(payload_b64) % 4) claims = json.loads(base64.urlsafe_b64decode(payload_b64)) exp = claims.get("exp") diff --git a/src/sap_cloud_sdk/agentgateway/agw_client.py b/src/sap_cloud_sdk/agentgateway/agw_client.py index baf4022..a9ada19 100644 --- a/src/sap_cloud_sdk/agentgateway/agw_client.py +++ b/src/sap_cloud_sdk/agentgateway/agw_client.py @@ -88,6 +88,7 @@ def __init__( self._config = config or ClientConfig() self._token_cache = _TokenCache(self._config) + @record_metrics(Module.AGENTGATEWAY, Operation.AGENTGATEWAY_CLEAR_TOKEN_CACHE) def clear_token_cache(self) -> None: """Drop all cached tokens. Forces a fresh token fetch on the next call. diff --git a/src/sap_cloud_sdk/core/telemetry/operation.py b/src/sap_cloud_sdk/core/telemetry/operation.py index 8619145..b114523 100644 --- a/src/sap_cloud_sdk/core/telemetry/operation.py +++ b/src/sap_cloud_sdk/core/telemetry/operation.py @@ -107,6 +107,7 @@ class Operation(str, Enum): # Agent Gateway Operations AGENTGATEWAY_LIST_MCP_TOOLS = "list_mcp_tools" AGENTGATEWAY_CALL_MCP_TOOL = "call_mcp_tool" + AGENTGATEWAY_CLEAR_TOKEN_CACHE = "clear_token_cache" # Agent Memory Operations AGENT_MEMORY_ADD_MEMORY = "add_memory" From 7410a0d7cdea10856fe3a00d27e4fb53aa587413 Mon Sep 17 00:00:00 2001 From: I763688 Date: Tue, 26 May 2026 10:59:38 +0100 Subject: [PATCH 3/6] chore: Remove clear_token_cache method from AgentGatewayClient and related tests --- src/sap_cloud_sdk/agentgateway/agw_client.py | 10 ---------- src/sap_cloud_sdk/core/telemetry/operation.py | 1 - tests/agentgateway/unit/test_agw_client.py | 19 ------------------- 3 files changed, 30 deletions(-) diff --git a/src/sap_cloud_sdk/agentgateway/agw_client.py b/src/sap_cloud_sdk/agentgateway/agw_client.py index a9ada19..3ecde2d 100644 --- a/src/sap_cloud_sdk/agentgateway/agw_client.py +++ b/src/sap_cloud_sdk/agentgateway/agw_client.py @@ -88,16 +88,6 @@ def __init__( self._config = config or ClientConfig() self._token_cache = _TokenCache(self._config) - @record_metrics(Module.AGENTGATEWAY, Operation.AGENTGATEWAY_CLEAR_TOKEN_CACHE) - def clear_token_cache(self) -> None: - """Drop all cached tokens. Forces a fresh token fetch on the next call. - - Useful when external state (revoked credentials, tenant change) makes - cached tokens unsafe to reuse, or for testing. No-op for LoB flow, - which delegates caching to BTP Destination Service. - """ - self._token_cache.clear() - @staticmethod def _resolve_value( value: str | Callable[[], str] | None, diff --git a/src/sap_cloud_sdk/core/telemetry/operation.py b/src/sap_cloud_sdk/core/telemetry/operation.py index b114523..8619145 100644 --- a/src/sap_cloud_sdk/core/telemetry/operation.py +++ b/src/sap_cloud_sdk/core/telemetry/operation.py @@ -107,7 +107,6 @@ class Operation(str, Enum): # Agent Gateway Operations AGENTGATEWAY_LIST_MCP_TOOLS = "list_mcp_tools" AGENTGATEWAY_CALL_MCP_TOOL = "call_mcp_tool" - AGENTGATEWAY_CLEAR_TOKEN_CACHE = "clear_token_cache" # Agent Memory Operations AGENT_MEMORY_ADD_MEMORY = "add_memory" diff --git a/tests/agentgateway/unit/test_agw_client.py b/tests/agentgateway/unit/test_agw_client.py index f54c242..a771f9e 100644 --- a/tests/agentgateway/unit/test_agw_client.py +++ b/tests/agentgateway/unit/test_agw_client.py @@ -608,25 +608,6 @@ async def test_app_tid_isolation(self, mock_tool): assert request_mock.call_count == 2 - @pytest.mark.asyncio - async def test_clear_token_cache_forces_refetch(self, mock_tool): - """clear_token_cache drops cached tokens, next call refetches.""" - patches, request_mock, _ = _patch_customer_flow( - token_request_side_effect=lambda *a, **kw: ( - "any-token", - time.monotonic() + 600, - ) - ) - - with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5]: - agw_client = create_client() - - await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt") - agw_client.clear_token_cache() - await agw_client.call_mcp_tool(tool=mock_tool, user_token="user-jwt") - - assert request_mock.call_count == 2 - @pytest.mark.asyncio async def test_401_invalidates_cache_and_retries(self, mock_tool): """A 401 from the MCP server drops the cached token and retries once.""" From 03545dd688d6b2ddf49c6ddbf613e10df54cb3ba Mon Sep 17 00:00:00 2001 From: I763688 Date: Tue, 26 May 2026 11:28:08 +0100 Subject: [PATCH 4/6] fix: update MCP tool discovery to handle token expiration during discovery loop --- src/sap_cloud_sdk/agentgateway/_customer.py | 71 ++++++++------------- 1 file changed, 28 insertions(+), 43 deletions(-) diff --git a/src/sap_cloud_sdk/agentgateway/_customer.py b/src/sap_cloud_sdk/agentgateway/_customer.py index 6f076ce..b125e6d 100644 --- a/src/sap_cloud_sdk/agentgateway/_customer.py +++ b/src/sap_cloud_sdk/agentgateway/_customer.py @@ -499,11 +499,15 @@ async def get_mcp_tools_customer( logger.info("Discovering tools from %d MCP server(s)", len(dependencies)) - # Get system token for discovery - loop = asyncio.get_running_loop() - system_token = await loop.run_in_executor( - None, get_system_token_mtls, credentials, timeout, config, cache, app_tid - ) + system_token = None + # Define a helper closure to refetch the system token on demand, since it may need to be + # refreshed during the discovery loop if any server returns a 401 + async def refetch_system_token() -> str: + loop = asyncio.get_running_loop() + new_token = await loop.run_in_executor( + None, get_system_token_mtls, credentials, timeout, config, cache, app_tid + ) + return new_token tools: list[MCPTool] = [] @@ -516,47 +520,28 @@ async def get_mcp_tools_customer( dep.global_tenant_id, ) - try: - server_tools = await _list_server_tools(url, system_token, dep, timeout) - tools.extend(server_tools) - logger.debug("Loaded %d tool(s) from %s", len(server_tools), dep.ord_id) - except Exception as exc: - unwrapped = _unwrap_exception_group(exc) - if _is_unauthorized(unwrapped): - logger.info( - "401 from %s — invalidating cached system token and retrying", - dep.ord_id, - ) - cache.invalidate_system_token(app_tid) - try: - fresh_token = await loop.run_in_executor( - None, - get_system_token_mtls, - credentials, - timeout, - config, - cache, - app_tid, - ) - server_tools = await _list_server_tools( - url, fresh_token, dep, timeout - ) - tools.extend(server_tools) - # Replace stale token for remaining iterations - system_token = fresh_token - logger.debug( - "Loaded %d tool(s) from %s after retry", - len(server_tools), - dep.ord_id, - ) - continue - except Exception: - logger.exception( - "Failed to load tools from %s after retry — skipping", + while True: + if not system_token: + # won't catch exceptions here - if token acquisition fails, + # we want the discovery to fail immediately + system_token = await refetch_system_token() + + try: + server_tools = await _list_server_tools(url, system_token, dep, timeout) + tools.extend(server_tools) + logger.debug("Loaded %d tool(s) from %s", len(server_tools), dep.ord_id) + except Exception as exc: + unwrapped = _unwrap_exception_group(exc) + if _is_unauthorized(unwrapped): + logger.info( + "401 from %s — invalidating cached system token and retrying", dep.ord_id, ) + cache.invalidate_system_token(app_tid) + system_token = None # Force refetch on next loop iteration continue - logger.exception("Failed to load tools from %s — skipping", dep.ord_id) + logger.exception("Failed to load tools from %s — skipping", dep.ord_id) + break # Success, move to next server logger.info( "Loaded %d MCP tool(s) from %d server(s)", len(tools), len(dependencies) From f9fb242f9b78d7dda816623559b977ccfe9975d2 Mon Sep 17 00:00:00 2001 From: I763688 Date: Tue, 26 May 2026 14:50:50 +0100 Subject: [PATCH 5/6] use client_id instead of app_tid for token cache keys --- src/sap_cloud_sdk/agentgateway/_customer.py | 22 +++--- .../agentgateway/_token_cache.py | 67 +++++++++---------- tests/agentgateway/unit/test_agw_client.py | 7 +- tests/agentgateway/unit/test_token_cache.py | 47 +++++++++++++ 4 files changed, 93 insertions(+), 50 deletions(-) diff --git a/src/sap_cloud_sdk/agentgateway/_customer.py b/src/sap_cloud_sdk/agentgateway/_customer.py index b125e6d..737ad02 100644 --- a/src/sap_cloud_sdk/agentgateway/_customer.py +++ b/src/sap_cloud_sdk/agentgateway/_customer.py @@ -320,9 +320,9 @@ def get_system_token_mtls( Returns: System-scoped access token. """ - cached = cache.get_system_token(app_tid) + cached = cache.get_system_token(credentials.client_id) if cached: - logger.debug("Using cached system token (app_tid=%s)", app_tid) + logger.debug("Using cached system token (client_id=%s)", credentials.client_id) return cached logger.info("Acquiring system token via mTLS client credentials") @@ -334,7 +334,7 @@ def get_system_token_mtls( app_tid=app_tid, extra_data={"response_type": "token"}, ) - cache.set_system_token(token, expires_at, app_tid) + cache.set_system_token(token, expires_at, credentials.client_id) return token @@ -363,9 +363,9 @@ def exchange_user_token( Returns: AGW-scoped access token with user identity. """ - cached = cache.get_user_token(user_token, app_tid) + cached = cache.get_user_token(user_token, credentials.client_id) if cached: - logger.debug("Using cached user token (app_tid=%s)", app_tid) + logger.debug("Using cached user token (client_id=%s)", credentials.client_id) return cached logger.info("Exchanging user token for AGW-scoped token via jwt-bearer grant") @@ -380,7 +380,7 @@ def exchange_user_token( "token_format": "jwt", }, ) - cache.set_user_token(user_token, token, expires_at, app_tid) + cache.set_user_token(user_token, token, expires_at, credentials.client_id) return token @@ -520,7 +520,7 @@ async def refetch_system_token() -> str: dep.global_tenant_id, ) - while True: + for attempt in (1, 2): if not system_token: # won't catch exceptions here - if token acquisition fails, # we want the discovery to fail immediately @@ -532,12 +532,12 @@ async def refetch_system_token() -> str: logger.debug("Loaded %d tool(s) from %s", len(server_tools), dep.ord_id) except Exception as exc: unwrapped = _unwrap_exception_group(exc) - if _is_unauthorized(unwrapped): + if _is_unauthorized(unwrapped) and attempt == 1: logger.info( "401 from %s — invalidating cached system token and retrying", dep.ord_id, ) - cache.invalidate_system_token(app_tid) + cache.invalidate_system_token(credentials.client_id) system_token = None # Force refetch on next loop iteration continue logger.exception("Failed to load tools from %s — skipping", dep.ord_id) @@ -609,9 +609,9 @@ async def _acquire_token() -> str: def _invalidate_token() -> None: if user_token: - cache.invalidate_user_token(user_token, app_tid) + cache.invalidate_user_token(user_token, credentials.client_id) else: - cache.invalidate_system_token(app_tid) + cache.invalidate_system_token(credentials.client_id) last_exc: Exception | None = None for attempt in (1, 2): diff --git a/src/sap_cloud_sdk/agentgateway/_token_cache.py b/src/sap_cloud_sdk/agentgateway/_token_cache.py index 324ff3f..0f4066d 100644 --- a/src/sap_cloud_sdk/agentgateway/_token_cache.py +++ b/src/sap_cloud_sdk/agentgateway/_token_cache.py @@ -6,12 +6,8 @@ flow. Keying: -- System tokens are keyed by `app_tid` (or "_default" when unset). -- User tokens are keyed by `sha256(user_jwt + "|" + (app_tid or ""))[:16]`. - -The `app_tid` component is required because `_request_token_mtls` includes -it in the form payload, producing a tenant-scoped token. Mixing tokens -across tenants would break principal propagation. +- System tokens are keyed by `client_id` (or "_default" when unset). +- User tokens are keyed by `sha256(user_jwt + "|" + (client_id or ""))[:16]`. Thread safety: Token fetches run in the default `ThreadPoolExecutor` via @@ -98,8 +94,7 @@ def compute_expires_at(token_data: dict, config: ClientConfig) -> float: class _TokenCache: """Per-client token cache with TTL and LRU eviction. - Both system and user tokens use OrderedDict for LRU ordering. Keys - include `app_tid` so tenant-scoped tokens never leak across tenants. + Both system and user tokens use OrderedDict for LRU ordering. """ _SYSTEM_DEFAULT_KEY = "_default" @@ -112,9 +107,9 @@ def __init__(self, config: ClientConfig): # --- System Token --- - def get_system_token(self, app_tid: str | None) -> str | None: - """Return a valid cached system token for `app_tid`, or None on miss/expiry.""" - key = app_tid or self._SYSTEM_DEFAULT_KEY + def get_system_token(self, client_id: str) -> str | None: + """Return a valid cached system token for `client_id`, or None on miss/expiry.""" + key = client_id cached = self._system_tokens.get(key) if cached and cached.is_valid(): self._system_tokens.move_to_end(key) @@ -123,28 +118,29 @@ def get_system_token(self, app_tid: str | None) -> str | None: del self._system_tokens[key] return None - def set_system_token( - self, token: str, expires_at: float, app_tid: str | None - ) -> None: - """Cache a system token under `app_tid`; evict LRU once size exceeds limit.""" - key = app_tid or self._SYSTEM_DEFAULT_KEY - self._system_tokens[key] = _CachedToken(token=token, expires_at=expires_at) + def set_system_token(self, token: str, expires_at: float, + client_id: str) -> None: + """Cache a system token under `client_id`; evict LRU once size exceeds limit.""" + key = client_id + self._system_tokens[key] = _CachedToken(token=token, + expires_at=expires_at) self._system_tokens.move_to_end(key) - while len(self._system_tokens) > self._config.max_system_token_cache_size: + while len(self._system_tokens + ) > self._config.max_system_token_cache_size: evicted, _ = self._system_tokens.popitem(last=False) logger.debug("System token cache full — evicted '%s'", evicted) - def invalidate_system_token(self, app_tid: str | None) -> None: - """Drop the cached system token for `app_tid` (no-op if absent).""" - key = app_tid or self._SYSTEM_DEFAULT_KEY + def invalidate_system_token(self, client_id: str) -> None: + """Drop the cached system token for `client_id` (no-op if absent).""" + key = client_id if self._system_tokens.pop(key, None): - logger.debug("Invalidated system token (app_tid=%s)", app_tid) + logger.debug("Invalidated system token (client_id=%s)", client_id) # --- User Tokens --- - def get_user_token(self, user_jwt: str, app_tid: str | None) -> str | None: - """Return a valid cached exchanged token for `(user_jwt, app_tid)`, or None.""" - key = self._hash_key(user_jwt, app_tid) + def get_user_token(self, user_jwt: str, client_id: str) -> str | None: + """Return a valid cached exchanged token for `(user_jwt, client_id)`, or None.""" + key = self._hash_key(user_jwt, client_id) cached = self._user_tokens.get(key) if cached and cached.is_valid(): self._user_tokens.move_to_end(key) @@ -158,21 +154,22 @@ def set_user_token( user_jwt: str, token: str, expires_at: float, - app_tid: str | None, + client_id: str, ) -> None: """Cache an exchanged user token; evict LRU once size exceeds limit.""" - key = self._hash_key(user_jwt, app_tid) - self._user_tokens[key] = _CachedToken(token=token, expires_at=expires_at) + key = self._hash_key(user_jwt, client_id) + self._user_tokens[key] = _CachedToken(token=token, + expires_at=expires_at) self._user_tokens.move_to_end(key) while len(self._user_tokens) > self._config.max_user_token_cache_size: evicted, _ = self._user_tokens.popitem(last=False) logger.debug("User token cache full — evicted '%s'", evicted) - def invalidate_user_token(self, user_jwt: str, app_tid: str | None) -> None: - """Drop the cached user token for `(user_jwt, app_tid)` (no-op if absent).""" - key = self._hash_key(user_jwt, app_tid) + def invalidate_user_token(self, user_jwt: str, client_id: str) -> None: + """Drop the cached user token for `(user_jwt, client_id)` (no-op if absent).""" + key = self._hash_key(user_jwt, client_id) if self._user_tokens.pop(key, None): - logger.debug("Invalidated user token (app_tid=%s)", app_tid) + logger.debug("Invalidated user token (client_id=%s)", client_id) # --- Maintenance --- @@ -182,7 +179,7 @@ def clear(self) -> None: self._user_tokens.clear() @staticmethod - def _hash_key(user_jwt: str, app_tid: str | None) -> str: - """Derive a short, stable cache key from `(user_jwt, app_tid)` via sha256.""" - material = f"{user_jwt}|{app_tid or ''}" + def _hash_key(user_jwt: str, client_id: str) -> str: + """Derive a short, stable cache key from `(user_jwt, client_id)` via sha256.""" + material = f"{user_jwt}|{client_id}" return hashlib.sha256(material.encode()).hexdigest()[:16] diff --git a/tests/agentgateway/unit/test_agw_client.py b/tests/agentgateway/unit/test_agw_client.py index a771f9e..6e7e47c 100644 --- a/tests/agentgateway/unit/test_agw_client.py +++ b/tests/agentgateway/unit/test_agw_client.py @@ -587,12 +587,11 @@ async def test_different_user_tokens_isolated(self, mock_tool): assert request_mock.call_count == 2 @pytest.mark.asyncio - async def test_app_tid_isolation(self, mock_tool): - """Same user_token across different app_tid values stays isolated.""" + async def test_same_user_token_different_app_tid_shares_cache(self, mock_tool): + """Same user_token + same client_id shares cache regardless of app_tid.""" patches, request_mock, _ = _patch_customer_flow( token_request_side_effect=[ ("tok-tenant-a", time.monotonic() + 600), - ("tok-tenant-b", time.monotonic() + 600), ] ) @@ -606,7 +605,7 @@ async def test_app_tid_isolation(self, mock_tool): tool=mock_tool, user_token="user-jwt", app_tid="tenant-b" ) - assert request_mock.call_count == 2 + assert request_mock.call_count == 1 @pytest.mark.asyncio async def test_401_invalidates_cache_and_retries(self, mock_tool): diff --git a/tests/agentgateway/unit/test_token_cache.py b/tests/agentgateway/unit/test_token_cache.py index 056e196..67fe1ea 100644 --- a/tests/agentgateway/unit/test_token_cache.py +++ b/tests/agentgateway/unit/test_token_cache.py @@ -11,6 +11,7 @@ import time from sap_cloud_sdk.agentgateway._token_cache import ( + _TokenCache, _parse_jwt_exp, compute_expires_at, ) @@ -112,3 +113,49 @@ def test_handles_invalid_expires_in_value(self): before = time.monotonic() result = compute_expires_at({"expires_in": "not-a-number"}, cfg) assert before + 300 - 1 <= result <= before + 300 + 1 + + +class TestTokenCacheClientIdIsolation: + """Tokens are isolated by client_id — same user JWT, different credentials, no sharing.""" + + def test_system_tokens_isolated_by_client_id(self): + cache = _TokenCache(ClientConfig()) + expires_at = time.monotonic() + 600 + + cache.set_system_token("token-a", expires_at, "client-a") + cache.set_system_token("token-b", expires_at, "client-b") + + assert cache.get_system_token("client-a") == "token-a" + assert cache.get_system_token("client-b") == "token-b" + + def test_user_tokens_isolated_by_client_id(self): + cache = _TokenCache(ClientConfig()) + expires_at = time.monotonic() + 600 + + cache.set_user_token("user-jwt", "token-a", expires_at, "client-a") + cache.set_user_token("user-jwt", "token-b", expires_at, "client-b") + + assert cache.get_user_token("user-jwt", "client-a") == "token-a" + assert cache.get_user_token("user-jwt", "client-b") == "token-b" + + def test_invalidate_system_token_does_not_affect_other_clients(self): + cache = _TokenCache(ClientConfig()) + expires_at = time.monotonic() + 600 + + cache.set_system_token("token-a", expires_at, "client-a") + cache.set_system_token("token-b", expires_at, "client-b") + cache.invalidate_system_token("client-a") + + assert cache.get_system_token("client-a") is None + assert cache.get_system_token("client-b") == "token-b" + + def test_invalidate_user_token_does_not_affect_other_clients(self): + cache = _TokenCache(ClientConfig()) + expires_at = time.monotonic() + 600 + + cache.set_user_token("user-jwt", "token-a", expires_at, "client-a") + cache.set_user_token("user-jwt", "token-b", expires_at, "client-b") + cache.invalidate_user_token("user-jwt", "client-a") + + assert cache.get_user_token("user-jwt", "client-a") is None + assert cache.get_user_token("user-jwt", "client-b") == "token-b" From 6f7f7601081811ff9d56ee4a1892a9602926506b Mon Sep 17 00:00:00 2001 From: I763688 Date: Tue, 26 May 2026 16:10:24 +0100 Subject: [PATCH 6/6] avoid passing ClientConfig to token request functions --- src/sap_cloud_sdk/agentgateway/_customer.py | 21 +++---- .../agentgateway/_token_cache.py | 61 ++++++++++--------- tests/agentgateway/unit/test_customer.py | 5 +- 3 files changed, 40 insertions(+), 47 deletions(-) diff --git a/src/sap_cloud_sdk/agentgateway/_customer.py b/src/sap_cloud_sdk/agentgateway/_customer.py index 737ad02..e25f454 100644 --- a/src/sap_cloud_sdk/agentgateway/_customer.py +++ b/src/sap_cloud_sdk/agentgateway/_customer.py @@ -25,7 +25,7 @@ IntegrationDependency, MCPTool, ) -from sap_cloud_sdk.agentgateway._token_cache import _TokenCache, compute_expires_at +from sap_cloud_sdk.agentgateway._token_cache import _TokenCache from sap_cloud_sdk.agentgateway.config import ClientConfig from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError @@ -213,7 +213,7 @@ def _request_token_mtls( credentials: CustomerCredentials, grant_type: str, timeout: float, - config: ClientConfig, + cache: _TokenCache, app_tid: str | None = None, extra_data: dict | None = None, ) -> tuple[str, float]: @@ -223,7 +223,7 @@ def _request_token_mtls( credentials: Customer credentials with certificate and private key. grant_type: OAuth2 grant type. timeout: HTTP timeout in seconds. - config: Client configuration (used to compute cache expiry). + cache: Token cache to calculate expiry (for buffer and fallback TTL). app_tid: BTP Application Tenant ID of subscriber (optional). extra_data: Additional form data for the token request. @@ -289,7 +289,7 @@ def _request_token_mtls( f"Token response missing 'access_token'. Keys: {list(token_data.keys())}" ) - expires_at = compute_expires_at(token_data, config) + expires_at = cache.compute_expires_at(token_data) logger.debug("Token acquired successfully (length: %d)", len(access_token)) return access_token, expires_at @@ -301,7 +301,6 @@ def _request_token_mtls( def get_system_token_mtls( credentials: CustomerCredentials, timeout: float, - config: ClientConfig, cache: _TokenCache, app_tid: str | None = None, ) -> str: @@ -313,7 +312,6 @@ def get_system_token_mtls( Args: credentials: Customer credentials. timeout: HTTP timeout in seconds. - config: Client configuration. cache: Token cache to consult and update. app_tid: BTP Application Tenant ID of subscriber (optional). @@ -330,7 +328,7 @@ def get_system_token_mtls( credentials, grant_type=_GRANT_TYPE_CLIENT_CREDENTIALS, timeout=timeout, - config=config, + cache=cache, app_tid=app_tid, extra_data={"response_type": "token"}, ) @@ -342,7 +340,6 @@ def exchange_user_token( credentials: CustomerCredentials, user_token: str, timeout: float, - config: ClientConfig, cache: _TokenCache, app_tid: str | None = None, ) -> str: @@ -356,7 +353,6 @@ def exchange_user_token( credentials: Customer credentials. user_token: User's JWT token to exchange. timeout: HTTP timeout in seconds. - config: Client configuration. cache: Token cache to consult and update. app_tid: BTP Application Tenant ID of subscriber (optional). @@ -373,7 +369,7 @@ def exchange_user_token( credentials, grant_type=_GRANT_TYPE_JWT_BEARER, timeout=timeout, - config=config, + cache=cache, app_tid=app_tid, extra_data={ "assertion": user_token, @@ -505,7 +501,7 @@ async def get_mcp_tools_customer( async def refetch_system_token() -> str: loop = asyncio.get_running_loop() new_token = await loop.run_in_executor( - None, get_system_token_mtls, credentials, timeout, config, cache, app_tid + None, get_system_token_mtls, credentials, timeout, cache, app_tid ) return new_token @@ -592,7 +588,6 @@ async def _acquire_token() -> str: credentials, user_token, timeout, - config, cache, app_tid, ) @@ -604,7 +599,7 @@ async def _acquire_token() -> str: "Principal propagation will NOT work." ) return await loop.run_in_executor( - None, get_system_token_mtls, credentials, timeout, config, cache, app_tid + None, get_system_token_mtls, credentials, timeout, cache, app_tid ) def _invalidate_token() -> None: diff --git a/src/sap_cloud_sdk/agentgateway/_token_cache.py b/src/sap_cloud_sdk/agentgateway/_token_cache.py index 0f4066d..4788ded 100644 --- a/src/sap_cloud_sdk/agentgateway/_token_cache.py +++ b/src/sap_cloud_sdk/agentgateway/_token_cache.py @@ -61,36 +61,6 @@ def _parse_jwt_exp(jwt: str) -> int | None: return None -def compute_expires_at(token_data: dict, config: ClientConfig) -> float: - """Resolve the cache expiry timestamp (monotonic) for a token response. - - Resolution order: - 1. `expires_in` from the response, minus the buffer. - 2. `exp` claim from `id_token` (translated from wall clock to monotonic), - minus the buffer. - 3. Config-provided fallback TTL. - """ - now_mono = time.monotonic() - buffer = config.token_expiry_buffer_seconds - - expires_in = token_data.get("expires_in") - if expires_in is not None: - try: - return now_mono + int(expires_in) - buffer - except (ValueError, TypeError): - pass - - id_token = token_data.get("id_token") - if id_token: - exp = _parse_jwt_exp(id_token) - if exp is not None: - remaining = exp - time.time() - if remaining > buffer: - return now_mono + remaining - buffer - - return now_mono + config.fallback_token_ttl_seconds - - class _TokenCache: """Per-client token cache with TTL and LRU eviction. @@ -171,6 +141,37 @@ def invalidate_user_token(self, user_jwt: str, client_id: str) -> None: if self._user_tokens.pop(key, None): logger.debug("Invalidated user token (client_id=%s)", client_id) + # --- Utility --- + + def compute_expires_at(self, token_data: dict) -> float: + """Resolve the cache expiry timestamp (monotonic) for a token response. + + Resolution order: + 1. `expires_in` from the response, minus the buffer. + 2. `exp` claim from `id_token` (translated from wall clock to monotonic), + minus the buffer. + 3. Config-provided fallback TTL. + """ + now_mono = time.monotonic() + buffer = self._config.token_expiry_buffer_seconds + + expires_in = token_data.get("expires_in") + if expires_in is not None: + try: + return now_mono + int(expires_in) - buffer + except (ValueError, TypeError): + pass + + id_token = token_data.get("id_token") + if id_token: + exp = _parse_jwt_exp(id_token) + if exp is not None: + remaining = exp - time.time() + if remaining > buffer: + return now_mono + remaining - buffer + + return now_mono + self._config.fallback_token_ttl_seconds + # --- Maintenance --- def clear(self) -> None: diff --git a/tests/agentgateway/unit/test_customer.py b/tests/agentgateway/unit/test_customer.py index 57a3cfa..b03097f 100644 --- a/tests/agentgateway/unit/test_customer.py +++ b/tests/agentgateway/unit/test_customer.py @@ -308,7 +308,7 @@ def test_requests_client_credentials_token(self, credentials): mock_client_class.return_value = mock_client result = get_system_token_mtls( - credentials, timeout=60.0, config=ClientConfig(), cache=_TokenCache(ClientConfig()) + credentials, timeout=60.0, cache=_TokenCache(ClientConfig()) ) assert result == "system-token-123" @@ -339,7 +339,6 @@ def test_raises_on_failed_request(self, credentials): get_system_token_mtls( credentials, timeout=60.0, - config=ClientConfig(), cache=_TokenCache(ClientConfig()), ) @@ -387,7 +386,6 @@ def test_exchanges_user_token_with_jwt_bearer(self, credentials): credentials, "user-jwt-token", timeout=60.0, - config=ClientConfig(), cache=_TokenCache(ClientConfig()), ) @@ -421,7 +419,6 @@ def test_passes_app_tid_when_provided(self, credentials): credentials, "user-jwt", timeout=60.0, - config=ClientConfig(), cache=_TokenCache(ClientConfig()), app_tid="test-tid", )