diff --git a/src/a2a/client/transports/http_helpers.py b/src/a2a/client/transports/http_helpers.py index 0a5721b50..301782e36 100644 --- a/src/a2a/client/transports/http_helpers.py +++ b/src/a2a/client/transports/http_helpers.py @@ -78,7 +78,22 @@ async def send_http_stream_request( async with aconnect_sse( httpx_client, method, url, **kwargs ) as event_source: - event_source.response.raise_for_status() + try: + event_source.response.raise_for_status() + except httpx.HTTPStatusError as e: + # Read upfront streaming error content immediately, otherwise lower-level handlers + # (e.g. response.json()) crash with 'ResponseNotRead' Access errors. + await event_source.response.aread() + raise e + + # If the response is not a stream, read it standardly (e.g., upfront JSON-RPC error payload) + if 'text/event-stream' not in event_source.response.headers.get( + 'content-type', '' + ): + content = await event_source.response.aread() + yield content.decode('utf-8') + return + async for sse in event_source.aiter_sse(): if not sse.data: continue diff --git a/src/a2a/compat/v0_3/grpc_handler.py b/src/a2a/compat/v0_3/grpc_handler.py index a298a6c5e..eb72cf76b 100644 --- a/src/a2a/compat/v0_3/grpc_handler.py +++ b/src/a2a/compat/v0_3/grpc_handler.py @@ -29,7 +29,7 @@ from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types.a2a_pb2 import AgentCard from a2a.utils.errors import A2AError, InvalidParamsError -from a2a.utils.helpers import maybe_await, validate, validate_async_generator +from a2a.utils.helpers import maybe_await, validate logger = logging.getLogger(__name__) @@ -170,10 +170,6 @@ async def _handler( context, _handler, a2a_v0_3_pb2.SendMessageResponse() ) - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def SendStreamingMessage( self, request: a2a_v0_3_pb2.SendMessageRequest, @@ -181,6 +177,10 @@ async def SendStreamingMessage( ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: """Handles the 'SendStreamingMessage' gRPC method (v0.3).""" + @validate( + lambda _: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) async def _handler( server_context: ServerCallContext, ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: @@ -233,10 +233,6 @@ async def _handler( return await self._handle_unary(context, _handler, a2a_v0_3_pb2.Task()) - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def TaskSubscription( self, request: a2a_v0_3_pb2.TaskSubscriptionRequest, @@ -244,6 +240,10 @@ async def TaskSubscription( ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: """Handles the 'TaskSubscription' gRPC method (v0.3).""" + @validate( + lambda _: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) async def _handler( server_context: ServerCallContext, ) -> AsyncIterable[a2a_v0_3_pb2.StreamResponse]: @@ -260,10 +260,6 @@ async def _handler( async for item in self._handle_stream(context, _handler): yield item - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) async def CreateTaskPushNotificationConfig( self, request: a2a_v0_3_pb2.CreateTaskPushNotificationConfigRequest, @@ -271,6 +267,10 @@ async def CreateTaskPushNotificationConfig( ) -> a2a_v0_3_pb2.TaskPushNotificationConfig: """Handles the 'CreateTaskPushNotificationConfig' gRPC method (v0.3).""" + @validate( + lambda _: self.agent_card.capabilities.push_notifications, + 'Push notifications are not supported by the agent', + ) async def _handler( server_context: ServerCallContext, ) -> a2a_v0_3_pb2.TaskPushNotificationConfig: diff --git a/src/a2a/compat/v0_3/rest_handler.py b/src/a2a/compat/v0_3/rest_handler.py index 8d39e9b8b..470f94b3e 100644 --- a/src/a2a/compat/v0_3/rest_handler.py +++ b/src/a2a/compat/v0_3/rest_handler.py @@ -31,7 +31,6 @@ from a2a.utils import constants from a2a.utils.helpers import ( validate, - validate_async_generator, validate_version, ) from a2a.utils.telemetry import SpanKind, trace_class @@ -85,7 +84,7 @@ async def on_message_send( return MessageToDict(pb2_v03_resp) @validate_version(constants.PROTOCOL_VERSION_0_3) - @validate_async_generator( + @validate( lambda self: self.agent_card.capabilities.streaming, 'Streaming is not supported by the agent', ) @@ -143,7 +142,7 @@ async def on_cancel_task( return MessageToDict(pb2_v03_task) @validate_version(constants.PROTOCOL_VERSION_0_3) - @validate_async_generator( + @validate( lambda self: self.agent_card.capabilities.streaming, 'Streaming is not supported by the agent', ) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 326dea236..b290fbf44 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterable, Awaitable, Callable +from typing import TypeVar try: @@ -34,8 +35,12 @@ from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import AgentCard from a2a.utils import proto_utils -from a2a.utils.errors import A2A_ERROR_REASONS, A2AError, TaskNotFoundError -from a2a.utils.helpers import maybe_await, validate, validate_async_generator +from a2a.utils.errors import ( + A2A_ERROR_REASONS, + A2AError, + TaskNotFoundError, +) +from a2a.utils.helpers import maybe_await, validate logger = logging.getLogger(__name__) @@ -101,6 +106,9 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: } +TResponse = TypeVar('TResponse') + + class GrpcHandler(a2a_grpc.A2AServiceServicer): """Maps incoming gRPC requests to the appropriate request handler method.""" @@ -128,284 +136,241 @@ def __init__( self.context_builder = context_builder or DefaultCallContextBuilder() self.card_modifier = card_modifier + async def _handle_unary( + self, + request: message.Message, + context: grpc.aio.ServicerContext, + handler_func: Callable[[ServerCallContext], Awaitable[TResponse]], + default_response: TResponse, + ) -> TResponse: + """Centralized error handling and context management for unary calls.""" + try: + server_context = self._build_call_context(context, request) + result = await handler_func(server_context) + self._set_extension_metadata(context, server_context) + except A2AError as e: + await self.abort_context(e, context) + else: + return result + return default_response + + async def _handle_stream( + self, + request: message.Message, + context: grpc.aio.ServicerContext, + handler_func: Callable[[ServerCallContext], AsyncIterable[TResponse]], + ) -> AsyncIterable[TResponse]: + """Centralized error handling and context management for streaming calls.""" + try: + server_context = self._build_call_context(context, request) + async for item in handler_func(server_context): + yield item + self._set_extension_metadata(context, server_context) + except A2AError as e: + await self.abort_context(e, context) + async def SendMessage( self, request: a2a_pb2.SendMessageRequest, context: grpc.aio.ServicerContext, ) -> a2a_pb2.SendMessageResponse: - """Handles the 'SendMessage' gRPC method. - - Args: - request: The incoming `SendMessageRequest` object. - context: Context provided by the server. + """Handles the 'SendMessage' gRPC method.""" - Returns: - A `SendMessageResponse` object containing the result (Task or - Message) or throws an error response if an A2AError is raised - by the handler. - """ - try: - # Construct the server context object - server_context = self._build_call_context(context, request) + async def _handler( + server_context: ServerCallContext, + ) -> a2a_pb2.SendMessageResponse: task_or_message = await self.request_handler.on_message_send( request, server_context ) - self._set_extension_metadata(context, server_context) if isinstance(task_or_message, a2a_pb2.Task): return a2a_pb2.SendMessageResponse(task=task_or_message) return a2a_pb2.SendMessageResponse(message=task_or_message) - except A2AError as e: - await self.abort_context(e, context) - return a2a_pb2.SendMessageResponse() - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) + return await self._handle_unary( + request, context, _handler, a2a_pb2.SendMessageResponse() + ) + async def SendStreamingMessage( self, request: a2a_pb2.SendMessageRequest, context: grpc.aio.ServicerContext, ) -> AsyncIterable[a2a_pb2.StreamResponse]: - """Handles the 'StreamMessage' gRPC method. - - Yields response objects as they are produced by the underlying handler's - stream. - - Args: - request: The incoming `SendMessageRequest` object. - context: Context provided by the server. + """Handles the 'StreamMessage' gRPC method.""" - Yields: - `StreamResponse` objects containing streaming events - (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) - or gRPC error responses if an A2AError is raised. - """ - server_context = self._build_call_context(context, request) - try: + @validate( + lambda _: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def _handler( + server_context: ServerCallContext, + ) -> AsyncIterable[a2a_pb2.StreamResponse]: async for event in self.request_handler.on_message_send_stream( request, server_context ): yield proto_utils.to_stream_response(event) - self._set_extension_metadata(context, server_context) - except A2AError as e: - await self.abort_context(e, context) - return + + async for item in self._handle_stream(request, context, _handler): + yield item async def CancelTask( self, request: a2a_pb2.CancelTaskRequest, context: grpc.aio.ServicerContext, ) -> a2a_pb2.Task: - """Handles the 'CancelTask' gRPC method. + """Handles the 'CancelTask' gRPC method.""" - Args: - request: The incoming `CancelTaskRequest` object. - context: Context provided by the server. - - Returns: - A `Task` object containing the updated Task or a gRPC error. - """ - try: - server_context = self._build_call_context(context, request) + async def _handler(server_context: ServerCallContext) -> a2a_pb2.Task: task = await self.request_handler.on_cancel_task( request, server_context ) if task: return task - await self.abort_context(TaskNotFoundError(), context) - except A2AError as e: - await self.abort_context(e, context) - return a2a_pb2.Task() + raise TaskNotFoundError + + return await self._handle_unary( + request, context, _handler, a2a_pb2.Task() + ) - @validate_async_generator( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) async def SubscribeToTask( self, request: a2a_pb2.SubscribeToTaskRequest, context: grpc.aio.ServicerContext, ) -> AsyncIterable[a2a_pb2.StreamResponse]: - """Handles the 'SubscribeToTask' gRPC method. - - Yields response objects as they are produced by the underlying handler's - stream. - - Args: - request: The incoming `SubscribeToTaskRequest` object. - context: Context provided by the server. + """Handles the 'SubscribeToTask' gRPC method.""" - Yields: - `StreamResponse` objects containing streaming events - """ - try: - server_context = self._build_call_context(context, request) + @validate( + lambda _: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def _handler( + server_context: ServerCallContext, + ) -> AsyncIterable[a2a_pb2.StreamResponse]: async for event in self.request_handler.on_subscribe_to_task( - request, - server_context, + request, server_context ): yield proto_utils.to_stream_response(event) - except A2AError as e: - await self.abort_context(e, context) + + async for item in self._handle_stream(request, context, _handler): + yield item async def GetTaskPushNotificationConfig( self, request: a2a_pb2.GetTaskPushNotificationConfigRequest, context: grpc.aio.ServicerContext, ) -> a2a_pb2.TaskPushNotificationConfig: - """Handles the 'GetTaskPushNotificationConfig' gRPC method. - - Args: - request: The incoming `GetTaskPushNotificationConfigRequest` object. - context: Context provided by the server. + """Handles the 'GetTaskPushNotificationConfig' gRPC method.""" - Returns: - A `TaskPushNotificationConfig` object containing the config. - """ - try: - server_context = self._build_call_context(context, request) + async def _handler( + server_context: ServerCallContext, + ) -> a2a_pb2.TaskPushNotificationConfig: return ( await self.request_handler.on_get_task_push_notification_config( - request, - server_context, + request, server_context ) ) - except A2AError as e: - await self.abort_context(e, context) - return a2a_pb2.TaskPushNotificationConfig() - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) + return await self._handle_unary( + request, context, _handler, a2a_pb2.TaskPushNotificationConfig() + ) + async def CreateTaskPushNotificationConfig( self, request: a2a_pb2.TaskPushNotificationConfig, context: grpc.aio.ServicerContext, ) -> a2a_pb2.TaskPushNotificationConfig: - """Handles the 'CreateTaskPushNotificationConfig' gRPC method. - - Requires the agent to support push notifications. - - Args: - request: The incoming `TaskPushNotificationConfig` object. - context: Context provided by the server. - - Returns: - A `TaskPushNotificationConfig` object + """Handles the 'CreateTaskPushNotificationConfig' gRPC method.""" - Raises: - A2AError: If push notifications are not supported by the agent - (due to the `@validate` decorator). - """ - try: - server_context = self._build_call_context(context, request) + @validate( + lambda _: self.agent_card.capabilities.push_notifications, + 'Push notifications are not supported by the agent', + ) + async def _handler( + server_context: ServerCallContext, + ) -> a2a_pb2.TaskPushNotificationConfig: return await self.request_handler.on_create_task_push_notification_config( - request, - server_context, + request, server_context ) - except A2AError as e: - await self.abort_context(e, context) - return a2a_pb2.TaskPushNotificationConfig() + + return await self._handle_unary( + request, context, _handler, a2a_pb2.TaskPushNotificationConfig() + ) async def ListTaskPushNotificationConfigs( self, request: a2a_pb2.ListTaskPushNotificationConfigsRequest, context: grpc.aio.ServicerContext, ) -> a2a_pb2.ListTaskPushNotificationConfigsResponse: - """Handles the 'ListTaskPushNotificationConfig' gRPC method. + """Handles the 'ListTaskPushNotificationConfig' gRPC method.""" - Args: - request: The incoming `ListTaskPushNotificationConfigsRequest` object. - context: Context provided by the server. - - Returns: - A `ListTaskPushNotificationConfigsResponse` object containing the configs. - """ - try: - server_context = self._build_call_context(context, request) + async def _handler( + server_context: ServerCallContext, + ) -> a2a_pb2.ListTaskPushNotificationConfigsResponse: return await self.request_handler.on_list_task_push_notification_configs( - request, - server_context, + request, server_context ) - except A2AError as e: - await self.abort_context(e, context) - return a2a_pb2.ListTaskPushNotificationConfigsResponse() + + return await self._handle_unary( + request, + context, + _handler, + a2a_pb2.ListTaskPushNotificationConfigsResponse(), + ) async def DeleteTaskPushNotificationConfig( self, request: a2a_pb2.DeleteTaskPushNotificationConfigRequest, context: grpc.aio.ServicerContext, ) -> empty_pb2.Empty: - """Handles the 'DeleteTaskPushNotificationConfig' gRPC method. - - Args: - request: The incoming `DeleteTaskPushNotificationConfigRequest` object. - context: Context provided by the server. + """Handles the 'DeleteTaskPushNotificationConfig' gRPC method.""" - Returns: - An empty `Empty` object. - """ - try: - server_context = self._build_call_context(context, request) + async def _handler( + server_context: ServerCallContext, + ) -> empty_pb2.Empty: await self.request_handler.on_delete_task_push_notification_config( - request, - server_context, + request, server_context ) return empty_pb2.Empty() - except A2AError as e: - await self.abort_context(e, context) - return empty_pb2.Empty() + + return await self._handle_unary( + request, context, _handler, empty_pb2.Empty() + ) async def GetTask( self, request: a2a_pb2.GetTaskRequest, context: grpc.aio.ServicerContext, ) -> a2a_pb2.Task: - """Handles the 'GetTask' gRPC method. + """Handles the 'GetTask' gRPC method.""" - Args: - request: The incoming `GetTaskRequest` object. - context: Context provided by the server. - - Returns: - A `Task` object. - """ - try: - server_context = self._build_call_context(context, request) + async def _handler(server_context: ServerCallContext) -> a2a_pb2.Task: task = await self.request_handler.on_get_task( request, server_context ) if task: return task - await self.abort_context(TaskNotFoundError(), context) - except A2AError as e: - await self.abort_context(e, context) - return a2a_pb2.Task() + raise TaskNotFoundError + + return await self._handle_unary( + request, context, _handler, a2a_pb2.Task() + ) async def ListTasks( self, request: a2a_pb2.ListTasksRequest, context: grpc.aio.ServicerContext, ) -> a2a_pb2.ListTasksResponse: - """Handles the 'ListTasks' gRPC method. - - Args: - request: The incoming `ListTasksRequest` object. - context: Context provided by the server. + """Handles the 'ListTasks' gRPC method.""" - Returns: - A `ListTasksResponse` object. - """ - try: - server_context = self._build_call_context(context, request) + async def _handler( + server_context: ServerCallContext, + ) -> a2a_pb2.ListTasksResponse: return await self.request_handler.on_list_tasks( request, server_context ) - except A2AError as e: - await self.abort_context(e, context) - return a2a_pb2.ListTasksResponse() + + return await self._handle_unary( + request, context, _handler, a2a_pb2.ListTasksResponse() + ) async def GetExtendedAgentCard( self, diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py index dfedd3b11..06188e412 100644 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ b/src/a2a/server/request_handlers/jsonrpc_handler.py @@ -52,7 +52,6 @@ from a2a.utils.helpers import ( maybe_await, validate, - validate_async_generator, validate_version, ) from a2a.utils.telemetry import SpanKind, trace_class @@ -178,7 +177,7 @@ async def on_message_send( return _build_error_response(request_id, e) @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate_async_generator( + @validate( lambda self: self.agent_card.capabilities.streaming, 'Streaming is not supported by the agent', ) @@ -244,7 +243,7 @@ async def on_cancel_task( return _build_error_response(request_id, TaskNotFoundError()) @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate_async_generator( + @validate( lambda self: self.agent_card.capabilities.streaming, 'Streaming is not supported by the agent', ) diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 96028115a..af889d9df 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -31,7 +31,6 @@ from a2a.utils.errors import TaskNotFoundError from a2a.utils.helpers import ( validate, - validate_async_generator, validate_version, ) from a2a.utils.telemetry import SpanKind, trace_class @@ -93,7 +92,7 @@ async def on_message_send( return MessageToDict(response) @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate_async_generator( + @validate( lambda self: self.agent_card.capabilities.streaming, 'Streaming is not supported by the agent', ) @@ -147,7 +146,7 @@ async def on_cancel_task( raise TaskNotFoundError @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate_async_generator( + @validate( lambda self: self.agent_card.capabilities.streaming, 'Streaming is not supported by the agent', ) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index d215f84d8..e5b37e5f4 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -232,91 +232,6 @@ def sync_wrapper(self: Any, *args, **kwargs) -> Any: return decorator -def validate_async_generator( - expression: Callable[[Any], bool], error_message: str | None = None -): - """Decorator that validates if a given expression evaluates to True for async generators. - - Typically used on class methods to check capabilities or configuration - before executing the method's logic. If the expression is False, - an `UnsupportedOperationError` is raised. - - Args: - expression: A callable that takes the instance (`self`) as its argument - and returns a boolean. - error_message: An optional custom error message for the `UnsupportedOperationError`. - If None, the string representation of the expression will be used. - - Examples: - Streaming capability validation with success case: - >>> import asyncio - >>> from a2a.utils.errors import UnsupportedOperationError - >>> - >>> class StreamingAgent: - ... def __init__(self, streaming_enabled: bool): - ... self.streaming_enabled = streaming_enabled - ... - ... @validate_async_generator( - ... lambda self: self.streaming_enabled, - ... 'Streaming is not supported by this agent', - ... ) - ... async def stream_messages(self, count: int): - ... for i in range(count): - ... yield f'Message {i}' - >>> - >>> async def run_streaming_test(): - ... # Successful streaming - ... agent = StreamingAgent(streaming_enabled=True) - ... async for msg in agent.stream_messages(2): - ... print(msg) - >>> - >>> asyncio.run(run_streaming_test()) - Message 0 - Message 1 - - Error case - validation fails: - >>> class FeatureAgent: - ... def __init__(self): - ... self.features = {'real_time': False} - ... - ... @validate_async_generator( - ... lambda self: self.features.get('real_time', False), - ... 'Real-time feature must be enabled to stream updates', - ... ) - ... async def real_time_updates(self): - ... yield 'This should not be yielded' - >>> - >>> async def run_error_test(): - ... agent = FeatureAgent() - ... try: - ... async for _ in agent.real_time_updates(): - ... pass - ... except UnsupportedOperationError as e: - ... print(e.message) - >>> - >>> asyncio.run(run_error_test()) - Real-time feature must be enabled to stream updates - - Note: - This decorator is specifically for async generator methods (async def with yield). - The validation happens before the generator starts yielding values. - """ - - def decorator(function): - @functools.wraps(function) - async def wrapper(self, *args, **kwargs): - if not expression(self): - final_message = error_message or str(expression) - logger.error('Unsupported Operation: %s', final_message) - raise UnsupportedOperationError(message=final_message) - async for i in function(self, *args, **kwargs): - yield i - - return wrapper - - return decorator - - def are_modalities_compatible( server_output_modes: list[str] | None, client_output_modes: list[str] | None ) -> bool: diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index b568865e6..5741aa003 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -442,6 +442,9 @@ async def test_send_message_streaming_sse_error( request = create_send_message_request() mock_event_source = AsyncMock() mock_event_source.response.raise_for_status = MagicMock() + mock_event_source.response.headers = { + 'content-type': 'text/event-stream' + } mock_event_source.aiter_sse = MagicMock( side_effect=SSEError('Simulated SSE error') ) @@ -463,6 +466,9 @@ async def test_send_message_streaming_request_error( request = create_send_message_request() mock_event_source = AsyncMock() mock_event_source.response.raise_for_status = MagicMock() + mock_event_source.response.headers = { + 'content-type': 'text/event-stream' + } mock_event_source.aiter_sse = MagicMock( side_effect=httpx.RequestError( 'Simulated request error', request=MagicMock() @@ -486,6 +492,9 @@ async def test_send_message_streaming_timeout( request = create_send_message_request() mock_event_source = AsyncMock() mock_event_source.response.raise_for_status = MagicMock() + mock_event_source.response.headers = { + 'content-type': 'text/event-stream' + } mock_event_source.aiter_sse = MagicMock( side_effect=httpx.TimeoutException('Timeout') ) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 944110a49..7648de577 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -87,6 +87,9 @@ async def test_send_message_streaming_timeout( ) mock_event_source = AsyncMock(spec=EventSource) mock_event_source.response = MagicMock(spec=httpx.Response) + mock_event_source.response.headers = { + 'content-type': 'text/event-stream' + } mock_event_source.response.raise_for_status.return_value = None mock_event_source.aiter_sse.side_effect = httpx.TimeoutException( 'Read timed out' @@ -295,6 +298,10 @@ async def test_send_message_streaming_with_new_extensions( ) mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.response = MagicMock(spec=httpx.Response) + mock_event_source.response.headers = { + 'content-type': 'text/event-stream' + } mock_event_source.aiter_sse.return_value = async_iterable_from_list([]) mock_aconnect_sse.return_value.__aenter__.return_value = ( mock_event_source @@ -708,6 +715,9 @@ async def test_rest_streaming_methods_prepend_tenant( # noqa: PLR0913 # 2. Setup mocks mock_event_source = AsyncMock(spec=EventSource) mock_event_source.response = MagicMock(spec=httpx.Response) + mock_event_source.response.headers = { + 'content-type': 'text/event-stream' + } mock_event_source.response.raise_for_status.return_value = None async def empty_aiter(): diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index e239d780f..b1013e98e 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -73,6 +73,10 @@ create_signature_verifier, ) +# Compat v0.3 imports for dedicated tests +from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc +from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler + # --- Test Constants --- @@ -292,6 +296,30 @@ def transport_setups(request) -> TransportSetup: return request.getfixturevalue(request.param) +@pytest.fixture( + params=[ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + pytest.param('grpc_setup', id='gRPC'), + pytest.param('grpc_03_setup', id='gRPC-0.3'), + ] +) +def error_handling_setups(request) -> TransportSetup: + """Parametrized fixture for error tests including compat 0.3 endpoint verification.""" + return request.getfixturevalue(request.param) + + +@pytest.fixture( + params=[ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + ] +) +def http_transport_setups(request) -> TransportSetup: + """Parametrized fixture that runs tests against HTTP-based transports only.""" + return request.getfixturevalue(request.param) + + # --- gRPC Setup --- @@ -307,7 +335,46 @@ async def grpc_server_and_handler( a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) await server.start() yield server_address, mock_request_handler - await server.stop(0) + + +@pytest_asyncio.fixture +async def grpc_03_server_and_handler( + mock_request_handler: AsyncMock, agent_card: AgentCard +) -> AsyncGenerator[tuple[str, AsyncMock], None]: + """Creates and manages an in-process v0.3 compat gRPC test server.""" + server = grpc.aio.server() + port = server.add_insecure_port('[::]:0') + server_address = f'localhost:{port}' + servicer = CompatGrpcHandler(agent_card, mock_request_handler) + a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) + await server.start() + try: + yield server_address, mock_request_handler + finally: + await server.stop(None) + + +@pytest.fixture +def grpc_03_setup( + grpc_03_server_and_handler, agent_card: AgentCard +) -> TransportSetup: + """Sets up the CompatGrpcTransport and in-process 0.3 server.""" + server_address, handler = grpc_03_server_and_handler + from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport + from a2a.client.base_client import BaseClient + from a2a.client.client import ClientConfig + + channel = grpc.aio.insecure_channel(server_address) + transport = CompatGrpcTransport(channel=channel, agent_card=agent_card) + + client = BaseClient( + card=agent_card, + config=ClientConfig(), + transport=transport, + consumers=[], + interceptors=[], + ) + return TransportSetup(client=client, handler=handler) # --- The Integration Tests --- @@ -927,3 +994,59 @@ async def test_rest_malformed_payload( assert response.status_code == 400 await transport.close() + + +@pytest.mark.asyncio +async def test_validate_version_unsupported(http_transport_setups) -> None: + """Integration test for @validate_version decorator.""" + client = http_transport_setups.client + + service_params = {'A2A-Version': '2.0.0'} + context = ClientCallContext(service_parameters=service_params) + + params = GetTaskRequest(id=GET_TASK_RESPONSE.id) + + with pytest.raises(VersionNotSupportedError) as exc_info: + await client.get_task(request=params, context=context) + + await client.close() + + +@pytest.mark.asyncio +async def test_validate_decorator_push_notifications_disabled( + error_handling_setups, agent_card: AgentCard +) -> None: + """Integration test for @validate decorator with push notifications disabled.""" + client = error_handling_setups.client + + agent_card.capabilities.push_notifications = False + + params = TaskPushNotificationConfig(task_id='123') + + with pytest.raises(UnsupportedOperationError) as exc_info: + await client.create_task_push_notification_config(request=params) + + await client.close() + + +@pytest.mark.asyncio +async def test_validate_streaming_disabled( + error_handling_setups, agent_card: AgentCard +) -> None: + """Integration test for @validate decorator when streaming is disabled.""" + client = error_handling_setups.client + transport = client._transport + + agent_card.capabilities.streaming = False + + params = SendMessageRequest( + message=Message(role=Role.ROLE_USER, parts=[Part(text='hi')]) + ) + + stream = transport.send_message_streaming(request=params) + + with pytest.raises(UnsupportedOperationError) as exc_info: + async for _ in stream: + pass + + await transport.close()