From fa2863b515b260e1c079be63d74880900701f574 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 5 Mar 2026 16:09:20 +0000 Subject: [PATCH 1/3] feat: add `tenant` to `ServerCallContext`, add tenant-prefixed routes for REST endpoints and introduce tenant extraction from REST API paths --- src/a2a/server/agent_execution/context.py | 5 ++ src/a2a/server/apps/rest/rest_adapter.py | 54 ++++++++++++ src/a2a/server/context.py | 1 + tests/server/apps/rest/test_rest_tenant.py | 99 ++++++++++++++++++++++ 4 files changed, 159 insertions(+) create mode 100644 tests/server/apps/rest/test_rest_tenant.py diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index ebbf74a91..73a4a9f4e 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -160,6 +160,11 @@ def add_activated_extension(self, uri: str) -> None: if self._call_context: self._call_context.activated_extensions.add(uri) + @property + def tenant(self) -> str: + """The tenant associated with this request.""" + return self._call_context.tenant if self._call_context else '' + @property def requested_extensions(self) -> set[str]: """Extensions that the client requested to activate.""" diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 113a8c47a..ddf7354b4 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -111,6 +111,9 @@ async def _handle_request( request: Request, ) -> Response: call_context = self._context_builder.build(request) + if 'tenant' in request.path_params: + call_context.tenant = request.path_params['tenant'] + response = await method(request, call_context) return JSONResponse(content=response) @@ -131,6 +134,8 @@ async def _handle_streaming_request( ) from e call_context = self._context_builder.build(request) + if 'tenant' in request.path_params: + call_context.tenant = request.path_params['tenant'] async def event_generator( stream: AsyncIterable[Any], @@ -250,10 +255,59 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: ('/tasks', 'GET'): functools.partial( self._handle_request, self.handler.list_tasks ), + # Tenant prefixed routes + ('/{tenant}/message:send', 'POST'): functools.partial( + self._handle_request, + self.handler.on_message_send, + ), + ('/{tenant}/message:stream', 'POST'): functools.partial( + self._handle_streaming_request, + self.handler.on_message_send_stream, + ), + ('/{tenant}/tasks/{id}:cancel', 'POST'): functools.partial( + self._handle_request, self.handler.on_cancel_task + ), + ('/{tenant}/tasks/{id}:subscribe', 'GET'): functools.partial( + self._handle_streaming_request, + self.handler.on_subscribe_to_task, + ), + ('/{tenant}/tasks/{id}', 'GET'): functools.partial( + self._handle_request, self.handler.on_get_task + ), + ( + '/{tenant}/tasks/{id}/pushNotificationConfigs/{push_id}', + 'GET', + ): functools.partial( + self._handle_request, self.handler.get_push_notification + ), + ( + '/{tenant}/tasks/{id}/pushNotificationConfigs/{push_id}', + 'DELETE', + ): functools.partial( + self._handle_request, self.handler.delete_push_notification + ), + ( + '/{tenant}/tasks/{id}/pushNotificationConfigs', + 'POST', + ): functools.partial( + self._handle_request, self.handler.set_push_notification + ), + ( + '/{tenant}/tasks/{id}/pushNotificationConfigs', + 'GET', + ): functools.partial( + self._handle_request, self.handler.list_push_notifications + ), + ('/{tenant}/tasks', 'GET'): functools.partial( + self._handle_request, self.handler.list_tasks + ), } if self.agent_card.capabilities.extended_agent_card: routes[('/extendedAgentCard', 'GET')] = functools.partial( self._handle_request, self.handle_authenticated_agent_card ) + routes[('/{tenant}/extendedAgentCard', 'GET')] = functools.partial( + self._handle_request, self.handle_authenticated_agent_card + ) return routes diff --git a/src/a2a/server/context.py b/src/a2a/server/context.py index 2b34cefee..c0ddd9219 100644 --- a/src/a2a/server/context.py +++ b/src/a2a/server/context.py @@ -21,5 +21,6 @@ class ServerCallContext(BaseModel): state: State = Field(default={}) user: User = Field(default=UnauthenticatedUser()) + tenant: str = Field(default='') requested_extensions: set[str] = Field(default_factory=set) activated_extensions: set[str] = Field(default_factory=set) diff --git a/tests/server/apps/rest/test_rest_tenant.py b/tests/server/apps/rest/test_rest_tenant.py new file mode 100644 index 000000000..a94dc0f7b --- /dev/null +++ b/tests/server/apps/rest/test_rest_tenant.py @@ -0,0 +1,99 @@ +import pytest +from unittest.mock import MagicMock +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient +from google.protobuf import json_format + +from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types.a2a_pb2 import ( + AgentCard, + Message, + Role, + Part, + SendMessageRequest, + SendMessageConfiguration, +) + + +@pytest.fixture +async def agent_card() -> AgentCard: + mock_agent_card = MagicMock(spec=AgentCard) + mock_agent_card.url = 'http://mockurl.com' + mock_capabilities = MagicMock() + mock_capabilities.streaming = False + mock_agent_card.capabilities = mock_capabilities + return mock_agent_card + + +@pytest.fixture +async def request_handler() -> RequestHandler: + handler = MagicMock(spec=RequestHandler) + # Return a default response so the test doesn't crash on return value expectation + handler.on_message_send.return_value = Message( + message_id='test', + role=Role.ROLE_AGENT, + parts=[Part(text='response message')], + ) + return handler + + +@pytest.fixture +async def app( + agent_card: AgentCard, request_handler: RequestHandler +) -> FastAPI: + return A2ARESTFastAPIApplication(agent_card, request_handler).build( + agent_card_url='/well-known/agent.json', rpc_url='' + ) + + +@pytest.fixture +async def client(app: FastAPI) -> AsyncClient: + return AsyncClient(transport=ASGITransport(app=app), base_url='http://test') + + +@pytest.mark.anyio +async def test_tenant_extraction_from_path( + client: AsyncClient, request_handler: MagicMock +) -> None: + request = SendMessageRequest( + message=Message(), + configuration=SendMessageConfiguration(), + ) + + # Test with tenant in URL + tenant_id = 'my-tenant-123' + response = await client.post( + f'/{tenant_id}/message:send', json=json_format.MessageToDict(request) + ) + response.raise_for_status() + + # Verify handler was called + assert request_handler.on_message_send.called + + # Verify call context has tenant + args, _ = request_handler.on_message_send.call_args + # args[0] is the request proto, args[1] is the ServerCallContext + context = args[1] + assert context.tenant == tenant_id + + +@pytest.mark.anyio +async def test_no_tenant_extraction( + client: AsyncClient, request_handler: MagicMock +) -> None: + request = SendMessageRequest( + message=Message(), + configuration=SendMessageConfiguration(), + ) + + # Test without tenant in URL + response = await client.post( + '/message:send', json=json_format.MessageToDict(request) + ) + response.raise_for_status() + + # Verify call context has empty string tenant (default) + args, _ = request_handler.on_message_send.call_args + context = args[1] + assert context.tenant == '' From 39d53517b783005a08f18458a7dfdcd537a64154 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 6 Mar 2026 09:39:01 +0000 Subject: [PATCH 2/3] refactor: add helper _build_call_context and reduce code duplication in routes --- src/a2a/server/apps/rest/rest_adapter.py | 74 ++++++------------------ 1 file changed, 17 insertions(+), 57 deletions(-) diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index ddf7354b4..ac9759172 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -110,9 +110,7 @@ async def _handle_request( method: Callable[[Request, ServerCallContext], Awaitable[Any]], request: Request, ) -> Response: - call_context = self._context_builder.build(request) - if 'tenant' in request.path_params: - call_context.tenant = request.path_params['tenant'] + call_context = self._build_call_context(request) response = await method(request, call_context) return JSONResponse(content=response) @@ -133,9 +131,7 @@ async def _handle_streaming_request( message=f'Failed to pre-consume request body: {e}' ) from e - call_context = self._context_builder.build(request) - if 'tenant' in request.path_params: - call_context.tenant = request.path_params['tenant'] + call_context = self._build_call_context(request) async def event_generator( stream: AsyncIterable[Any], @@ -210,7 +206,7 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: A dictionary where each key is a tuple of (path, http_method) and the value is the callable handler for that route. """ - routes: dict[tuple[str, str], Callable[[Request], Any]] = { + base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { ('/message:send', 'POST'): functools.partial( self._handle_request, self.handler.on_message_send ), @@ -255,59 +251,23 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: ('/tasks', 'GET'): functools.partial( self._handle_request, self.handler.list_tasks ), - # Tenant prefixed routes - ('/{tenant}/message:send', 'POST'): functools.partial( - self._handle_request, - self.handler.on_message_send, - ), - ('/{tenant}/message:stream', 'POST'): functools.partial( - self._handle_streaming_request, - self.handler.on_message_send_stream, - ), - ('/{tenant}/tasks/{id}:cancel', 'POST'): functools.partial( - self._handle_request, self.handler.on_cancel_task - ), - ('/{tenant}/tasks/{id}:subscribe', 'GET'): functools.partial( - self._handle_streaming_request, - self.handler.on_subscribe_to_task, - ), - ('/{tenant}/tasks/{id}', 'GET'): functools.partial( - self._handle_request, self.handler.on_get_task - ), - ( - '/{tenant}/tasks/{id}/pushNotificationConfigs/{push_id}', - 'GET', - ): functools.partial( - self._handle_request, self.handler.get_push_notification - ), - ( - '/{tenant}/tasks/{id}/pushNotificationConfigs/{push_id}', - 'DELETE', - ): functools.partial( - self._handle_request, self.handler.delete_push_notification - ), - ( - '/{tenant}/tasks/{id}/pushNotificationConfigs', - 'POST', - ): functools.partial( - self._handle_request, self.handler.set_push_notification - ), - ( - '/{tenant}/tasks/{id}/pushNotificationConfigs', - 'GET', - ): functools.partial( - self._handle_request, self.handler.list_push_notifications - ), - ('/{tenant}/tasks', 'GET'): functools.partial( - self._handle_request, self.handler.list_tasks - ), } + if self.agent_card.capabilities.extended_agent_card: - routes[('/extendedAgentCard', 'GET')] = functools.partial( - self._handle_request, self.handle_authenticated_agent_card - ) - routes[('/{tenant}/extendedAgentCard', 'GET')] = functools.partial( + base_routes[('/extendedAgentCard', 'GET')] = functools.partial( self._handle_request, self.handle_authenticated_agent_card ) + routes: dict[tuple[str, str], Callable[[Request], Any]] = { + (p, method): handler + for (path, method), handler in base_routes.items() + for p in (path, f'/{{tenant}}{path}') + } + return routes + + def _build_call_context(self, request: Request) -> ServerCallContext: + call_context = self._context_builder.build(request) + if 'tenant' in request.path_params: + call_context.tenant = request.path_params['tenant'] + return call_context From 2a5716643c8632d05ebb9b927be8467d81125562 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 6 Mar 2026 10:06:26 +0000 Subject: [PATCH 3/3] fix: add tenant support to `handle_authenticated_agent_card` and add more tests --- src/a2a/server/apps/rest/rest_adapter.py | 2 +- tests/server/apps/rest/test_rest_tenant.py | 167 ++++++++++++++++----- 2 files changed, 130 insertions(+), 39 deletions(-) diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index ac9759172..454a9f24b 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -186,7 +186,7 @@ async def handle_authenticated_agent_card( card_to_serve = self.agent_card if self.extended_card_modifier: - context = self._context_builder.build(request) + context = self._build_call_context(request) card_to_serve = await maybe_await( self.extended_card_modifier(card_to_serve, context) ) diff --git a/tests/server/apps/rest/test_rest_tenant.py b/tests/server/apps/rest/test_rest_tenant.py index a94dc0f7b..db1ddd5e0 100644 --- a/tests/server/apps/rest/test_rest_tenant.py +++ b/tests/server/apps/rest/test_rest_tenant.py @@ -2,17 +2,18 @@ from unittest.mock import MagicMock from fastapi import FastAPI from httpx import ASGITransport, AsyncClient -from google.protobuf import json_format from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( AgentCard, + ListTaskPushNotificationConfigsResponse, + ListTasksResponse, Message, - Role, Part, - SendMessageRequest, - SendMessageConfiguration, + Role, + Task, + TaskPushNotificationConfig, ) @@ -22,6 +23,8 @@ async def agent_card() -> AgentCard: mock_agent_card.url = 'http://mockurl.com' mock_capabilities = MagicMock() mock_capabilities.streaming = False + mock_capabilities.push_notifications = True + mock_capabilities.extended_agent_card = True mock_agent_card.capabilities = mock_capabilities return mock_agent_card @@ -29,22 +32,46 @@ async def agent_card() -> AgentCard: @pytest.fixture async def request_handler() -> RequestHandler: handler = MagicMock(spec=RequestHandler) - # Return a default response so the test doesn't crash on return value expectation + # Setup default return values for all handlers handler.on_message_send.return_value = Message( message_id='test', role=Role.ROLE_AGENT, parts=[Part(text='response message')], ) + handler.on_cancel_task.return_value = Task(id='1') + handler.on_get_task.return_value = Task(id='1') + handler.on_list_tasks.return_value = ListTasksResponse() + handler.on_create_task_push_notification_config.return_value = ( + TaskPushNotificationConfig() + ) + handler.on_get_task_push_notification_config.return_value = ( + TaskPushNotificationConfig() + ) + handler.on_list_task_push_notification_configs.return_value = ( + ListTaskPushNotificationConfigsResponse() + ) + handler.on_delete_task_push_notification_config.return_value = None return handler +@pytest.fixture +async def extended_card_modifier() -> MagicMock: + modifier = MagicMock() + modifier.return_value = AgentCard() + return modifier + + @pytest.fixture async def app( - agent_card: AgentCard, request_handler: RequestHandler + agent_card: AgentCard, + request_handler: RequestHandler, + extended_card_modifier: MagicMock, ) -> FastAPI: - return A2ARESTFastAPIApplication(agent_card, request_handler).build( - agent_card_url='/well-known/agent.json', rpc_url='' - ) + return A2ARESTFastAPIApplication( + agent_card, + request_handler, + extended_card_modifier=extended_card_modifier, + ).build(agent_card_url='/well-known/agent.json', rpc_url='') @pytest.fixture @@ -52,48 +79,112 @@ async def client(app: FastAPI) -> AsyncClient: return AsyncClient(transport=ASGITransport(app=app), base_url='http://test') +@pytest.mark.parametrize( + 'path_template, method, handler_method_name, json_body', + [ + ('/message:send', 'POST', 'on_message_send', {'message': {}}), + ('/tasks/1:cancel', 'POST', 'on_cancel_task', None), + ('/tasks/1', 'GET', 'on_get_task', None), + ('/tasks', 'GET', 'on_list_tasks', None), + ( + '/tasks/1/pushNotificationConfigs/p1', + 'GET', + 'on_get_task_push_notification_config', + None, + ), + ( + '/tasks/1/pushNotificationConfigs/p1', + 'DELETE', + 'on_delete_task_push_notification_config', + None, + ), + ( + '/tasks/1/pushNotificationConfigs', + 'POST', + 'on_create_task_push_notification_config', + {'config': {'url': 'http://foo'}}, + ), + ( + '/tasks/1/pushNotificationConfigs', + 'GET', + 'on_list_task_push_notification_configs', + None, + ), + ], +) @pytest.mark.anyio -async def test_tenant_extraction_from_path( - client: AsyncClient, request_handler: MagicMock +async def test_tenant_extraction_parametrized( + client: AsyncClient, + request_handler: MagicMock, + extended_card_modifier: MagicMock, + path_template: str, + method: str, + handler_method_name: str, + json_body: dict | None, ) -> None: - request = SendMessageRequest( - message=Message(), - configuration=SendMessageConfiguration(), - ) + """Test tenant extraction for standard REST endpoints.""" + # Test with tenant + tenant = 'my-tenant' + tenant_path = f'/{tenant}{path_template}' - # Test with tenant in URL - tenant_id = 'my-tenant-123' - response = await client.post( - f'/{tenant_id}/message:send', json=json_format.MessageToDict(request) - ) + response = await client.request(method, tenant_path, json=json_body) response.raise_for_status() - # Verify handler was called - assert request_handler.on_message_send.called + # Verify handler call + handler_mock = getattr(request_handler, handler_method_name) + + assert handler_mock.called + args, _ = handler_mock.call_args + context = args[1] + assert context.tenant == tenant + + # Reset mock for non-tenant test + handler_mock.reset_mock() + + # Test without tenant + response = await client.request(method, path_template, json=json_body) + response.raise_for_status() - # Verify call context has tenant - args, _ = request_handler.on_message_send.call_args - # args[0] is the request proto, args[1] is the ServerCallContext + # Verify context.tenant == "" + assert handler_mock.called + args, _ = handler_mock.call_args context = args[1] - assert context.tenant == tenant_id + assert context.tenant == '' @pytest.mark.anyio -async def test_no_tenant_extraction( - client: AsyncClient, request_handler: MagicMock +async def test_tenant_extraction_extended_agent_card( + client: AsyncClient, + extended_card_modifier: MagicMock, ) -> None: - request = SendMessageRequest( - message=Message(), - configuration=SendMessageConfiguration(), - ) + """Test tenant extraction specifically for extendedAgentCard endpoint. - # Test without tenant in URL - response = await client.post( - '/message:send', json=json_format.MessageToDict(request) - ) + This verifies that `extended_card_modifier` receives the correct context + including the tenant, confirming that `_build_call_context` is used correctly. + """ + # Test with tenant + tenant = 'my-tenant' + tenant_path = f'/{tenant}/extendedAgentCard' + + response = await client.get(tenant_path) + response.raise_for_status() + + # Verify extended_card_modifier called with tenant context + assert extended_card_modifier.called + args, _ = extended_card_modifier.call_args + # args[0] is card_to_serve, args[1] is context + context = args[1] + assert context.tenant == tenant + + # Reset mock for non-tenant test + extended_card_modifier.reset_mock() + + # Test without tenant + response = await client.get('/extendedAgentCard') response.raise_for_status() - # Verify call context has empty string tenant (default) - args, _ = request_handler.on_message_send.call_args + # Verify extended_card_modifier called with empty tenant context + assert extended_card_modifier.called + args, _ = extended_card_modifier.call_args context = args[1] assert context.tenant == ''