From befb4f6a47fc2477b667559bae6d5ba2e8c2bd27 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 19 Mar 2026 17:31:03 +0000 Subject: [PATCH 01/25] wip --- samples/hello_world_agent.py | 16 +- src/a2a/compat/v0_3/jsonrpc_adapter.py | 2 +- src/a2a/compat/v0_3/rest_adapter.py | 5 +- src/a2a/server/apps/__init__.py | 10 - src/a2a/server/apps/jsonrpc/__init__.py | 20 -- src/a2a/server/apps/jsonrpc/fastapi_app.py | 148 --------- src/a2a/server/apps/jsonrpc/starlette_app.py | 169 ----------- src/a2a/server/apps/rest/fastapi_app.py | 2 +- src/a2a/server/apps/rest/rest_adapter.py | 5 +- src/a2a/server/routes/__init__.py | 20 ++ src/a2a/server/routes/agent_card_routes.py | 85 ++++++ .../jsonrpc_dispatcher.py} | 53 +--- src/a2a/server/routes/jsonrpc_routes.py | 107 +++++++ tck/sut_agent.py | 22 +- tests/__init__.py | 1 + tests/compat/v0_3/test_jsonrpc_app_compat.py | 11 +- .../cross_version/client_server/server_0_3.py | 15 +- .../cross_version/client_server/server_1_0.py | 21 +- tests/integration/test_agent_card.py | 17 +- .../test_client_server_integration.py | 65 ++-- tests/integration/test_end_to_end.py | 14 +- tests/integration/test_tenant.py | 17 +- tests/integration/test_version_header.py | 17 +- tests/server/apps/jsonrpc/test_fastapi_app.py | 79 ----- .../server/apps/jsonrpc/test_serialization.py | 280 ------------------ .../server/apps/jsonrpc/test_starlette_app.py | 81 ----- tests/server/routes/test_agent_card_routes.py | 105 +++++++ .../test_jsonrpc_dispatcher.py} | 239 +++------------ tests/server/routes/test_jsonrpc_routes.py | 96 ++++++ tests/server/test_integration.py | 102 ++++--- 30 files changed, 668 insertions(+), 1156 deletions(-) delete mode 100644 src/a2a/server/apps/jsonrpc/__init__.py delete mode 100644 src/a2a/server/apps/jsonrpc/fastapi_app.py delete mode 100644 src/a2a/server/apps/jsonrpc/starlette_app.py create mode 100644 src/a2a/server/routes/__init__.py create mode 100644 src/a2a/server/routes/agent_card_routes.py rename src/a2a/server/{apps/jsonrpc/jsonrpc_app.py => routes/jsonrpc_dispatcher.py} (93%) create mode 100644 src/a2a/server/routes/jsonrpc_routes.py create mode 100644 tests/__init__.py delete mode 100644 tests/server/apps/jsonrpc/test_fastapi_app.py delete mode 100644 tests/server/apps/jsonrpc/test_serialization.py delete mode 100644 tests/server/apps/jsonrpc/test_starlette_app.py create mode 100644 tests/server/routes/test_agent_card_routes.py rename tests/server/{apps/jsonrpc/test_jsonrpc_app.py => routes/test_jsonrpc_dispatcher.py} (51%) create mode 100644 tests/server/routes/test_jsonrpc_routes.py diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index 38dfdf561..e46b9ede4 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -11,12 +11,13 @@ from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from a2a.server.apps import A2ARESTFastAPIApplication from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import GrpcHandler from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_updater import TaskUpdater from a2a.types import ( @@ -197,14 +198,17 @@ async def serve( ) rest_app = rest_app_builder.build() - jsonrpc_app_builder = A2AFastAPIApplication( + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=request_handler, + rpc_url='/a2a/jsonrpc/', + ) + agent_card_routes = AgentCardRoutes( agent_card=agent_card, - http_handler=request_handler, - enable_v0_3_compat=True, ) - app = FastAPI() - jsonrpc_app_builder.add_routes_to_app(app, rpc_url='/a2a/jsonrpc/') + app.routes.extend(jsonrpc_routes.routes) + app.routes.extend(agent_card_routes.routes) app.mount('/a2a/rest', rest_app) grpc_server = grpc.aio.server() diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index 30a04dd91..073c7854b 100644 --- a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -10,8 +10,8 @@ if TYPE_CHECKING: from starlette.requests import Request - from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder from a2a.server.request_handlers.request_handler import RequestHandler + from a2a.server.routes import CallContextBuilder from a2a.types.a2a_pb2 import AgentCard _package_starlette_installed = True diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index b0296e402..8cae6b630 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -33,12 +33,9 @@ from a2a.compat.v0_3 import conversions from a2a.compat.v0_3.rest_handler import REST03Handler -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, - DefaultCallContextBuilder, -) from a2a.server.apps.rest.rest_adapter import RESTAdapterInterface from a2a.server.context import ServerCallContext +from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder from a2a.utils.error_handlers import ( rest_error_handler, rest_stream_error_handler, diff --git a/src/a2a/server/apps/__init__.py b/src/a2a/server/apps/__init__.py index 579deaa54..1cdb32953 100644 --- a/src/a2a/server/apps/__init__.py +++ b/src/a2a/server/apps/__init__.py @@ -1,18 +1,8 @@ """HTTP application components for the A2A server.""" -from a2a.server.apps.jsonrpc import ( - A2AFastAPIApplication, - A2AStarletteApplication, - CallContextBuilder, - JSONRPCApplication, -) from a2a.server.apps.rest import A2ARESTFastAPIApplication __all__ = [ - 'A2AFastAPIApplication', 'A2ARESTFastAPIApplication', - 'A2AStarletteApplication', - 'CallContextBuilder', - 'JSONRPCApplication', ] diff --git a/src/a2a/server/apps/jsonrpc/__init__.py b/src/a2a/server/apps/jsonrpc/__init__.py deleted file mode 100644 index 1121fdbc3..000000000 --- a/src/a2a/server/apps/jsonrpc/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""A2A JSON-RPC Applications.""" - -from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, - DefaultCallContextBuilder, - JSONRPCApplication, - StarletteUserProxy, -) -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication - - -__all__ = [ - 'A2AFastAPIApplication', - 'A2AStarletteApplication', - 'CallContextBuilder', - 'DefaultCallContextBuilder', - 'JSONRPCApplication', - 'StarletteUserProxy', -] diff --git a/src/a2a/server/apps/jsonrpc/fastapi_app.py b/src/a2a/server/apps/jsonrpc/fastapi_app.py deleted file mode 100644 index 0ec9d1ab2..000000000 --- a/src/a2a/server/apps/jsonrpc/fastapi_app.py +++ /dev/null @@ -1,148 +0,0 @@ -import logging - -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from fastapi import FastAPI - - _package_fastapi_installed = True -else: - try: - from fastapi import FastAPI - - _package_fastapi_installed = True - except ImportError: - FastAPI = Any - - _package_fastapi_installed = False - -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, - JSONRPCApplication, -) -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - DEFAULT_RPC_URL, -) - - -logger = logging.getLogger(__name__) - - -class A2AFastAPIApplication(JSONRPCApplication): - """A FastAPI application implementing the A2A protocol server endpoints. - - Handles incoming JSON-RPC requests, routes them to the appropriate - handler methods, and manages response generation including Server-Sent Events - (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB - enable_v0_3_compat: bool = False, - ) -> None: - """Initializes the A2AFastAPIApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - max_content_length: The maximum allowed content length for incoming - requests. Defaults to 10MB. Set to None for unbounded maximum. - enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. - """ - if not _package_fastapi_installed: - raise ImportError( - 'The `fastapi` package is required to use the `A2AFastAPIApplication`.' - ' It can be added as a part of `a2a-sdk` optional dependencies,' - ' `a2a-sdk[http-server]`.' - ) - super().__init__( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - max_content_length=max_content_length, - enable_v0_3_compat=enable_v0_3_compat, - ) - - def add_routes_to_app( - self, - app: FastAPI, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - ) -> None: - """Adds the routes to the FastAPI application. - - Args: - app: The FastAPI application to add the routes to. - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A JSON-RPC endpoint. - """ - app.post( - rpc_url, - openapi_extra={ - 'requestBody': { - 'content': { - 'application/json': { - 'schema': { - '$ref': '#/components/schemas/A2ARequest' - } - } - }, - 'required': True, - 'description': 'A2ARequest', - } - }, - )(self._handle_requests) - app.get(agent_card_url)(self._handle_get_agent_card) - - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - **kwargs: Any, - ) -> FastAPI: - """Builds and returns the FastAPI application instance. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A JSON-RPC endpoint. - **kwargs: Additional keyword arguments to pass to the FastAPI constructor. - - Returns: - A configured FastAPI application instance. - """ - app = FastAPI(**kwargs) - - self.add_routes_to_app(app, agent_card_url, rpc_url) - - return app diff --git a/src/a2a/server/apps/jsonrpc/starlette_app.py b/src/a2a/server/apps/jsonrpc/starlette_app.py deleted file mode 100644 index 553fa2503..000000000 --- a/src/a2a/server/apps/jsonrpc/starlette_app.py +++ /dev/null @@ -1,169 +0,0 @@ -import logging - -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from starlette.applications import Starlette - from starlette.routing import Route - - _package_starlette_installed = True - -else: - try: - from starlette.applications import Starlette - from starlette.routing import Route - - _package_starlette_installed = True - except ImportError: - Starlette = Any - Route = Any - - _package_starlette_installed = False - -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - CallContextBuilder, - JSONRPCApplication, -) -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - DEFAULT_RPC_URL, -) - - -logger = logging.getLogger(__name__) - - -class A2AStarletteApplication(JSONRPCApplication): - """A Starlette application implementing the A2A protocol server endpoints. - - Handles incoming JSON-RPC requests, routes them to the appropriate - handler methods, and manages response generation including Server-Sent Events - (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB - enable_v0_3_compat: bool = False, - ) -> None: - """Initializes the A2AStarletteApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - max_content_length: The maximum allowed content length for incoming - requests. Defaults to 10MB. Set to None for unbounded maximum. - enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. - """ - if not _package_starlette_installed: - raise ImportError( - 'Packages `starlette` and `sse-starlette` are required to use the' - ' `A2AStarletteApplication`. It can be added as a part of `a2a-sdk`' - ' optional dependencies, `a2a-sdk[http-server]`.' - ) - super().__init__( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - max_content_length=max_content_length, - enable_v0_3_compat=enable_v0_3_compat, - ) - - def routes( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - ) -> list[Route]: - """Returns the Starlette Routes for handling A2A requests. - - Args: - agent_card_url: The URL path for the agent card endpoint. - rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests). - - Returns: - A list of Starlette Route objects. - """ - return [ - Route( - rpc_url, - self._handle_requests, - methods=['POST'], - name='a2a_handler', - ), - Route( - agent_card_url, - self._handle_get_agent_card, - methods=['GET'], - name='agent_card', - ), - ] - - def add_routes_to_app( - self, - app: Starlette, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - ) -> None: - """Adds the routes to the Starlette application. - - Args: - app: The Starlette application to add the routes to. - agent_card_url: The URL path for the agent card endpoint. - rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests). - """ - routes = self.routes( - agent_card_url=agent_card_url, - rpc_url=rpc_url, - ) - app.routes.extend(routes) - - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - **kwargs: Any, - ) -> Starlette: - """Builds and returns the Starlette application instance. - - Args: - agent_card_url: The URL path for the agent card endpoint. - rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests). - **kwargs: Additional keyword arguments to pass to the Starlette constructor. - - Returns: - A configured Starlette application instance. - """ - app = Starlette(**kwargs) - - self.add_routes_to_app(app, agent_card_url, rpc_url) - - return app diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index ea9a501b9..4feac9072 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -28,10 +28,10 @@ from a2a.compat.v0_3.rest_adapter import REST03Adapter -from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder from a2a.server.apps.rest.rest_adapter import RESTAdapter from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.routes import CallContextBuilder from a2a.types.a2a_pb2 import AgentCard from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 6b8abb99e..ebf996a47 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -33,16 +33,13 @@ _package_starlette_installed = False -from a2a.server.apps.jsonrpc import ( - CallContextBuilder, - DefaultCallContextBuilder, -) from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( agent_card_to_dict, ) from a2a.server.request_handlers.rest_handler import RESTHandler +from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder from a2a.types.a2a_pb2 import AgentCard from a2a.utils.error_handlers import ( rest_error_handler, diff --git a/src/a2a/server/routes/__init__.py b/src/a2a/server/routes/__init__.py new file mode 100644 index 000000000..ec65d8b34 --- /dev/null +++ b/src/a2a/server/routes/__init__.py @@ -0,0 +1,20 @@ +"""A2A Routes.""" + +from a2a.server.routes.agent_card_routes import AgentCardRoutes +from a2a.server.routes.jsonrpc_dispatcher import ( + CallContextBuilder, + DefaultCallContextBuilder, + JsonRpcDispatcher, + StarletteUserProxy, +) +from a2a.server.routes.jsonrpc_routes import JsonRpcRoutes + + +__all__ = [ + 'AgentCardRoutes', + 'CallContextBuilder', + 'DefaultCallContextBuilder', + 'JsonRpcDispatcher', + 'JsonRpcRoutes', + 'StarletteUserProxy', +] diff --git a/src/a2a/server/routes/agent_card_routes.py b/src/a2a/server/routes/agent_card_routes.py new file mode 100644 index 000000000..30d635f11 --- /dev/null +++ b/src/a2a/server/routes/agent_card_routes.py @@ -0,0 +1,85 @@ +import logging + +from collections.abc import Awaitable, Callable, Sequence +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from starlette.middleware import Middleware + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + from starlette.routing import Route + + _package_starlette_installed = True +else: + try: + from starlette.middleware import Middleware + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + from starlette.routing import Route + + _package_starlette_installed = True + except ImportError: + Middleware = Any + Route = Any + Request = Any + Response = Any + JSONResponse = Any + + _package_starlette_installed = False + +from a2a.server.request_handlers.response_helpers import agent_card_to_dict +from a2a.types.a2a_pb2 import AgentCard +from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH +from a2a.utils.helpers import maybe_await + + +logger = logging.getLogger(__name__) + + +class AgentCardRoutes: + """Provides the Starlette Route for the A2A protocol agent card endpoint.""" + + def __init__( + self, + agent_card: AgentCard, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + card_url: str = AGENT_CARD_WELL_KNOWN_PATH, + middleware: Sequence['Middleware'] | None = None, + ) -> None: + """Initializes the AgentCardRoute. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + card_url: The URL for the agent card endpoint. + middleware: An optional list of Starlette middleware to apply to the + agent card endpoint. + """ + if not _package_starlette_installed: + raise ImportError( + 'The `starlette` package is required to use the `AgentCardRoutes`.' + ' `a2a-sdk[http-server]`.' + ) + + self.agent_card = agent_card + self.card_modifier = card_modifier + + async def get_agent_card(request: Request) -> Response: + card_to_serve = self.agent_card + if self.card_modifier: + card_to_serve = await maybe_await( + self.card_modifier(card_to_serve) + ) + return JSONResponse(agent_card_to_dict(card_to_serve)) + + self.routes = [ + Route( + path=card_url, + endpoint=get_agent_card, + methods=['GET'], + middleware=middleware, + ) + ] diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/routes/jsonrpc_dispatcher.py similarity index 93% rename from src/a2a/server/apps/jsonrpc/jsonrpc_app.py rename to src/a2a/server/routes/jsonrpc_dispatcher.py index 219470766..14a0cc0bb 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -31,7 +31,6 @@ from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( - agent_card_to_dict, build_error_response, ) from a2a.types import A2ARequest @@ -49,14 +48,12 @@ TaskPushNotificationConfig, ) from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, - DEFAULT_RPC_URL, + DEFAULT_MAX_CONTENT_LENGTH, ) from a2a.utils.errors import ( A2AError, UnsupportedOperationError, ) -from a2a.utils.helpers import maybe_await INTERNAL_ERROR_CODE = -32603 @@ -167,7 +164,7 @@ def build(self, request: Request) -> ServerCallContext: ) -class JSONRPCApplication(ABC): +class JsonRpcDispatcher: """Base class for A2A JSONRPC applications. Handles incoming JSON-RPC requests, routes them to the appropriate @@ -204,10 +201,10 @@ def __init__( # noqa: PLR0913 [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, - max_content_length: int | None = 10 * 1024 * 1024, # 10MB enable_v0_3_compat: bool = False, + max_content_length: int | None = DEFAULT_MAX_CONTENT_LENGTH, ) -> None: - """Initializes the JSONRPCApplication. + """Initializes the JsonRpcDispatcher. Args: agent_card: The AgentCard describing the agent's capabilities. @@ -230,7 +227,7 @@ def __init__( # noqa: PLR0913 if not _package_starlette_installed: raise ImportError( 'Packages `starlette` and `sse-starlette` are required to use the' - ' `JSONRPCApplication`. They can be added as a part of `a2a-sdk`' + ' `JsonRpcDispatcher`. They can be added as a part of `a2a-sdk`' ' optional dependencies, `a2a-sdk[http-server]`.' ) @@ -600,43 +597,3 @@ async def event_generator( # handler_result is a dict (JSON-RPC response) return JSONResponse(handler_result, headers=headers) - - async def _handle_get_agent_card(self, request: Request) -> JSONResponse: - """Handles GET requests for the agent card endpoint. - - Args: - request: The incoming Starlette Request object. - - Returns: - A JSONResponse containing the agent card data. - """ - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return JSONResponse( - agent_card_to_dict( - card_to_serve, - ) - ) - - @abstractmethod - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = DEFAULT_RPC_URL, - **kwargs: Any, - ) -> FastAPI | Starlette: - """Builds and returns the JSONRPC application instance. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A JSON-RPC endpoint. - **kwargs: Additional keyword arguments to pass to the FastAPI constructor. - - Returns: - A configured JSONRPC application instance. - """ - raise NotImplementedError( - 'Subclasses must implement the build method to create the application instance.' - ) diff --git a/src/a2a/server/routes/jsonrpc_routes.py b/src/a2a/server/routes/jsonrpc_routes.py new file mode 100644 index 000000000..cc0e12612 --- /dev/null +++ b/src/a2a/server/routes/jsonrpc_routes.py @@ -0,0 +1,107 @@ +import logging + +from collections.abc import Awaitable, Callable, Sequence +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from starlette.middleware import Middleware + from starlette.routing import Route, Router + + _package_starlette_installed = True +else: + try: + from starlette.middleware import Middleware + from starlette.routing import Route, Router + + _package_starlette_installed = True + except ImportError: + Middleware = Any + Route = Any + Router = Any + + _package_starlette_installed = False + + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.routes.jsonrpc_dispatcher import ( + CallContextBuilder, + JsonRpcDispatcher, +) +from a2a.types.a2a_pb2 import AgentCard +from a2a.utils.constants import DEFAULT_RPC_URL + + +logger = logging.getLogger(__name__) + + +class JsonRpcRoutes: + """Provides the Starlette Route for the A2A protocol JSON-RPC endpoint. + + Handles incoming JSON-RPC requests, routes them to the appropriate + handler methods, and manages response generation including Server-Sent Events + (SSE). + """ + + def __init__( # noqa: PLR0913 + self, + agent_card: AgentCard, + request_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, + context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, + enable_v0_3_compat: bool = False, + rpc_url: str = DEFAULT_RPC_URL, + middleware: Sequence[Middleware] | None = None, + ) -> None: + """Initializes the JsonRpcRoute. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the request_handler. If None, no + ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. + enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. + rpc_url: The URL prefix for the RPC endpoints. + middleware: An optional list of Starlette middleware to apply to the routes. + """ + if not _package_starlette_installed: + raise ImportError( + 'The `starlette` package is required to use the `JsonRpcRoutes`.' + ' It can be added as a part of `a2a-sdk` optional dependencies,' + ' `a2a-sdk[http-server]`.' + ) + + self.dispatcher = JsonRpcDispatcher( + agent_card=agent_card, + http_handler=request_handler, + extended_agent_card=extended_agent_card, + context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, + enable_v0_3_compat=enable_v0_3_compat, + ) + + self.routes = [ + Route( + path=rpc_url, + endpoint=self.dispatcher._handle_requests, # noqa: SLF001 + methods=['POST'], + middleware=middleware, + ) + ] diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 7196b828b..955493437 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -18,13 +18,16 @@ from a2a.server.agent_execution.context import RequestContext from a2a.server.apps import ( A2ARESTFastAPIApplication, - A2AStarletteApplication, ) from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) from a2a.server.request_handlers.grpc_handler import GrpcHandler +from a2a.server.routes import ( + AgentCardRoutes, + JsonRpcRoutes, +) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_store import TaskStore from a2a.types import ( @@ -196,15 +199,22 @@ def serve(task_store: TaskStore) -> None: task_store=task_store, ) - main_app = Starlette() - # JSONRPC - jsonrpc_server = A2AStarletteApplication( + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=request_handler, + rpc_url=JSONRPC_URL, + ) + # Agent Card + agent_card_routes = AgentCardRoutes( agent_card=agent_card, - http_handler=request_handler, ) - jsonrpc_server.add_routes_to_app(main_app, rpc_url=JSONRPC_URL) + routes = [ + *jsonrpc_routes.routes, + *agent_card_routes.routes, + ] + main_app = Starlette(routes=routes) # REST rest_server = A2ARESTFastAPIApplication( agent_card=agent_card, diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..792d60054 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# diff --git a/tests/compat/v0_3/test_jsonrpc_app_compat.py b/tests/compat/v0_3/test_jsonrpc_app_compat.py index 4f09bb230..4b344c67d 100644 --- a/tests/compat/v0_3/test_jsonrpc_app_compat.py +++ b/tests/compat/v0_3/test_jsonrpc_app_compat.py @@ -6,7 +6,8 @@ import pytest from starlette.testclient import TestClient -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from starlette.applications import Starlette +from a2a.server.routes import JsonRpcRoutes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( AgentCard, @@ -50,16 +51,18 @@ def test_app(mock_handler): mock_agent_card.capabilities.streaming = False mock_agent_card.capabilities.push_notifications = True mock_agent_card.capabilities.extended_agent_card = True - return A2AStarletteApplication( + router = JsonRpcRoutes( agent_card=mock_agent_card, - http_handler=mock_handler, + request_handler=mock_handler, enable_v0_3_compat=True, + rpc_url='/', ) + return Starlette(routes=router.routes) @pytest.fixture def client(test_app): - return TestClient(test_app.build()) + return TestClient(test_app) def test_send_message_v03_compat( diff --git a/tests/integration/cross_version/client_server/server_0_3.py b/tests/integration/cross_version/client_server/server_0_3.py index 7bd5f7e75..96152c135 100644 --- a/tests/integration/cross_version/client_server/server_0_3.py +++ b/tests/integration/cross_version/client_server/server_0_3.py @@ -8,7 +8,7 @@ from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext -from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication +from a2a.server.apps.jsonrpc import A2AFastAPIApplication from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication from a2a.server.events.event_queue import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager @@ -188,12 +188,13 @@ async def main_async(http_port: int, grpc_port: int): ) app = FastAPI() - app.mount( - '/jsonrpc', - A2AFastAPIApplication( - http_handler=handler, agent_card=agent_card - ).build(), - ) + jsonrpc_app = A2AFastAPIApplication( + agent_card=agent_card, + http_handler=handler, + extended_agent_card=agent_card, + ).build() + app.mount('/jsonrpc', jsonrpc_app) + app.mount( '/rest', A2ARESTFastAPIApplication( diff --git a/tests/integration/cross_version/client_server/server_1_0.py b/tests/integration/cross_version/client_server/server_1_0.py index e079fdf21..907c010ff 100644 --- a/tests/integration/cross_version/client_server/server_1_0.py +++ b/tests/integration/cross_version/client_server/server_1_0.py @@ -5,7 +5,8 @@ import grpc from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes +from a2a.server.apps import A2ARESTFastAPIApplication from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler @@ -166,10 +167,20 @@ async def main_async(http_port: int, grpc_port: int): app = FastAPI() app.add_middleware(CustomLoggingMiddleware) - jsonrpc_app = A2AFastAPIApplication( - http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True - ).build() - app.mount('/jsonrpc', jsonrpc_app) + agent_card_routes = AgentCardRoutes( + agent_card=agent_card, card_url='/.well-known/agent-card.json' + ) + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=handler, + extended_agent_card=agent_card, + rpc_url='/', + enable_v0_3_compat=True, + ) + app.mount( + '/jsonrpc', + FastAPI(routes=jsonrpc_routes.routes + agent_card_routes.routes), + ) app.mount( '/rest', diff --git a/tests/integration/test_agent_card.py b/tests/integration/test_agent_card.py index eb7c03f4c..42aca3843 100644 --- a/tests/integration/test_agent_card.py +++ b/tests/integration/test_agent_card.py @@ -4,7 +4,9 @@ from fastapi import FastAPI from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.apps import A2ARESTFastAPIApplication +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler @@ -70,10 +72,15 @@ async def test_agent_card_integration(header_val: str | None) -> None: app = FastAPI() # Mount JSONRPC application - # In JSONRPCApplication, the default agent_card_url is AGENT_CARD_WELL_KNOWN_PATH - jsonrpc_app = A2AFastAPIApplication( - http_handler=handler, agent_card=agent_card - ).build() + jsonrpc_routes = [ + *AgentCardRoutes( + agent_card=agent_card, card_url='/.well-known/agent-card.json' + ).routes, + *JsonRpcRoutes( + agent_card=agent_card, request_handler=handler, rpc_url='/' + ).routes, + ] + jsonrpc_app = Starlette(routes=jsonrpc_routes) app.mount('/jsonrpc', jsonrpc_app) # Mount REST application diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index e239d780f..f6f1b4182 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -23,7 +23,9 @@ with_a2a_extensions, ) from a2a.client.transports import JsonRpcTransport, RestTransport -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.apps import A2ARESTFastAPIApplication +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( @@ -220,10 +222,14 @@ def http_base_setup(mock_request_handler: AsyncMock, agent_card: AgentCard): def jsonrpc_setup(http_base_setup) -> TransportSetup: """Sets up the JsonRpcTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - app_builder = A2AFastAPIApplication( - agent_card, mock_request_handler, extended_agent_card=agent_card + agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, + extended_agent_card=agent_card, + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) factory = ClientFactory( config=ClientConfig( @@ -619,12 +625,16 @@ async def test_json_transport_get_signed_base_card( }, ) - app_builder = A2AFastAPIApplication( - agent_card, - mock_request_handler, - card_modifier=signer, # Sign the base card + agent_card_routes = AgentCardRoutes( + agent_card=agent_card, card_url='/', card_modifier=signer ) - app = app_builder.build() + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, + extended_agent_card=agent_card, + rpc_url='/', + ) + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) agent_url = agent_card.supported_interfaces[0].url @@ -639,7 +649,8 @@ async def test_json_transport_get_signed_base_card( # Verification happens here result = await resolver.get_agent_card( - signature_verifier=signature_verifier + relative_card_path='/', + signature_verifier=signature_verifier, ) # Create transport with the verified card @@ -684,15 +695,15 @@ async def test_client_get_signed_extended_card( }, ) - app_builder = A2AFastAPIApplication( - agent_card, - mock_request_handler, + agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, extended_agent_card=extended_agent_card, - extended_card_modifier=lambda card, ctx: signer( - card - ), # Sign the extended card + extended_card_modifier=lambda card, ctx: signer(card), + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) transport = JsonRpcTransport( @@ -753,16 +764,17 @@ async def test_client_get_signed_base_and_extended_cards( }, ) - app_builder = A2AFastAPIApplication( - agent_card, - mock_request_handler, + agent_card_routes = AgentCardRoutes( + agent_card=agent_card, card_url='/', card_modifier=signer + ) + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_request_handler, extended_agent_card=extended_agent_card, - card_modifier=signer, # Sign the base card - extended_card_modifier=lambda card, ctx: signer( - card - ), # Sign the extended card + extended_card_modifier=lambda card, ctx: signer(card), + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) agent_url = agent_card.supported_interfaces[0].url @@ -777,7 +789,8 @@ async def test_client_get_signed_base_and_extended_cards( # 1. Fetch base card base_card = await resolver.get_agent_card( - signature_verifier=signature_verifier + relative_card_path='/', + signature_verifier=signature_verifier, ) # 2. Create transport with base card diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index ddf9edbf3..f75e8c9da 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -10,7 +10,9 @@ from a2a.client.client import ClientConfig from a2a.client.client_factory import ClientFactory from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.apps import A2ARESTFastAPIApplication +from a2a.server.routes import JsonRpcRoutes, AgentCardRoutes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler @@ -192,10 +194,14 @@ def rest_setup(agent_card, base_e2e_setup) -> ClientSetup: @pytest.fixture def jsonrpc_setup(agent_card, base_e2e_setup) -> ClientSetup: task_store, handler = base_e2e_setup - app_builder = A2AFastAPIApplication( - agent_card, handler, extended_agent_card=agent_card + agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=handler, + extended_agent_card=agent_card, + rpc_url='/', ) - app = app_builder.build() + app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) httpx_client = httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url='http://testserver' ) diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py index 903b90a29..21698b4f4 100644 --- a/tests/integration/test_tenant.py +++ b/tests/integration/test_tenant.py @@ -19,7 +19,8 @@ from a2a.client import ClientConfig, ClientFactory from a2a.utils.constants import TransportProtocol -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes +from starlette.applications import Starlette from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.context import ServerCallContext @@ -197,10 +198,18 @@ def jsonrpc_agent_card(self): @pytest.fixture def server_app(self, jsonrpc_agent_card, mock_handler): - app = A2AStarletteApplication( + agent_card_routes = AgentCardRoutes( + agent_card=jsonrpc_agent_card, card_url='/' + ) + jsonrpc_routes = JsonRpcRoutes( agent_card=jsonrpc_agent_card, - http_handler=mock_handler, - ).build(rpc_url='/jsonrpc') + request_handler=mock_handler, + extended_agent_card=jsonrpc_agent_card, + rpc_url='/jsonrpc', + ) + app = Starlette( + routes=[*agent_card_routes.routes, *jsonrpc_routes.routes] + ) return app @pytest.mark.asyncio diff --git a/tests/integration/test_version_header.py b/tests/integration/test_version_header.py index 40aa91446..754b14168 100644 --- a/tests/integration/test_version_header.py +++ b/tests/integration/test_version_header.py @@ -4,7 +4,8 @@ from starlette.testclient import TestClient from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from a2a.server.apps import A2ARESTFastAPIApplication +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler @@ -56,10 +57,16 @@ async def mock_on_message_send_stream(*args, **kwargs): handler.on_message_send_stream = mock_on_message_send_stream app = FastAPI() - jsonrpc_app = A2AFastAPIApplication( - http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True - ).build() - app.mount('/jsonrpc', jsonrpc_app) + agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=handler, + extended_agent_card=agent_card, + rpc_url='/jsonrpc', + enable_v0_3_compat=True, + ) + app.routes.extend(agent_card_routes.routes) + app.routes.extend(jsonrpc_routes.routes) rest_app = A2ARESTFastAPIApplication( http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True ).build() diff --git a/tests/server/apps/jsonrpc/test_fastapi_app.py b/tests/server/apps/jsonrpc/test_fastapi_app.py deleted file mode 100644 index 11831df57..000000000 --- a/tests/server/apps/jsonrpc/test_fastapi_app.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from a2a.server.apps.jsonrpc import fastapi_app -from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication -from a2a.server.request_handlers.request_handler import ( - RequestHandler, # For mock spec -) -from a2a.types.a2a_pb2 import AgentCard # For mock spec - - -# --- A2AFastAPIApplication Tests --- - - -class TestA2AFastAPIApplicationOptionalDeps: - # Running tests in this class requires the optional dependency fastapi to be - # present in the test environment. - - @pytest.fixture(scope='class', autouse=True) - def ensure_pkg_fastapi_is_present(self): - try: - import fastapi as _fastapi # noqa: F401 - except ImportError: - pytest.fail( - f'Running tests in {self.__class__.__name__} requires' - ' the optional dependency fastapi to be present in the test' - ' environment. Run `uv sync --dev ...` before running the test' - ' suite.' - ) - - @pytest.fixture(scope='class') - def mock_app_params(self) -> dict: - # Mock http_handler - mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in __init__ - mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed - # in __init__ - mock_agent_card.url = 'http://example.com' - # Ensure 'capabilities.extended_agent_card' attribute exists - return {'agent_card': mock_agent_card, 'http_handler': mock_handler} - - @pytest.fixture(scope='class') - def mark_pkg_fastapi_not_installed(self): - pkg_fastapi_installed_flag = fastapi_app._package_fastapi_installed - fastapi_app._package_fastapi_installed = False - yield - fastapi_app._package_fastapi_installed = pkg_fastapi_installed_flag - - def test_create_a2a_fastapi_app_with_present_deps_succeeds( - self, mock_app_params: dict - ): - try: - _app = A2AFastAPIApplication(**mock_app_params) - except ImportError: - pytest.fail( - 'With the fastapi package present, creating a' - ' A2AFastAPIApplication instance should not raise ImportError' - ) - - def test_create_a2a_fastapi_app_with_missing_deps_raises_importerror( - self, - mock_app_params: dict, - mark_pkg_fastapi_not_installed: Any, - ): - with pytest.raises( - ImportError, - match=( - 'The `fastapi` package is required to use the' - ' `A2AFastAPIApplication`' - ), - ): - _app = A2AFastAPIApplication(**mock_app_params) - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/server/apps/jsonrpc/test_serialization.py b/tests/server/apps/jsonrpc/test_serialization.py deleted file mode 100644 index 825f8e2a1..000000000 --- a/tests/server/apps/jsonrpc/test_serialization.py +++ /dev/null @@ -1,280 +0,0 @@ -"""Tests for JSON-RPC serialization behavior.""" - -from unittest import mock - -import pytest -from starlette.testclient import TestClient - -from a2a.server.apps import A2AFastAPIApplication, A2AStarletteApplication -from a2a.server.jsonrpc_models import JSONParseError -from a2a.types import ( - InvalidRequestError, -) -from a2a.types.a2a_pb2 import ( - AgentCapabilities, - AgentInterface, - AgentCard, - AgentSkill, - APIKeySecurityScheme, - Message, - Part, - Role, - SecurityRequirement, - SecurityScheme, -) - - -@pytest.fixture -def minimal_agent_card(): - """Provides a minimal AgentCard for testing.""" - return AgentCard( - name='TestAgent', - description='A test agent.', - supported_interfaces=[ - AgentInterface( - url='http://example.com/agent', protocol_binding='HTTP+JSON' - ) - ], - version='1.0.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['text/plain'], - skills=[ - AgentSkill( - id='skill-1', - name='Test Skill', - description='A test skill', - tags=['test'], - ) - ], - ) - - -@pytest.fixture -def agent_card_with_api_key(): - """Provides an AgentCard with an APIKeySecurityScheme for testing serialization.""" - api_key_scheme = APIKeySecurityScheme( - name='X-API-KEY', - location='header', - ) - - security_scheme = SecurityScheme(api_key_security_scheme=api_key_scheme) - - card = AgentCard( - name='APIKeyAgent', - description='An agent that uses API Key auth.', - supported_interfaces=[ - AgentInterface( - url='http://example.com/apikey-agent', - protocol_binding='HTTP+JSON', - ) - ], - version='1.0.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['text/plain'], - ) - # Add security scheme to the map - card.security_schemes['api_key_auth'].CopyFrom(security_scheme) - - return card - - -def test_starlette_agent_card_serialization(minimal_agent_card: AgentCard): - """Tests that the A2AStarletteApplication endpoint correctly serializes agent card.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - response = client.get('/.well-known/agent-card.json') - assert response.status_code == 200 - response_data = response.json() - - assert response_data['name'] == 'TestAgent' - assert response_data['description'] == 'A test agent.' - assert ( - response_data['supportedInterfaces'][0]['url'] - == 'http://example.com/agent' - ) - assert response_data['version'] == '1.0.0' - - -def test_starlette_agent_card_with_api_key_scheme( - agent_card_with_api_key: AgentCard, -): - """Tests that the A2AStarletteApplication endpoint correctly serializes API key schemes.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(agent_card_with_api_key, handler) - client = TestClient(app_instance.build()) - - response = client.get('/.well-known/agent-card.json') - assert response.status_code == 200 - response_data = response.json() - - # Check security schemes are serialized - assert 'securitySchemes' in response_data - assert 'api_key_auth' in response_data['securitySchemes'] - - -def test_fastapi_agent_card_serialization(minimal_agent_card: AgentCard): - """Tests that the A2AFastAPIApplication endpoint correctly serializes agent card.""" - handler = mock.AsyncMock() - app_instance = A2AFastAPIApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - response = client.get('/.well-known/agent-card.json') - assert response.status_code == 200 - response_data = response.json() - - assert response_data['name'] == 'TestAgent' - assert response_data['description'] == 'A test agent.' - - -def test_handle_invalid_json(minimal_agent_card: AgentCard): - """Test handling of malformed JSON.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - response = client.post( - '/', - content='{ "jsonrpc": "2.0", "method": "test", "id": 1, "params": { "key": "value" }', - ) - assert response.status_code == 200 - data = response.json() - assert data['error']['code'] == JSONParseError().code - - -def test_handle_oversized_payload(minimal_agent_card: AgentCard): - """Test handling of oversized JSON payloads.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - large_string = 'a' * 11 * 1_000_000 # 11MB string - payload = { - 'jsonrpc': '2.0', - 'method': 'test', - 'id': 1, - 'params': {'data': large_string}, - } - - response = client.post('/', json=payload) - assert response.status_code == 200 - data = response.json() - assert data['error']['code'] == -32600 - - -@pytest.mark.parametrize( - 'max_content_length', - [ - None, - 11 * 1024 * 1024, - 30 * 1024 * 1024, - ], -) -def test_handle_oversized_payload_with_max_content_length( - minimal_agent_card: AgentCard, - max_content_length: int | None, -): - """Test handling of JSON payloads with sizes within custom max_content_length.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication( - minimal_agent_card, handler, max_content_length=max_content_length - ) - client = TestClient(app_instance.build()) - - large_string = 'a' * 11 * 1_000_000 # 11MB string - payload = { - 'jsonrpc': '2.0', - 'method': 'test', - 'id': 1, - 'params': {'data': large_string}, - } - - response = client.post('/', json=payload) - assert response.status_code == 200 - data = response.json() - # When max_content_length is set, requests up to that size should not be - # rejected due to payload size. The request might fail for other reasons, - # but it shouldn't be an InvalidRequestError related to the content length. - if max_content_length is not None: - assert data['error']['code'] != -32600 - - -def test_handle_unicode_characters(minimal_agent_card: AgentCard): - """Test handling of unicode characters in JSON payload.""" - handler = mock.AsyncMock() - app_instance = A2AStarletteApplication(minimal_agent_card, handler) - client = TestClient(app_instance.build()) - - unicode_text = 'こんにちは世界' # "Hello world" in Japanese - - # Mock a handler response - handler.on_message_send.return_value = Message( - role=Role.ROLE_AGENT, - parts=[Part(text=f'Received: {unicode_text}')], - message_id='response-unicode', - ) - - unicode_payload = { - 'jsonrpc': '2.0', - 'method': 'SendMessage', - 'id': 'unicode_test', - 'params': { - 'message': { - 'role': 'ROLE_USER', - 'parts': [{'text': unicode_text}], - 'messageId': 'msg-unicode', - } - }, - } - - response = client.post('/', json=unicode_payload) - - # We are testing that the server can correctly deserialize the unicode payload - assert response.status_code == 200 - data = response.json() - # Check that we got a result (handler was called) - if 'result' in data: - # Response should contain the unicode text - result = data['result'] - if 'message' in result: - assert ( - result['message']['parts'][0]['text'] - == f'Received: {unicode_text}' - ) - elif 'parts' in result: - assert result['parts'][0]['text'] == f'Received: {unicode_text}' - - -def test_fastapi_sub_application(minimal_agent_card: AgentCard): - """ - Tests that the A2AFastAPIApplication endpoint correctly passes the url in sub-application. - """ - from fastapi import FastAPI - - handler = mock.AsyncMock() - sub_app_instance = A2AFastAPIApplication(minimal_agent_card, handler) - app_instance = FastAPI() - app_instance.mount('/a2a', sub_app_instance.build()) - client = TestClient(app_instance) - - response = client.get('/a2a/openapi.json') - assert response.status_code == 200 - response_data = response.json() - - # The generated a2a.json (OpenAPI 2.0 / Swagger) does not typically include a 'servers' block - # unless specifically configured or converted to OpenAPI 3.0. - # FastAPI usually generates OpenAPI 3.0 schemas which have 'servers'. - # When we inject the raw Swagger 2.0 schema, it won't have 'servers'. - # We check if it is indeed the injected schema by checking for 'swagger': '2.0' - # or by checking for 'basePath' if we want to test path correctness. - - if response_data.get('swagger') == '2.0': - # It's the injected Swagger 2.0 schema - pass - else: - # It's an auto-generated OpenAPI 3.0+ schema (fallback or otherwise) - assert 'servers' in response_data - assert response_data['servers'] == [{'url': '/a2a'}] diff --git a/tests/server/apps/jsonrpc/test_starlette_app.py b/tests/server/apps/jsonrpc/test_starlette_app.py deleted file mode 100644 index fa6868712..000000000 --- a/tests/server/apps/jsonrpc/test_starlette_app.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from a2a.server.apps.jsonrpc import starlette_app -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication -from a2a.server.request_handlers.request_handler import ( - RequestHandler, # For mock spec -) -from a2a.types.a2a_pb2 import AgentCard # For mock spec - - -# --- A2AStarletteApplication Tests --- - - -class TestA2AStarletteApplicationOptionalDeps: - # Running tests in this class requires optional dependencies starlette and - # sse-starlette to be present in the test environment. - - @pytest.fixture(scope='class', autouse=True) - def ensure_pkg_starlette_is_present(self): - try: - import sse_starlette as _sse_starlette # noqa: F401 - import starlette as _starlette # noqa: F401 - except ImportError: - pytest.fail( - f'Running tests in {self.__class__.__name__} requires' - ' optional dependencies starlette and sse-starlette to be' - ' present in the test environment. Run `uv sync --dev ...`' - ' before running the test suite.' - ) - - @pytest.fixture(scope='class') - def mock_app_params(self) -> dict: - # Mock http_handler - mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in __init__ - mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed - # in __init__ - mock_agent_card.url = 'http://example.com' - # Ensure 'capabilities.extended_agent_card' attribute exists - return {'agent_card': mock_agent_card, 'http_handler': mock_handler} - - @pytest.fixture(scope='class') - def mark_pkg_starlette_not_installed(self): - pkg_starlette_installed_flag = ( - starlette_app._package_starlette_installed - ) - starlette_app._package_starlette_installed = False - yield - starlette_app._package_starlette_installed = ( - pkg_starlette_installed_flag - ) - - def test_create_a2a_starlette_app_with_present_deps_succeeds( - self, mock_app_params: dict - ): - try: - _app = A2AStarletteApplication(**mock_app_params) - except ImportError: - pytest.fail( - 'With packages starlette and see-starlette present, creating an' - ' A2AStarletteApplication instance should not raise ImportError' - ) - - def test_create_a2a_starlette_app_with_missing_deps_raises_importerror( - self, - mock_app_params: dict, - mark_pkg_starlette_not_installed: Any, - ): - with pytest.raises( - ImportError, - match='Packages `starlette` and `sse-starlette` are required', - ): - _app = A2AStarletteApplication(**mock_app_params) - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/server/routes/test_agent_card_routes.py b/tests/server/routes/test_agent_card_routes.py new file mode 100644 index 000000000..8f86ec937 --- /dev/null +++ b/tests/server/routes/test_agent_card_routes.py @@ -0,0 +1,105 @@ +# ruff: noqa: INP001 +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from starlette.testclient import TestClient +from starlette.middleware import Middleware + +from a2a.server.routes.agent_card_routes import AgentCardRoutes +from a2a.types.a2a_pb2 import AgentCard + + +@pytest.fixture +def agent_card(): + return AgentCard() + + +def test_get_agent_card_success(agent_card): + """Tests that the agent card route returns the card correctly.""" + routes = AgentCardRoutes(agent_card=agent_card).routes + + from starlette.applications import Starlette + + app = Starlette(routes=routes) + client = TestClient(app) + + response = client.get('/.well-known/agent-card.json') + assert response.status_code == 200 + assert response.headers['content-type'] == 'application/json' + assert response.json() == {} # Empty card serializes to empty dict/json + + +def test_get_agent_card_with_modifier(agent_card): + """Tests that card_modifier is called and modifies the response.""" + + # To test modification, let's assume we can mock the dict conversion or just see if the modifier runs. + # Actually card_modifier receives AgentCard and returns AgentCard. + async def modifier(card: AgentCard) -> AgentCard: + # Clone or modify + modified = AgentCard() + # Set some field if possible, or just return a different instance to verify. + # Since Protobuf objects have fields, let's look at one we can set. + # Usually they have fields like 'url' in v0.3 or others. + # Let's just return a MagicMock or set Something that shows up in dict if we know it. + # Wait, if we return a different object, we can verify it. + # Let's try to mock the conversion or just verify it was called. + return card + + mock_modifier = AsyncMock(side_effect=modifier) + routes = AgentCardRoutes( + agent_card=agent_card, card_modifier=mock_modifier + ).routes + + from starlette.applications import Starlette + + app = Starlette(routes=routes) + client = TestClient(app) + + response = client.get('/.well-known/agent-card.json') + assert response.status_code == 200 + assert mock_modifier.called + + +def test_agent_card_custom_url(agent_card): + """Tests that custom card_url is respected.""" + custom_url = '/custom/path/agent.json' + routes = AgentCardRoutes(agent_card=agent_card, card_url=custom_url).routes + + from starlette.applications import Starlette + + app = Starlette(routes=routes) + client = TestClient(app) + + # Check that default returns 404 + assert client.get('/.well-known/agent-card.json').status_code == 404 + # Check that custom returns 200 + assert client.get(custom_url).status_code == 200 + + +def test_agent_card_with_middleware(agent_card): + """Tests that middleware is applied to the routes.""" + middleware_called = False + + class MyMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + nonlocal middleware_called + middleware_called = True + await self.app(scope, receive, send) + + routes = AgentCardRoutes( + agent_card=agent_card, middleware=[Middleware(MyMiddleware)] + ).routes + + from starlette.applications import Starlette + + app = Starlette(routes=routes) + client = TestClient(app) + + response = client.get('/.well-known/agent-card.json') + assert response.status_code == 200 + assert middleware_called is True diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/routes/test_jsonrpc_dispatcher.py similarity index 51% rename from tests/server/apps/jsonrpc/test_jsonrpc_app.py rename to tests/server/routes/test_jsonrpc_dispatcher.py index be54958b0..7241cac4b 100644 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -1,38 +1,37 @@ # ruff: noqa: INP001 +import json from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest - from starlette.responses import JSONResponse from starlette.testclient import TestClient - -# Attempt to import StarletteBaseUser, fallback to MagicMock if not available try: from starlette.authentication import BaseUser as StarletteBaseUser except ImportError: StarletteBaseUser = MagicMock() # type: ignore from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.server.apps.jsonrpc import ( - jsonrpc_app, # Keep this import for optional deps test -) -from a2a.server.apps.jsonrpc.jsonrpc_app import ( - JSONRPCApplication, - StarletteUserProxy, -) -from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import ( - RequestHandler, -) # For mock spec +from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( AgentCard, Message, Part, Role, ) +from a2a.server.routes import jsonrpc_dispatcher +from a2a.server.routes.jsonrpc_dispatcher import ( + CallContextBuilder, + DefaultCallContextBuilder, + JsonRpcDispatcher, + StarletteUserProxy, +) +from a2a.server.routes.jsonrpc_routes import JsonRpcRoutes +from a2a.server.routes.agent_card_routes import AgentCardRoutes +from a2a.server.jsonrpc_models import JSONRPCError +from a2a.utils.errors import A2AError # --- StarletteUserProxy Tests --- @@ -58,12 +57,7 @@ def test_starlette_user_proxy_user_name(self): assert proxy.user_name == 'Test User DisplayName' def test_starlette_user_proxy_user_name_raises_attribute_error(self): - """ - Tests that if the underlying starlette user object is missing the - display_name attribute, the proxy currently raises an AttributeError. - """ starlette_user_mock = MagicMock(spec=StarletteBaseUser) - # Ensure display_name is not present on the mock to trigger AttributeError del starlette_user_mock.display_name proxy = StarletteUserProxy(starlette_user_mock) @@ -71,13 +65,12 @@ def test_starlette_user_proxy_user_name_raises_attribute_error(self): _ = proxy.user_name -# --- JSONRPCApplication Tests (Selected) --- +# --- JsonRpcDispatcher Tests --- @pytest.fixture def mock_handler(): handler = AsyncMock(spec=RequestHandler) - # Return a proto Message object directly - the handler wraps it in SendMessageResponse handler.on_message_send.return_value = Message( message_id='test', role=Role.ROLE_AGENT, @@ -90,23 +83,26 @@ def mock_handler(): def test_app(mock_handler): mock_agent_card = MagicMock(spec=AgentCard) mock_agent_card.url = 'http://mockurl.com' - # Set up capabilities.streaming to avoid validation issues mock_agent_card.capabilities = MagicMock() mock_agent_card.capabilities.streaming = False - return A2AStarletteApplication( - agent_card=mock_agent_card, http_handler=mock_handler + + jsonrpc_routes = JsonRpcRoutes( + agent_card=mock_agent_card, request_handler=mock_handler, rpc_url='/' ) + from starlette.applications import Starlette + + return Starlette(routes=jsonrpc_routes.routes) + @pytest.fixture def client(test_app): - return TestClient(test_app.build(), headers={'A2A-Version': '1.0'}) + return TestClient(test_app, headers={'A2A-Version': '1.0'}) def _make_send_message_request( text: str = 'hi', tenant: str | None = None ) -> dict: - """Helper to create a JSON-RPC send message request.""" params: dict[str, Any] = { 'message': { 'messageId': '1', @@ -125,113 +121,39 @@ def _make_send_message_request( } -class TestJSONRPCApplicationSetup: # Renamed to avoid conflict - def test_jsonrpc_app_build_method_abstract_raises_typeerror( - self, - ): # Renamed test - mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in JSONRPCApplication.__init__ - mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed in __init__ - mock_agent_card.url = 'http://mockurl.com' - # Ensure 'supportsAuthenticatedExtendedCard' attribute exists - - # This will fail at definition time if an abstract method is not implemented - with pytest.raises( - TypeError, - match=r".*abstract class IncompleteJSONRPCApp .* abstract method '?build'?", - ): - - class IncompleteJSONRPCApp(JSONRPCApplication): - # Intentionally not implementing 'build' - def some_other_method(self): - pass - - IncompleteJSONRPCApp( - agent_card=mock_agent_card, http_handler=mock_handler - ) # type: ignore[abstract] - - -class TestJSONRPCApplicationOptionalDeps: - # Running tests in this class requires optional dependencies starlette and - # sse-starlette to be present in the test environment. - - @pytest.fixture(scope='class', autouse=True) - def ensure_pkg_starlette_is_present(self): - try: - import sse_starlette as _sse_starlette # noqa: F401, PLC0415 - import starlette as _starlette # noqa: F401, PLC0415 - except ImportError: - pytest.fail( - f'Running tests in {self.__class__.__name__} requires' - ' optional dependencies starlette and sse-starlette to be' - ' present in the test environment. Run `uv sync --dev ...`' - ' before running the test suite.' - ) - +class TestJsonRpcDispatcherOptionalDependencies: @pytest.fixture(scope='class') def mock_app_params(self) -> dict: - # Mock http_handler mock_handler = MagicMock(spec=RequestHandler) - # Mock agent_card with essential attributes accessed in __init__ mock_agent_card = MagicMock(spec=AgentCard) - # Ensure 'url' attribute exists on the mock_agent_card, as it's accessed - # in __init__ mock_agent_card.url = 'http://example.com' - # Ensure 'supportsAuthenticatedExtendedCard' attribute exists return {'agent_card': mock_agent_card, 'http_handler': mock_handler} @pytest.fixture(scope='class') def mark_pkg_starlette_not_installed(self): - pkg_starlette_installed_flag = jsonrpc_app._package_starlette_installed - jsonrpc_app._package_starlette_installed = False + pkg_starlette_installed_flag = ( + jsonrpc_dispatcher._package_starlette_installed + ) + jsonrpc_dispatcher._package_starlette_installed = False yield - jsonrpc_app._package_starlette_installed = pkg_starlette_installed_flag - - def test_create_jsonrpc_based_app_with_present_deps_succeeds( - self, mock_app_params: dict - ): - class MockJSONRPCApp(JSONRPCApplication): - def build( # type: ignore[override] - self, - agent_card_url='/.well-known/agent.json', - rpc_url='/', - **kwargs, - ): - return object() # type: ignore[return-value] - - try: - _app = MockJSONRPCApp(**mock_app_params) - except ImportError: - pytest.fail( - 'With packages starlette and see-starlette present, creating a' - ' JSONRPCApplication-based instance should not raise' - ' ImportError' - ) + jsonrpc_dispatcher._package_starlette_installed = ( + pkg_starlette_installed_flag + ) - def test_create_jsonrpc_based_app_with_missing_deps_raises_importerror( + def test_create_dispatcher_with_missing_deps_raises_importerror( self, mock_app_params: dict, mark_pkg_starlette_not_installed: Any ): - class MockJSONRPCApp(JSONRPCApplication): - def build( # type: ignore[override] - self, - agent_card_url='/.well-known/agent.json', - rpc_url='/', - **kwargs, - ): - return object() # type: ignore[return-value] - with pytest.raises( ImportError, match=( 'Packages `starlette` and `sse-starlette` are required to use' - ' the `JSONRPCApplication`' + ' the `JsonRpcDispatcher`' ), ): - _app = MockJSONRPCApp(**mock_app_params) + JsonRpcDispatcher(**mock_app_params) -class TestJSONRPCApplicationExtensions: +class TestJsonRpcDispatcherExtensions: def test_request_with_single_extension(self, client, mock_handler): headers = {HTTP_EXTENSION_HEADER: 'foo'} response = client.post( @@ -261,24 +183,6 @@ def test_request_with_comma_separated_extensions( call_context = mock_handler.on_message_send.call_args[0][1] assert call_context.requested_extensions == {'foo', 'bar'} - def test_request_with_comma_separated_extensions_no_space( - self, client, mock_handler - ): - headers = [ - (HTTP_EXTENSION_HEADER, 'foo, bar'), - (HTTP_EXTENSION_HEADER, 'baz'), - ] - response = client.post( - '/', - headers=headers, - json=_make_send_message_request(), - ) - response.raise_for_status() - - mock_handler.on_message_send.assert_called_once() - call_context = mock_handler.on_message_send.call_args[0][1] - assert call_context.requested_extensions == {'foo', 'bar', 'baz'} - def test_method_added_to_call_context_state(self, client, mock_handler): response = client.post( '/', @@ -290,29 +194,10 @@ def test_method_added_to_call_context_state(self, client, mock_handler): call_context = mock_handler.on_message_send.call_args[0][1] assert call_context.state['method'] == 'SendMessage' - def test_request_with_multiple_extension_headers( - self, client, mock_handler - ): - headers = [ - (HTTP_EXTENSION_HEADER, 'foo'), - (HTTP_EXTENSION_HEADER, 'bar'), - ] - response = client.post( - '/', - headers=headers, - json=_make_send_message_request(), - ) - response.raise_for_status() - - mock_handler.on_message_send.assert_called_once() - call_context = mock_handler.on_message_send.call_args[0][1] - assert call_context.requested_extensions == {'foo', 'bar'} - def test_response_with_activated_extensions(self, client, mock_handler): def side_effect(request, context: ServerCallContext): context.activated_extensions.add('foo') context.activated_extensions.add('baz') - # Return a proto Message object directly return Message( message_id='test', role=Role.ROLE_AGENT, @@ -335,7 +220,7 @@ def side_effect(request, context: ServerCallContext): } -class TestJSONRPCApplicationTenant: +class TestJsonRpcDispatcherTenant: def test_tenant_extraction_from_params(self, client, mock_handler): tenant_id = 'my-tenant-123' response = client.post( @@ -362,20 +247,23 @@ def test_no_tenant_extraction(self, client, mock_handler): assert call_context.tenant == '' -class TestJSONRPCApplicationV03Compat: +class TestJsonRpcDispatcherV03Compat: def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): mock_agent_card = MagicMock(spec=AgentCard) mock_agent_card.url = 'http://mockurl.com' mock_agent_card.capabilities = MagicMock() mock_agent_card.capabilities.streaming = False - app = A2AStarletteApplication( + from starlette.applications import Starlette + + jsonrpc_routes = JsonRpcRoutes( agent_card=mock_agent_card, - http_handler=mock_handler, + request_handler=mock_handler, enable_v0_3_compat=True, + rpc_url='/', ) - - client = TestClient(app.build()) + app = Starlette(routes=jsonrpc_routes.routes) + client = TestClient(app) request_data = { 'jsonrpc': '2.0', @@ -390,8 +278,11 @@ def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): }, } + dispatcher_instance = jsonrpc_routes.dispatcher with patch.object( - app._v03_adapter, 'handle_request', new_callable=AsyncMock + dispatcher_instance._v03_adapter, + 'handle_request', + new_callable=AsyncMock, ) as mock_handle: mock_handle.return_value = JSONResponse( {'jsonrpc': '2.0', 'id': '1', 'result': {}} @@ -403,42 +294,6 @@ def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): assert mock_handle.called assert mock_handle.call_args[1]['method'] == 'message/send' - def test_v0_3_compat_flag_disabled_rejects_v0_3_method(self, mock_handler): - mock_agent_card = MagicMock(spec=AgentCard) - mock_agent_card.url = 'http://mockurl.com' - mock_agent_card.capabilities = MagicMock() - mock_agent_card.capabilities.streaming = False - - app = A2AStarletteApplication( - agent_card=mock_agent_card, - http_handler=mock_handler, - enable_v0_3_compat=False, - ) - - client = TestClient(app.build()) - - request_data = { - 'jsonrpc': '2.0', - 'id': '1', - 'method': 'message/send', - 'params': { - 'message': { - 'messageId': 'msg-1', - 'role': 'ROLE_USER', - 'parts': [{'text': 'Hello'}], - } - }, - } - - response = client.post('/', json=request_data) - - assert response.status_code == 200 - # Should return MethodNotFoundError because the v0.3 method is not recognized - # without the adapter enabled. - resp_json = response.json() - assert 'error' in resp_json - assert resp_json['error']['code'] == -32601 - if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/server/routes/test_jsonrpc_routes.py b/tests/server/routes/test_jsonrpc_routes.py new file mode 100644 index 000000000..b4cd2f2bb --- /dev/null +++ b/tests/server/routes/test_jsonrpc_routes.py @@ -0,0 +1,96 @@ +# ruff: noqa: INP001 +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from starlette.testclient import TestClient +from starlette.middleware import Middleware + +from a2a.server.routes.jsonrpc_routes import JsonRpcRoutes +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types.a2a_pb2 import AgentCard + + +@pytest.fixture +def agent_card(): + return AgentCard() + + +@pytest.fixture +def mock_handler(): + return AsyncMock(spec=RequestHandler) + + +def test_routes_creation(agent_card, mock_handler): + """Tests that JsonRpcRoutes creates Route objects list.""" + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, request_handler=mock_handler + ) + + assert hasattr(jsonrpc_routes, 'routes') + assert isinstance(jsonrpc_routes.routes, list) + assert len(jsonrpc_routes.routes) == 1 + + from starlette.routing import Route + + assert isinstance(jsonrpc_routes.routes[0], Route) + assert jsonrpc_routes.routes[0].methods == {'POST'} + + +def test_jsonrpc_custom_url(agent_card, mock_handler): + """Tests that custom rpc_url is respected for routing.""" + custom_url = '/custom/api/jsonrpc' + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, request_handler=mock_handler, rpc_url=custom_url + ) + + from starlette.applications import Starlette + + app = Starlette(routes=jsonrpc_routes.routes) + client = TestClient(app) + + # Check that default path returns 404 + assert client.post('/a2a/jsonrpc', json={}).status_code == 404 + + # Check that custom path routes to dispatcher (which will return JSON-RPC response, even if error) + response = client.post( + custom_url, json={'jsonrpc': '2.0', 'id': '1', 'method': 'foo'} + ) + assert response.status_code == 200 + resp_json = response.json() + assert 'error' in resp_json + # Method not found error from dispatcher + assert resp_json['error']['code'] == -32601 + + +def test_jsonrpc_with_middleware(agent_card, mock_handler): + """Tests that middleware is applied to the route.""" + middleware_called = False + + class MyMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + nonlocal middleware_called + middleware_called = True + await self.app(scope, receive, send) + + jsonrpc_routes = JsonRpcRoutes( + agent_card=agent_card, + request_handler=mock_handler, + middleware=[Middleware(MyMiddleware)], + rpc_url='/', + ) + + from starlette.applications import Starlette + + app = Starlette(routes=jsonrpc_routes.routes) + client = TestClient(app) + + # Call to trigger middleware + # Empty JSON might raise error, let's send a base valid format for dispatcher + client.post( + '/', json={'jsonrpc': '2.0', 'id': '1', 'method': 'SendMessage'} + ) + assert middleware_called is True diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 525c8e127..0cc5524d2 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -18,10 +18,8 @@ from starlette.routing import Route from starlette.testclient import TestClient -from a2a.server.apps import ( - A2AFastAPIApplication, - A2AStarletteApplication, -) +from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes + from a2a.server.context import ServerCallContext from a2a.server.jsonrpc_models import ( InternalError, @@ -148,14 +146,48 @@ def handler(): return handler +class AppBuilder: + def __init__(self, agent_card, handler, card_modifier=None): + self.agent_card = agent_card + self.handler = handler + self.card_modifier = card_modifier + + def build( + self, + rpc_url='/', + agent_card_url=AGENT_CARD_WELL_KNOWN_PATH, + middleware=None, + routes=None, + ): + from starlette.applications import Starlette + + app_instance = Starlette(middleware=middleware, routes=routes or []) + + # Agent card router + card_routes = AgentCardRoutes( + self.agent_card, + card_url=agent_card_url, + card_modifier=self.card_modifier, + ) + app_instance.routes.extend(card_routes.routes) + + # JSON-RPC router + rpc_routes = JsonRpcRoutes( + self.agent_card, self.handler, rpc_url=rpc_url + ) + app_instance.routes.extend(rpc_routes.routes) + + return app_instance + + @pytest.fixture def app(agent_card: AgentCard, handler: mock.AsyncMock): - return A2AStarletteApplication(agent_card, handler) + return AppBuilder(agent_card, handler) @pytest.fixture -def client(app: A2AStarletteApplication, **kwargs): - """Create a test client with the Starlette app.""" +def client(app, **kwargs): + """Create a test client with the app builder.""" return TestClient(app.build(**kwargs), headers={'A2A-Version': '1.0'}) @@ -172,9 +204,7 @@ def test_agent_card_endpoint(client: TestClient, agent_card: AgentCard): assert 'streaming' in data['capabilities'] -def test_agent_card_custom_url( - app: A2AStarletteApplication, agent_card: AgentCard -): +def test_agent_card_custom_url(app, agent_card: AgentCard): """Test the agent card endpoint with a custom URL.""" client = TestClient(app.build(agent_card_url='/my-agent')) response = client.get('/my-agent') @@ -183,9 +213,7 @@ def test_agent_card_custom_url( assert data['name'] == agent_card.name -def test_starlette_rpc_endpoint_custom_url( - app: A2AStarletteApplication, handler: mock.AsyncMock -): +def test_starlette_rpc_endpoint_custom_url(app, handler: mock.AsyncMock): """Test the RPC endpoint with a custom URL.""" # Provide a valid Task object as the return value task_status = MINIMAL_TASK_STATUS @@ -208,9 +236,7 @@ def test_starlette_rpc_endpoint_custom_url( assert data['result']['id'] == 'task1' -def test_fastapi_rpc_endpoint_custom_url( - app: A2AFastAPIApplication, handler: mock.AsyncMock -): +def test_fastapi_rpc_endpoint_custom_url(app, handler: mock.AsyncMock): """Test the RPC endpoint with a custom URL.""" # Provide a valid Task object as the return value task_status = MINIMAL_TASK_STATUS @@ -233,9 +259,7 @@ def test_fastapi_rpc_endpoint_custom_url( assert data['result']['id'] == 'task1' -def test_starlette_build_with_extra_routes( - app: A2AStarletteApplication, agent_card: AgentCard -): +def test_starlette_build_with_extra_routes(app, agent_card: AgentCard): """Test building the app with additional routes.""" def custom_handler(request): @@ -243,7 +267,7 @@ def custom_handler(request): extra_route = Route('/hello', custom_handler, methods=['GET']) test_app = app.build(routes=[extra_route]) - client = TestClient(test_app) + client = TestClient(test_app, headers={'A2A-Version': '1.0'}) # Test the added route response = client.get('/hello') @@ -257,9 +281,7 @@ def custom_handler(request): assert data['name'] == agent_card.name -def test_fastapi_build_with_extra_routes( - app: A2AFastAPIApplication, agent_card: AgentCard -): +def test_fastapi_build_with_extra_routes(app, agent_card: AgentCard): """Test building the app with additional routes.""" def custom_handler(request): @@ -281,9 +303,7 @@ def custom_handler(request): assert data['name'] == agent_card.name -def test_fastapi_build_custom_agent_card_path( - app: A2AFastAPIApplication, agent_card: AgentCard -): +def test_fastapi_build_custom_agent_card_path(app, agent_card: AgentCard): """Test building the app with a custom agent card path.""" test_app = app.build(agent_card_url='/agent-card') @@ -471,7 +491,7 @@ def test_get_push_notification_config( handler.on_get_task_push_notification_config.assert_awaited_once() -def test_server_auth(app: A2AStarletteApplication, handler: mock.AsyncMock): +def test_server_auth(app, handler: mock.AsyncMock): class TestAuthMiddleware(AuthenticationBackend): async def authenticate( self, conn: HTTPConnection @@ -534,9 +554,7 @@ async def authenticate( @pytest.mark.asyncio -async def test_message_send_stream( - app: A2AStarletteApplication, handler: mock.AsyncMock -) -> None: +async def test_message_send_stream(app, handler: mock.AsyncMock) -> None: """Test streaming message sending.""" # Setup mock streaming response @@ -614,9 +632,7 @@ async def stream_generator(): @pytest.mark.asyncio -async def test_task_resubscription( - app: A2AStarletteApplication, handler: mock.AsyncMock -) -> None: +async def test_task_resubscription(app, handler: mock.AsyncMock) -> None: """Test task resubscription streaming.""" # Setup mock streaming response @@ -751,9 +767,7 @@ async def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AStarletteApplication( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) @@ -776,9 +790,7 @@ def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AStarletteApplication( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) @@ -801,9 +813,7 @@ async def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AFastAPIApplication( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) @@ -823,9 +833,7 @@ def modifier(card: AgentCard) -> AgentCard: modified_card.name = 'Dynamically Modified Agent' return modified_card - app_instance = A2AFastAPIApplication( - agent_card, handler, card_modifier=modifier - ) + app_instance = AppBuilder(agent_card, handler, card_modifier=modifier) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) @@ -937,7 +945,7 @@ def test_agent_card_backward_compatibility_supports_extended_card( ): """Test that supportsAuthenticatedExtendedCard is injected when extended_agent_card is True.""" agent_card.capabilities.extended_agent_card = True - app_instance = A2AStarletteApplication(agent_card, handler) + app_instance = AppBuilder(agent_card, handler) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) assert response.status_code == 200 @@ -950,7 +958,7 @@ def test_agent_card_backward_compatibility_no_extended_card( ): """Test that supportsAuthenticatedExtendedCard is absent when extended_agent_card is False.""" agent_card.capabilities.extended_agent_card = False - app_instance = A2AStarletteApplication(agent_card, handler) + app_instance = AppBuilder(agent_card, handler) client = TestClient(app_instance.build()) response = client.get(AGENT_CARD_WELL_KNOWN_PATH) assert response.status_code == 200 From 71a9285df9e94d754cee684ef9c476997223823b Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Thu, 19 Mar 2026 21:10:39 +0100 Subject: [PATCH 02/25] Update src/a2a/server/routes/agent_card_routes.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/a2a/server/routes/agent_card_routes.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/a2a/server/routes/agent_card_routes.py b/src/a2a/server/routes/agent_card_routes.py index 30d635f11..7b63b5c01 100644 --- a/src/a2a/server/routes/agent_card_routes.py +++ b/src/a2a/server/routes/agent_card_routes.py @@ -60,8 +60,9 @@ def __init__( """ if not _package_starlette_installed: raise ImportError( - 'The `starlette` package is required to use the `AgentCardRoutes`.' - ' `a2a-sdk[http-server]`.' + 'The `starlette` package is required to use `AgentCardRoutes`. ' + 'It can be installed as part of `a2a-sdk` optional dependencies, `a2a-sdk[http-server]`.' + ) ) self.agent_card = agent_card From fee5d5e67a8e441bf13abe31ba07d8830e92e3fd Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 19 Mar 2026 20:29:36 +0000 Subject: [PATCH 03/25] fix suggestions --- src/a2a/server/routes/agent_card_routes.py | 17 +++++------ src/a2a/server/routes/jsonrpc_dispatcher.py | 29 ------------------- src/a2a/utils/constants.py | 1 - tests/server/routes/test_agent_card_routes.py | 9 +----- tests/server/routes/test_jsonrpc_routes.py | 5 +--- 5 files changed, 9 insertions(+), 52 deletions(-) diff --git a/src/a2a/server/routes/agent_card_routes.py b/src/a2a/server/routes/agent_card_routes.py index 7b63b5c01..067477bd1 100644 --- a/src/a2a/server/routes/agent_card_routes.py +++ b/src/a2a/server/routes/agent_card_routes.py @@ -63,24 +63,21 @@ def __init__( 'The `starlette` package is required to use `AgentCardRoutes`. ' 'It can be installed as part of `a2a-sdk` optional dependencies, `a2a-sdk[http-server]`.' ) - ) self.agent_card = agent_card self.card_modifier = card_modifier - async def get_agent_card(request: Request) -> Response: - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await( - self.card_modifier(card_to_serve) - ) - return JSONResponse(agent_card_to_dict(card_to_serve)) - self.routes = [ Route( path=card_url, - endpoint=get_agent_card, + endpoint=self._get_agent_card, methods=['GET'], middleware=middleware, ) ] + + async def _get_agent_card(self, request: Request) -> Response: + card_to_serve = self.agent_card + if self.card_modifier: + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) + return JSONResponse(agent_card_to_dict(card_to_serve)) diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index 14a0cc0bb..970d0620b 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -47,9 +47,6 @@ SubscribeToTaskRequest, TaskPushNotificationConfig, ) -from a2a.utils.constants import ( - DEFAULT_MAX_CONTENT_LENGTH, -) from a2a.utils.errors import ( A2AError, UnsupportedOperationError, @@ -202,7 +199,6 @@ def __init__( # noqa: PLR0913 ] | None = None, enable_v0_3_compat: bool = False, - max_content_length: int | None = DEFAULT_MAX_CONTENT_LENGTH, ) -> None: """Initializes the JsonRpcDispatcher. @@ -220,8 +216,6 @@ def __init__( # noqa: PLR0913 extended_card_modifier: An optional callback to dynamically modify the extended agent card before it is served. It receives the call context. - max_content_length: The maximum allowed content length for incoming - requests. Defaults to 10MB. Set to None for unbounded maximum. enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. """ if not _package_starlette_installed: @@ -242,7 +236,6 @@ def __init__( # noqa: PLR0913 extended_card_modifier=extended_card_modifier, ) self._context_builder = context_builder or DefaultCallContextBuilder() - self._max_content_length = max_content_length self.enable_v0_3_compat = enable_v0_3_compat self._v03_adapter: JSONRPC03Adapter | None = None @@ -298,22 +291,6 @@ def _generate_error_response( status_code=200, ) - def _allowed_content_length(self, request: Request) -> bool: - """Checks if the request content length is within the allowed maximum. - - Args: - request: The incoming Starlette Request object. - - Returns: - False if the content length is larger than the allowed maximum, True otherwise. - """ - if self._max_content_length is not None: - with contextlib.suppress(ValueError): - content_length = int(request.headers.get('content-length', '0')) - if content_length and content_length > self._max_content_length: - return False - return True - async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911, PLR0912 """Handles incoming POST requests to the main A2A endpoint. @@ -344,12 +321,6 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911 request_id, str | int ): request_id = None - # Treat payloads lager than allowed as invalid request (-32600) before routing - if not self._allowed_content_length(request): - return self._generate_error_response( - request_id, - InvalidRequestError(message='Payload too large'), - ) logger.debug('Request body: %s', body) # 1) Validate base JSON-RPC structure only (-32600 on failure) try: diff --git a/src/a2a/utils/constants.py b/src/a2a/utils/constants.py index 6cee2a05c..5497d8a24 100644 --- a/src/a2a/utils/constants.py +++ b/src/a2a/utils/constants.py @@ -20,7 +20,6 @@ class TransportProtocol(str, Enum): GRPC = 'GRPC' -DEFAULT_MAX_CONTENT_LENGTH = 10 * 1024 * 1024 # 10MB JSONRPC_PARSE_ERROR_CODE = -32700 VERSION_HEADER = 'A2A-Version' diff --git a/tests/server/routes/test_agent_card_routes.py b/tests/server/routes/test_agent_card_routes.py index 8f86ec937..01ccce8c6 100644 --- a/tests/server/routes/test_agent_card_routes.py +++ b/tests/server/routes/test_agent_card_routes.py @@ -6,6 +6,7 @@ import pytest from starlette.testclient import TestClient from starlette.middleware import Middleware +from starlette.applications import Starlette from a2a.server.routes.agent_card_routes import AgentCardRoutes from a2a.types.a2a_pb2 import AgentCard @@ -20,8 +21,6 @@ def test_get_agent_card_success(agent_card): """Tests that the agent card route returns the card correctly.""" routes = AgentCardRoutes(agent_card=agent_card).routes - from starlette.applications import Starlette - app = Starlette(routes=routes) client = TestClient(app) @@ -52,8 +51,6 @@ async def modifier(card: AgentCard) -> AgentCard: agent_card=agent_card, card_modifier=mock_modifier ).routes - from starlette.applications import Starlette - app = Starlette(routes=routes) client = TestClient(app) @@ -67,8 +64,6 @@ def test_agent_card_custom_url(agent_card): custom_url = '/custom/path/agent.json' routes = AgentCardRoutes(agent_card=agent_card, card_url=custom_url).routes - from starlette.applications import Starlette - app = Starlette(routes=routes) client = TestClient(app) @@ -95,8 +90,6 @@ async def __call__(self, scope, receive, send): agent_card=agent_card, middleware=[Middleware(MyMiddleware)] ).routes - from starlette.applications import Starlette - app = Starlette(routes=routes) client = TestClient(app) diff --git a/tests/server/routes/test_jsonrpc_routes.py b/tests/server/routes/test_jsonrpc_routes.py index b4cd2f2bb..5d1b01d98 100644 --- a/tests/server/routes/test_jsonrpc_routes.py +++ b/tests/server/routes/test_jsonrpc_routes.py @@ -5,6 +5,7 @@ import pytest from starlette.testclient import TestClient from starlette.middleware import Middleware +from starlette.applications import Starlette from a2a.server.routes.jsonrpc_routes import JsonRpcRoutes from a2a.server.request_handlers.request_handler import RequestHandler @@ -44,8 +45,6 @@ def test_jsonrpc_custom_url(agent_card, mock_handler): agent_card=agent_card, request_handler=mock_handler, rpc_url=custom_url ) - from starlette.applications import Starlette - app = Starlette(routes=jsonrpc_routes.routes) client = TestClient(app) @@ -83,8 +82,6 @@ async def __call__(self, scope, receive, send): rpc_url='/', ) - from starlette.applications import Starlette - app = Starlette(routes=jsonrpc_routes.routes) client = TestClient(app) From 23e149daef73add8387fdff1fd8985e31818e680 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 20 Mar 2026 08:19:10 +0000 Subject: [PATCH 04/25] revert test --- .../cross_version/client_server/server_0_3.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/integration/cross_version/client_server/server_0_3.py b/tests/integration/cross_version/client_server/server_0_3.py index 96152c135..7bd5f7e75 100644 --- a/tests/integration/cross_version/client_server/server_0_3.py +++ b/tests/integration/cross_version/client_server/server_0_3.py @@ -8,7 +8,7 @@ from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext -from a2a.server.apps.jsonrpc import A2AFastAPIApplication +from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication from a2a.server.events.event_queue import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager @@ -188,13 +188,12 @@ async def main_async(http_port: int, grpc_port: int): ) app = FastAPI() - jsonrpc_app = A2AFastAPIApplication( - agent_card=agent_card, - http_handler=handler, - extended_agent_card=agent_card, - ).build() - app.mount('/jsonrpc', jsonrpc_app) - + app.mount( + '/jsonrpc', + A2AFastAPIApplication( + http_handler=handler, agent_card=agent_card + ).build(), + ) app.mount( '/rest', A2ARESTFastAPIApplication( From b3c201e5a00d77d1ac2a9c9e1fe28da3d3780641 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 20 Mar 2026 08:26:49 +0000 Subject: [PATCH 05/25] revert wrong changes --- tests/__init__.py | 1 - tests/compat/v0_3/test_jsonrpc_app_compat.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 792d60054..e69de29bb 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +0,0 @@ -# diff --git a/tests/compat/v0_3/test_jsonrpc_app_compat.py b/tests/compat/v0_3/test_jsonrpc_app_compat.py index 4b344c67d..f95818456 100644 --- a/tests/compat/v0_3/test_jsonrpc_app_compat.py +++ b/tests/compat/v0_3/test_jsonrpc_app_compat.py @@ -51,13 +51,13 @@ def test_app(mock_handler): mock_agent_card.capabilities.streaming = False mock_agent_card.capabilities.push_notifications = True mock_agent_card.capabilities.extended_agent_card = True - router = JsonRpcRoutes( + jsonrpc_routes = JsonRpcRoutes( agent_card=mock_agent_card, request_handler=mock_handler, enable_v0_3_compat=True, rpc_url='/', ) - return Starlette(routes=router.routes) + return Starlette(routes=jsonrpc_routes.routes) @pytest.fixture From 3c962e2263076a1d68a9679b9bea1215047b7df6 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 20 Mar 2026 08:52:36 +0000 Subject: [PATCH 06/25] make method public --- src/a2a/server/routes/jsonrpc_dispatcher.py | 2 +- src/a2a/server/routes/jsonrpc_routes.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index 970d0620b..1ce5f0fe8 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -291,7 +291,7 @@ def _generate_error_response( status_code=200, ) - async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911, PLR0912 + async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, PLR0912 """Handles incoming POST requests to the main A2A endpoint. Parses the request body as JSON, validates it against A2A request types, diff --git a/src/a2a/server/routes/jsonrpc_routes.py b/src/a2a/server/routes/jsonrpc_routes.py index cc0e12612..73bca8280 100644 --- a/src/a2a/server/routes/jsonrpc_routes.py +++ b/src/a2a/server/routes/jsonrpc_routes.py @@ -100,7 +100,7 @@ def __init__( # noqa: PLR0913 self.routes = [ Route( path=rpc_url, - endpoint=self.dispatcher._handle_requests, # noqa: SLF001 + endpoint=self.dispatcher.handle_requests, methods=['POST'], middleware=middleware, ) From 921651b83d15b9c8962b7861b7e04eed38f2ecb1 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 20 Mar 2026 16:35:20 +0000 Subject: [PATCH 07/25] fix --- samples/hello_world_agent.py | 10 +- src/a2a/server/routes/__init__.py | 12 +- src/a2a/server/routes/agent_card_routes.py | 68 ++++------ src/a2a/server/routes/jsonrpc_routes.py | 116 ++++++++---------- tck/sut_agent.py | 12 +- tests/compat/v0_3/test_jsonrpc_app_compat.py | 6 +- .../cross_version/client_server/server_1_0.py | 6 +- tests/integration/test_agent_card.py | 10 +- .../test_client_server_integration.py | 26 ++-- tests/integration/test_end_to_end.py | 8 +- tests/integration/test_tenant.py | 8 +- tests/integration/test_version_header.py | 10 +- tests/server/routes/test_agent_card_routes.py | 33 +---- .../server/routes/test_jsonrpc_dispatcher.py | 14 +-- tests/server/routes/test_jsonrpc_routes.py | 49 ++------ tests/server/test_integration.py | 10 +- 16 files changed, 157 insertions(+), 241 deletions(-) diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index e46b9ede4..fa9ab3c2b 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -17,7 +17,7 @@ from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) -from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_updater import TaskUpdater from a2a.types import ( @@ -198,17 +198,17 @@ async def serve( ) rest_app = rest_app_builder.build() - jsonrpc_routes = JsonRpcRoutes( + jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=request_handler, rpc_url='/a2a/jsonrpc/', ) - agent_card_routes = AgentCardRoutes( + agent_card_routes = create_agent_card_routes( agent_card=agent_card, ) app = FastAPI() - app.routes.extend(jsonrpc_routes.routes) - app.routes.extend(agent_card_routes.routes) + app.routes.extend(jsonrpc_routes) + app.routes.extend(agent_card_routes) app.mount('/a2a/rest', rest_app) grpc_server = grpc.aio.server() diff --git a/src/a2a/server/routes/__init__.py b/src/a2a/server/routes/__init__.py index ec65d8b34..a559480a4 100644 --- a/src/a2a/server/routes/__init__.py +++ b/src/a2a/server/routes/__init__.py @@ -1,20 +1,16 @@ """A2A Routes.""" -from a2a.server.routes.agent_card_routes import AgentCardRoutes +from a2a.server.routes.agent_card_routes import create_agent_card_routes from a2a.server.routes.jsonrpc_dispatcher import ( CallContextBuilder, DefaultCallContextBuilder, - JsonRpcDispatcher, - StarletteUserProxy, ) -from a2a.server.routes.jsonrpc_routes import JsonRpcRoutes +from a2a.server.routes.jsonrpc_routes import create_jsonrpc_routes __all__ = [ - 'AgentCardRoutes', + 'create_agent_card_routes', 'CallContextBuilder', 'DefaultCallContextBuilder', - 'JsonRpcDispatcher', - 'JsonRpcRoutes', - 'StarletteUserProxy', + 'create_jsonrpc_routes', ] diff --git a/src/a2a/server/routes/agent_card_routes.py b/src/a2a/server/routes/agent_card_routes.py index 067477bd1..a521c8d2c 100644 --- a/src/a2a/server/routes/agent_card_routes.py +++ b/src/a2a/server/routes/agent_card_routes.py @@ -37,47 +37,29 @@ logger = logging.getLogger(__name__) -class AgentCardRoutes: - """Provides the Starlette Route for the A2A protocol agent card endpoint.""" - - def __init__( - self, - agent_card: AgentCard, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - middleware: Sequence['Middleware'] | None = None, - ) -> None: - """Initializes the AgentCardRoute. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - card_url: The URL for the agent card endpoint. - middleware: An optional list of Starlette middleware to apply to the - agent card endpoint. - """ - if not _package_starlette_installed: - raise ImportError( - 'The `starlette` package is required to use `AgentCardRoutes`. ' - 'It can be installed as part of `a2a-sdk` optional dependencies, `a2a-sdk[http-server]`.' - ) - - self.agent_card = agent_card - self.card_modifier = card_modifier - - self.routes = [ - Route( - path=card_url, - endpoint=self._get_agent_card, - methods=['GET'], - middleware=middleware, - ) - ] - - async def _get_agent_card(self, request: Request) -> Response: - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) +def create_agent_card_routes( + agent_card: AgentCard, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + card_url: str = AGENT_CARD_WELL_KNOWN_PATH, +) -> list['Route']: + """Creates the Starlette Route for the A2A protocol agent card endpoint.""" + if not _package_starlette_installed: + raise ImportError( + 'The `starlette` package is required to use `create_agent_card_routes`. ' + 'It can be installed as part of `a2a-sdk` optional dependencies, `a2a-sdk[http-server]`.' + ) + + async def _get_agent_card(request: Request) -> Response: + card_to_serve = agent_card + if card_modifier: + card_to_serve = await maybe_await(card_modifier(card_to_serve)) return JSONResponse(agent_card_to_dict(card_to_serve)) + + return [ + Route( + path=card_url, + endpoint=_get_agent_card, + methods=['GET'], + ) + ] diff --git a/src/a2a/server/routes/jsonrpc_routes.py b/src/a2a/server/routes/jsonrpc_routes.py index 73bca8280..2aeaf295a 100644 --- a/src/a2a/server/routes/jsonrpc_routes.py +++ b/src/a2a/server/routes/jsonrpc_routes.py @@ -36,72 +36,64 @@ logger = logging.getLogger(__name__) -class JsonRpcRoutes: - """Provides the Starlette Route for the A2A protocol JSON-RPC endpoint. +def create_jsonrpc_routes( # noqa: PLR0913 + agent_card: AgentCard, + request_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, + context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, + enable_v0_3_compat: bool = False, + rpc_url: str = DEFAULT_RPC_URL, +) -> list['Route']: + """Creates the Starlette Route for the A2A protocol JSON-RPC endpoint. Handles incoming JSON-RPC requests, routes them to the appropriate handler methods, and manages response generation including Server-Sent Events (SSE). - """ - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - request_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - enable_v0_3_compat: bool = False, - rpc_url: str = DEFAULT_RPC_URL, - middleware: Sequence[Middleware] | None = None, - ) -> None: - """Initializes the JsonRpcRoute. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - request_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the request_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. - rpc_url: The URL prefix for the RPC endpoints. - middleware: An optional list of Starlette middleware to apply to the routes. - """ - if not _package_starlette_installed: - raise ImportError( - 'The `starlette` package is required to use the `JsonRpcRoutes`.' - ' It can be added as a part of `a2a-sdk` optional dependencies,' - ' `a2a-sdk[http-server]`.' - ) - - self.dispatcher = JsonRpcDispatcher( - agent_card=agent_card, - http_handler=request_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - enable_v0_3_compat=enable_v0_3_compat, + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the request_handler. If None, no + ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. + enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. + rpc_url: The URL prefix for the RPC endpoints. + """ + if not _package_starlette_installed: + raise ImportError( + 'The `starlette` package is required to use `create_jsonrpc_routes`.' + ' It can be added as a part of `a2a-sdk` optional dependencies,' + ' `a2a-sdk[http-server]`.' ) - self.routes = [ - Route( - path=rpc_url, - endpoint=self.dispatcher.handle_requests, - methods=['POST'], - middleware=middleware, - ) - ] + dispatcher = JsonRpcDispatcher( + agent_card=agent_card, + http_handler=request_handler, + extended_agent_card=extended_agent_card, + context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, + enable_v0_3_compat=enable_v0_3_compat, + ) + + return [ + Route( + path=rpc_url, + endpoint=dispatcher.handle_requests, + methods=['POST'], + ) + ] diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 955493437..d133e257a 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -25,8 +25,8 @@ ) from a2a.server.request_handlers.grpc_handler import GrpcHandler from a2a.server.routes import ( - AgentCardRoutes, - JsonRpcRoutes, + create_agent_card_routes, + create_jsonrpc_routes, ) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_store import TaskStore @@ -200,18 +200,18 @@ def serve(task_store: TaskStore) -> None: ) # JSONRPC - jsonrpc_routes = JsonRpcRoutes( + jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=request_handler, rpc_url=JSONRPC_URL, ) # Agent Card - agent_card_routes = AgentCardRoutes( + agent_card_routes = create_agent_card_routes( agent_card=agent_card, ) routes = [ - *jsonrpc_routes.routes, - *agent_card_routes.routes, + *jsonrpc_routes, + *agent_card_routes, ] main_app = Starlette(routes=routes) diff --git a/tests/compat/v0_3/test_jsonrpc_app_compat.py b/tests/compat/v0_3/test_jsonrpc_app_compat.py index f95818456..8120e322f 100644 --- a/tests/compat/v0_3/test_jsonrpc_app_compat.py +++ b/tests/compat/v0_3/test_jsonrpc_app_compat.py @@ -7,7 +7,7 @@ from starlette.testclient import TestClient from starlette.applications import Starlette -from a2a.server.routes import JsonRpcRoutes +from a2a.server.routes import create_jsonrpc_routes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( AgentCard, @@ -51,13 +51,13 @@ def test_app(mock_handler): mock_agent_card.capabilities.streaming = False mock_agent_card.capabilities.push_notifications = True mock_agent_card.capabilities.extended_agent_card = True - jsonrpc_routes = JsonRpcRoutes( + jsonrpc_routes = create_jsonrpc_routes( agent_card=mock_agent_card, request_handler=mock_handler, enable_v0_3_compat=True, rpc_url='/', ) - return Starlette(routes=jsonrpc_routes.routes) + return Starlette(routes=jsonrpc_routes) @pytest.fixture diff --git a/tests/integration/cross_version/client_server/server_1_0.py b/tests/integration/cross_version/client_server/server_1_0.py index 907c010ff..f6121a337 100644 --- a/tests/integration/cross_version/client_server/server_1_0.py +++ b/tests/integration/cross_version/client_server/server_1_0.py @@ -5,7 +5,7 @@ import grpc from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes +from a2a.server.routes import AgentCardRoutes, create_jsonrpc_routes from a2a.server.apps import A2ARESTFastAPIApplication from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager @@ -170,7 +170,7 @@ async def main_async(http_port: int, grpc_port: int): agent_card_routes = AgentCardRoutes( agent_card=agent_card, card_url='/.well-known/agent-card.json' ) - jsonrpc_routes = JsonRpcRoutes( + jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=handler, extended_agent_card=agent_card, @@ -179,7 +179,7 @@ async def main_async(http_port: int, grpc_port: int): ) app.mount( '/jsonrpc', - FastAPI(routes=jsonrpc_routes.routes + agent_card_routes.routes), + FastAPI(routes=jsonrpc_routes + agent_card_routes.routes), ) app.mount( diff --git a/tests/integration/test_agent_card.py b/tests/integration/test_agent_card.py index 42aca3843..719b7be9f 100644 --- a/tests/integration/test_agent_card.py +++ b/tests/integration/test_agent_card.py @@ -6,7 +6,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from starlette.applications import Starlette from a2a.server.apps import A2ARESTFastAPIApplication -from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler @@ -73,12 +73,12 @@ async def test_agent_card_integration(header_val: str | None) -> None: # Mount JSONRPC application jsonrpc_routes = [ - *AgentCardRoutes( + *create_agent_card_routes( agent_card=agent_card, card_url='/.well-known/agent-card.json' - ).routes, - *JsonRpcRoutes( + ), + *create_jsonrpc_routes( agent_card=agent_card, request_handler=handler, rpc_url='/' - ).routes, + ), ] jsonrpc_app = Starlette(routes=jsonrpc_routes) app.mount('/jsonrpc', jsonrpc_app) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index f6f1b4182..6df44a58d 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -25,7 +25,7 @@ from a2a.client.transports import JsonRpcTransport, RestTransport from starlette.applications import Starlette from a2a.server.apps import A2ARESTFastAPIApplication -from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( @@ -222,14 +222,14 @@ def http_base_setup(mock_request_handler: AsyncMock, agent_card: AgentCard): def jsonrpc_setup(http_base_setup) -> TransportSetup: """Sets up the JsonRpcTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') - jsonrpc_routes = JsonRpcRoutes( + agent_card_routes = create_agent_card_routes(agent_card=agent_card, card_url='/') + jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=mock_request_handler, extended_agent_card=agent_card, rpc_url='/', ) - app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) + app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) factory = ClientFactory( config=ClientConfig( @@ -625,16 +625,16 @@ async def test_json_transport_get_signed_base_card( }, ) - agent_card_routes = AgentCardRoutes( + agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/', card_modifier=signer ) - jsonrpc_routes = JsonRpcRoutes( + jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=mock_request_handler, extended_agent_card=agent_card, rpc_url='/', ) - app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) + app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) agent_url = agent_card.supported_interfaces[0].url @@ -695,15 +695,15 @@ async def test_client_get_signed_extended_card( }, ) - agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') - jsonrpc_routes = JsonRpcRoutes( + agent_card_routes = create_agent_card_routes(agent_card=agent_card, card_url='/') + jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=mock_request_handler, extended_agent_card=extended_agent_card, extended_card_modifier=lambda card, ctx: signer(card), rpc_url='/', ) - app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) + app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) transport = JsonRpcTransport( @@ -764,17 +764,17 @@ async def test_client_get_signed_base_and_extended_cards( }, ) - agent_card_routes = AgentCardRoutes( + agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/', card_modifier=signer ) - jsonrpc_routes = JsonRpcRoutes( + jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=mock_request_handler, extended_agent_card=extended_agent_card, extended_card_modifier=lambda card, ctx: signer(card), rpc_url='/', ) - app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) + app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) agent_url = agent_card.supported_interfaces[0].url diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index f75e8c9da..2c0ea8e52 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -12,7 +12,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from starlette.applications import Starlette from a2a.server.apps import A2ARESTFastAPIApplication -from a2a.server.routes import JsonRpcRoutes, AgentCardRoutes +from a2a.server.routes import create_jsonrpc_routes, create_agent_card_routes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler @@ -194,14 +194,14 @@ def rest_setup(agent_card, base_e2e_setup) -> ClientSetup: @pytest.fixture def jsonrpc_setup(agent_card, base_e2e_setup) -> ClientSetup: task_store, handler = base_e2e_setup - agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') - jsonrpc_routes = JsonRpcRoutes( + agent_card_routes = create_agent_card_routes(agent_card=agent_card, card_url='/') + jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=handler, extended_agent_card=agent_card, rpc_url='/', ) - app = Starlette(routes=[*agent_card_routes.routes, *jsonrpc_routes.routes]) + app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) httpx_client = httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url='http://testserver' ) diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py index 21698b4f4..ffc21306b 100644 --- a/tests/integration/test_tenant.py +++ b/tests/integration/test_tenant.py @@ -19,7 +19,7 @@ from a2a.client import ClientConfig, ClientFactory from a2a.utils.constants import TransportProtocol -from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from starlette.applications import Starlette from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.context import ServerCallContext @@ -198,17 +198,17 @@ def jsonrpc_agent_card(self): @pytest.fixture def server_app(self, jsonrpc_agent_card, mock_handler): - agent_card_routes = AgentCardRoutes( + agent_card_routes = create_agent_card_routes( agent_card=jsonrpc_agent_card, card_url='/' ) - jsonrpc_routes = JsonRpcRoutes( + jsonrpc_routes = create_jsonrpc_routes( agent_card=jsonrpc_agent_card, request_handler=mock_handler, extended_agent_card=jsonrpc_agent_card, rpc_url='/jsonrpc', ) app = Starlette( - routes=[*agent_card_routes.routes, *jsonrpc_routes.routes] + routes=[*agent_card_routes, *jsonrpc_routes] ) return app diff --git a/tests/integration/test_version_header.py b/tests/integration/test_version_header.py index 754b14168..7dd79adf4 100644 --- a/tests/integration/test_version_header.py +++ b/tests/integration/test_version_header.py @@ -5,7 +5,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2ARESTFastAPIApplication -from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler @@ -57,16 +57,16 @@ async def mock_on_message_send_stream(*args, **kwargs): handler.on_message_send_stream = mock_on_message_send_stream app = FastAPI() - agent_card_routes = AgentCardRoutes(agent_card=agent_card, card_url='/') - jsonrpc_routes = JsonRpcRoutes( + agent_card_routes = create_agent_card_routes(agent_card=agent_card, card_url='/') + jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=handler, extended_agent_card=agent_card, rpc_url='/jsonrpc', enable_v0_3_compat=True, ) - app.routes.extend(agent_card_routes.routes) - app.routes.extend(jsonrpc_routes.routes) + app.routes.extend(agent_card_routes) + app.routes.extend(jsonrpc_routes) rest_app = A2ARESTFastAPIApplication( http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True ).build() diff --git a/tests/server/routes/test_agent_card_routes.py b/tests/server/routes/test_agent_card_routes.py index 01ccce8c6..de028f5f4 100644 --- a/tests/server/routes/test_agent_card_routes.py +++ b/tests/server/routes/test_agent_card_routes.py @@ -8,7 +8,7 @@ from starlette.middleware import Middleware from starlette.applications import Starlette -from a2a.server.routes.agent_card_routes import AgentCardRoutes +from a2a.server.routes.agent_card_routes import create_agent_card_routes from a2a.types.a2a_pb2 import AgentCard @@ -19,7 +19,7 @@ def agent_card(): def test_get_agent_card_success(agent_card): """Tests that the agent card route returns the card correctly.""" - routes = AgentCardRoutes(agent_card=agent_card).routes + routes = create_agent_card_routes(agent_card=agent_card) app = Starlette(routes=routes) client = TestClient(app) @@ -47,9 +47,9 @@ async def modifier(card: AgentCard) -> AgentCard: return card mock_modifier = AsyncMock(side_effect=modifier) - routes = AgentCardRoutes( + routes = create_agent_card_routes( agent_card=agent_card, card_modifier=mock_modifier - ).routes + ) app = Starlette(routes=routes) client = TestClient(app) @@ -62,7 +62,7 @@ async def modifier(card: AgentCard) -> AgentCard: def test_agent_card_custom_url(agent_card): """Tests that custom card_url is respected.""" custom_url = '/custom/path/agent.json' - routes = AgentCardRoutes(agent_card=agent_card, card_url=custom_url).routes + routes = create_agent_card_routes(agent_card=agent_card, card_url=custom_url) app = Starlette(routes=routes) client = TestClient(app) @@ -73,26 +73,3 @@ def test_agent_card_custom_url(agent_card): assert client.get(custom_url).status_code == 200 -def test_agent_card_with_middleware(agent_card): - """Tests that middleware is applied to the routes.""" - middleware_called = False - - class MyMiddleware: - def __init__(self, app): - self.app = app - - async def __call__(self, scope, receive, send): - nonlocal middleware_called - middleware_called = True - await self.app(scope, receive, send) - - routes = AgentCardRoutes( - agent_card=agent_card, middleware=[Middleware(MyMiddleware)] - ).routes - - app = Starlette(routes=routes) - client = TestClient(app) - - response = client.get('/.well-known/agent-card.json') - assert response.status_code == 200 - assert middleware_called is True diff --git a/tests/server/routes/test_jsonrpc_dispatcher.py b/tests/server/routes/test_jsonrpc_dispatcher.py index 7241cac4b..4fb398660 100644 --- a/tests/server/routes/test_jsonrpc_dispatcher.py +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -28,8 +28,8 @@ JsonRpcDispatcher, StarletteUserProxy, ) -from a2a.server.routes.jsonrpc_routes import JsonRpcRoutes -from a2a.server.routes.agent_card_routes import AgentCardRoutes +from a2a.server.routes.jsonrpc_routes import create_jsonrpc_routes +from a2a.server.routes.agent_card_routes import create_agent_card_routes from a2a.server.jsonrpc_models import JSONRPCError from a2a.utils.errors import A2AError @@ -86,13 +86,13 @@ def test_app(mock_handler): mock_agent_card.capabilities = MagicMock() mock_agent_card.capabilities.streaming = False - jsonrpc_routes = JsonRpcRoutes( + jsonrpc_routes = create_jsonrpc_routes( agent_card=mock_agent_card, request_handler=mock_handler, rpc_url='/' ) from starlette.applications import Starlette - return Starlette(routes=jsonrpc_routes.routes) + return Starlette(routes=jsonrpc_routes) @pytest.fixture @@ -256,13 +256,13 @@ def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): from starlette.applications import Starlette - jsonrpc_routes = JsonRpcRoutes( + jsonrpc_routes = create_jsonrpc_routes( agent_card=mock_agent_card, request_handler=mock_handler, enable_v0_3_compat=True, rpc_url='/', ) - app = Starlette(routes=jsonrpc_routes.routes) + app = Starlette(routes=jsonrpc_routes) client = TestClient(app) request_data = { @@ -278,7 +278,7 @@ def test_v0_3_compat_flag_routes_to_adapter(self, mock_handler): }, } - dispatcher_instance = jsonrpc_routes.dispatcher + dispatcher_instance = jsonrpc_routes[0].endpoint.__self__ with patch.object( dispatcher_instance._v03_adapter, 'handle_request', diff --git a/tests/server/routes/test_jsonrpc_routes.py b/tests/server/routes/test_jsonrpc_routes.py index 5d1b01d98..bf75b0521 100644 --- a/tests/server/routes/test_jsonrpc_routes.py +++ b/tests/server/routes/test_jsonrpc_routes.py @@ -7,7 +7,7 @@ from starlette.middleware import Middleware from starlette.applications import Starlette -from a2a.server.routes.jsonrpc_routes import JsonRpcRoutes +from a2a.server.routes.jsonrpc_routes import create_jsonrpc_routes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import AgentCard @@ -23,29 +23,28 @@ def mock_handler(): def test_routes_creation(agent_card, mock_handler): - """Tests that JsonRpcRoutes creates Route objects list.""" - jsonrpc_routes = JsonRpcRoutes( + """Tests that create_jsonrpc_routes creates Route objects list.""" + routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=mock_handler ) - assert hasattr(jsonrpc_routes, 'routes') - assert isinstance(jsonrpc_routes.routes, list) - assert len(jsonrpc_routes.routes) == 1 + assert isinstance(routes, list) + assert len(routes) == 1 from starlette.routing import Route - assert isinstance(jsonrpc_routes.routes[0], Route) - assert jsonrpc_routes.routes[0].methods == {'POST'} + assert isinstance(routes[0], Route) + assert routes[0].methods == {'POST'} def test_jsonrpc_custom_url(agent_card, mock_handler): """Tests that custom rpc_url is respected for routing.""" custom_url = '/custom/api/jsonrpc' - jsonrpc_routes = JsonRpcRoutes( + routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=mock_handler, rpc_url=custom_url ) - app = Starlette(routes=jsonrpc_routes.routes) + app = Starlette(routes=routes) client = TestClient(app) # Check that default path returns 404 @@ -61,33 +60,3 @@ def test_jsonrpc_custom_url(agent_card, mock_handler): # Method not found error from dispatcher assert resp_json['error']['code'] == -32601 - -def test_jsonrpc_with_middleware(agent_card, mock_handler): - """Tests that middleware is applied to the route.""" - middleware_called = False - - class MyMiddleware: - def __init__(self, app): - self.app = app - - async def __call__(self, scope, receive, send): - nonlocal middleware_called - middleware_called = True - await self.app(scope, receive, send) - - jsonrpc_routes = JsonRpcRoutes( - agent_card=agent_card, - request_handler=mock_handler, - middleware=[Middleware(MyMiddleware)], - rpc_url='/', - ) - - app = Starlette(routes=jsonrpc_routes.routes) - client = TestClient(app) - - # Call to trigger middleware - # Empty JSON might raise error, let's send a base valid format for dispatcher - client.post( - '/', json={'jsonrpc': '2.0', 'id': '1', 'method': 'SendMessage'} - ) - assert middleware_called is True diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 0cc5524d2..bdbfe62a7 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -18,7 +18,7 @@ from starlette.routing import Route from starlette.testclient import TestClient -from a2a.server.routes import AgentCardRoutes, JsonRpcRoutes +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.context import ServerCallContext from a2a.server.jsonrpc_models import ( @@ -164,18 +164,18 @@ def build( app_instance = Starlette(middleware=middleware, routes=routes or []) # Agent card router - card_routes = AgentCardRoutes( + card_routes = create_agent_card_routes( self.agent_card, card_url=agent_card_url, card_modifier=self.card_modifier, ) - app_instance.routes.extend(card_routes.routes) + app_instance.routes.extend(card_routes) # JSON-RPC router - rpc_routes = JsonRpcRoutes( + rpc_routes = create_jsonrpc_routes( self.agent_card, self.handler, rpc_url=rpc_url ) - app_instance.routes.extend(rpc_routes.routes) + app_instance.routes.extend(rpc_routes) return app_instance From d48a3c23564c820782e4136c00966c8017237d11 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 20 Mar 2026 16:48:02 +0000 Subject: [PATCH 08/25] linter --- src/a2a/server/routes/__init__.py | 2 +- src/a2a/server/routes/agent_card_routes.py | 2 +- src/a2a/server/routes/jsonrpc_routes.py | 2 +- .../integration/cross_version/client_server/server_1_0.py | 6 +++--- tests/integration/test_client_server_integration.py | 8 ++++++-- tests/integration/test_end_to_end.py | 4 +++- tests/integration/test_tenant.py | 4 +--- tests/integration/test_version_header.py | 4 +++- tests/server/routes/test_agent_card_routes.py | 6 +++--- tests/server/routes/test_jsonrpc_routes.py | 1 - 10 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/a2a/server/routes/__init__.py b/src/a2a/server/routes/__init__.py index a559480a4..cf7ed1cdc 100644 --- a/src/a2a/server/routes/__init__.py +++ b/src/a2a/server/routes/__init__.py @@ -9,8 +9,8 @@ __all__ = [ - 'create_agent_card_routes', 'CallContextBuilder', 'DefaultCallContextBuilder', + 'create_agent_card_routes', 'create_jsonrpc_routes', ] diff --git a/src/a2a/server/routes/agent_card_routes.py b/src/a2a/server/routes/agent_card_routes.py index a521c8d2c..c1f7ecffe 100644 --- a/src/a2a/server/routes/agent_card_routes.py +++ b/src/a2a/server/routes/agent_card_routes.py @@ -1,6 +1,6 @@ import logging -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any diff --git a/src/a2a/server/routes/jsonrpc_routes.py b/src/a2a/server/routes/jsonrpc_routes.py index 2aeaf295a..e55254f2f 100644 --- a/src/a2a/server/routes/jsonrpc_routes.py +++ b/src/a2a/server/routes/jsonrpc_routes.py @@ -1,6 +1,6 @@ import logging -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any diff --git a/tests/integration/cross_version/client_server/server_1_0.py b/tests/integration/cross_version/client_server/server_1_0.py index f6121a337..5b9cba9b2 100644 --- a/tests/integration/cross_version/client_server/server_1_0.py +++ b/tests/integration/cross_version/client_server/server_1_0.py @@ -5,7 +5,7 @@ import grpc from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.routes import AgentCardRoutes, create_jsonrpc_routes +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.apps import A2ARESTFastAPIApplication from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager @@ -167,7 +167,7 @@ async def main_async(http_port: int, grpc_port: int): app = FastAPI() app.add_middleware(CustomLoggingMiddleware) - agent_card_routes = AgentCardRoutes( + agent_card_routes = create_agent_card_routes( agent_card=agent_card, card_url='/.well-known/agent-card.json' ) jsonrpc_routes = create_jsonrpc_routes( @@ -179,7 +179,7 @@ async def main_async(http_port: int, grpc_port: int): ) app.mount( '/jsonrpc', - FastAPI(routes=jsonrpc_routes + agent_card_routes.routes), + FastAPI(routes=jsonrpc_routes + agent_card_routes), ) app.mount( diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 6df44a58d..19a86ad58 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -222,7 +222,9 @@ def http_base_setup(mock_request_handler: AsyncMock, agent_card: AgentCard): def jsonrpc_setup(http_base_setup) -> TransportSetup: """Sets up the JsonRpcTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - agent_card_routes = create_agent_card_routes(agent_card=agent_card, card_url='/') + agent_card_routes = create_agent_card_routes( + agent_card=agent_card, card_url='/' + ) jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=mock_request_handler, @@ -695,7 +697,9 @@ async def test_client_get_signed_extended_card( }, ) - agent_card_routes = create_agent_card_routes(agent_card=agent_card, card_url='/') + agent_card_routes = create_agent_card_routes( + agent_card=agent_card, card_url='/' + ) jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=mock_request_handler, diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 2c0ea8e52..a6f8f866a 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -194,7 +194,9 @@ def rest_setup(agent_card, base_e2e_setup) -> ClientSetup: @pytest.fixture def jsonrpc_setup(agent_card, base_e2e_setup) -> ClientSetup: task_store, handler = base_e2e_setup - agent_card_routes = create_agent_card_routes(agent_card=agent_card, card_url='/') + agent_card_routes = create_agent_card_routes( + agent_card=agent_card, card_url='/' + ) jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=handler, diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py index ffc21306b..6ceb1e070 100644 --- a/tests/integration/test_tenant.py +++ b/tests/integration/test_tenant.py @@ -207,9 +207,7 @@ def server_app(self, jsonrpc_agent_card, mock_handler): extended_agent_card=jsonrpc_agent_card, rpc_url='/jsonrpc', ) - app = Starlette( - routes=[*agent_card_routes, *jsonrpc_routes] - ) + app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) return app @pytest.mark.asyncio diff --git a/tests/integration/test_version_header.py b/tests/integration/test_version_header.py index 7dd79adf4..383d536c7 100644 --- a/tests/integration/test_version_header.py +++ b/tests/integration/test_version_header.py @@ -57,7 +57,9 @@ async def mock_on_message_send_stream(*args, **kwargs): handler.on_message_send_stream = mock_on_message_send_stream app = FastAPI() - agent_card_routes = create_agent_card_routes(agent_card=agent_card, card_url='/') + agent_card_routes = create_agent_card_routes( + agent_card=agent_card, card_url='/' + ) jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=handler, diff --git a/tests/server/routes/test_agent_card_routes.py b/tests/server/routes/test_agent_card_routes.py index de028f5f4..435921d60 100644 --- a/tests/server/routes/test_agent_card_routes.py +++ b/tests/server/routes/test_agent_card_routes.py @@ -62,7 +62,9 @@ async def modifier(card: AgentCard) -> AgentCard: def test_agent_card_custom_url(agent_card): """Tests that custom card_url is respected.""" custom_url = '/custom/path/agent.json' - routes = create_agent_card_routes(agent_card=agent_card, card_url=custom_url) + routes = create_agent_card_routes( + agent_card=agent_card, card_url=custom_url + ) app = Starlette(routes=routes) client = TestClient(app) @@ -71,5 +73,3 @@ def test_agent_card_custom_url(agent_card): assert client.get('/.well-known/agent-card.json').status_code == 404 # Check that custom returns 200 assert client.get(custom_url).status_code == 200 - - diff --git a/tests/server/routes/test_jsonrpc_routes.py b/tests/server/routes/test_jsonrpc_routes.py index bf75b0521..1d3fb5909 100644 --- a/tests/server/routes/test_jsonrpc_routes.py +++ b/tests/server/routes/test_jsonrpc_routes.py @@ -59,4 +59,3 @@ def test_jsonrpc_custom_url(agent_card, mock_handler): assert 'error' in resp_json # Method not found error from dispatcher assert resp_json['error']['code'] == -32601 - From 234eb58f12429dff9f28bc9adc1c495a169a5047 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Sat, 21 Mar 2026 13:46:21 +0000 Subject: [PATCH 09/25] feat: implement Starlette REST routes for A2A protocol, supporting message/task handling, agent cards, and optional v0.3 compatibility. --- src/a2a/server/routes/rest_routes.py | 242 +++++++++++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 src/a2a/server/routes/rest_routes.py diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py new file mode 100644 index 000000000..8a69c27ff --- /dev/null +++ b/src/a2a/server/routes/rest_routes.py @@ -0,0 +1,242 @@ +import functools +import json +import logging +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from typing import TYPE_CHECKING, Any + +from google.protobuf.json_format import MessageToDict + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_handlers.response_helpers import agent_card_to_dict +from a2a.server.request_handlers.rest_handler import RESTHandler +from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder +from a2a.types.a2a_pb2 import AgentCard +from a2a.utils.error_handlers import ( + rest_error_handler, + rest_stream_error_handler, +) +from a2a.utils.errors import ( + ExtendedAgentCardNotConfiguredError, + InvalidRequestError, +) +from a2a.utils.helpers import maybe_await + +if TYPE_CHECKING: + from sse_starlette.sse import EventSourceResponse + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + from starlette.routing import Route + + _package_starlette_installed = True +else: + try: + from sse_starlette.sse import EventSourceResponse + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + from starlette.routing import Route + + _package_starlette_installed = True + except ImportError: + EventSourceResponse = Any + Request = Any + JSONResponse = Any + Response = Any + Route = Any + + _package_starlette_installed = False + +logger = logging.getLogger(__name__) + + +def create_rest_routes( # noqa: PLR0913 + agent_card: AgentCard, + request_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, + context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, + enable_v0_3_compat: bool = False, +) -> list['Route']: + """Creates the Starlette Routes for the A2A protocol REST endpoint. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the request_handler. If None, no + ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. + enable_v0_3_compat: If True, mounts backward-compatible v0.3 protocol + endpoints using REST03Adapter. + """ + if not _package_starlette_installed: + raise ImportError( + 'Packages `starlette` and `sse-starlette` are required to use' + ' the `create_rest_routes`. They can be added as a part of `a2a-sdk` ' + 'optional dependencies, `a2a-sdk[http-server]`.' + ) + + handler = RESTHandler( + agent_card=agent_card, request_handler=request_handler + ) + _context_builder = context_builder or DefaultCallContextBuilder() + + def _build_call_context(request: 'Request') -> ServerCallContext: + call_context = _context_builder.build(request) + if 'tenant' in request.path_params: + call_context.tenant = request.path_params['tenant'] + return call_context + + @rest_error_handler + async def _handle_request( + method: Callable[['Request', ServerCallContext], Awaitable[Any]], + request: 'Request', + ) -> 'Response': + from starlette.responses import JSONResponse + + call_context = _build_call_context(request) + response = await method(request, call_context) + return JSONResponse(content=response) + + @rest_stream_error_handler + async def _handle_streaming_request( + method: Callable[['Request', ServerCallContext], AsyncIterable[Any]], + request: 'Request', + ) -> 'EventSourceResponse': + from sse_starlette.sse import EventSourceResponse + + try: + await request.body() + except (ValueError, RuntimeError, OSError) as e: + raise InvalidRequestError( + message=f'Failed to pre-consume request body: {e}' + ) from e + + call_context = _build_call_context(request) + + async def event_generator( + stream: AsyncIterable[Any], + ) -> AsyncIterator[str]: + async for item in stream: + yield json.dumps(item) + + return EventSourceResponse( + event_generator(method(request, call_context)) + ) + + async def _handle_authenticated_agent_card( + request: 'Request', call_context: ServerCallContext | None = None + ) -> dict[str, Any]: + if not agent_card.capabilities.extended_agent_card: + raise ExtendedAgentCardNotConfiguredError( + message='Authenticated card not supported' + ) + card_to_serve = extended_agent_card or agent_card + + if extended_card_modifier: + # Re-generate context if none passed to replicate RESTAdapter exact logic + context = call_context or _build_call_context(request) + card_to_serve = await maybe_await( + extended_card_modifier(card_to_serve, context) + ) + elif card_modifier: + card_to_serve = await maybe_await(card_modifier(card_to_serve)) + + return MessageToDict(card_to_serve, preserving_proto_field_name=True) + + # Dictionary of routes, mapping to bound helper methods + base_routes: dict[tuple[str, str], Callable[['Request'], Any]] = { + ('/message:send', 'POST'): functools.partial( + _handle_request, handler.on_message_send + ), + ('/message:stream', 'POST'): functools.partial( + _handle_streaming_request, + handler.on_message_send_stream, + ), + ('/tasks/{id}:cancel', 'POST'): functools.partial( + _handle_request, handler.on_cancel_task + ), + ('/tasks/{id}:subscribe', 'GET'): functools.partial( + _handle_streaming_request, + handler.on_subscribe_to_task, + ), + ('/tasks/{id}:subscribe', 'POST'): functools.partial( + _handle_streaming_request, + handler.on_subscribe_to_task, + ), + ('/tasks/{id}', 'GET'): functools.partial( + _handle_request, handler.on_get_task + ), + ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): functools.partial( + _handle_request, handler.get_push_notification + ), + ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'DELETE'): functools.partial( + _handle_request, handler.delete_push_notification + ), + ('/tasks/{id}/pushNotificationConfigs', 'POST'): functools.partial( + _handle_request, handler.set_push_notification + ), + ('/tasks/{id}/pushNotificationConfigs', 'GET'): functools.partial( + _handle_request, handler.list_push_notifications + ), + ('/tasks', 'GET'): functools.partial( + _handle_request, handler.list_tasks + ), + ('/extendedAgentCard', 'GET'): functools.partial( + _handle_request, _handle_authenticated_agent_card + ) + } + + v03_routes = {} + if enable_v0_3_compat: + from a2a.compat.v0_3.rest_adapter import REST03Adapter + + v03_adapter = REST03Adapter( + agent_card=agent_card, + http_handler=request_handler, + extended_agent_card=extended_agent_card, + context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, + ) + v03_routes = v03_adapter.routes() + + routes: list['Route'] = [] + for (path, method), endpoint in base_routes.items(): + routes.append( + Route( + path=path, + endpoint=endpoint, + methods=[method], + ) + ) + routes.append( + Route( + path=f'/{{tenant}}{path}', + endpoint=endpoint, + methods=[method], + ) + ) + + for (path, method), endpoint in v03_routes.items(): + routes.append( + Route( + path=path, + endpoint=endpoint, + methods=[method], + ) + ) + + return routes From dd881bdf3b33229ce31a97863fc1a6613e2210de Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Sun, 22 Mar 2026 09:46:32 +0000 Subject: [PATCH 10/25] wip --- src/a2a/server/apps/__init__.py | 8 - src/a2a/server/apps/rest/__init__.py | 8 - src/a2a/server/apps/rest/fastapi_app.py | 194 ------------ src/a2a/server/apps/rest/rest_adapter.py | 293 ------------------ src/a2a/server/routes/__init__.py | 2 + src/a2a/server/routes/rest_routes.py | 18 +- ...p_compat.py => test_rest_routes_compat.py} | 21 +- tests/e2e/push_notifications/agent_app.py | 34 +- .../cross_version/client_server/server_1_0.py | 11 +- tests/integration/test_agent_card.py | 10 +- .../test_client_server_integration.py | 10 +- tests/integration/test_end_to_end.py | 11 +- tests/integration/test_version_header.py | 10 +- 13 files changed, 75 insertions(+), 555 deletions(-) delete mode 100644 src/a2a/server/apps/__init__.py delete mode 100644 src/a2a/server/apps/rest/__init__.py delete mode 100644 src/a2a/server/apps/rest/fastapi_app.py delete mode 100644 src/a2a/server/apps/rest/rest_adapter.py rename tests/compat/v0_3/{test_rest_fastapi_app_compat.py => test_rest_routes_compat.py} (90%) diff --git a/src/a2a/server/apps/__init__.py b/src/a2a/server/apps/__init__.py deleted file mode 100644 index 1cdb32953..000000000 --- a/src/a2a/server/apps/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""HTTP application components for the A2A server.""" - -from a2a.server.apps.rest import A2ARESTFastAPIApplication - - -__all__ = [ - 'A2ARESTFastAPIApplication', -] diff --git a/src/a2a/server/apps/rest/__init__.py b/src/a2a/server/apps/rest/__init__.py deleted file mode 100644 index bafe4cb60..000000000 --- a/src/a2a/server/apps/rest/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""A2A REST Applications.""" - -from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication - - -__all__ = [ - 'A2ARESTFastAPIApplication', -] diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py deleted file mode 100644 index 4feac9072..000000000 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ /dev/null @@ -1,194 +0,0 @@ -import logging - -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from fastapi import APIRouter, FastAPI, Request, Response - from fastapi.responses import JSONResponse - from starlette.exceptions import HTTPException as StarletteHTTPException - - _package_fastapi_installed = True -else: - try: - from fastapi import APIRouter, FastAPI, Request, Response - from fastapi.responses import JSONResponse - from starlette.exceptions import HTTPException as StarletteHTTPException - - _package_fastapi_installed = True - except ImportError: - APIRouter = Any - FastAPI = Any - Request = Any - Response = Any - StarletteHTTPException = Any - - _package_fastapi_installed = False - - -from a2a.compat.v0_3.rest_adapter import REST03Adapter -from a2a.server.apps.rest.rest_adapter import RESTAdapter -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.routes import CallContextBuilder -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH - - -logger = logging.getLogger(__name__) - - -_HTTP_TO_GRPC_STATUS_MAP = { - 400: 'INVALID_ARGUMENT', - 401: 'UNAUTHENTICATED', - 403: 'PERMISSION_DENIED', - 404: 'NOT_FOUND', - 405: 'UNIMPLEMENTED', - 409: 'ALREADY_EXISTS', - 415: 'INVALID_ARGUMENT', - 422: 'INVALID_ARGUMENT', - 500: 'INTERNAL', - 501: 'UNIMPLEMENTED', - 502: 'INTERNAL', - 503: 'UNAVAILABLE', - 504: 'DEADLINE_EXCEEDED', -} - - -class A2ARESTFastAPIApplication: - """A FastAPI application implementing the A2A protocol server REST endpoints. - - Handles incoming REST requests, routes them to the appropriate - handler methods, and manages response generation including Server-Sent Events - (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - enable_v0_3_compat: bool = False, - ): - """Initializes the A2ARESTFastAPIApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - enable_v0_3_compat: If True, mounts backward-compatible v0.3 protocol - endpoints under the '/v0.3' path prefix using REST03Adapter. - """ - if not _package_fastapi_installed: - raise ImportError( - 'The `fastapi` package is required to use the' - ' `A2ARESTFastAPIApplication`. It can be added as a part of' - ' `a2a-sdk` optional dependencies, `a2a-sdk[http-server]`.' - ) - self._adapter = RESTAdapter( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - ) - self.enable_v0_3_compat = enable_v0_3_compat - self._v03_adapter = None - - if self.enable_v0_3_compat: - self._v03_adapter = REST03Adapter( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - ) - - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = '', - **kwargs: Any, - ) -> FastAPI: - """Builds and returns the FastAPI application instance. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A REST endpoint base path. - **kwargs: Additional keyword arguments to pass to the FastAPI constructor. - - Returns: - A configured FastAPI application instance. - """ - app = FastAPI(**kwargs) - - @app.exception_handler(StarletteHTTPException) - async def http_exception_handler( - request: Request, exc: StarletteHTTPException - ) -> Response: - """Catches framework-level HTTP exceptions. - - For example, 404 Not Found for bad routes, 422 Unprocessable Entity - for schema validation, and formats them into the A2A standard - google.rpc.Status JSON format (AIP-193). - """ - grpc_status = _HTTP_TO_GRPC_STATUS_MAP.get( - exc.status_code, 'UNKNOWN' - ) - return JSONResponse( - status_code=exc.status_code, - content={ - 'error': { - 'code': exc.status_code, - 'status': grpc_status, - 'message': str(exc.detail) - if hasattr(exc, 'detail') - else 'HTTP Exception', - } - }, - media_type='application/json', - ) - - if self.enable_v0_3_compat and self._v03_adapter: - v03_adapter = self._v03_adapter - v03_router = APIRouter() - for route, callback in v03_adapter.routes().items(): - v03_router.add_api_route( - f'{rpc_url}{route[0]}', callback, methods=[route[1]] - ) - app.include_router(v03_router) - - router = APIRouter() - for route, callback in self._adapter.routes().items(): - router.add_api_route( - f'{rpc_url}{route[0]}', callback, methods=[route[1]] - ) - - @router.get(f'{rpc_url}{agent_card_url}') - async def get_agent_card(request: Request) -> Response: - card = await self._adapter.handle_get_agent_card(request) - return JSONResponse(card) - - app.include_router(router) - - return app diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py deleted file mode 100644 index ebf996a47..000000000 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ /dev/null @@ -1,293 +0,0 @@ -import functools -import json -import logging - -from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable -from typing import TYPE_CHECKING, Any - -from google.protobuf.json_format import MessageToDict - -from a2a.utils.helpers import maybe_await - - -if TYPE_CHECKING: - from sse_starlette.sse import EventSourceResponse - from starlette.requests import Request - from starlette.responses import JSONResponse, Response - - _package_starlette_installed = True - -else: - try: - from sse_starlette.sse import EventSourceResponse - from starlette.requests import Request - from starlette.responses import JSONResponse, Response - - _package_starlette_installed = True - except ImportError: - EventSourceResponse = Any - Request = Any - JSONResponse = Any - Response = Any - - _package_starlette_installed = False - -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.response_helpers import ( - agent_card_to_dict, -) -from a2a.server.request_handlers.rest_handler import RESTHandler -from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.error_handlers import ( - rest_error_handler, - rest_stream_error_handler, -) -from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, - InvalidRequestError, -) - - -logger = logging.getLogger(__name__) - - -class RESTAdapterInterface(ABC): - """Interface for RESTAdapter.""" - - @abstractmethod - async def handle_get_agent_card( - self, request: 'Request', call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Handles GET requests for the agent card endpoint.""" - - @abstractmethod - def routes(self) -> dict[tuple[str, str], Callable[['Request'], Any]]: - """Constructs a dictionary of API routes and their corresponding handlers.""" - - -class RESTAdapter(RESTAdapterInterface): - """Adapter to make RequestHandler work with RESTful API. - - Defines REST requests processors and the routes to attach them too, as well as - manages response generation including Server-Sent Events (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - ): - """Initializes the RESTApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - """ - if not _package_starlette_installed: - raise ImportError( - 'Packages `starlette` and `sse-starlette` are required to use' - ' the `RESTAdapter`. They can be added as a part of `a2a-sdk`' - ' optional dependencies, `a2a-sdk[http-server]`.' - ) - self.agent_card = agent_card - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier - self.handler = RESTHandler( - agent_card=agent_card, request_handler=http_handler - ) - self._context_builder = context_builder or DefaultCallContextBuilder() - - @rest_error_handler - async def _handle_request( - self, - method: Callable[[Request, ServerCallContext], Awaitable[Any]], - request: Request, - ) -> Response: - call_context = self._build_call_context(request) - - response = await method(request, call_context) - return JSONResponse(content=response) - - @rest_stream_error_handler - async def _handle_streaming_request( - self, - method: Callable[[Request, ServerCallContext], AsyncIterable[Any]], - request: Request, - ) -> EventSourceResponse: - # Pre-consume and cache the request body to prevent deadlock in streaming context - # This is required because Starlette's request.body() can only be consumed once, - # and attempting to consume it after EventSourceResponse starts causes deadlock - try: - await request.body() - except (ValueError, RuntimeError, OSError) as e: - raise InvalidRequestError( - message=f'Failed to pre-consume request body: {e}' - ) from e - - call_context = self._build_call_context(request) - - async def event_generator( - stream: AsyncIterable[Any], - ) -> AsyncIterator[str]: - async for item in stream: - yield json.dumps(item) - - return EventSourceResponse( - event_generator(method(request, call_context)) - ) - - async def handle_get_agent_card( - self, request: Request, call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Handles GET requests for the agent card endpoint. - - Args: - request: The incoming Starlette Request object. - call_context: ServerCallContext - - Returns: - A JSONResponse containing the agent card data. - """ - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return agent_card_to_dict(card_to_serve) - - async def _handle_authenticated_agent_card( - self, request: Request, call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Hook for per credential agent card response. - - If a dynamic card is needed based on the credentials provided in the request - override this method and return the customized content. - - Args: - request: The incoming Starlette Request object. - call_context: ServerCallContext - - Returns: - A JSONResponse containing the authenticated card. - """ - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - card_to_serve = self.extended_agent_card - - if not card_to_serve: - card_to_serve = self.agent_card - - if self.extended_card_modifier: - context = self._build_call_context(request) - card_to_serve = await maybe_await( - self.extended_card_modifier(card_to_serve, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return MessageToDict(card_to_serve, preserving_proto_field_name=True) - - def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: - """Constructs a dictionary of API routes and their corresponding handlers. - - This method maps URL paths and HTTP methods to the appropriate handler - functions from the RESTHandler. It can be used by a web framework - (like Starlette or FastAPI) to set up the application's endpoints. - - Returns: - A dictionary where each key is a tuple of (path, http_method) and - the value is the callable handler for that route. - """ - base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { - ('/message:send', 'POST'): functools.partial( - self._handle_request, self.handler.on_message_send - ), - ('/message:stream', 'POST'): functools.partial( - self._handle_streaming_request, - self.handler.on_message_send_stream, - ), - ('/tasks/{id}:cancel', 'POST'): functools.partial( - self._handle_request, self.handler.on_cancel_task - ), - ('/tasks/{id}:subscribe', 'GET'): functools.partial( - self._handle_streaming_request, - self.handler.on_subscribe_to_task, - ), - ('/tasks/{id}:subscribe', 'POST'): functools.partial( - self._handle_streaming_request, - self.handler.on_subscribe_to_task, - ), - ('/tasks/{id}', 'GET'): functools.partial( - self._handle_request, self.handler.on_get_task - ), - ( - '/tasks/{id}/pushNotificationConfigs/{push_id}', - 'GET', - ): functools.partial( - self._handle_request, self.handler.get_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs/{push_id}', - 'DELETE', - ): functools.partial( - self._handle_request, self.handler.delete_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs', - 'POST', - ): functools.partial( - self._handle_request, self.handler.set_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs', - 'GET', - ): functools.partial( - self._handle_request, self.handler.list_push_notifications - ), - ('/tasks', 'GET'): functools.partial( - self._handle_request, self.handler.list_tasks - ), - } - - if self.agent_card.capabilities.extended_agent_card: - base_routes[('/extendedAgentCard', 'GET')] = functools.partial( - self._handle_request, self._handle_authenticated_agent_card - ) - - routes: dict[tuple[str, str], Callable[[Request], Any]] = { - (p, method): handler - for (path, method), handler in base_routes.items() - for p in (path, f'/{{tenant}}{path}') - } - - return routes - - def _build_call_context(self, request: Request) -> ServerCallContext: - call_context = self._context_builder.build(request) - if 'tenant' in request.path_params: - call_context.tenant = request.path_params['tenant'] - return call_context diff --git a/src/a2a/server/routes/__init__.py b/src/a2a/server/routes/__init__.py index cf7ed1cdc..bb6ae0ba1 100644 --- a/src/a2a/server/routes/__init__.py +++ b/src/a2a/server/routes/__init__.py @@ -6,6 +6,7 @@ DefaultCallContextBuilder, ) from a2a.server.routes.jsonrpc_routes import create_jsonrpc_routes +from a2a.server.routes.rest_routes import create_rest_routes __all__ = [ @@ -13,4 +14,5 @@ 'DefaultCallContextBuilder', 'create_agent_card_routes', 'create_jsonrpc_routes', + 'create_rest_routes', ] diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 8a69c27ff..c03ca24a0 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -179,12 +179,14 @@ async def _handle_authenticated_agent_card( ('/tasks/{id}', 'GET'): functools.partial( _handle_request, handler.on_get_task ), - ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): functools.partial( - _handle_request, handler.get_push_notification - ), - ('/tasks/{id}/pushNotificationConfigs/{push_id}', 'DELETE'): functools.partial( - _handle_request, handler.delete_push_notification - ), + ( + '/tasks/{id}/pushNotificationConfigs/{push_id}', + 'GET', + ): functools.partial(_handle_request, handler.get_push_notification), + ( + '/tasks/{id}/pushNotificationConfigs/{push_id}', + 'DELETE', + ): functools.partial(_handle_request, handler.delete_push_notification), ('/tasks/{id}/pushNotificationConfigs', 'POST'): functools.partial( _handle_request, handler.set_push_notification ), @@ -196,8 +198,8 @@ async def _handle_authenticated_agent_card( ), ('/extendedAgentCard', 'GET'): functools.partial( _handle_request, _handle_authenticated_agent_card - ) - } + ), + } v03_routes = {} if enable_v0_3_compat: diff --git a/tests/compat/v0_3/test_rest_fastapi_app_compat.py b/tests/compat/v0_3/test_rest_routes_compat.py similarity index 90% rename from tests/compat/v0_3/test_rest_fastapi_app_compat.py rename to tests/compat/v0_3/test_rest_routes_compat.py index 8625b7e0f..5ee0f60ca 100644 --- a/tests/compat/v0_3/test_rest_fastapi_app_compat.py +++ b/tests/compat/v0_3/test_rest_routes_compat.py @@ -8,8 +8,9 @@ from fastapi import FastAPI from google.protobuf import json_format from httpx import ASGITransport, AsyncClient - -from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication +from starlette.applications import Starlette +from a2a.server.routes.rest_routes import create_rest_routes +from a2a.server.routes import create_agent_card_routes from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import ( AgentCard, @@ -50,13 +51,15 @@ async def request_handler() -> RequestHandler: async def app( agent_card: AgentCard, request_handler: RequestHandler, -) -> FastAPI: - """Builds the FastAPI application for testing.""" - return A2ARESTFastAPIApplication( - agent_card, - request_handler, - enable_v0_3_compat=True, - ).build(agent_card_url='/well-known/agent.json', rpc_url='') +) -> Starlette: + """Builds the Starlette application for testing.""" + rest_routes = create_rest_routes( + agent_card, request_handler, enable_v0_3_compat=True + ) + agent_card_routes = create_agent_card_routes( + agent_card=agent_card, card_url='/well-known/agent.json' + ) + return Starlette(routes=rest_routes + agent_card_routes) @pytest.fixture diff --git a/tests/e2e/push_notifications/agent_app.py b/tests/e2e/push_notifications/agent_app.py index ca1a234bc..94ccae03a 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -3,9 +3,11 @@ from fastapi import FastAPI from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2ARESTFastAPIApplication from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue +from starlette.applications import Starlette +from a2a.server.routes.rest_routes import create_rest_routes +from a2a.server.routes import create_agent_card_routes from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import ( BasePushNotificationSender, @@ -136,20 +138,22 @@ async def cancel( def create_agent_app( url: str, notification_client: httpx.AsyncClient -) -> FastAPI: - """Creates a new HTTP+REST FastAPI application for the test agent.""" +) -> Starlette: + """Creates a new HTTP+REST Starlette application for the test agent.""" push_config_store = InMemoryPushNotificationConfigStore() - app = A2ARESTFastAPIApplication( - agent_card=test_agent_card(url), - http_handler=DefaultRequestHandler( - agent_executor=TestAgentExecutor(), - task_store=InMemoryTaskStore(), - push_config_store=push_config_store, - push_sender=BasePushNotificationSender( - httpx_client=notification_client, - config_store=push_config_store, - context=ServerCallContext(), - ), + card = test_agent_card(url) + handler = DefaultRequestHandler( + agent_executor=TestAgentExecutor(), + task_store=InMemoryTaskStore(), + push_config_store=push_config_store, + push_sender=BasePushNotificationSender( + httpx_client=notification_client, + config_store=push_config_store, + context=ServerCallContext(), ), ) - return app.build() + rest_routes = create_rest_routes(agent_card=card, request_handler=handler) + agent_card_routes = create_agent_card_routes( + agent_card=card, card_url='/.well-known/agent-card.json' + ) + return Starlette(routes=[*rest_routes, *agent_card_routes]) diff --git a/tests/integration/cross_version/client_server/server_1_0.py b/tests/integration/cross_version/client_server/server_1_0.py index 5b9cba9b2..74e0bc23b 100644 --- a/tests/integration/cross_version/client_server/server_1_0.py +++ b/tests/integration/cross_version/client_server/server_1_0.py @@ -6,7 +6,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes -from a2a.server.apps import A2ARESTFastAPIApplication +from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler @@ -182,11 +182,14 @@ async def main_async(http_port: int, grpc_port: int): FastAPI(routes=jsonrpc_routes + agent_card_routes), ) + rest_routes = create_rest_routes( + agent_card=agent_card, + request_handler=handler, + enable_v0_3_compat=True, + ) app.mount( '/rest', - A2ARESTFastAPIApplication( - http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True - ).build(), + FastAPI(routes=rest_routes + agent_card_routes), ) # Start gRPC Server diff --git a/tests/integration/test_agent_card.py b/tests/integration/test_agent_card.py index 719b7be9f..fb31beb73 100644 --- a/tests/integration/test_agent_card.py +++ b/tests/integration/test_agent_card.py @@ -5,7 +5,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from starlette.applications import Starlette -from a2a.server.apps import A2ARESTFastAPIApplication +from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager @@ -83,10 +83,10 @@ async def test_agent_card_integration(header_val: str | None) -> None: jsonrpc_app = Starlette(routes=jsonrpc_routes) app.mount('/jsonrpc', jsonrpc_app) - # Mount REST application - rest_app = A2ARESTFastAPIApplication( - http_handler=handler, agent_card=agent_card - ).build() + rest_routes = create_rest_routes( + agent_card=agent_card, request_handler=handler + ) + rest_app = Starlette(routes=rest_routes) app.mount('/rest', rest_app) expected_content = { diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 94d0313a6..223fe5c0c 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -24,8 +24,7 @@ ) from a2a.client.transports import JsonRpcTransport, RestTransport from starlette.applications import Starlette -from a2a.server.apps import A2ARESTFastAPIApplication -from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes, create_rest_routes from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( @@ -251,10 +250,13 @@ def jsonrpc_setup(http_base_setup) -> TransportSetup: def rest_setup(http_base_setup) -> TransportSetup: """Sets up the RestTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - app_builder = A2ARESTFastAPIApplication( + rest_routes = create_rest_routes( agent_card, mock_request_handler, extended_agent_card=agent_card ) - app = app_builder.build() + agent_card_routes = create_agent_card_routes( + agent_card=agent_card, card_url='/' + ) + app = Starlette(routes=[*rest_routes, *agent_card_routes]) httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) factory = ClientFactory( config=ClientConfig( diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index a6f8f866a..d6fe41070 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -10,8 +10,8 @@ from a2a.client.client import ClientConfig from a2a.client.client_factory import ClientFactory from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.routes.rest_routes import create_rest_routes from starlette.applications import Starlette -from a2a.server.apps import A2ARESTFastAPIApplication from a2a.server.routes import create_jsonrpc_routes, create_agent_card_routes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager @@ -173,8 +173,13 @@ def base_e2e_setup(): @pytest.fixture def rest_setup(agent_card, base_e2e_setup) -> ClientSetup: task_store, handler = base_e2e_setup - app_builder = A2ARESTFastAPIApplication(agent_card, handler) - app = app_builder.build() + rest_routes = create_rest_routes( + agent_card=agent_card, request_handler=handler + ) + agent_card_routes = create_agent_card_routes( + agent_card=agent_card, card_url='/' + ) + app = Starlette(routes=[*rest_routes, *agent_card_routes]) httpx_client = httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url='http://testserver' ) diff --git a/tests/integration/test_version_header.py b/tests/integration/test_version_header.py index 383d536c7..96435b38f 100644 --- a/tests/integration/test_version_header.py +++ b/tests/integration/test_version_header.py @@ -4,7 +4,7 @@ from starlette.testclient import TestClient from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.apps import A2ARESTFastAPIApplication +from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager @@ -69,9 +69,11 @@ async def mock_on_message_send_stream(*args, **kwargs): ) app.routes.extend(agent_card_routes) app.routes.extend(jsonrpc_routes) - rest_app = A2ARESTFastAPIApplication( - http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True - ).build() + + rest_routes = create_rest_routes( + agent_card=agent_card, request_handler=handler, enable_v0_3_compat=True + ) + rest_app = Starlette(routes=rest_routes) app.mount('/rest', rest_app) return app From b26802a38e10b1e291b078fd301b9b910db3762c Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Sun, 22 Mar 2026 10:21:07 +0000 Subject: [PATCH 11/25] wip --- samples/hello_world_agent.py | 2 +- src/a2a/compat/v0_3/rest_adapter.py | 3 +- src/a2a/server/routes/rest_routes.py | 23 +- tck/sut_agent.py | 6 +- tests/integration/test_agent_card.py | 11 +- .../test_client_server_integration.py | 6 +- tests/integration/test_version_header.py | 6 +- tests/server/apps/rest/__init__.py | 0 .../server/apps/rest/test_rest_fastapi_app.py | 728 ------------------ tests/server/routes/test_agent_card_routes.py | 4 - .../server/routes/test_jsonrpc_dispatcher.py | 1 - tests/server/routes/test_jsonrpc_routes.py | 1 - tests/server/routes/test_rest_routes.py | 97 +++ 13 files changed, 128 insertions(+), 760 deletions(-) delete mode 100644 tests/server/apps/rest/__init__.py delete mode 100644 tests/server/apps/rest/test_rest_fastapi_app.py create mode 100644 tests/server/routes/test_rest_routes.py diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index fa9ab3c2b..20b9804ba 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -5,13 +5,13 @@ import grpc import uvicorn +from a2a.server.apps import A2ARESTFastAPIApplication from fastapi import FastAPI from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext -from a2a.server.apps import A2ARESTFastAPIApplication from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import GrpcHandler from a2a.server.request_handlers.default_request_handler import ( diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index 8cae6b630..3d1e9cb77 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -33,7 +33,6 @@ from a2a.compat.v0_3 import conversions from a2a.compat.v0_3.rest_handler import REST03Handler -from a2a.server.apps.rest.rest_adapter import RESTAdapterInterface from a2a.server.context import ServerCallContext from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder from a2a.utils.error_handlers import ( @@ -50,7 +49,7 @@ logger = logging.getLogger(__name__) -class REST03Adapter(RESTAdapterInterface): +class REST03Adapter: """Adapter to make RequestHandler work with v0.3 RESTful API. Defines v0.3 REST request processors and their routes, as well as managing response generation including Server-Sent Events (SSE). diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index c03ca24a0..c0049c5a6 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -1,14 +1,16 @@ import functools import json import logging + from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from typing import TYPE_CHECKING, Any from google.protobuf.json_format import MessageToDict +from a2a.compat.v0_3.rest_adapter import REST03Adapter +from a2a.utils.constants import DEFAULT_RPC_URL from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.response_helpers import agent_card_to_dict from a2a.server.request_handlers.rest_handler import RESTHandler from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder from a2a.types.a2a_pb2 import AgentCard @@ -22,6 +24,7 @@ ) from a2a.utils.helpers import maybe_await + if TYPE_CHECKING: from sse_starlette.sse import EventSourceResponse from starlette.requests import Request @@ -61,7 +64,8 @@ def create_rest_routes( # noqa: PLR0913 ] | None = None, enable_v0_3_compat: bool = False, -) -> list['Route']: + rpc_url: str = DEFAULT_RPC_URL, +) -> list[Route]: """Creates the Starlette Routes for the A2A protocol REST endpoint. Args: @@ -104,7 +108,6 @@ async def _handle_request( method: Callable[['Request', ServerCallContext], Awaitable[Any]], request: 'Request', ) -> 'Response': - from starlette.responses import JSONResponse call_context = _build_call_context(request) response = await method(request, call_context) @@ -115,8 +118,6 @@ async def _handle_streaming_request( method: Callable[['Request', ServerCallContext], AsyncIterable[Any]], request: 'Request', ) -> 'EventSourceResponse': - from sse_starlette.sse import EventSourceResponse - try: await request.body() except (ValueError, RuntimeError, OSError) as e: @@ -157,7 +158,7 @@ async def _handle_authenticated_agent_card( return MessageToDict(card_to_serve, preserving_proto_field_name=True) # Dictionary of routes, mapping to bound helper methods - base_routes: dict[tuple[str, str], Callable[['Request'], Any]] = { + base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { ('/message:send', 'POST'): functools.partial( _handle_request, handler.on_message_send ), @@ -203,8 +204,6 @@ async def _handle_authenticated_agent_card( v03_routes = {} if enable_v0_3_compat: - from a2a.compat.v0_3.rest_adapter import REST03Adapter - v03_adapter = REST03Adapter( agent_card=agent_card, http_handler=request_handler, @@ -215,18 +214,18 @@ async def _handle_authenticated_agent_card( ) v03_routes = v03_adapter.routes() - routes: list['Route'] = [] + routes: list[Route] = [] for (path, method), endpoint in base_routes.items(): routes.append( Route( - path=path, + path=f'{rpc_url}{path}', endpoint=endpoint, methods=[method], ) ) routes.append( Route( - path=f'/{{tenant}}{path}', + path=f'/{{tenant}}{rpc_url}{path}', endpoint=endpoint, methods=[method], ) @@ -235,7 +234,7 @@ async def _handle_authenticated_agent_card( for (path, method), endpoint in v03_routes.items(): routes.append( Route( - path=path, + path=f'{rpc_url}{path}', endpoint=endpoint, methods=[method], ) diff --git a/tck/sut_agent.py b/tck/sut_agent.py index d133e257a..a25922db1 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -8,6 +8,9 @@ import grpc.aio import uvicorn +from a2a.server.apps import ( + A2ARESTFastAPIApplication, +) from starlette.applications import Starlette import a2a.compat.v0_3.a2a_v0_3_pb2_grpc as a2a_v0_3_grpc @@ -16,9 +19,6 @@ from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext -from a2a.server.apps import ( - A2ARESTFastAPIApplication, -) from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, diff --git a/tests/integration/test_agent_card.py b/tests/integration/test_agent_card.py index fb31beb73..85a282a9e 100644 --- a/tests/integration/test_agent_card.py +++ b/tests/integration/test_agent_card.py @@ -83,9 +83,14 @@ async def test_agent_card_integration(header_val: str | None) -> None: jsonrpc_app = Starlette(routes=jsonrpc_routes) app.mount('/jsonrpc', jsonrpc_app) - rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=handler - ) + rest_routes = [ + *create_agent_card_routes( + agent_card=agent_card, card_url='/.well-known/agent-card.json' + ), + *create_rest_routes( + agent_card=agent_card, request_handler=handler + ), + ] rest_app = Starlette(routes=rest_routes) app.mount('/rest', rest_app) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 223fe5c0c..ac7a51797 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -24,7 +24,11 @@ ) from a2a.client.transports import JsonRpcTransport, RestTransport from starlette.applications import Starlette -from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes, create_rest_routes +from a2a.server.routes import ( + create_agent_card_routes, + create_jsonrpc_routes, + create_rest_routes, +) from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( diff --git a/tests/integration/test_version_header.py b/tests/integration/test_version_header.py index 96435b38f..56331302d 100644 --- a/tests/integration/test_version_header.py +++ b/tests/integration/test_version_header.py @@ -63,7 +63,6 @@ async def mock_on_message_send_stream(*args, **kwargs): jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=handler, - extended_agent_card=agent_card, rpc_url='/jsonrpc', enable_v0_3_compat=True, ) @@ -71,10 +70,9 @@ async def mock_on_message_send_stream(*args, **kwargs): app.routes.extend(jsonrpc_routes) rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=handler, enable_v0_3_compat=True + agent_card=agent_card, request_handler=handler, rpc_url='/rest', enable_v0_3_compat=True ) - rest_app = Starlette(routes=rest_routes) - app.mount('/rest', rest_app) + app.routes.extend(rest_routes) return app diff --git a/tests/server/apps/rest/__init__.py b/tests/server/apps/rest/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py deleted file mode 100644 index 1c976c94b..000000000 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ /dev/null @@ -1,728 +0,0 @@ -import logging -import json - -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from fastapi import FastAPI -from google.protobuf import json_format -from httpx import ASGITransport, AsyncClient - -from a2a.server.apps.rest import fastapi_app, rest_adapter -from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication -from a2a.server.apps.rest.rest_adapter import RESTAdapter -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types import a2a_pb2 -from a2a.types.a2a_pb2 import ( - AgentCard, - ListTaskPushNotificationConfigsResponse, - ListTasksResponse, - Message, - Part, - Role, - Task, - TaskPushNotificationConfig, - TaskState, - TaskStatus, -) - - -logger = logging.getLogger(__name__) - - -@pytest.fixture -async def agent_card() -> AgentCard: - mock_agent_card = MagicMock(spec=AgentCard) - mock_agent_card.url = 'http://mockurl.com' - - # Mock the capabilities object with streaming enabled - mock_capabilities = MagicMock() - mock_capabilities.streaming = True - mock_capabilities.push_notifications = True - mock_capabilities.extended_agent_card = True - mock_agent_card.capabilities = mock_capabilities - - return mock_agent_card - - -@pytest.fixture -async def streaming_agent_card() -> AgentCard: - """Agent card that supports streaming for testing streaming endpoints.""" - mock_agent_card = MagicMock(spec=AgentCard) - mock_agent_card.url = 'http://mockurl.com' - - # Mock the capabilities object with streaming enabled - mock_capabilities = MagicMock() - mock_capabilities.streaming = True - mock_agent_card.capabilities = mock_capabilities - - return mock_agent_card - - -@pytest.fixture -async def request_handler() -> RequestHandler: - return MagicMock(spec=RequestHandler) - - -@pytest.fixture -async def extended_card_modifier() -> MagicMock | None: - return None - - -@pytest.fixture -async def streaming_app( - streaming_agent_card: AgentCard, request_handler: RequestHandler -) -> FastAPI: - """Builds the FastAPI application for testing streaming endpoints.""" - - return A2ARESTFastAPIApplication( - streaming_agent_card, request_handler - ).build(agent_card_url='/well-known/agent-card.json', rpc_url='') - - -@pytest.fixture -async def streaming_client(streaming_app: FastAPI) -> AsyncClient: - """HTTP client for the streaming FastAPI application.""" - return AsyncClient( - transport=ASGITransport(app=streaming_app), - base_url='http://test', - headers={'A2A-Version': '1.0'}, - ) - - -@pytest.fixture -async def app( - agent_card: AgentCard, - request_handler: RequestHandler, - extended_card_modifier: MagicMock | None, -) -> FastAPI: - """Builds the FastAPI application for testing.""" - - return A2ARESTFastAPIApplication( - agent_card, - request_handler, - extended_card_modifier=extended_card_modifier, - ).build(agent_card_url='/well-known/agent.json', rpc_url='') - - -@pytest.fixture -async def client(app: FastAPI) -> AsyncClient: - return AsyncClient( - transport=ASGITransport(app=app), - base_url='http://testapp', - headers={'A2A-Version': '1.0'}, - ) - - -@pytest.fixture -def mark_pkg_starlette_not_installed(): - pkg_starlette_installed_flag = rest_adapter._package_starlette_installed - rest_adapter._package_starlette_installed = False - yield - rest_adapter._package_starlette_installed = pkg_starlette_installed_flag - - -@pytest.fixture -def mark_pkg_fastapi_not_installed(): - pkg_fastapi_installed_flag = fastapi_app._package_fastapi_installed - fastapi_app._package_fastapi_installed = False - yield - fastapi_app._package_fastapi_installed = pkg_fastapi_installed_flag - - -@pytest.mark.anyio -async def test_create_rest_adapter_with_present_deps_succeeds( - agent_card: AgentCard, request_handler: RequestHandler -): - try: - _app = RESTAdapter(agent_card, request_handler) - except ImportError: - pytest.fail( - 'With packages starlette and see-starlette present, creating an' - ' RESTAdapter instance should not raise ImportError' - ) - - -@pytest.mark.anyio -async def test_create_rest_adapter_with_missing_deps_raises_importerror( - agent_card: AgentCard, - request_handler: RequestHandler, - mark_pkg_starlette_not_installed: Any, -): - with pytest.raises( - ImportError, - match=( - r'Packages `starlette` and `sse-starlette` are required to use' - r' the `RESTAdapter`.' - ), - ): - _app = RESTAdapter(agent_card, request_handler) - - -@pytest.mark.anyio -async def test_create_a2a_rest_fastapi_app_with_present_deps_succeeds( - agent_card: AgentCard, request_handler: RequestHandler -): - try: - _app = A2ARESTFastAPIApplication(agent_card, request_handler).build( - agent_card_url='/well-known/agent.json', rpc_url='' - ) - except ImportError: - pytest.fail( - 'With the fastapi package present, creating a' - ' A2ARESTFastAPIApplication instance should not raise ImportError' - ) - - -@pytest.mark.anyio -async def test_create_a2a_rest_fastapi_app_with_missing_deps_raises_importerror( - agent_card: AgentCard, - request_handler: RequestHandler, - mark_pkg_fastapi_not_installed: Any, -): - with pytest.raises( - ImportError, - match=( - 'The `fastapi` package is required to use the' - ' `A2ARESTFastAPIApplication`' - ), - ): - _app = A2ARESTFastAPIApplication(agent_card, request_handler).build( - agent_card_url='/well-known/agent.json', rpc_url='' - ) - - -@pytest.mark.anyio -async def test_create_a2a_rest_fastapi_app_with_v0_3_compat( - agent_card: AgentCard, request_handler: RequestHandler -): - app = A2ARESTFastAPIApplication( - agent_card, request_handler, enable_v0_3_compat=True - ).build(agent_card_url='/well-known/agent.json', rpc_url='') - - routes = [getattr(route, 'path', '') for route in app.routes] - assert '/v1/message:send' in routes - - -@pytest.mark.anyio -async def test_send_message_success_message( - client: AsyncClient, request_handler: MagicMock -) -> None: - expected_response = a2a_pb2.SendMessageResponse( - message=a2a_pb2.Message( - message_id='test', - role=a2a_pb2.Role.ROLE_AGENT, - parts=[ - a2a_pb2.Part(text='response message'), - ], - ), - ) - request_handler.on_message_send.return_value = Message( - message_id='test', - role=Role.ROLE_AGENT, - parts=[Part(text='response message')], - ) - - request = a2a_pb2.SendMessageRequest( - message=a2a_pb2.Message(), - configuration=a2a_pb2.SendMessageConfiguration(), - ) - # To see log output, run pytest with '--log-cli=true --log-cli-level=INFO' - response = await client.post( - '/message:send', json=json_format.MessageToDict(request) - ) - # request should always be successful - response.raise_for_status() - - actual_response = a2a_pb2.SendMessageResponse() - json_format.Parse(response.text, actual_response) - assert expected_response == actual_response - - -@pytest.mark.anyio -async def test_send_message_success_task( - client: AsyncClient, request_handler: MagicMock -) -> None: - expected_response = a2a_pb2.SendMessageResponse( - task=a2a_pb2.Task( - id='test_task_id', - context_id='test_context_id', - status=a2a_pb2.TaskStatus( - state=a2a_pb2.TaskState.TASK_STATE_COMPLETED, - message=a2a_pb2.Message( - message_id='test', - role=a2a_pb2.Role.ROLE_AGENT, - parts=[ - a2a_pb2.Part(text='response task message'), - ], - ), - ), - ), - ) - request_handler.on_message_send.return_value = Task( - id='test_task_id', - context_id='test_context_id', - status=TaskStatus( - state=TaskState.TASK_STATE_COMPLETED, - message=Message( - message_id='test', - role=Role.ROLE_AGENT, - parts=[Part(text='response task message')], - ), - ), - ) - - request = a2a_pb2.SendMessageRequest( - message=a2a_pb2.Message(), - configuration=a2a_pb2.SendMessageConfiguration(), - ) - # To see log output, run pytest with '--log-cli=true --log-cli-level=INFO' - response = await client.post( - '/message:send', json=json_format.MessageToDict(request) - ) - # request should always be successful - response.raise_for_status() - - actual_response = a2a_pb2.SendMessageResponse() - json_format.Parse(response.text, actual_response) - assert expected_response == actual_response - - -@pytest.mark.anyio -async def test_streaming_message_request_body_consumption( - streaming_client: AsyncClient, request_handler: MagicMock -) -> None: - """Test that streaming endpoint properly handles request body consumption. - - This test verifies the fix for the deadlock issue where request.body() - was being consumed inside the EventSourceResponse context, causing - the application to hang indefinitely. - """ - - # Mock the async generator response from the request handler - async def mock_stream_response(): - """Mock streaming response generator.""" - yield Message( - message_id='stream_msg_1', - role=Role.ROLE_AGENT, - parts=[Part(text='First streaming response')], - ) - yield Message( - message_id='stream_msg_2', - role=Role.ROLE_AGENT, - parts=[Part(text='Second streaming response')], - ) - - request_handler.on_message_send_stream.return_value = mock_stream_response() - - # Create a valid streaming request - request = a2a_pb2.SendMessageRequest( - message=a2a_pb2.Message( - message_id='test_stream_msg', - role=a2a_pb2.ROLE_USER, - parts=[a2a_pb2.Part(text='Test streaming message')], - ), - configuration=a2a_pb2.SendMessageConfiguration(), - ) - - # This should not hang indefinitely (previously it would due to the deadlock) - response = await streaming_client.post( - '/message:stream', - json=json_format.MessageToDict(request), - headers={'Accept': 'text/event-stream'}, - timeout=10.0, # Reasonable timeout to prevent hanging in tests - ) - - # The response should be successful - response.raise_for_status() - assert response.status_code == 200 - assert 'text/event-stream' in response.headers.get('content-type', '') - - # Verify that the request handler was called - request_handler.on_message_send_stream.assert_called_once() - - -@pytest.mark.anyio -async def test_streaming_content_verification( - streaming_client: AsyncClient, request_handler: MagicMock -) -> None: - """Test that streaming endpoint returns correct SSE content.""" - - async def mock_stream_response(): - yield Message( - message_id='stream_msg_1', - role=Role.ROLE_AGENT, - parts=[Part(text='First chunk')], - ) - yield Message( - message_id='stream_msg_2', - role=Role.ROLE_AGENT, - parts=[Part(text='Second chunk')], - ) - - request_handler.on_message_send_stream.return_value = mock_stream_response() - - request = a2a_pb2.SendMessageRequest( - message=a2a_pb2.Message( - message_id='test_stream_msg', - role=a2a_pb2.ROLE_USER, - parts=[a2a_pb2.Part(text='Test message')], - ), - ) - - response = await streaming_client.post( - '/message:stream', - headers={'Accept': 'text/event-stream'}, - json=json_format.MessageToDict(request), - ) - - response.raise_for_status() - - # Read the response content - lines = [line async for line in response.aiter_lines()] - - # SSE format is "data: \n\n" - # httpx.aiter_lines() will give us each line. - data_lines = [ - json.loads(line[6:]) for line in lines if line.startswith('data: ') - ] - - expected_data_lines = [ - { - 'message': { - 'messageId': 'stream_msg_1', - 'role': 'ROLE_AGENT', - 'parts': [{'text': 'First chunk'}], - } - }, - { - 'message': { - 'messageId': 'stream_msg_2', - 'role': 'ROLE_AGENT', - 'parts': [{'text': 'Second chunk'}], - } - }, - ] - - assert data_lines == expected_data_lines - - -@pytest.mark.anyio -async def test_subscribe_to_task_get( - streaming_client: AsyncClient, request_handler: MagicMock -) -> None: - """Test that GET /tasks/{id}:subscribe works.""" - - async def mock_stream_response(): - yield Task( - id='task-1', - context_id='ctx-1', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - - request_handler.on_subscribe_to_task.return_value = mock_stream_response() - - response = await streaming_client.get( - '/tasks/task-1:subscribe', - headers={'Accept': 'text/event-stream'}, - ) - - response.raise_for_status() - assert response.status_code == 200 - - # Verify handler call - request_handler.on_subscribe_to_task.assert_called_once() - args, _ = request_handler.on_subscribe_to_task.call_args - assert args[0].id == 'task-1' - - -@pytest.mark.anyio -async def test_subscribe_to_task_post( - streaming_client: AsyncClient, request_handler: MagicMock -) -> None: - """Test that POST /tasks/{id}:subscribe works.""" - - async def mock_stream_response(): - yield Task( - id='task-1', - context_id='ctx-1', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - - request_handler.on_subscribe_to_task.return_value = mock_stream_response() - - response = await streaming_client.post( - '/tasks/task-1:subscribe', - headers={'Accept': 'text/event-stream'}, - ) - - response.raise_for_status() - assert response.status_code == 200 - - # Verify handler call - request_handler.on_subscribe_to_task.assert_called_once() - args, _ = request_handler.on_subscribe_to_task.call_args - assert args[0].id == 'task-1' - - -@pytest.mark.anyio -async def test_streaming_endpoint_with_invalid_content_type( - streaming_client: AsyncClient, request_handler: MagicMock -) -> None: - """Test streaming endpoint behavior with invalid content type.""" - - async def mock_stream_response(): - yield Message( - message_id='stream_msg_1', - role=Role.ROLE_AGENT, - parts=[Part(text='Response')], - ) - - request_handler.on_message_send_stream.return_value = mock_stream_response() - - request = a2a_pb2.SendMessageRequest( - message=a2a_pb2.Message( - message_id='test_stream_msg', - role=a2a_pb2.ROLE_USER, - parts=[a2a_pb2.Part(text='Test message')], - ), - configuration=a2a_pb2.SendMessageConfiguration(), - ) - - # Send request without proper event-stream headers - response = await streaming_client.post( - '/message:stream', - json=json_format.MessageToDict(request), - timeout=10.0, - ) - - # Should still succeed (the adapter handles content-type internally) - response.raise_for_status() - assert response.status_code == 200 - - -@pytest.mark.anyio -async def test_send_message_rejected_task( - client: AsyncClient, request_handler: MagicMock -) -> None: - expected_response = a2a_pb2.SendMessageResponse( - task=a2a_pb2.Task( - id='test_task_id', - context_id='test_context_id', - status=a2a_pb2.TaskStatus( - state=a2a_pb2.TaskState.TASK_STATE_REJECTED, - message=a2a_pb2.Message( - message_id='test', - role=a2a_pb2.Role.ROLE_AGENT, - parts=[ - a2a_pb2.Part(text="I don't want to work"), - ], - ), - ), - ), - ) - request_handler.on_message_send.return_value = Task( - id='test_task_id', - context_id='test_context_id', - status=TaskStatus( - state=TaskState.TASK_STATE_REJECTED, - message=Message( - message_id='test', - role=Role.ROLE_AGENT, - parts=[Part(text="I don't want to work")], - ), - ), - ) - request = a2a_pb2.SendMessageRequest( - message=a2a_pb2.Message(), - configuration=a2a_pb2.SendMessageConfiguration(), - ) - - response = await client.post( - '/message:send', json=json_format.MessageToDict(request) - ) - - response.raise_for_status() - actual_response = a2a_pb2.SendMessageResponse() - json_format.Parse(response.text, actual_response) - assert expected_response == actual_response - - -@pytest.mark.anyio -class TestTenantExtraction: - @pytest.fixture(autouse=True) - def configure_mocks(self, request_handler: MagicMock) -> None: - # Setup default return values for all handlers - async def mock_stream(*args, **kwargs): - if False: - yield - - request_handler.on_subscribe_to_task.side_effect = ( - lambda *args, **kwargs: mock_stream() - ) - - request_handler.on_message_send.return_value = Message( - message_id='test', - role=Role.ROLE_AGENT, - parts=[Part(text='response message')], - ) - request_handler.on_cancel_task.return_value = Task(id='1') - request_handler.on_get_task.return_value = Task(id='1') - request_handler.on_list_tasks.return_value = ListTasksResponse() - request_handler.on_create_task_push_notification_config.return_value = ( - TaskPushNotificationConfig() - ) - request_handler.on_get_task_push_notification_config.return_value = ( - TaskPushNotificationConfig() - ) - request_handler.on_list_task_push_notification_configs.return_value = ( - ListTaskPushNotificationConfigsResponse() - ) - request_handler.on_delete_task_push_notification_config.return_value = ( - None - ) - - @pytest.fixture - def extended_card_modifier(self) -> MagicMock: - modifier = MagicMock() - modifier.return_value = AgentCard() - return modifier - - @pytest.mark.parametrize( - 'path_template, method, handler_method_name, json_body', - [ - ('/message:send', 'POST', 'on_message_send', {'message': {}}), - ('/tasks/1:cancel', 'POST', 'on_cancel_task', None), - ('/tasks/1:subscribe', 'GET', 'on_subscribe_to_task', None), - ('/tasks/1:subscribe', 'POST', 'on_subscribe_to_task', None), - ('/tasks/1', 'GET', 'on_get_task', None), - ('/tasks', 'GET', 'on_list_tasks', None), - ( - '/tasks/1/pushNotificationConfigs/p1', - 'GET', - 'on_get_task_push_notification_config', - None, - ), - ( - '/tasks/1/pushNotificationConfigs/p1', - 'DELETE', - 'on_delete_task_push_notification_config', - None, - ), - ( - '/tasks/1/pushNotificationConfigs', - 'POST', - 'on_create_task_push_notification_config', - {'url': 'http://foo'}, - ), - ( - '/tasks/1/pushNotificationConfigs', - 'GET', - 'on_list_task_push_notification_configs', - None, - ), - ], - ) - async def test_tenant_extraction_parametrized( # noqa: PLR0913 # Test parametrization requires many arguments - self, - client: AsyncClient, - request_handler: MagicMock, - path_template: str, - method: str, - handler_method_name: str, - json_body: dict | None, - ) -> None: - """Test tenant extraction for standard REST endpoints.""" - # Test with tenant - tenant = 'my-tenant' - tenant_path = f'/{tenant}{path_template}' - - response = await client.request(method, tenant_path, json=json_body) - response.raise_for_status() - - # Verify handler call - handler_mock = getattr(request_handler, handler_method_name) - - assert handler_mock.called - args, _ = handler_mock.call_args - context = args[1] - assert context.tenant == tenant - - # Reset mock for non-tenant test - handler_mock.reset_mock() - - # Test without tenant - response = await client.request(method, path_template, json=json_body) - response.raise_for_status() - - # Verify context.tenant == "" - assert handler_mock.called - args, _ = handler_mock.call_args - context = args[1] - assert context.tenant == '' - - async def test_tenant_extraction_extended_agent_card( - self, - client: AsyncClient, - extended_card_modifier: MagicMock, - ) -> None: - """Test tenant extraction specifically for extendedAgentCard endpoint.""" - # Test with tenant - tenant = 'my-tenant' - tenant_path = f'/{tenant}/extendedAgentCard' - - response = await client.get(tenant_path) - response.raise_for_status() - - # Verify extended_card_modifier called with tenant context - assert extended_card_modifier.called - args, _ = extended_card_modifier.call_args - context = args[1] - assert context.tenant == tenant - - # Reset mock for non-tenant test - extended_card_modifier.reset_mock() - - # Test without tenant - response = await client.get('/extendedAgentCard') - response.raise_for_status() - - # Verify extended_card_modifier called with empty tenant context - assert extended_card_modifier.called - args, _ = extended_card_modifier.call_args - context = args[1] - assert context.tenant == '' - - -@pytest.mark.anyio -async def test_global_http_exception_handler_returns_rpc_status( - client: AsyncClient, -) -> None: - """Test that a standard FastAPI 404 is transformed into the A2A google.rpc.Status format.""" - - # Send a request to an endpoint that does not exist - response = await client.get('/non-existent-route') - - # Verify it returns a 404 with standard application/json - assert response.status_code == 404 - assert response.headers.get('content-type') == 'application/json' - - data = response.json() - - # Assert the payload is wrapped in the "error" envelope - assert 'error' in data - error_payload = data['error'] - - # Assert it has the correct AIP-193 format - assert error_payload['code'] == 404 - assert error_payload['status'] == 'NOT_FOUND' - assert 'Not Found' in error_payload['message'] - - # Standard HTTP errors shouldn't leak details - assert 'details' not in error_payload - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/server/routes/test_agent_card_routes.py b/tests/server/routes/test_agent_card_routes.py index 435921d60..b24438a57 100644 --- a/tests/server/routes/test_agent_card_routes.py +++ b/tests/server/routes/test_agent_card_routes.py @@ -1,11 +1,7 @@ -# ruff: noqa: INP001 -import asyncio -from typing import Any from unittest.mock import AsyncMock, MagicMock import pytest from starlette.testclient import TestClient -from starlette.middleware import Middleware from starlette.applications import Starlette from a2a.server.routes.agent_card_routes import create_agent_card_routes diff --git a/tests/server/routes/test_jsonrpc_dispatcher.py b/tests/server/routes/test_jsonrpc_dispatcher.py index 4fb398660..586486b01 100644 --- a/tests/server/routes/test_jsonrpc_dispatcher.py +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -1,4 +1,3 @@ -# ruff: noqa: INP001 import json from typing import Any from unittest.mock import AsyncMock, MagicMock, patch diff --git a/tests/server/routes/test_jsonrpc_routes.py b/tests/server/routes/test_jsonrpc_routes.py index 1d3fb5909..5bfa931ee 100644 --- a/tests/server/routes/test_jsonrpc_routes.py +++ b/tests/server/routes/test_jsonrpc_routes.py @@ -1,4 +1,3 @@ -# ruff: noqa: INP001 from typing import Any from unittest.mock import AsyncMock, MagicMock diff --git a/tests/server/routes/test_rest_routes.py b/tests/server/routes/test_rest_routes.py new file mode 100644 index 000000000..0c1b5a2a6 --- /dev/null +++ b/tests/server/routes/test_rest_routes.py @@ -0,0 +1,97 @@ +from unittest.mock import AsyncMock + +import pytest +from starlette.applications import Starlette +from starlette.testclient import TestClient +from starlette.routing import Route + +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.routes.rest_routes import create_rest_routes +from a2a.types.a2a_pb2 import AgentCard, Task, ListTasksResponse + + +@pytest.fixture +def agent_card(): + return AgentCard() + + +@pytest.fixture +def mock_handler(): + return AsyncMock(spec=RequestHandler) + + +def test_routes_creation(agent_card, mock_handler): + """Tests that create_rest_routes creates Route objects list.""" + routes = create_rest_routes( + agent_card=agent_card, request_handler=mock_handler + ) + + assert isinstance(routes, list) + assert len(routes) > 0 + assert all(isinstance(r, Route) for r in routes) + + +def test_routes_creation_v03_compat(agent_card, mock_handler): + """Tests that create_rest_routes creates more routes with enable_v0_3_compat.""" + routes_without_compat = create_rest_routes( + agent_card=agent_card, request_handler=mock_handler, enable_v0_3_compat=False + ) + routes_with_compat = create_rest_routes( + agent_card=agent_card, request_handler=mock_handler, enable_v0_3_compat=True + ) + + assert len(routes_with_compat) > len(routes_without_compat) + + +def test_rest_endpoints_routing(agent_card, mock_handler): + """Tests that mounted routes route to the handler endpoints.""" + mock_handler.on_message_send.return_value = Task(id='123') + + routes = create_rest_routes( + agent_card=agent_card, request_handler=mock_handler + ) + app = Starlette(routes=routes) + client = TestClient(app) + + # Test POST /message:send + response = client.post('/message:send', json={}, headers={'A2A-Version': '1.0'}) + assert response.status_code == 200 + assert response.json()['task']['id'] == '123' + assert mock_handler.on_message_send.called + + +def test_rest_endpoints_routing_tenant(agent_card, mock_handler): + """Tests that mounted routes with {tenant} route to the handler endpoints.""" + mock_handler.on_message_send.return_value = Task(id='123') + + routes = create_rest_routes( + agent_card=agent_card, request_handler=mock_handler + ) + app = Starlette(routes=routes) + client = TestClient(app) + + # Test POST /{tenant}/message:send + response = client.post('/my-tenant/message:send', json={}, headers={'A2A-Version': '1.0'}) + assert response.status_code == 200 + + # Verify that tenant was set in call context + call_args = mock_handler.on_message_send.call_args + assert call_args is not None + # call_args[0] is positional args. In on_message_send(params, context): + context = call_args[0][1] + assert context.tenant == 'my-tenant' + + +def test_rest_list_tasks(agent_card, mock_handler): + """Tests that list tasks endpoint is routed to the handler.""" + mock_handler.on_list_tasks.return_value = ListTasksResponse() + + routes = create_rest_routes( + agent_card=agent_card, request_handler=mock_handler + ) + app = Starlette(routes=routes) + client = TestClient(app) + + response = client.get('/tasks', headers={'A2A-Version': '1.0'}) + assert response.status_code == 200 + assert mock_handler.on_list_tasks.called From fa2169ef5e169a3132c8b4630400a3e1f47f2fde Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Sun, 22 Mar 2026 10:38:37 +0000 Subject: [PATCH 12/25] wip --- src/a2a/server/routes/rest_routes.py | 48 ++++++++++++------------ tests/integration/test_agent_card.py | 4 +- tests/integration/test_version_header.py | 5 ++- tests/server/routes/test_rest_routes.py | 18 ++++++--- 4 files changed, 42 insertions(+), 33 deletions(-) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index c0049c5a6..f81a28bd8 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -8,7 +8,6 @@ from google.protobuf.json_format import MessageToDict from a2a.compat.v0_3.rest_adapter import REST03Adapter -from a2a.utils.constants import DEFAULT_RPC_URL from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.rest_handler import RESTHandler @@ -64,7 +63,7 @@ def create_rest_routes( # noqa: PLR0913 ] | None = None, enable_v0_3_compat: bool = False, - rpc_url: str = DEFAULT_RPC_URL, + rpc_url: str = '', ) -> list[Route]: """Creates the Starlette Routes for the A2A protocol REST endpoint. @@ -84,6 +83,7 @@ def create_rest_routes( # noqa: PLR0913 call context. enable_v0_3_compat: If True, mounts backward-compatible v0.3 protocol endpoints using REST03Adapter. + rpc_url: The URL prefix for the REST endpoints. """ if not _package_starlette_installed: raise ImportError( @@ -92,6 +92,28 @@ def create_rest_routes( # noqa: PLR0913 'optional dependencies, `a2a-sdk[http-server]`.' ) + v03_routes = {} + if enable_v0_3_compat: + v03_adapter = REST03Adapter( + agent_card=agent_card, + http_handler=request_handler, + extended_agent_card=extended_agent_card, + context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, + ) + v03_routes = v03_adapter.routes() + + routes: list[Route] = [] + for (path, method), endpoint in v03_routes.items(): + routes.append( + Route( + path=f'{rpc_url}{path}', + endpoint=endpoint, + methods=[method], + ) + ) + handler = RESTHandler( agent_card=agent_card, request_handler=request_handler ) @@ -202,19 +224,6 @@ async def _handle_authenticated_agent_card( ), } - v03_routes = {} - if enable_v0_3_compat: - v03_adapter = REST03Adapter( - agent_card=agent_card, - http_handler=request_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - ) - v03_routes = v03_adapter.routes() - - routes: list[Route] = [] for (path, method), endpoint in base_routes.items(): routes.append( Route( @@ -231,13 +240,4 @@ async def _handle_authenticated_agent_card( ) ) - for (path, method), endpoint in v03_routes.items(): - routes.append( - Route( - path=f'{rpc_url}{path}', - endpoint=endpoint, - methods=[method], - ) - ) - return routes diff --git a/tests/integration/test_agent_card.py b/tests/integration/test_agent_card.py index 85a282a9e..494fd151c 100644 --- a/tests/integration/test_agent_card.py +++ b/tests/integration/test_agent_card.py @@ -87,9 +87,7 @@ async def test_agent_card_integration(header_val: str | None) -> None: *create_agent_card_routes( agent_card=agent_card, card_url='/.well-known/agent-card.json' ), - *create_rest_routes( - agent_card=agent_card, request_handler=handler - ), + *create_rest_routes(agent_card=agent_card, request_handler=handler), ] rest_app = Starlette(routes=rest_routes) app.mount('/rest', rest_app) diff --git a/tests/integration/test_version_header.py b/tests/integration/test_version_header.py index 56331302d..05310ec37 100644 --- a/tests/integration/test_version_header.py +++ b/tests/integration/test_version_header.py @@ -70,7 +70,10 @@ async def mock_on_message_send_stream(*args, **kwargs): app.routes.extend(jsonrpc_routes) rest_routes = create_rest_routes( - agent_card=agent_card, request_handler=handler, rpc_url='/rest', enable_v0_3_compat=True + agent_card=agent_card, + request_handler=handler, + rpc_url='/rest', + enable_v0_3_compat=True, ) app.routes.extend(rest_routes) return app diff --git a/tests/server/routes/test_rest_routes.py b/tests/server/routes/test_rest_routes.py index 0c1b5a2a6..1d9c91b46 100644 --- a/tests/server/routes/test_rest_routes.py +++ b/tests/server/routes/test_rest_routes.py @@ -34,10 +34,14 @@ def test_routes_creation(agent_card, mock_handler): def test_routes_creation_v03_compat(agent_card, mock_handler): """Tests that create_rest_routes creates more routes with enable_v0_3_compat.""" routes_without_compat = create_rest_routes( - agent_card=agent_card, request_handler=mock_handler, enable_v0_3_compat=False + agent_card=agent_card, + request_handler=mock_handler, + enable_v0_3_compat=False, ) routes_with_compat = create_rest_routes( - agent_card=agent_card, request_handler=mock_handler, enable_v0_3_compat=True + agent_card=agent_card, + request_handler=mock_handler, + enable_v0_3_compat=True, ) assert len(routes_with_compat) > len(routes_without_compat) @@ -54,7 +58,9 @@ def test_rest_endpoints_routing(agent_card, mock_handler): client = TestClient(app) # Test POST /message:send - response = client.post('/message:send', json={}, headers={'A2A-Version': '1.0'}) + response = client.post( + '/message:send', json={}, headers={'A2A-Version': '1.0'} + ) assert response.status_code == 200 assert response.json()['task']['id'] == '123' assert mock_handler.on_message_send.called @@ -71,9 +77,11 @@ def test_rest_endpoints_routing_tenant(agent_card, mock_handler): client = TestClient(app) # Test POST /{tenant}/message:send - response = client.post('/my-tenant/message:send', json={}, headers={'A2A-Version': '1.0'}) + response = client.post( + '/my-tenant/message:send', json={}, headers={'A2A-Version': '1.0'} + ) assert response.status_code == 200 - + # Verify that tenant was set in call context call_args = mock_handler.on_message_send.call_args assert call_args is not None From 214ead038e469b28ab2d394db050b8dbc3eba379 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 23 Mar 2026 15:35:00 +0000 Subject: [PATCH 13/25] update samples --- samples/hello_world_agent.py | 6 ++---- tck/sut_agent.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index 20b9804ba..cf3c85056 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -191,13 +191,11 @@ async def serve( agent_executor=SampleAgentExecutor(), task_store=task_store ) - rest_app_builder = A2ARESTFastAPIApplication( + rest_routes = create_rest_routes( agent_card=agent_card, http_handler=request_handler, enable_v0_3_compat=True, ) - rest_app = rest_app_builder.build() - jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=request_handler, @@ -209,7 +207,7 @@ async def serve( app = FastAPI() app.routes.extend(jsonrpc_routes) app.routes.extend(agent_card_routes) - app.mount('/a2a/rest', rest_app) + app.routes.extend(rest_routes) grpc_server = grpc.aio.server() grpc_server.add_insecure_port(f'{host}:{grpc_port}') diff --git a/tck/sut_agent.py b/tck/sut_agent.py index a25922db1..473187075 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -27,6 +27,7 @@ from a2a.server.routes import ( create_agent_card_routes, create_jsonrpc_routes, + create_rest_routes, ) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_store import TaskStore @@ -209,19 +210,19 @@ def serve(task_store: TaskStore) -> None: agent_card_routes = create_agent_card_routes( agent_card=agent_card, ) + # REST + rest_routes = create_rest_routes( + agent_card=agent_card, + http_handler=request_handler, + rpc_url=REST_URL, + ) + routes = [ *jsonrpc_routes, *agent_card_routes, + *rest_routes, ] - main_app = Starlette(routes=routes) - # REST - rest_server = A2ARESTFastAPIApplication( - agent_card=agent_card, - http_handler=request_handler, - ) - rest_app = rest_server.build(rpc_url=REST_URL) - main_app.mount('', rest_app) config = uvicorn.Config( main_app, host='127.0.0.1', port=http_port, log_level='info' From 5a977c48ef726223bfeeade0daf1b05783a0a502 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 23 Mar 2026 15:35:52 +0000 Subject: [PATCH 14/25] fox --- samples/hello_world_agent.py | 2 +- tck/sut_agent.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index cf3c85056..4fb905333 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -193,7 +193,7 @@ async def serve( rest_routes = create_rest_routes( agent_card=agent_card, - http_handler=request_handler, + request_handler=request_handler, enable_v0_3_compat=True, ) jsonrpc_routes = create_jsonrpc_routes( diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 473187075..d08c7305d 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -213,7 +213,7 @@ def serve(task_store: TaskStore) -> None: # REST rest_routes = create_rest_routes( agent_card=agent_card, - http_handler=request_handler, + request_handler=request_handler, rpc_url=REST_URL, ) From d1257d7d35f740d39d9571dbbd810075a3254502 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 23 Mar 2026 16:02:35 +0000 Subject: [PATCH 15/25] rename --- samples/hello_world_agent.py | 14 ++++++++------ src/a2a/server/routes/rest_routes.py | 10 +++++----- tck/sut_agent.py | 2 +- tests/integration/test_version_header.py | 2 +- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index 4fb905333..40a363ec5 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -17,7 +17,7 @@ from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) -from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes, create_rest_routes from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_updater import TaskUpdater from a2a.types import ( @@ -166,22 +166,22 @@ async def serve( AgentInterface( protocol_binding='JSONRPC', protocol_version='1.0', - url=f'http://{host}:{port}/a2a/jsonrpc/', + url=f'http://{host}:{port}/a2a/jsonrpc', ), AgentInterface( protocol_binding='JSONRPC', protocol_version='0.3', - url=f'http://{host}:{port}/a2a/jsonrpc/', + url=f'http://{host}:{port}/a2a/jsonrpc', ), AgentInterface( protocol_binding='HTTP+JSON', protocol_version='1.0', - url=f'http://{host}:{port}/a2a/rest/', + url=f'http://{host}:{port}/a2a/rest', ), AgentInterface( protocol_binding='HTTP+JSON', protocol_version='0.3', - url=f'http://{host}:{port}/a2a/rest/', + url=f'http://{host}:{port}/a2a/rest', ), ], ) @@ -194,12 +194,14 @@ async def serve( rest_routes = create_rest_routes( agent_card=agent_card, request_handler=request_handler, + path_prefix='/a2a/rest', enable_v0_3_compat=True, ) jsonrpc_routes = create_jsonrpc_routes( agent_card=agent_card, request_handler=request_handler, - rpc_url='/a2a/jsonrpc/', + rpc_url='/a2a/jsonrpc', + enable_v0_3_compat=True, ) agent_card_routes = create_agent_card_routes( agent_card=agent_card, diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index f81a28bd8..da2db878b 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -63,7 +63,7 @@ def create_rest_routes( # noqa: PLR0913 ] | None = None, enable_v0_3_compat: bool = False, - rpc_url: str = '', + path_prefix: str = '', ) -> list[Route]: """Creates the Starlette Routes for the A2A protocol REST endpoint. @@ -83,7 +83,7 @@ def create_rest_routes( # noqa: PLR0913 call context. enable_v0_3_compat: If True, mounts backward-compatible v0.3 protocol endpoints using REST03Adapter. - rpc_url: The URL prefix for the REST endpoints. + path_prefix: The URL prefix for the REST endpoints. """ if not _package_starlette_installed: raise ImportError( @@ -108,7 +108,7 @@ def create_rest_routes( # noqa: PLR0913 for (path, method), endpoint in v03_routes.items(): routes.append( Route( - path=f'{rpc_url}{path}', + path=f'{path_prefix}{path}', endpoint=endpoint, methods=[method], ) @@ -227,14 +227,14 @@ async def _handle_authenticated_agent_card( for (path, method), endpoint in base_routes.items(): routes.append( Route( - path=f'{rpc_url}{path}', + path=f'{path_prefix}{path}', endpoint=endpoint, methods=[method], ) ) routes.append( Route( - path=f'/{{tenant}}{rpc_url}{path}', + path=f'/{{tenant}}{path_prefix}{path}', endpoint=endpoint, methods=[method], ) diff --git a/tck/sut_agent.py b/tck/sut_agent.py index d08c7305d..2dcbe2196 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -214,7 +214,7 @@ def serve(task_store: TaskStore) -> None: rest_routes = create_rest_routes( agent_card=agent_card, request_handler=request_handler, - rpc_url=REST_URL, + path_prefix=REST_URL, ) routes = [ diff --git a/tests/integration/test_version_header.py b/tests/integration/test_version_header.py index 05310ec37..8e4c4a57d 100644 --- a/tests/integration/test_version_header.py +++ b/tests/integration/test_version_header.py @@ -72,7 +72,7 @@ async def mock_on_message_send_stream(*args, **kwargs): rest_routes = create_rest_routes( agent_card=agent_card, request_handler=handler, - rpc_url='/rest', + path_prefix='/rest', enable_v0_3_compat=True, ) app.routes.extend(rest_routes) From 8c230f2bebfdd935d8504ecaaf30a8eee4eaa459 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 23 Mar 2026 16:06:56 +0000 Subject: [PATCH 16/25] linter --- samples/hello_world_agent.py | 7 +++++-- tck/sut_agent.py | 3 --- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py index 40a363ec5..e286fa130 100644 --- a/samples/hello_world_agent.py +++ b/samples/hello_world_agent.py @@ -5,7 +5,6 @@ import grpc import uvicorn -from a2a.server.apps import A2ARESTFastAPIApplication from fastapi import FastAPI from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc @@ -17,7 +16,11 @@ from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) -from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes, create_rest_routes +from a2a.server.routes import ( + create_agent_card_routes, + create_jsonrpc_routes, + create_rest_routes, +) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_updater import TaskUpdater from a2a.types import ( diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 2dcbe2196..259b16a5d 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -8,9 +8,6 @@ import grpc.aio import uvicorn -from a2a.server.apps import ( - A2ARESTFastAPIApplication, -) from starlette.applications import Starlette import a2a.compat.v0_3.a2a_v0_3_pb2_grpc as a2a_v0_3_grpc From 9e7d74aeec665d6fd6b4e7839c781d867117aa26 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 23 Mar 2026 16:19:22 +0000 Subject: [PATCH 17/25] remove middleware --- src/a2a/server/routes/agent_card_routes.py | 3 --- src/a2a/server/routes/jsonrpc_routes.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/src/a2a/server/routes/agent_card_routes.py b/src/a2a/server/routes/agent_card_routes.py index c1f7ecffe..680a632dc 100644 --- a/src/a2a/server/routes/agent_card_routes.py +++ b/src/a2a/server/routes/agent_card_routes.py @@ -5,7 +5,6 @@ if TYPE_CHECKING: - from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route @@ -13,14 +12,12 @@ _package_starlette_installed = True else: try: - from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route _package_starlette_installed = True except ImportError: - Middleware = Any Route = Any Request = Any Response = Any diff --git a/src/a2a/server/routes/jsonrpc_routes.py b/src/a2a/server/routes/jsonrpc_routes.py index e55254f2f..24592d9df 100644 --- a/src/a2a/server/routes/jsonrpc_routes.py +++ b/src/a2a/server/routes/jsonrpc_routes.py @@ -5,18 +5,15 @@ if TYPE_CHECKING: - from starlette.middleware import Middleware from starlette.routing import Route, Router _package_starlette_installed = True else: try: - from starlette.middleware import Middleware from starlette.routing import Route, Router _package_starlette_installed = True except ImportError: - Middleware = Any Route = Any Router = Any From 33ded8b9d473115f17f13456810568d83b425935 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 23 Mar 2026 16:23:40 +0000 Subject: [PATCH 18/25] revert unwanted changes --- .../contrib/tasks/vertex_task_converter.py | 51 +-- src/a2a/server/apps/__init__.py | 8 - src/a2a/server/apps/rest/fastapi_app.py | 194 ------------ src/a2a/server/apps/rest/rest_adapter.py | 293 ------------------ src/a2a/server/routes/jsonrpc_routes.py | 2 +- 5 files changed, 27 insertions(+), 521 deletions(-) delete mode 100644 src/a2a/server/apps/__init__.py delete mode 100644 src/a2a/server/apps/rest/fastapi_app.py delete mode 100644 src/a2a/server/apps/rest/rest_adapter.py diff --git a/src/a2a/contrib/tasks/vertex_task_converter.py b/src/a2a/contrib/tasks/vertex_task_converter.py index 71ccbc288..6f23dad2e 100644 --- a/src/a2a/contrib/tasks/vertex_task_converter.py +++ b/src/a2a/contrib/tasks/vertex_task_converter.py @@ -1,4 +1,5 @@ try: + from google.genai import types as genai_types from vertexai import types as vertexai_types except ImportError as e: raise ImportError( @@ -25,40 +26,40 @@ _TO_SDK_TASK_STATE = { - vertexai_types.State.STATE_UNSPECIFIED: TaskState.unknown, - vertexai_types.State.SUBMITTED: TaskState.submitted, - vertexai_types.State.WORKING: TaskState.working, - vertexai_types.State.COMPLETED: TaskState.completed, - vertexai_types.State.CANCELLED: TaskState.canceled, - vertexai_types.State.FAILED: TaskState.failed, - vertexai_types.State.REJECTED: TaskState.rejected, - vertexai_types.State.INPUT_REQUIRED: TaskState.input_required, - vertexai_types.State.AUTH_REQUIRED: TaskState.auth_required, + vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown, + vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted, + vertexai_types.A2aTaskState.WORKING: TaskState.working, + vertexai_types.A2aTaskState.COMPLETED: TaskState.completed, + vertexai_types.A2aTaskState.CANCELLED: TaskState.canceled, + vertexai_types.A2aTaskState.FAILED: TaskState.failed, + vertexai_types.A2aTaskState.REJECTED: TaskState.rejected, + vertexai_types.A2aTaskState.INPUT_REQUIRED: TaskState.input_required, + vertexai_types.A2aTaskState.AUTH_REQUIRED: TaskState.auth_required, } _SDK_TO_STORED_TASK_STATE = {v: k for k, v in _TO_SDK_TASK_STATE.items()} -def to_sdk_task_state(stored_state: vertexai_types.State) -> TaskState: +def to_sdk_task_state(stored_state: vertexai_types.A2aTaskState) -> TaskState: """Converts a proto A2aTask.State to a TaskState enum.""" return _TO_SDK_TASK_STATE.get(stored_state, TaskState.unknown) -def to_stored_task_state(task_state: TaskState) -> vertexai_types.State: +def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState: """Converts a TaskState enum to a proto A2aTask.State enum value.""" return _SDK_TO_STORED_TASK_STATE.get( - task_state, vertexai_types.State.STATE_UNSPECIFIED + task_state, vertexai_types.A2aTaskState.STATE_UNSPECIFIED ) -def to_stored_part(part: Part) -> vertexai_types.Part: +def to_stored_part(part: Part) -> genai_types.Part: """Converts a SDK Part to a proto Part.""" if isinstance(part.root, TextPart): - return vertexai_types.Part(text=part.root.text) + return genai_types.Part(text=part.root.text) if isinstance(part.root, DataPart): data_bytes = json.dumps(part.root.data).encode('utf-8') - return vertexai_types.Part( - inline_data=vertexai_types.Blob( + return genai_types.Part( + inline_data=genai_types.Blob( mime_type='application/json', data=data_bytes ) ) @@ -66,14 +67,14 @@ def to_stored_part(part: Part) -> vertexai_types.Part: file_content = part.root.file if isinstance(file_content, FileWithBytes): decoded_bytes = base64.b64decode(file_content.bytes) - return vertexai_types.Part( - inline_data=vertexai_types.Blob( + return genai_types.Part( + inline_data=genai_types.Blob( mime_type=file_content.mime_type or '', data=decoded_bytes ) ) if isinstance(file_content, FileWithUri): - return vertexai_types.Part( - file_data=vertexai_types.FileData( + return genai_types.Part( + file_data=genai_types.FileData( mime_type=file_content.mime_type or '', file_uri=file_content.uri, ) @@ -81,14 +82,14 @@ def to_stored_part(part: Part) -> vertexai_types.Part: raise ValueError(f'Unsupported part type: {type(part.root)}') -def to_sdk_part(stored_part: vertexai_types.Part) -> Part: +def to_sdk_part(stored_part: genai_types.Part) -> Part: """Converts a proto Part to a SDK Part.""" if stored_part.text: return Part(root=TextPart(text=stored_part.text)) if stored_part.inline_data: - encoded_bytes = base64.b64encode(stored_part.inline_data.data).decode( - 'utf-8' - ) + encoded_bytes = base64.b64encode( + stored_part.inline_data.data or b'' + ).decode('utf-8') return Part( root=FilePart( file=FileWithBytes( @@ -97,7 +98,7 @@ def to_sdk_part(stored_part: vertexai_types.Part) -> Part: ) ) ) - if stored_part.file_data: + if stored_part.file_data and stored_part.file_data.file_uri: return Part( root=FilePart( file=FileWithUri( diff --git a/src/a2a/server/apps/__init__.py b/src/a2a/server/apps/__init__.py deleted file mode 100644 index 1cdb32953..000000000 --- a/src/a2a/server/apps/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""HTTP application components for the A2A server.""" - -from a2a.server.apps.rest import A2ARESTFastAPIApplication - - -__all__ = [ - 'A2ARESTFastAPIApplication', -] diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py deleted file mode 100644 index 4feac9072..000000000 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ /dev/null @@ -1,194 +0,0 @@ -import logging - -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from fastapi import APIRouter, FastAPI, Request, Response - from fastapi.responses import JSONResponse - from starlette.exceptions import HTTPException as StarletteHTTPException - - _package_fastapi_installed = True -else: - try: - from fastapi import APIRouter, FastAPI, Request, Response - from fastapi.responses import JSONResponse - from starlette.exceptions import HTTPException as StarletteHTTPException - - _package_fastapi_installed = True - except ImportError: - APIRouter = Any - FastAPI = Any - Request = Any - Response = Any - StarletteHTTPException = Any - - _package_fastapi_installed = False - - -from a2a.compat.v0_3.rest_adapter import REST03Adapter -from a2a.server.apps.rest.rest_adapter import RESTAdapter -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.routes import CallContextBuilder -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH - - -logger = logging.getLogger(__name__) - - -_HTTP_TO_GRPC_STATUS_MAP = { - 400: 'INVALID_ARGUMENT', - 401: 'UNAUTHENTICATED', - 403: 'PERMISSION_DENIED', - 404: 'NOT_FOUND', - 405: 'UNIMPLEMENTED', - 409: 'ALREADY_EXISTS', - 415: 'INVALID_ARGUMENT', - 422: 'INVALID_ARGUMENT', - 500: 'INTERNAL', - 501: 'UNIMPLEMENTED', - 502: 'INTERNAL', - 503: 'UNAVAILABLE', - 504: 'DEADLINE_EXCEEDED', -} - - -class A2ARESTFastAPIApplication: - """A FastAPI application implementing the A2A protocol server REST endpoints. - - Handles incoming REST requests, routes them to the appropriate - handler methods, and manages response generation including Server-Sent Events - (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - enable_v0_3_compat: bool = False, - ): - """Initializes the A2ARESTFastAPIApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - enable_v0_3_compat: If True, mounts backward-compatible v0.3 protocol - endpoints under the '/v0.3' path prefix using REST03Adapter. - """ - if not _package_fastapi_installed: - raise ImportError( - 'The `fastapi` package is required to use the' - ' `A2ARESTFastAPIApplication`. It can be added as a part of' - ' `a2a-sdk` optional dependencies, `a2a-sdk[http-server]`.' - ) - self._adapter = RESTAdapter( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - ) - self.enable_v0_3_compat = enable_v0_3_compat - self._v03_adapter = None - - if self.enable_v0_3_compat: - self._v03_adapter = REST03Adapter( - agent_card=agent_card, - http_handler=http_handler, - extended_agent_card=extended_agent_card, - context_builder=context_builder, - card_modifier=card_modifier, - extended_card_modifier=extended_card_modifier, - ) - - def build( - self, - agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH, - rpc_url: str = '', - **kwargs: Any, - ) -> FastAPI: - """Builds and returns the FastAPI application instance. - - Args: - agent_card_url: The URL for the agent card endpoint. - rpc_url: The URL for the A2A REST endpoint base path. - **kwargs: Additional keyword arguments to pass to the FastAPI constructor. - - Returns: - A configured FastAPI application instance. - """ - app = FastAPI(**kwargs) - - @app.exception_handler(StarletteHTTPException) - async def http_exception_handler( - request: Request, exc: StarletteHTTPException - ) -> Response: - """Catches framework-level HTTP exceptions. - - For example, 404 Not Found for bad routes, 422 Unprocessable Entity - for schema validation, and formats them into the A2A standard - google.rpc.Status JSON format (AIP-193). - """ - grpc_status = _HTTP_TO_GRPC_STATUS_MAP.get( - exc.status_code, 'UNKNOWN' - ) - return JSONResponse( - status_code=exc.status_code, - content={ - 'error': { - 'code': exc.status_code, - 'status': grpc_status, - 'message': str(exc.detail) - if hasattr(exc, 'detail') - else 'HTTP Exception', - } - }, - media_type='application/json', - ) - - if self.enable_v0_3_compat and self._v03_adapter: - v03_adapter = self._v03_adapter - v03_router = APIRouter() - for route, callback in v03_adapter.routes().items(): - v03_router.add_api_route( - f'{rpc_url}{route[0]}', callback, methods=[route[1]] - ) - app.include_router(v03_router) - - router = APIRouter() - for route, callback in self._adapter.routes().items(): - router.add_api_route( - f'{rpc_url}{route[0]}', callback, methods=[route[1]] - ) - - @router.get(f'{rpc_url}{agent_card_url}') - async def get_agent_card(request: Request) -> Response: - card = await self._adapter.handle_get_agent_card(request) - return JSONResponse(card) - - app.include_router(router) - - return app diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py deleted file mode 100644 index ebf996a47..000000000 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ /dev/null @@ -1,293 +0,0 @@ -import functools -import json -import logging - -from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable -from typing import TYPE_CHECKING, Any - -from google.protobuf.json_format import MessageToDict - -from a2a.utils.helpers import maybe_await - - -if TYPE_CHECKING: - from sse_starlette.sse import EventSourceResponse - from starlette.requests import Request - from starlette.responses import JSONResponse, Response - - _package_starlette_installed = True - -else: - try: - from sse_starlette.sse import EventSourceResponse - from starlette.requests import Request - from starlette.responses import JSONResponse, Response - - _package_starlette_installed = True - except ImportError: - EventSourceResponse = Any - Request = Any - JSONResponse = Any - Response = Any - - _package_starlette_installed = False - -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.response_helpers import ( - agent_card_to_dict, -) -from a2a.server.request_handlers.rest_handler import RESTHandler -from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.error_handlers import ( - rest_error_handler, - rest_stream_error_handler, -) -from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, - InvalidRequestError, -) - - -logger = logging.getLogger(__name__) - - -class RESTAdapterInterface(ABC): - """Interface for RESTAdapter.""" - - @abstractmethod - async def handle_get_agent_card( - self, request: 'Request', call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Handles GET requests for the agent card endpoint.""" - - @abstractmethod - def routes(self) -> dict[tuple[str, str], Callable[['Request'], Any]]: - """Constructs a dictionary of API routes and their corresponding handlers.""" - - -class RESTAdapter(RESTAdapterInterface): - """Adapter to make RequestHandler work with RESTful API. - - Defines REST requests processors and the routes to attach them too, as well as - manages response generation including Server-Sent Events (SSE). - """ - - def __init__( # noqa: PLR0913 - self, - agent_card: AgentCard, - http_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - ): - """Initializes the RESTApplication. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A - requests via http. - extended_agent_card: An optional, distinct AgentCard to be served - at the authenticated extended card endpoint. - context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no - ServerCallContext is passed. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - """ - if not _package_starlette_installed: - raise ImportError( - 'Packages `starlette` and `sse-starlette` are required to use' - ' the `RESTAdapter`. They can be added as a part of `a2a-sdk`' - ' optional dependencies, `a2a-sdk[http-server]`.' - ) - self.agent_card = agent_card - self.extended_agent_card = extended_agent_card - self.card_modifier = card_modifier - self.extended_card_modifier = extended_card_modifier - self.handler = RESTHandler( - agent_card=agent_card, request_handler=http_handler - ) - self._context_builder = context_builder or DefaultCallContextBuilder() - - @rest_error_handler - async def _handle_request( - self, - method: Callable[[Request, ServerCallContext], Awaitable[Any]], - request: Request, - ) -> Response: - call_context = self._build_call_context(request) - - response = await method(request, call_context) - return JSONResponse(content=response) - - @rest_stream_error_handler - async def _handle_streaming_request( - self, - method: Callable[[Request, ServerCallContext], AsyncIterable[Any]], - request: Request, - ) -> EventSourceResponse: - # Pre-consume and cache the request body to prevent deadlock in streaming context - # This is required because Starlette's request.body() can only be consumed once, - # and attempting to consume it after EventSourceResponse starts causes deadlock - try: - await request.body() - except (ValueError, RuntimeError, OSError) as e: - raise InvalidRequestError( - message=f'Failed to pre-consume request body: {e}' - ) from e - - call_context = self._build_call_context(request) - - async def event_generator( - stream: AsyncIterable[Any], - ) -> AsyncIterator[str]: - async for item in stream: - yield json.dumps(item) - - return EventSourceResponse( - event_generator(method(request, call_context)) - ) - - async def handle_get_agent_card( - self, request: Request, call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Handles GET requests for the agent card endpoint. - - Args: - request: The incoming Starlette Request object. - call_context: ServerCallContext - - Returns: - A JSONResponse containing the agent card data. - """ - card_to_serve = self.agent_card - if self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return agent_card_to_dict(card_to_serve) - - async def _handle_authenticated_agent_card( - self, request: Request, call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - """Hook for per credential agent card response. - - If a dynamic card is needed based on the credentials provided in the request - override this method and return the customized content. - - Args: - request: The incoming Starlette Request object. - call_context: ServerCallContext - - Returns: - A JSONResponse containing the authenticated card. - """ - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - card_to_serve = self.extended_agent_card - - if not card_to_serve: - card_to_serve = self.agent_card - - if self.extended_card_modifier: - context = self._build_call_context(request) - card_to_serve = await maybe_await( - self.extended_card_modifier(card_to_serve, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) - - return MessageToDict(card_to_serve, preserving_proto_field_name=True) - - def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: - """Constructs a dictionary of API routes and their corresponding handlers. - - This method maps URL paths and HTTP methods to the appropriate handler - functions from the RESTHandler. It can be used by a web framework - (like Starlette or FastAPI) to set up the application's endpoints. - - Returns: - A dictionary where each key is a tuple of (path, http_method) and - the value is the callable handler for that route. - """ - base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { - ('/message:send', 'POST'): functools.partial( - self._handle_request, self.handler.on_message_send - ), - ('/message:stream', 'POST'): functools.partial( - self._handle_streaming_request, - self.handler.on_message_send_stream, - ), - ('/tasks/{id}:cancel', 'POST'): functools.partial( - self._handle_request, self.handler.on_cancel_task - ), - ('/tasks/{id}:subscribe', 'GET'): functools.partial( - self._handle_streaming_request, - self.handler.on_subscribe_to_task, - ), - ('/tasks/{id}:subscribe', 'POST'): functools.partial( - self._handle_streaming_request, - self.handler.on_subscribe_to_task, - ), - ('/tasks/{id}', 'GET'): functools.partial( - self._handle_request, self.handler.on_get_task - ), - ( - '/tasks/{id}/pushNotificationConfigs/{push_id}', - 'GET', - ): functools.partial( - self._handle_request, self.handler.get_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs/{push_id}', - 'DELETE', - ): functools.partial( - self._handle_request, self.handler.delete_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs', - 'POST', - ): functools.partial( - self._handle_request, self.handler.set_push_notification - ), - ( - '/tasks/{id}/pushNotificationConfigs', - 'GET', - ): functools.partial( - self._handle_request, self.handler.list_push_notifications - ), - ('/tasks', 'GET'): functools.partial( - self._handle_request, self.handler.list_tasks - ), - } - - if self.agent_card.capabilities.extended_agent_card: - base_routes[('/extendedAgentCard', 'GET')] = functools.partial( - self._handle_request, self._handle_authenticated_agent_card - ) - - routes: dict[tuple[str, str], Callable[[Request], Any]] = { - (p, method): handler - for (path, method), handler in base_routes.items() - for p in (path, f'/{{tenant}}{path}') - } - - return routes - - def _build_call_context(self, request: Request) -> ServerCallContext: - call_context = self._context_builder.build(request) - if 'tenant' in request.path_params: - call_context.tenant = request.path_params['tenant'] - return call_context diff --git a/src/a2a/server/routes/jsonrpc_routes.py b/src/a2a/server/routes/jsonrpc_routes.py index 24592d9df..7c1663260 100644 --- a/src/a2a/server/routes/jsonrpc_routes.py +++ b/src/a2a/server/routes/jsonrpc_routes.py @@ -36,6 +36,7 @@ def create_jsonrpc_routes( # noqa: PLR0913 agent_card: AgentCard, request_handler: RequestHandler, + rpc_url: str, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] @@ -45,7 +46,6 @@ def create_jsonrpc_routes( # noqa: PLR0913 ] | None = None, enable_v0_3_compat: bool = False, - rpc_url: str = DEFAULT_RPC_URL, ) -> list['Route']: """Creates the Starlette Route for the A2A protocol JSON-RPC endpoint. From 02ebab66c8c2d45d8ac996da8affd9bde45789a5 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 23 Mar 2026 16:26:02 +0000 Subject: [PATCH 19/25] revert changes --- src/a2a/server/routes/agent_card_routes.py | 5 ----- src/a2a/server/routes/jsonrpc_routes.py | 13 +++---------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/src/a2a/server/routes/agent_card_routes.py b/src/a2a/server/routes/agent_card_routes.py index 680a632dc..9b850ff4f 100644 --- a/src/a2a/server/routes/agent_card_routes.py +++ b/src/a2a/server/routes/agent_card_routes.py @@ -1,5 +1,3 @@ -import logging - from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any @@ -31,9 +29,6 @@ from a2a.utils.helpers import maybe_await -logger = logging.getLogger(__name__) - - def create_agent_card_routes( agent_card: AgentCard, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] diff --git a/src/a2a/server/routes/jsonrpc_routes.py b/src/a2a/server/routes/jsonrpc_routes.py index 7c1663260..9138ed8ea 100644 --- a/src/a2a/server/routes/jsonrpc_routes.py +++ b/src/a2a/server/routes/jsonrpc_routes.py @@ -1,21 +1,18 @@ -import logging - from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from starlette.routing import Route, Router + from starlette.routing import Route _package_starlette_installed = True else: try: - from starlette.routing import Route, Router + from starlette.routing import Route _package_starlette_installed = True except ImportError: Route = Any - Router = Any _package_starlette_installed = False @@ -27,10 +24,6 @@ JsonRpcDispatcher, ) from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.constants import DEFAULT_RPC_URL - - -logger = logging.getLogger(__name__) def create_jsonrpc_routes( # noqa: PLR0913 @@ -57,6 +50,7 @@ def create_jsonrpc_routes( # noqa: PLR0913 agent_card: The AgentCard describing the agent's capabilities. request_handler: The handler instance responsible for processing A2A requests via http. + rpc_url: The URL prefix for the RPC endpoints. extended_agent_card: An optional, distinct AgentCard to be served at the authenticated extended card endpoint. context_builder: The CallContextBuilder used to construct the @@ -68,7 +62,6 @@ def create_jsonrpc_routes( # noqa: PLR0913 the extended agent card before it is served. It receives the call context. enable_v0_3_compat: Whether to enable v0.3 backward compatibility on the same endpoint. - rpc_url: The URL prefix for the RPC endpoints. """ if not _package_starlette_installed: raise ImportError( From 83fa7b892554d7555f8438af5cf5cf5623351829 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 23 Mar 2026 16:32:34 +0000 Subject: [PATCH 20/25] revert changes --- .../tasks/test_vertex_task_converter.py | 76 +++++++++++-------- tests/server/routes/test_jsonrpc_routes.py | 3 +- 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/tests/contrib/tasks/test_vertex_task_converter.py b/tests/contrib/tasks/test_vertex_task_converter.py index d71f764b7..de6ae8cd6 100644 --- a/tests/contrib/tasks/test_vertex_task_converter.py +++ b/tests/contrib/tasks/test_vertex_task_converter.py @@ -7,7 +7,7 @@ 'vertexai', reason='Vertex Task Converter tests require vertexai' ) from vertexai import types as vertexai_types - +from google.genai import types as genai_types from a2a.contrib.tasks.vertex_task_converter import ( to_sdk_artifact, to_sdk_part, @@ -18,7 +18,7 @@ to_stored_task, to_stored_task_state, ) -from a2a.compat.v0_3.types import ( +from a2a.types import ( Artifact, DataPart, FilePart, @@ -34,29 +34,39 @@ def test_to_sdk_task_state() -> None: assert ( - to_sdk_task_state(vertexai_types.State.STATE_UNSPECIFIED) + to_sdk_task_state(vertexai_types.A2aTaskState.STATE_UNSPECIFIED) == TaskState.unknown ) assert ( - to_sdk_task_state(vertexai_types.State.SUBMITTED) == TaskState.submitted + to_sdk_task_state(vertexai_types.A2aTaskState.SUBMITTED) + == TaskState.submitted + ) + assert ( + to_sdk_task_state(vertexai_types.A2aTaskState.WORKING) + == TaskState.working ) - assert to_sdk_task_state(vertexai_types.State.WORKING) == TaskState.working assert ( - to_sdk_task_state(vertexai_types.State.COMPLETED) == TaskState.completed + to_sdk_task_state(vertexai_types.A2aTaskState.COMPLETED) + == TaskState.completed ) assert ( - to_sdk_task_state(vertexai_types.State.CANCELLED) == TaskState.canceled + to_sdk_task_state(vertexai_types.A2aTaskState.CANCELLED) + == TaskState.canceled ) - assert to_sdk_task_state(vertexai_types.State.FAILED) == TaskState.failed assert ( - to_sdk_task_state(vertexai_types.State.REJECTED) == TaskState.rejected + to_sdk_task_state(vertexai_types.A2aTaskState.FAILED) + == TaskState.failed ) assert ( - to_sdk_task_state(vertexai_types.State.INPUT_REQUIRED) + to_sdk_task_state(vertexai_types.A2aTaskState.REJECTED) + == TaskState.rejected + ) + assert ( + to_sdk_task_state(vertexai_types.A2aTaskState.INPUT_REQUIRED) == TaskState.input_required ) assert ( - to_sdk_task_state(vertexai_types.State.AUTH_REQUIRED) + to_sdk_task_state(vertexai_types.A2aTaskState.AUTH_REQUIRED) == TaskState.auth_required ) assert to_sdk_task_state(999) == TaskState.unknown # type: ignore @@ -65,35 +75,39 @@ def test_to_sdk_task_state() -> None: def test_to_stored_task_state() -> None: assert ( to_stored_task_state(TaskState.unknown) - == vertexai_types.State.STATE_UNSPECIFIED + == vertexai_types.A2aTaskState.STATE_UNSPECIFIED ) assert ( to_stored_task_state(TaskState.submitted) - == vertexai_types.State.SUBMITTED + == vertexai_types.A2aTaskState.SUBMITTED ) assert ( - to_stored_task_state(TaskState.working) == vertexai_types.State.WORKING + to_stored_task_state(TaskState.working) + == vertexai_types.A2aTaskState.WORKING ) assert ( to_stored_task_state(TaskState.completed) - == vertexai_types.State.COMPLETED + == vertexai_types.A2aTaskState.COMPLETED ) assert ( to_stored_task_state(TaskState.canceled) - == vertexai_types.State.CANCELLED + == vertexai_types.A2aTaskState.CANCELLED + ) + assert ( + to_stored_task_state(TaskState.failed) + == vertexai_types.A2aTaskState.FAILED ) - assert to_stored_task_state(TaskState.failed) == vertexai_types.State.FAILED assert ( to_stored_task_state(TaskState.rejected) - == vertexai_types.State.REJECTED + == vertexai_types.A2aTaskState.REJECTED ) assert ( to_stored_task_state(TaskState.input_required) - == vertexai_types.State.INPUT_REQUIRED + == vertexai_types.A2aTaskState.INPUT_REQUIRED ) assert ( to_stored_task_state(TaskState.auth_required) - == vertexai_types.State.AUTH_REQUIRED + == vertexai_types.A2aTaskState.AUTH_REQUIRED ) @@ -155,15 +169,15 @@ class BadPart: def test_to_sdk_part_text() -> None: - stored_part = vertexai_types.Part(text='hello back') + stored_part = genai_types.Part(text='hello back') sdk_part = to_sdk_part(stored_part) assert isinstance(sdk_part.root, TextPart) assert sdk_part.root.text == 'hello back' def test_to_sdk_part_inline_data() -> None: - stored_part = vertexai_types.Part( - inline_data=vertexai_types.Blob( + stored_part = genai_types.Part( + inline_data=genai_types.Blob( mime_type='application/json', data=b'{"key": "val"}', ) @@ -177,8 +191,8 @@ def test_to_sdk_part_inline_data() -> None: def test_to_sdk_part_file_data() -> None: - stored_part = vertexai_types.Part( - file_data=vertexai_types.FileData( + stored_part = genai_types.Part( + file_data=genai_types.FileData( mime_type='image/jpeg', file_uri='gs://bucket/image.jpg', ) @@ -191,7 +205,7 @@ def test_to_sdk_part_file_data() -> None: def test_to_sdk_part_unsupported() -> None: - stored_part = vertexai_types.Part() + stored_part = genai_types.Part() with pytest.raises(ValueError, match='Unsupported part:'): to_sdk_part(stored_part) @@ -210,7 +224,7 @@ def test_to_stored_artifact() -> None: def test_to_sdk_artifact() -> None: stored_artifact = vertexai_types.TaskArtifact( artifact_id='art-456', - parts=[vertexai_types.Part(text='part_2')], + parts=[genai_types.Part(text='part_2')], ) sdk_artifact = to_sdk_artifact(stored_artifact) assert sdk_artifact.artifact_id == 'art-456' @@ -236,7 +250,7 @@ def test_to_stored_task() -> None: stored_task = to_stored_task(sdk_task) assert stored_task.context_id == 'ctx-1' assert stored_task.metadata == {'foo': 'bar'} - assert stored_task.state == vertexai_types.State.WORKING + assert stored_task.state == vertexai_types.A2aTaskState.WORKING assert stored_task.output is not None assert stored_task.output.artifacts is not None assert len(stored_task.output.artifacts) == 1 @@ -247,13 +261,13 @@ def test_to_sdk_task() -> None: stored_task = vertexai_types.A2aTask( name='projects/123/locations/us-central1/agentEngines/456/tasks/task-2', context_id='ctx-2', - state=vertexai_types.State.COMPLETED, + state=vertexai_types.A2aTaskState.COMPLETED, metadata={'a': 'b'}, output=vertexai_types.TaskOutput( artifacts=[ vertexai_types.TaskArtifact( artifact_id='art-2', - parts=[vertexai_types.Part(text='result')], + parts=[genai_types.Part(text='result')], ) ] ), @@ -275,7 +289,7 @@ def test_to_sdk_task_no_output() -> None: stored_task = vertexai_types.A2aTask( name='tasks/task-3', context_id='ctx-3', - state=vertexai_types.State.SUBMITTED, + state=vertexai_types.A2aTaskState.SUBMITTED, metadata=None, ) sdk_task = to_sdk_task(stored_task) diff --git a/tests/server/routes/test_jsonrpc_routes.py b/tests/server/routes/test_jsonrpc_routes.py index 5bfa931ee..0b56ebf4d 100644 --- a/tests/server/routes/test_jsonrpc_routes.py +++ b/tests/server/routes/test_jsonrpc_routes.py @@ -3,7 +3,6 @@ import pytest from starlette.testclient import TestClient -from starlette.middleware import Middleware from starlette.applications import Starlette from a2a.server.routes.jsonrpc_routes import create_jsonrpc_routes @@ -24,7 +23,7 @@ def mock_handler(): def test_routes_creation(agent_card, mock_handler): """Tests that create_jsonrpc_routes creates Route objects list.""" routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=mock_handler + agent_card=agent_card, request_handler=mock_handler, rpc_url='/a2a/jsonrpc' ) assert isinstance(routes, list) From 50a45cbb05dcd94647f157b1befec588d1f8d8d6 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 23 Mar 2026 16:34:45 +0000 Subject: [PATCH 21/25] revert changes --- tests/contrib/tasks/test_vertex_task_converter.py | 2 +- tests/server/routes/test_jsonrpc_routes.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/contrib/tasks/test_vertex_task_converter.py b/tests/contrib/tasks/test_vertex_task_converter.py index de6ae8cd6..a060bc451 100644 --- a/tests/contrib/tasks/test_vertex_task_converter.py +++ b/tests/contrib/tasks/test_vertex_task_converter.py @@ -18,7 +18,7 @@ to_stored_task, to_stored_task_state, ) -from a2a.types import ( +from a2a.compat.v0_3.types import ( Artifact, DataPart, FilePart, diff --git a/tests/server/routes/test_jsonrpc_routes.py b/tests/server/routes/test_jsonrpc_routes.py index 0b56ebf4d..3330d14c8 100644 --- a/tests/server/routes/test_jsonrpc_routes.py +++ b/tests/server/routes/test_jsonrpc_routes.py @@ -23,7 +23,9 @@ def mock_handler(): def test_routes_creation(agent_card, mock_handler): """Tests that create_jsonrpc_routes creates Route objects list.""" routes = create_jsonrpc_routes( - agent_card=agent_card, request_handler=mock_handler, rpc_url='/a2a/jsonrpc' + agent_card=agent_card, + request_handler=mock_handler, + rpc_url='/a2a/jsonrpc', ) assert isinstance(routes, list) From 998eb28be7b932844dba12e00ce3911a9c062a62 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Tue, 24 Mar 2026 15:26:13 +0000 Subject: [PATCH 22/25] refactor: use Mount for tenant-scoped routing and update type annotations to BaseRoute --- src/a2a/server/routes/rest_routes.py | 22 ++++++++++------------ tests/integration/test_version_header.py | 2 +- tests/server/routes/test_rest_routes.py | 4 ++-- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index da2db878b..fc68807be 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -28,7 +28,7 @@ from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response - from starlette.routing import Route + from starlette.routing import BaseRoute, Route _package_starlette_installed = True else: @@ -36,7 +36,7 @@ from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response - from starlette.routing import Route + from starlette.routing import BaseRoute, Mount, Route _package_starlette_installed = True except ImportError: @@ -45,6 +45,8 @@ JSONResponse = Any Response = Any Route = Any + Mount = Any + BaseRoute = Any _package_starlette_installed = False @@ -64,7 +66,7 @@ def create_rest_routes( # noqa: PLR0913 | None = None, enable_v0_3_compat: bool = False, path_prefix: str = '', -) -> list[Route]: +) -> list['BaseRoute']: """Creates the Starlette Routes for the A2A protocol REST endpoint. Args: @@ -104,7 +106,7 @@ def create_rest_routes( # noqa: PLR0913 ) v03_routes = v03_adapter.routes() - routes: list[Route] = [] + routes: list['BaseRoute'] = [] for (path, method), endpoint in v03_routes.items(): routes.append( Route( @@ -224,20 +226,16 @@ async def _handle_authenticated_agent_card( ), } + base_route_objects = [] for (path, method), endpoint in base_routes.items(): - routes.append( + base_route_objects.append( Route( path=f'{path_prefix}{path}', endpoint=endpoint, methods=[method], ) ) - routes.append( - Route( - path=f'/{{tenant}}{path_prefix}{path}', - endpoint=endpoint, - methods=[method], - ) - ) + routes.extend(base_route_objects) + routes.append(Mount(path='/{tenant}', routes=base_route_objects)) return routes diff --git a/tests/integration/test_version_header.py b/tests/integration/test_version_header.py index 8e4c4a57d..683c56833 100644 --- a/tests/integration/test_version_header.py +++ b/tests/integration/test_version_header.py @@ -153,7 +153,7 @@ def test_version_header_integration( # noqa: PLR0912, PLR0913, PLR0915 assert response.status_code == 400, response.text else: - url = '/jsonrpc/' + url = '/jsonrpc' if endpoint_ver == '0.3': payload = { 'jsonrpc': '2.0', diff --git a/tests/server/routes/test_rest_routes.py b/tests/server/routes/test_rest_routes.py index 1d9c91b46..98bf4130d 100644 --- a/tests/server/routes/test_rest_routes.py +++ b/tests/server/routes/test_rest_routes.py @@ -3,7 +3,7 @@ import pytest from starlette.applications import Starlette from starlette.testclient import TestClient -from starlette.routing import Route +from starlette.routing import BaseRoute, Route from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.routes.rest_routes import create_rest_routes @@ -28,7 +28,7 @@ def test_routes_creation(agent_card, mock_handler): assert isinstance(routes, list) assert len(routes) > 0 - assert all(isinstance(r, Route) for r in routes) + assert all(isinstance(r, BaseRoute) for r in routes) def test_routes_creation_v03_compat(agent_card, mock_handler): From a8b3d0d96d3fe8b2380826172bce46ea8d0e0aa1 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Tue, 24 Mar 2026 15:29:27 +0000 Subject: [PATCH 23/25] linter --- src/a2a/server/routes/rest_routes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index fc68807be..f8caa7fdb 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -28,7 +28,7 @@ from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response - from starlette.routing import BaseRoute, Route + from starlette.routing import BaseRoute, Mount, Route _package_starlette_installed = True else: @@ -106,7 +106,7 @@ def create_rest_routes( # noqa: PLR0913 ) v03_routes = v03_adapter.routes() - routes: list['BaseRoute'] = [] + routes: list[BaseRoute] = [] for (path, method), endpoint in v03_routes.items(): routes.append( Route( From c6ae4494c4f807c6fb10ae781074a0c2aac5dce9 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Tue, 24 Mar 2026 16:01:56 +0000 Subject: [PATCH 24/25] revert --- src/a2a/server/routes/rest_routes.py | 35 ++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index f8caa7fdb..e332512c4 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -139,9 +139,13 @@ async def _handle_request( @rest_stream_error_handler async def _handle_streaming_request( - method: Callable[['Request', ServerCallContext], AsyncIterable[Any]], - request: 'Request', - ) -> 'EventSourceResponse': + self, + method: Callable[[Request, ServerCallContext], AsyncIterable[Any]], + request: Request, + ) -> EventSourceResponse: + # Pre-consume and cache the request body to prevent deadlock in streaming context + # This is required because Starlette's request.body() can only be consumed once, + # and attempting to consume it after EventSourceResponse starts causes deadlock try: await request.body() except (ValueError, RuntimeError, OSError) as e: @@ -149,17 +153,28 @@ async def _handle_streaming_request( message=f'Failed to pre-consume request body: {e}' ) from e - call_context = _build_call_context(request) + call_context = self._build_call_context(request) - async def event_generator( - stream: AsyncIterable[Any], - ) -> AsyncIterator[str]: + # Eagerly fetch the first item from the stream so that errors raised + # before any event is yielded (e.g. validation, parsing, or handler + # failures) propagate here and are caught by + # @rest_stream_error_handler, which returns a JSONResponse with + # the correct HTTP status code instead of starting an SSE stream. + # Without this, the error would be raised after SSE headers are + # already sent, and the client would see a broken stream instead + # of a proper error response. + stream = aiter(method(request, call_context)) + try: + first_item = await anext(stream) + except StopAsyncIteration: + return EventSourceResponse(iter([])) + + async def event_generator() -> AsyncIterator[str]: + yield json.dumps(first_item) async for item in stream: yield json.dumps(item) - return EventSourceResponse( - event_generator(method(request, call_context)) - ) + return EventSourceResponse(event_generator()) async def _handle_authenticated_agent_card( request: 'Request', call_context: ServerCallContext | None = None From fb4a107356326f48e8c92c2edcf5f424d041e516 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Tue, 24 Mar 2026 16:09:07 +0000 Subject: [PATCH 25/25] fix --- src/a2a/server/routes/rest_routes.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index e332512c4..1923f038a 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -139,7 +139,6 @@ async def _handle_request( @rest_stream_error_handler async def _handle_streaming_request( - self, method: Callable[[Request, ServerCallContext], AsyncIterable[Any]], request: Request, ) -> EventSourceResponse: @@ -153,7 +152,7 @@ async def _handle_streaming_request( message=f'Failed to pre-consume request body: {e}' ) from e - call_context = self._build_call_context(request) + call_context = _build_call_context(request) # Eagerly fetch the first item from the stream so that errors raised # before any event is yielded (e.g. validation, parsing, or handler