Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions src/a2a/server/apps/rest/rest_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,26 @@ async def _handle_streaming_request(

call_context = self._build_call_context(request)

async def event_generator(
stream: AsyncIterable[Any],
) -> AsyncIterator[str]:
# Eagerly fetch the first item from the stream so that errors raised
# before any event is yielded (e.g. validation, parsing, or handler
# failures) propagate here and are caught by
# @rest_stream_error_handler, which returns a JSONResponse with
# the correct HTTP status code instead of starting an SSE stream.
# Without this, the error would be raised after SSE headers are
# already sent, and the client would see a broken stream instead
# of a proper error response.
stream = aiter(method(request, call_context))
try:
first_item = await anext(stream)
except StopAsyncIteration:
return EventSourceResponse(iter([]))

async def event_generator() -> AsyncIterator[str]:
yield json.dumps(first_item)
async for item in stream:
yield json.dumps(item)

return EventSourceResponse(
event_generator(method(request, call_context))
)
return EventSourceResponse(event_generator())

async def handle_get_agent_card(
self, request: Request, call_context: ServerCallContext | None = None
Expand Down
88 changes: 75 additions & 13 deletions tests/integration/test_client_server_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio

from collections.abc import AsyncGenerator
from typing import Any, NamedTuple
from unittest.mock import ANY, AsyncMock, patch
Expand All @@ -8,22 +7,25 @@
import httpx
import pytest
import pytest_asyncio

from cryptography.hazmat.primitives.asymmetric import ec
from google.protobuf.json_format import MessageToDict
from google.protobuf.timestamp_pb2 import Timestamp

from a2a.client import Client, ClientConfig
from a2a.client.base_client import BaseClient
from a2a.client.card_resolver import A2ACardResolver
from a2a.client.client_factory import ClientFactory
from a2a.client.client import ClientCallContext
from a2a.client.client_factory import ClientFactory
from a2a.client.service_parameters import (
ServiceParametersFactory,
with_a2a_extensions,
)
from a2a.client.transports import JsonRpcTransport, RestTransport
from starlette.applications import Starlette

# Compat v0.3 imports for dedicated tests
from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
from a2a.server.apps import A2ARESTFastAPIApplication
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
from a2a.server.request_handlers import GrpcHandler, RequestHandler
Expand Down Expand Up @@ -52,12 +54,10 @@
TaskStatus,
TaskStatusUpdateEvent,
)
from a2a.utils.constants import (
TransportProtocol,
)
from a2a.utils.constants import TransportProtocol
from a2a.utils.errors import (
ExtendedAgentCardNotConfiguredError,
ContentTypeNotSupportedError,
ExtendedAgentCardNotConfiguredError,
ExtensionSupportRequiredError,
InternalError,
InvalidAgentResponseError,
Expand All @@ -75,11 +75,6 @@
create_signature_verifier,
)

# Compat v0.3 imports for dedicated tests
from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler


# --- Test Constants ---

TASK_FROM_STREAM = Task(
Expand Down Expand Up @@ -368,9 +363,9 @@ def grpc_03_setup(
) -> TransportSetup:
"""Sets up the CompatGrpcTransport and in-process 0.3 server."""
server_address, handler = grpc_03_server_and_handler
from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport
from a2a.client.base_client import BaseClient
from a2a.client.client import ClientConfig
from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport

channel = grpc.aio.insecure_channel(server_address)
transport = CompatGrpcTransport(channel=channel, agent_card=agent_card)
Expand Down Expand Up @@ -926,6 +921,73 @@ async def test_client_handles_a2a_errors(transport_setups, error_cls) -> None:
await client.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
'error_cls',
[
TaskNotFoundError,
TaskNotCancelableError,
PushNotificationNotSupportedError,
UnsupportedOperationError,
ContentTypeNotSupportedError,
InvalidAgentResponseError,
ExtendedAgentCardNotConfiguredError,
ExtensionSupportRequiredError,
VersionNotSupportedError,
],
)
@pytest.mark.parametrize(
'handler_attr, client_method, request_params',
[
pytest.param(
'on_message_send_stream',
'send_message',
SendMessageRequest(
message=Message(
role=Role.ROLE_USER,
message_id='msg-integration-test',
parts=[Part(text='Hello, integration test!')],
)
),
id='stream',
),
pytest.param(
'on_subscribe_to_task',
'subscribe',
SubscribeToTaskRequest(id='some-id'),
id='subscribe',
),
],
)
async def test_client_handles_a2a_errors_streaming(
transport_setups, error_cls, handler_attr, client_method, request_params
) -> None:
"""Integration test to verify error propagation from streaming handlers to client.

The handler raises an A2AError before yielding any events. All transports
must propagate this as the exact error_cls, not wrapped in an ExceptionGroup
or converted to a generic client error.
"""
client = transport_setups.client
handler = transport_setups.handler

async def mock_generator(*args, **kwargs):
raise error_cls('Test error message')
yield

getattr(handler, handler_attr).side_effect = mock_generator

with pytest.raises(error_cls) as exc_info:
async for _ in getattr(client, client_method)(request=request_params):
pass

assert 'Test error message' in str(exc_info.value)

getattr(handler, handler_attr).side_effect = None

await client.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
'request_kwargs, expected_error_code',
Expand Down
Loading