From a4bc9e6e75a0e00ee7ab66f6776d934fdaf4fe9d Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 10:26:38 +0000 Subject: [PATCH] fix: fix REST error handling Do one iteration to catch exceptions occurred beforehand to return an error instead of sending headers for SSE. Error handling during the execution is not defined in the spec: https://github.com/a2aproject/A2A/issues/1262. --- src/a2a/server/apps/rest/rest_adapter.py | 23 +++-- .../test_client_server_integration.py | 88 ++++++++++++++++--- 2 files changed, 92 insertions(+), 19 deletions(-) diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 6b8abb99e..e892d62f0 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -152,15 +152,26 @@ async def _handle_streaming_request( call_context = self._build_call_context(request) - async def event_generator( - stream: AsyncIterable[Any], - ) -> AsyncIterator[str]: + # Eagerly fetch the first item from the stream so that errors raised + # before any event is yielded (e.g. validation, parsing, or handler + # failures) propagate here and are caught by + # @rest_stream_error_handler, which returns a JSONResponse with + # the correct HTTP status code instead of starting an SSE stream. + # Without this, the error would be raised after SSE headers are + # already sent, and the client would see a broken stream instead + # of a proper error response. + stream = aiter(method(request, call_context)) + try: + first_item = await anext(stream) + except StopAsyncIteration: + return EventSourceResponse(iter([])) + + async def event_generator() -> AsyncIterator[str]: + yield json.dumps(first_item) async for item in stream: yield json.dumps(item) - return EventSourceResponse( - event_generator(method(request, call_context)) - ) + return EventSourceResponse(event_generator()) async def handle_get_agent_card( self, request: Request, call_context: ServerCallContext | None = None diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index b1013e98e..09aadd021 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,5 +1,4 @@ import asyncio - from collections.abc import AsyncGenerator from typing import Any, NamedTuple from unittest.mock import ANY, AsyncMock, patch @@ -8,7 +7,6 @@ import httpx import pytest import pytest_asyncio - from cryptography.hazmat.primitives.asymmetric import ec from google.protobuf.json_format import MessageToDict from google.protobuf.timestamp_pb2 import Timestamp @@ -16,13 +14,17 @@ from a2a.client import Client, ClientConfig from a2a.client.base_client import BaseClient from a2a.client.card_resolver import A2ACardResolver -from a2a.client.client_factory import ClientFactory from a2a.client.client import ClientCallContext +from a2a.client.client_factory import ClientFactory from a2a.client.service_parameters import ( ServiceParametersFactory, with_a2a_extensions, ) from a2a.client.transports import JsonRpcTransport, RestTransport + +# Compat v0.3 imports for dedicated tests +from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc +from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication from a2a.server.request_handlers import GrpcHandler, RequestHandler from a2a.types import a2a_pb2_grpc @@ -50,12 +52,10 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils.constants import ( - TransportProtocol, -) +from a2a.utils.constants import TransportProtocol from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, ContentTypeNotSupportedError, + ExtendedAgentCardNotConfiguredError, ExtensionSupportRequiredError, InternalError, InvalidAgentResponseError, @@ -73,11 +73,6 @@ create_signature_verifier, ) -# Compat v0.3 imports for dedicated tests -from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc -from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler - - # --- Test Constants --- TASK_FROM_STREAM = Task( @@ -360,9 +355,9 @@ def grpc_03_setup( ) -> TransportSetup: """Sets up the CompatGrpcTransport and in-process 0.3 server.""" server_address, handler = grpc_03_server_and_handler - from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport from a2a.client.base_client import BaseClient from a2a.client.client import ClientConfig + from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport channel = grpc.aio.insecure_channel(server_address) transport = CompatGrpcTransport(channel=channel, agent_card=agent_card) @@ -909,6 +904,73 @@ async def test_client_handles_a2a_errors(transport_setups, error_cls) -> None: await client.close() +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'error_cls', + [ + TaskNotFoundError, + TaskNotCancelableError, + PushNotificationNotSupportedError, + UnsupportedOperationError, + ContentTypeNotSupportedError, + InvalidAgentResponseError, + ExtendedAgentCardNotConfiguredError, + ExtensionSupportRequiredError, + VersionNotSupportedError, + ], +) +@pytest.mark.parametrize( + 'handler_attr, client_method, request_params', + [ + pytest.param( + 'on_message_send_stream', + 'send_message', + SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg-integration-test', + parts=[Part(text='Hello, integration test!')], + ) + ), + id='stream', + ), + pytest.param( + 'on_subscribe_to_task', + 'subscribe', + SubscribeToTaskRequest(id='some-id'), + id='subscribe', + ), + ], +) +async def test_client_handles_a2a_errors_streaming( + transport_setups, error_cls, handler_attr, client_method, request_params +) -> None: + """Integration test to verify error propagation from streaming handlers to client. + + The handler raises an A2AError before yielding any events. All transports + must propagate this as the exact error_cls, not wrapped in an ExceptionGroup + or converted to a generic client error. + """ + client = transport_setups.client + handler = transport_setups.handler + + async def mock_generator(*args, **kwargs): + raise error_cls('Test error message') + yield + + getattr(handler, handler_attr).side_effect = mock_generator + + with pytest.raises(error_cls) as exc_info: + async for _ in getattr(client, client_method)(request=request_params): + pass + + assert 'Test error message' in str(exc_info.value) + + getattr(handler, handler_attr).side_effect = None + + await client.close() + + @pytest.mark.asyncio @pytest.mark.parametrize( 'request_kwargs, expected_error_code',