diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index ebf996a47..2a1ed95c3 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -149,15 +149,26 @@ async def _handle_streaming_request( call_context = self._build_call_context(request) - async def event_generator( - stream: AsyncIterable[Any], - ) -> AsyncIterator[str]: + # Eagerly fetch the first item from the stream so that errors raised + # before any event is yielded (e.g. validation, parsing, or handler + # failures) propagate here and are caught by + # @rest_stream_error_handler, which returns a JSONResponse with + # the correct HTTP status code instead of starting an SSE stream. + # Without this, the error would be raised after SSE headers are + # already sent, and the client would see a broken stream instead + # of a proper error response. + stream = aiter(method(request, call_context)) + try: + first_item = await anext(stream) + except StopAsyncIteration: + return EventSourceResponse(iter([])) + + async def event_generator() -> AsyncIterator[str]: + yield json.dumps(first_item) async for item in stream: yield json.dumps(item) - return EventSourceResponse( - event_generator(method(request, call_context)) - ) + return EventSourceResponse(event_generator()) async def handle_get_agent_card( self, request: Request, call_context: ServerCallContext | None = None diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 94d0313a6..2df24790b 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,5 +1,4 @@ import asyncio - from collections.abc import AsyncGenerator from typing import Any, NamedTuple from unittest.mock import ANY, AsyncMock, patch @@ -8,7 +7,6 @@ import httpx import pytest import pytest_asyncio - from cryptography.hazmat.primitives.asymmetric import ec from google.protobuf.json_format import MessageToDict from google.protobuf.timestamp_pb2 import Timestamp @@ -16,14 +14,18 @@ from a2a.client import Client, ClientConfig from a2a.client.base_client import BaseClient from a2a.client.card_resolver import A2ACardResolver -from a2a.client.client_factory import ClientFactory from a2a.client.client import ClientCallContext +from a2a.client.client_factory import ClientFactory from a2a.client.service_parameters import ( ServiceParametersFactory, with_a2a_extensions, ) from a2a.client.transports import JsonRpcTransport, RestTransport from starlette.applications import Starlette + +# Compat v0.3 imports for dedicated tests +from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc +from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler from a2a.server.apps import A2ARESTFastAPIApplication from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.request_handlers import GrpcHandler, RequestHandler @@ -52,12 +54,10 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils.constants import ( - TransportProtocol, -) +from a2a.utils.constants import TransportProtocol from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, ContentTypeNotSupportedError, + ExtendedAgentCardNotConfiguredError, ExtensionSupportRequiredError, InternalError, InvalidAgentResponseError, @@ -75,11 +75,6 @@ create_signature_verifier, ) -# Compat v0.3 imports for dedicated tests -from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc -from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler - - # --- Test Constants --- TASK_FROM_STREAM = Task( @@ -368,9 +363,9 @@ def grpc_03_setup( ) -> TransportSetup: """Sets up the CompatGrpcTransport and in-process 0.3 server.""" server_address, handler = grpc_03_server_and_handler - from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport from a2a.client.base_client import BaseClient from a2a.client.client import ClientConfig + from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport channel = grpc.aio.insecure_channel(server_address) transport = CompatGrpcTransport(channel=channel, agent_card=agent_card) @@ -926,6 +921,73 @@ async def test_client_handles_a2a_errors(transport_setups, error_cls) -> None: await client.close() +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'error_cls', + [ + TaskNotFoundError, + TaskNotCancelableError, + PushNotificationNotSupportedError, + UnsupportedOperationError, + ContentTypeNotSupportedError, + InvalidAgentResponseError, + ExtendedAgentCardNotConfiguredError, + ExtensionSupportRequiredError, + VersionNotSupportedError, + ], +) +@pytest.mark.parametrize( + 'handler_attr, client_method, request_params', + [ + pytest.param( + 'on_message_send_stream', + 'send_message', + SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg-integration-test', + parts=[Part(text='Hello, integration test!')], + ) + ), + id='stream', + ), + pytest.param( + 'on_subscribe_to_task', + 'subscribe', + SubscribeToTaskRequest(id='some-id'), + id='subscribe', + ), + ], +) +async def test_client_handles_a2a_errors_streaming( + transport_setups, error_cls, handler_attr, client_method, request_params +) -> None: + """Integration test to verify error propagation from streaming handlers to client. + + The handler raises an A2AError before yielding any events. All transports + must propagate this as the exact error_cls, not wrapped in an ExceptionGroup + or converted to a generic client error. + """ + client = transport_setups.client + handler = transport_setups.handler + + async def mock_generator(*args, **kwargs): + raise error_cls('Test error message') + yield + + getattr(handler, handler_attr).side_effect = mock_generator + + with pytest.raises(error_cls) as exc_info: + async for _ in getattr(client, client_method)(request=request_params): + pass + + assert 'Test error message' in str(exc_info.value) + + getattr(handler, handler_attr).side_effect = None + + await client.close() + + @pytest.mark.asyncio @pytest.mark.parametrize( 'request_kwargs, expected_error_code',