Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 151 additions & 27 deletions src/sap_cloud_sdk/agentgateway/_customer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
IntegrationDependency,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we also add a token cache for the LoB flow, to keep it consistent?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was under the impression that, since the LoB flow goes through the destination service, that service would handle the token request and caching. However, it seems that's not the case. I'll add caching to the LoB flow.

MCPTool,
)
from sap_cloud_sdk.agentgateway._token_cache import _TokenCache
from sap_cloud_sdk.agentgateway.config import ClientConfig
from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -211,19 +213,24 @@ def _request_token_mtls(
credentials: CustomerCredentials,
grant_type: str,
timeout: float,
cache: _TokenCache,
app_tid: str | None = None,
extra_data: dict | None = None,
) -> str:
) -> tuple[str, float]:
"""Make mTLS token request to IAS.

Args:
credentials: Customer credentials with certificate and private key.
grant_type: OAuth2 grant type.
timeout: HTTP timeout in seconds.
cache: Token cache to calculate expiry (for buffer and fallback TTL).
app_tid: BTP Application Tenant ID of subscriber (optional).
extra_data: Additional form data for the token request.

Returns:
Access token string.
Tuple of (access_token, expires_at) where expires_at is a
time.monotonic() value indicating when the cached token should
be refreshed (already includes the configured buffer).

Raises:
AgentGatewaySDKError: If token request fails.
Expand Down Expand Up @@ -282,8 +289,10 @@ def _request_token_mtls(
f"Token response missing 'access_token'. Keys: {list(token_data.keys())}"
)

expires_at = cache.compute_expires_at(token_data)

logger.debug("Token acquired successfully (length: %d)", len(access_token))
return access_token
return access_token, expires_at

except httpx.RequestError as e:
raise AgentGatewaySDKError(f"Token request failed: {e}")
Expand All @@ -292,61 +301,83 @@ def _request_token_mtls(
def get_system_token_mtls(
    credentials: CustomerCredentials,
    timeout: float,
    cache: _TokenCache,
    app_tid: str | None = None,
) -> str:
    """Get system-scoped token using mTLS client credentials flow.

    Used for tool discovery where user identity is not needed. Returns
    a cached token if still valid; otherwise acquires a fresh one and
    stores it in the cache before returning it.

    Args:
        credentials: Customer credentials.
        timeout: HTTP timeout in seconds.
        cache: Token cache to consult and update.
        app_tid: BTP Application Tenant ID of subscriber (optional).

    Returns:
        System-scoped access token.

    Raises:
        AgentGatewaySDKError: Propagated from the token request if it fails.
    """
    # NOTE(review): the cache lookup passes only client_id, while app_tid is
    # forwarded to the token request. If tokens differ per app_tid, a token
    # cached for one tenant could be returned for another — confirm against
    # _TokenCache and the callers.
    cached = cache.get_system_token(credentials.client_id)
    if cached:
        logger.debug("Using cached system token (client_id=%s)", credentials.client_id)
        return cached

    logger.info("Acquiring system token via mTLS client credentials")
    token, expires_at = _request_token_mtls(
        credentials,
        grant_type=_GRANT_TYPE_CLIENT_CREDENTIALS,
        timeout=timeout,
        cache=cache,
        app_tid=app_tid,
        extra_data={"response_type": "token"},
    )
    # expires_at comes from the token request helper alongside the token;
    # store both so later calls can reuse the token.
    cache.set_system_token(token, expires_at, credentials.client_id)
    return token


def exchange_user_token(
    credentials: CustomerCredentials,
    user_token: str,
    timeout: float,
    cache: _TokenCache,
    app_tid: str | None = None,
) -> str:
    """Exchange user token for AGW-scoped token using jwt-bearer grant.

    Used for tool invocation where user identity must be preserved
    for principal propagation. Returns a cached exchanged token if
    still valid; otherwise acquires a fresh one and stores it in the
    cache before returning it.

    Args:
        credentials: Customer credentials.
        user_token: User's JWT token to exchange.
        timeout: HTTP timeout in seconds.
        cache: Token cache to consult and update.
        app_tid: BTP Application Tenant ID of subscriber (optional).

    Returns:
        AGW-scoped access token with user identity.

    Raises:
        AgentGatewaySDKError: Propagated from the token request if it fails.
    """
    # NOTE(review): the cache lookup passes user_token and client_id, while
    # app_tid is forwarded to the token request but not to the cache —
    # confirm that exchanged tokens do not differ per app_tid.
    cached = cache.get_user_token(user_token, credentials.client_id)
    if cached:
        logger.debug("Using cached user token (client_id=%s)", credentials.client_id)
        return cached

    logger.info("Exchanging user token for AGW-scoped token via jwt-bearer grant")
    token, expires_at = _request_token_mtls(
        credentials,
        grant_type=_GRANT_TYPE_JWT_BEARER,
        timeout=timeout,
        cache=cache,
        app_tid=app_tid,
        extra_data={
            "assertion": user_token,
            "token_format": "jwt",
        },
    )
    # Store the exchanged token with its expiry so repeated invocations with
    # the same user token can reuse it.
    cache.set_user_token(user_token, token, expires_at, credentials.client_id)
    return token


def _build_mcp_url(gateway_url: str, ord_id: str, gt_id: str) -> str:
Expand Down Expand Up @@ -433,6 +464,8 @@ async def _list_server_tools(
async def get_mcp_tools_customer(
credentials: CustomerCredentials,
timeout: float,
config: ClientConfig,
cache: _TokenCache,
app_tid: str | None = None,
) -> list[MCPTool]:
"""List all MCP tools from servers defined in credentials.
Expand All @@ -442,6 +475,9 @@ async def get_mcp_tools_customer(

Args:
credentials: Customer credentials with integrationDependencies.
timeout: HTTP timeout in seconds.
config: Client configuration.
cache: Token cache shared across calls.
app_tid: BTP Application Tenant ID of subscriber (optional).

Returns:
Expand All @@ -459,11 +495,15 @@ async def get_mcp_tools_customer(

logger.info("Discovering tools from %d MCP server(s)", len(dependencies))

# Get system token for discovery
loop = asyncio.get_running_loop()
system_token = await loop.run_in_executor(
None, get_system_token_mtls, credentials, timeout, app_tid
)
system_token = None
# Define a helper closure to refetch the system token on demand, since it may need to be
# refreshed during the discovery loop if any server returns a 401
async def refetch_system_token() -> str:
loop = asyncio.get_running_loop()
new_token = await loop.run_in_executor(
None, get_system_token_mtls, credentials, timeout, cache, app_tid
)
return new_token

tools: list[MCPTool] = []

Expand All @@ -476,12 +516,28 @@ async def get_mcp_tools_customer(
dep.global_tenant_id,
)

try:
server_tools = await _list_server_tools(url, system_token, dep, timeout)
tools.extend(server_tools)
logger.debug("Loaded %d tool(s) from %s", len(server_tools), dep.ord_id)
except Exception:
logger.exception("Failed to load tools from %s — skipping", dep.ord_id)
for attempt in (1, 2):
if not system_token:
# won't catch exceptions here - if token acquisition fails,
# we want the discovery to fail immediately
system_token = await refetch_system_token()

try:
server_tools = await _list_server_tools(url, system_token, dep, timeout)
tools.extend(server_tools)
logger.debug("Loaded %d tool(s) from %s", len(server_tools), dep.ord_id)
except Exception as exc:
unwrapped = _unwrap_exception_group(exc)
if _is_unauthorized(unwrapped) and attempt == 1:
logger.info(
"401 from %s — invalidating cached system token and retrying",
dep.ord_id,
)
cache.invalidate_system_token(credentials.client_id)
system_token = None # Force refetch on next loop iteration
continue
logger.exception("Failed to load tools from %s — skipping", dep.ord_id)
break # Success, move to next server

logger.info(
"Loaded %d MCP tool(s) from %d server(s)", len(tools), len(dependencies)
Expand All @@ -494,6 +550,8 @@ async def call_mcp_tool_customer(
tool: MCPTool,
user_token: str | None,
timeout: float,
config: ClientConfig,
cache: _TokenCache,
app_tid: str | None = None,
**kwargs,
) -> str:
Expand All @@ -502,11 +560,16 @@ async def call_mcp_tool_customer(
If user_token is provided, exchanges it for an AGW-scoped token to preserve
user identity for principal propagation. Otherwise, falls back to system token.

On a 401 from the MCP server, drops the cached token and retries once.

Args:
credentials: Customer credentials.
tool: MCPTool to invoke.
user_token: User's JWT token for principal propagation (optional).
If None, system token is used instead (no principal propagation).
timeout: HTTP timeout in seconds.
config: Client configuration.
cache: Token cache shared across calls.
app_tid: BTP Application Tenant ID of subscriber (optional).
**kwargs: Tool input parameters.

Expand All @@ -517,26 +580,73 @@ async def call_mcp_tool_customer(

loop = asyncio.get_running_loop()

if user_token:
# Exchange user token for AGW-scoped token (with principal propagation)
agw_token = await loop.run_in_executor(
None, exchange_user_token, credentials, user_token, timeout, app_tid
)
else:
async def _acquire_token() -> str:
if user_token:
return await loop.run_in_executor(
None,
exchange_user_token,
credentials,
user_token,
timeout,
cache,
app_tid,
)
# TODO: IBD workaround - use system token when user_token is not available.
# This bypasses principal propagation. Remove this fallback once IBD
# supports proper user token flow.
logger.warning(
"No user_token provided - using system token for tool invocation. "
"Principal propagation will NOT work."
)
agw_token = await loop.run_in_executor(
None, get_system_token_mtls, credentials, timeout, app_tid
return await loop.run_in_executor(
None, get_system_token_mtls, credentials, timeout, cache, app_tid
)

def _invalidate_token() -> None:
if user_token:
cache.invalidate_user_token(user_token, credentials.client_id)
else:
cache.invalidate_system_token(credentials.client_id)

last_exc: Exception | None = None
for attempt in (1, 2):
agw_token = await _acquire_token()
try:
return await _invoke_tool(tool, agw_token, timeout, **kwargs)
except Exception as exc:
unwrapped = _unwrap_exception_group(exc)
if _is_unauthorized(unwrapped) and attempt == 1:
logger.info(
"401 from MCP server for tool '%s' — invalidating cached token and retrying",
tool.name,
)
_invalidate_token()
last_exc = exc
continue
raise

# Defensive — should not be reachable; second attempt either returns or raises.
raise AgentGatewaySDKError(
f"Tool invocation for '{tool.name}' failed after 401 retry: {last_exc}"
)


async def _invoke_tool(
tool: MCPTool,
auth_token: str,
timeout: float,
**kwargs,
) -> str:
"""Open an MCP session to `tool.url` and invoke `tool.name` with `kwargs`.

Returns the first content block's text, or empty string when content is
empty. Raises whatever the MCP transport / session raises (notably
`httpx.HTTPStatusError` on 401, which the caller uses to drive cache
invalidation and retry).
"""
async with httpx.AsyncClient(
headers={
"Authorization": f"Bearer {agw_token}",
"Authorization": f"Bearer {auth_token}",
"x-correlation-id": str(uuid.uuid4()),
},
timeout=timeout,
Expand All @@ -556,3 +666,17 @@ async def call_mcp_tool_customer(

first = result.content[0]
return str(getattr(first, "text", ""))


def _unwrap_exception_group(exc: BaseException) -> BaseException:
"""Unwrap nested ExceptionGroups to find the underlying cause."""
while isinstance(exc, BaseExceptionGroup) and exc.exceptions:
exc = exc.exceptions[0]
return exc


def _is_unauthorized(exc: BaseException) -> bool:
    """Report whether `exc` is an httpx status error carrying HTTP 401.

    Args:
        exc: Exception to inspect; expected to be already unwrapped from
            any exception group by the caller.

    Returns:
        True only when `exc` is an `httpx.HTTPStatusError` whose response
        is present and has status code 401; False for anything else.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    response = exc.response
    if response is None:
        return False
    return response.status_code == 401
Loading