diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 33302d90c..65ae850ae 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -109,6 +109,8 @@ async def get_task( params = MessageToDict(request) if 'id' in params: del params['id'] # id is part of the URL path + if 'tenant' in params: + del params['tenant'] response_data = await self._execute_request( 'GET', @@ -127,12 +129,16 @@ async def list_tasks( context: ClientCallContext | None = None, ) -> ListTasksResponse: """Retrieves tasks for an agent.""" + params = MessageToDict(request) + if 'tenant' in params: + del params['tenant'] + response_data = await self._execute_request( 'GET', '/tasks', request.tenant, context=context, - params=MessageToDict(request), + params=params, ) response: ListTasksResponse = ParseDict( response_data, ListTasksResponse() @@ -185,8 +191,10 @@ async def get_task_push_notification_config( params = MessageToDict(request) if 'id' in params: del params['id'] - if 'task_id' in params: - del params['task_id'] + if 'taskId' in params: + del params['taskId'] + if 'tenant' in params: + del params['tenant'] response_data = await self._execute_request( 'GET', @@ -208,8 +216,10 @@ async def list_task_push_notification_configs( ) -> ListTaskPushNotificationConfigsResponse: """Lists push notification configurations for a specific task.""" params = MessageToDict(request) - if 'task_id' in params: - del params['task_id'] + if 'taskId' in params: + del params['taskId'] + if 'tenant' in params: + del params['tenant'] response_data = await self._execute_request( 'GET', @@ -233,8 +243,10 @@ async def delete_task_push_notification_config( params = MessageToDict(request) if 'id' in params: del params['id'] - if 'task_id' in params: - del params['task_id'] + if 'taskId' in params: + del params['taskId'] + if 'tenant' in params: + del params['tenant'] await self._execute_request( 'DELETE', diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 4e7d75f2e..769e457c1 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -7,7 +7,6 @@ MessageToDict, MessageToJson, Parse, - ParseDict, ) @@ -27,7 +26,6 @@ AgentCard, CancelTaskRequest, GetTaskPushNotificationConfigRequest, - GetTaskRequest, SubscribeToTaskRequest, ) from a2a.utils import proto_utils @@ -220,12 +218,11 @@ async def set_push_notification( (due to the `@validate` decorator), A2AError if processing error is found. """ - task_id = request.path_params['id'] body = await request.body() params = a2a_pb2.TaskPushNotificationConfig() Parse(body, params) # Set the parent to the task resource name format - params.task_id = task_id + params.task_id = request.path_params['id'] config = ( await self.request_handler.on_create_task_push_notification_config( params, context @@ -247,10 +244,9 @@ async def on_get_task( Returns: A `Task` object containing the Task. """ - task_id = request.path_params['id'] - history_length_str = request.query_params.get('historyLength') - history_length = int(history_length_str) if history_length_str else None - params = GetTaskRequest(id=task_id, history_length=history_length) + params = a2a_pb2.GetTaskRequest() + proto_utils.parse_params(request.query_params, params) + params.id = request.path_params['id'] task = await self.request_handler.on_get_task(params, context) if task: return MessageToDict(task) @@ -295,12 +291,8 @@ async def list_tasks( A list of `dict` representing the `Task` objects. """ params = a2a_pb2.ListTasksRequest() - # Parse query params, keeping arrays/repeated fields in mind if there are any - # Using a simple ParseDict for now, might need more robust query param parsing - # if the request structure contains nested or repeated elements - ParseDict( - dict(request.query_params), params, ignore_unknown_fields=True - ) + proto_utils.parse_params(request.query_params, params) + result = await self.request_handler.on_list_tasks(params, context) return MessageToDict(result) @@ -318,13 +310,9 @@ async def list_push_notifications( Returns: A list of `dict` representing the `TaskPushNotificationConfig` objects. """ - task_id = request.path_params['id'] - params = a2a_pb2.ListTaskPushNotificationConfigsRequest(task_id=task_id) - - # Parse query params, keeping arrays/repeated fields in mind if there are any - ParseDict( - dict(request.query_params), params, ignore_unknown_fields=True - ) + params = a2a_pb2.ListTaskPushNotificationConfigsRequest() + proto_utils.parse_params(request.query_params, params) + params.task_id = request.path_params['id'] result = ( await self.request_handler.on_list_task_push_notification_configs( diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 79238c2b1..cdfc306f4 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -17,7 +17,19 @@ This module provides helper functions for common proto type operations. """ -from typing import Any +from typing import TYPE_CHECKING, Any + +from google.protobuf.json_format import ParseDict +from google.protobuf.message import Message as ProtobufMessage + + +if TYPE_CHECKING: + from starlette.datastructures import QueryParams +else: + try: + from starlette.datastructures import QueryParams + except ImportError: + QueryParams = Any from a2a.types.a2a_pb2 import ( Message, @@ -131,3 +143,49 @@ def parse_string_integers_in_dict(value: Any, max_safe_digits: int = 15) -> Any: if stripped_value.isdigit() and len(stripped_value) > max_safe_digits: return int(value) return value + + +def parse_params(params: QueryParams, message: ProtobufMessage) -> None: + """Converts REST query parameters back into a Protobuf message. + + Handles A2A-specific pre-processing before calling ParseDict: + - Booleans: 'true'/'false' -> True/False + - Repeated: Supports BOTH repeated keys and comma-separated values. + - Others: Handles string->enum/timestamp/number conversion via ParseDict. + + See Also: + https://a2a-protocol.org/latest/specification/#115-query-parameter-naming-for-request-parameters + """ + descriptor = message.DESCRIPTOR + fields = {f.camelcase_name: f for f in descriptor.fields} + processed: dict[str, Any] = {} + + keys = params.keys() + + for k in keys: + if k not in fields: + continue + + field = fields[k] + v_list = params.getlist(k) + + if field.label == field.LABEL_REPEATED: + accumulated: list[Any] = [] + for v in v_list: + if not v: + continue + if isinstance(v, str): + accumulated.extend([x for x in v.split(',') if x]) + else: + accumulated.append(v) + processed[k] = accumulated + else: + # For non-repeated fields, the last one wins. + raw_val = v_list[-1] + if raw_val is not None: + parsed_val: Any = raw_val + if field.type == field.TYPE_BOOL and isinstance(raw_val, str): + parsed_val = raw_val.lower() == 'true' + processed[k] = parsed_val + + ParseDict(processed, message, ignore_unknown_fields=True) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 742b570a2..ec29ddc56 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -5,6 +5,7 @@ import pytest from google.protobuf import json_format +from google.protobuf.timestamp_pb2 import Timestamp from httpx_sse import EventSource, ServerSentEvent from a2a.client import create_text_message_object @@ -16,16 +17,16 @@ AgentCard, AgentInterface, CancelTaskRequest, - TaskPushNotificationConfig, DeleteTaskPushNotificationConfigRequest, GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest, ListTaskPushNotificationConfigsRequest, ListTasksRequest, - Message, SendMessageRequest, SubscribeToTaskRequest, + TaskPushNotificationConfig, + TaskState, ) from a2a.utils.constants import TransportProtocol from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP @@ -175,6 +176,47 @@ async def test_send_message_with_timeout_context( assert 'timeout' in kwargs assert kwargs['timeout'] == httpx.Timeout(10.0) + @pytest.mark.asyncio + async def test_url_serialization( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + """Test that query parameters are correctly serialized to the URL.""" + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + url='http://agent.example.com/api', + ) + + timestamp = Timestamp() + timestamp.FromJsonString('2024-03-09T16:00:00Z') + + request = ListTasksRequest( + tenant='my-tenant', + status=TaskState.TASK_STATE_WORKING, + include_artifacts=True, + status_timestamp_after=timestamp, + ) + + # Use real build_request to get actual URL serialization + mock_httpx_client.build_request.side_effect = ( + httpx.AsyncClient().build_request + ) + mock_httpx_client.send.return_value = AsyncMock( + spec=httpx.Response, status_code=200, json=lambda: {'tasks': []} + ) + + await client.list_tasks(request=request) + + mock_httpx_client.send.assert_called_once() + sent_request = mock_httpx_client.send.call_args[0][0] + + # Check decoded query parameters for spec compliance + params = sent_request.url.params + assert params['status'] == 'TASK_STATE_WORKING' + assert params['includeArtifacts'] == 'true' + assert params['statusTimestampAfter'] == '2024-03-09T16:00:00Z' + assert 'tenant' not in params + class TestRestTransportExtensions: @pytest.mark.asyncio @@ -616,7 +658,7 @@ async def test_rest_get_task_prepend_empty_tenant( # 3. Verify the URL args, _ = mock_httpx_client.build_request.call_args - assert args[1] == f'http://agent.example.com/api/tasks/task-123' + assert args[1] == 'http://agent.example.com/api/tasks/task-123' @pytest.mark.parametrize( 'method_name, request_obj, expected_path', diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 8952962b0..3376f33d7 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -8,6 +8,7 @@ import pytest import pytest_asyncio from google.protobuf.json_format import MessageToDict +from google.protobuf.timestamp_pb2 import Timestamp from grpc.aio import Channel from jwt.api_jwk import PyJWK @@ -30,35 +31,31 @@ create_agent_card_signer, create_signature_verifier, ) -from a2a.client.card_resolver import A2ACardResolver + from a2a.types.a2a_pb2 import ( AgentCapabilities, AgentCard, AgentInterface, CancelTaskRequest, + DeleteTaskPushNotificationConfigRequest, GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTaskPushNotificationConfigsResponse, + ListTasksRequest, + ListTasksResponse, Message, Part, - TaskPushNotificationConfig, Role, SendMessageRequest, - SendMessageRequest, - TaskPushNotificationConfig, - DeleteTaskPushNotificationConfigRequest, - ListTaskPushNotificationConfigsRequest, - ListTaskPushNotificationConfigsResponse, SubscribeToTaskRequest, Task, TaskPushNotificationConfig, TaskState, TaskStatus, TaskStatusUpdateEvent, - ListTasksRequest, - ListTasksResponse, ) -from cryptography.hazmat.primitives import asymmetric from cryptography.hazmat.primitives.asymmetric import ec # --- Test Constants --- @@ -162,7 +159,9 @@ def agent_card() -> AgentCard: name='Test Agent', description='An agent for integration testing.', version='1.0.0', - capabilities=AgentCapabilities(streaming=True, push_notifications=True), + capabilities=AgentCapabilities( + streaming=True, push_notifications=True, extended_agent_card=True + ), skills=[], default_input_modes=['text/plain'], default_output_modes=['text/plain'], @@ -182,7 +181,7 @@ class TransportSetup(NamedTuple): """Holds the transport and handler for a given test.""" transport: ClientTransport - handler: AsyncMock + handler: RequestHandler | AsyncMock # --- HTTP/JSON-RPC/REST Setup --- @@ -218,7 +217,9 @@ def jsonrpc_setup(http_base_setup) -> TransportSetup: def rest_setup(http_base_setup) -> TransportSetup: """Sets up the RestTransport and in-memory server.""" mock_request_handler, agent_card = http_base_setup - app_builder = A2ARESTFastAPIApplication(agent_card, mock_request_handler) + app_builder = A2ARESTFastAPIApplication( + agent_card, mock_request_handler, extended_agent_card=agent_card + ) app = app_builder.build() httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) transport = RestTransport( @@ -229,6 +230,30 @@ def rest_setup(http_base_setup) -> TransportSetup: return TransportSetup(transport=transport, handler=mock_request_handler) +@pytest_asyncio.fixture +async def grpc_setup( + grpc_server_and_handler: tuple[str, AsyncMock], + agent_card: AgentCard, +) -> TransportSetup: + """Sets up the GrpcTransport and in-process server.""" + server_address, handler = grpc_server_and_handler + channel = grpc.aio.insecure_channel(server_address) + transport = GrpcTransport(channel=channel, agent_card=agent_card) + return TransportSetup(transport=transport, handler=handler) + + +@pytest.fixture( + params=[ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('rest_setup', id='REST'), + pytest.param('grpc_setup', id='gRPC'), + ] +) +def transport_setups(request) -> TransportSetup: + """Parametrized fixture that runs tests against all supported transports.""" + return request.getfixturevalue(request.param) + + # --- gRPC Setup --- @@ -251,24 +276,10 @@ async def grpc_server_and_handler( @pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_setup_fixture', - [ - pytest.param('jsonrpc_setup', id='JSON-RPC'), - pytest.param('rest_setup', id='REST'), - ], -) -async def test_http_transport_sends_message_streaming( - transport_setup_fixture: str, request -) -> None: - """ - Integration test for HTTP-based transports (JSON-RPC, REST) streaming. - """ - transport_setup: TransportSetup = request.getfixturevalue( - transport_setup_fixture - ) - transport = transport_setup.transport - handler = transport_setup.handler +async def test_transport_sends_message_streaming(transport_setups) -> None: + """Integration test for all transports streaming.""" + transport = transport_setups.transport + handler = transport_setups.handler message_to_send = Message( role=Role.ROLE_USER, @@ -281,85 +292,18 @@ async def test_http_transport_sends_message_streaming( events = [event async for event in stream] assert len(events) == 1 - first_event = events[0] - - # StreamResponse wraps the Task in its 'task' field - assert first_event.task.id == TASK_FROM_STREAM.id - assert first_event.task.context_id == TASK_FROM_STREAM.context_id - - handler.on_message_send_stream.assert_called_once() - call_args, _ = handler.on_message_send_stream.call_args - received_params: SendMessageRequest = call_args[0] - - assert received_params.message.message_id == message_to_send.message_id - assert ( - received_params.message.parts[0].text == message_to_send.parts[0].text - ) + assert events[0].task.id == TASK_FROM_STREAM.id - await transport.close() - - -@pytest.mark.asyncio -async def test_grpc_transport_sends_message_streaming( - grpc_server_and_handler: tuple[str, AsyncMock], - agent_card: AgentCard, -) -> None: - """ - Integration test specifically for the gRPC transport streaming. - """ - server_address, handler = grpc_server_and_handler - - def channel_factory(address: str) -> Channel: - return grpc.aio.insecure_channel(address) - - channel = channel_factory(server_address) - transport = GrpcTransport(channel=channel, agent_card=agent_card) - - message_to_send = Message( - role=Role.ROLE_USER, - message_id='msg-grpc-integration-test', - parts=[Part(text='Hello, gRPC integration test!')], - ) - params = SendMessageRequest(message=message_to_send) - - stream = transport.send_message_streaming(request=params) - first_event = await anext(stream) - - # StreamResponse wraps the Task in its 'task' field - assert first_event.task.id == TASK_FROM_STREAM.id - assert first_event.task.context_id == TASK_FROM_STREAM.context_id - - handler.on_message_send_stream.assert_called_once() - call_args, _ = handler.on_message_send_stream.call_args - received_params: SendMessageRequest = call_args[0] - - assert received_params.message.message_id == message_to_send.message_id - assert ( - received_params.message.parts[0].text == message_to_send.parts[0].text - ) + handler.on_message_send_stream.assert_called_once_with(params, ANY) await transport.close() @pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_setup_fixture', - [ - pytest.param('jsonrpc_setup', id='JSON-RPC'), - pytest.param('rest_setup', id='REST'), - ], -) -async def test_http_transport_sends_message_blocking( - transport_setup_fixture: str, request -) -> None: - """ - Integration test for HTTP-based transports (JSON-RPC, REST) blocking. - """ - transport_setup: TransportSetup = request.getfixturevalue( - transport_setup_fixture - ) - transport = transport_setup.transport - handler = transport_setup.handler +async def test_transport_sends_message_blocking(transport_setups) -> None: + """Integration test for all transports blocking.""" + transport = transport_setups.transport + handler = transport_setups.handler message_to_send = Message( role=Role.ROLE_USER, @@ -370,500 +314,155 @@ async def test_http_transport_sends_message_blocking( result = await transport.send_message(request=params) - # SendMessageResponse wraps Task in its 'task' field - assert result.task.id == TASK_FROM_BLOCKING.id - assert result.task.context_id == TASK_FROM_BLOCKING.context_id - - handler.on_message_send.assert_awaited_once() - call_args, _ = handler.on_message_send.call_args - received_params: SendMessageRequest = call_args[0] - - assert received_params.message.message_id == message_to_send.message_id - assert ( - received_params.message.parts[0].text == message_to_send.parts[0].text - ) - - if hasattr(transport, 'close'): - await transport.close() - - -@pytest.mark.asyncio -async def test_grpc_transport_sends_message_blocking( - grpc_server_and_handler: tuple[str, AsyncMock], - agent_card: AgentCard, -) -> None: - """ - Integration test specifically for the gRPC transport blocking. - """ - server_address, handler = grpc_server_and_handler - - def channel_factory(address: str) -> Channel: - return grpc.aio.insecure_channel(address) - - channel = channel_factory(server_address) - transport = GrpcTransport(channel=channel, agent_card=agent_card) - - message_to_send = Message( - role=Role.ROLE_USER, - message_id='msg-grpc-integration-test-blocking', - parts=[Part(text='Hello, gRPC blocking test!')], - ) - params = SendMessageRequest(message=message_to_send) - - result = await transport.send_message(request=params) - - # SendMessageResponse wraps Task in its 'task' field assert result.task.id == TASK_FROM_BLOCKING.id - assert result.task.context_id == TASK_FROM_BLOCKING.context_id - - handler.on_message_send.assert_awaited_once() - call_args, _ = handler.on_message_send.call_args - received_params: SendMessageRequest = call_args[0] - - assert received_params.message.message_id == message_to_send.message_id - assert ( - received_params.message.parts[0].text == message_to_send.parts[0].text - ) + handler.on_message_send.assert_awaited_once_with(params, ANY) await transport.close() @pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_setup_fixture', - [ - pytest.param('jsonrpc_setup', id='JSON-RPC'), - pytest.param('rest_setup', id='REST'), - ], -) -async def test_http_transport_get_task( - transport_setup_fixture: str, request -) -> None: - transport_setup: TransportSetup = request.getfixturevalue( - transport_setup_fixture - ) - transport = transport_setup.transport - handler = transport_setup.handler +async def test_transport_get_task(transport_setups) -> None: + transport = transport_setups.transport + handler = transport_setups.handler - # Use GetTaskRequest with name (AIP resource format) params = GetTaskRequest(id=GET_TASK_RESPONSE.id) result = await transport.get_task(request=params) assert result.id == GET_TASK_RESPONSE.id - handler.on_get_task.assert_awaited_once() - - if hasattr(transport, 'close'): - await transport.close() - - -@pytest.mark.asyncio -async def test_grpc_transport_get_task( - grpc_server_and_handler: tuple[str, AsyncMock], - agent_card: AgentCard, -) -> None: - server_address, handler = grpc_server_and_handler - - def channel_factory(address: str) -> Channel: - return grpc.aio.insecure_channel(address) - - channel = channel_factory(server_address) - transport = GrpcTransport(channel=channel, agent_card=agent_card) - - # Use GetTaskRequest with name (AIP resource format) - params = GetTaskRequest(id=f'{GET_TASK_RESPONSE.id}') - result = await transport.get_task(request=params) - - assert result.id == GET_TASK_RESPONSE.id - handler.on_get_task.assert_awaited_once() + handler.on_get_task.assert_awaited_once_with(params, ANY) await transport.close() @pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_setup_fixture', - [ - pytest.param('jsonrpc_setup', id='JSON-RPC'), - pytest.param('rest_setup', id='REST'), - ], -) -async def test_http_transport_list_tasks( - transport_setup_fixture: str, request -) -> None: - transport_setup: TransportSetup = request.getfixturevalue( - transport_setup_fixture +async def test_transport_list_tasks(transport_setups) -> None: + transport = transport_setups.transport + handler = transport_setups.handler + + t = Timestamp() + t.FromJsonString('2024-03-09T16:00:00Z') + params = ListTasksRequest( + context_id='ctx-1', + status=TaskState.TASK_STATE_WORKING, + page_size=10, + page_token='page-1', + history_length=5, + status_timestamp_after=t, + include_artifacts=True, ) - transport = transport_setup.transport - handler = transport_setup.handler - - params = ListTasksRequest(page_size=10, page_token='page-1') - result = await transport.list_tasks(request=params) - - assert len(result.tasks) == 2 - assert result.next_page_token == 'page-2' - assert result.total_size == 12 - assert result.page_size == 10 - handler.on_list_tasks.assert_awaited_once() - - if hasattr(transport, 'close'): - await transport.close() - - -@pytest.mark.asyncio -async def test_grpc_transport_list_tasks( - grpc_server_and_handler: tuple[str, AsyncMock], - agent_card: AgentCard, -) -> None: - server_address, handler = grpc_server_and_handler - - def channel_factory(address: str) -> Channel: - return grpc.aio.insecure_channel(address) - - channel = channel_factory(server_address) - transport = GrpcTransport(channel=channel, agent_card=agent_card) - - params = ListTasksRequest(page_size=10, page_token='page-1') result = await transport.list_tasks(request=params) assert len(result.tasks) == 2 assert result.next_page_token == 'page-2' - handler.on_list_tasks.assert_awaited_once() + handler.on_list_tasks.assert_awaited_once_with(params, ANY) await transport.close() @pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_setup_fixture', - [ - pytest.param('jsonrpc_setup', id='JSON-RPC'), - pytest.param('rest_setup', id='REST'), - ], -) -async def test_http_transport_cancel_task( - transport_setup_fixture: str, request -) -> None: - transport_setup: TransportSetup = request.getfixturevalue( - transport_setup_fixture - ) - transport = transport_setup.transport - handler = transport_setup.handler +async def test_transport_cancel_task(transport_setups) -> None: + transport = transport_setups.transport + handler = transport_setups.handler - # Use CancelTaskRequest with name (AIP resource format) - params = CancelTaskRequest(id=f'{CANCEL_TASK_RESPONSE.id}') + params = CancelTaskRequest(id=CANCEL_TASK_RESPONSE.id) result = await transport.cancel_task(request=params) assert result.id == CANCEL_TASK_RESPONSE.id - handler.on_cancel_task.assert_awaited_once() - - if hasattr(transport, 'close'): - await transport.close() - - -@pytest.mark.asyncio -async def test_grpc_transport_cancel_task( - grpc_server_and_handler: tuple[str, AsyncMock], - agent_card: AgentCard, -) -> None: - server_address, handler = grpc_server_and_handler - - def channel_factory(address: str) -> Channel: - return grpc.aio.insecure_channel(address) - - channel = channel_factory(server_address) - transport = GrpcTransport(channel=channel, agent_card=agent_card) - - # Use CancelTaskRequest with name (AIP resource format) - params = CancelTaskRequest(id=f'{CANCEL_TASK_RESPONSE.id}') - result = await transport.cancel_task(request=params) - - assert result.id == CANCEL_TASK_RESPONSE.id - handler.on_cancel_task.assert_awaited_once() + handler.on_cancel_task.assert_awaited_once_with(params, ANY) await transport.close() @pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_setup_fixture', - [ - pytest.param('jsonrpc_setup', id='JSON-RPC'), - pytest.param('rest_setup', id='REST'), - ], -) -async def test_http_transport_create_task_push_notification_config( - transport_setup_fixture: str, request +async def test_transport_create_task_push_notification_config( + transport_setups, ) -> None: - transport_setup: TransportSetup = request.getfixturevalue( - transport_setup_fixture - ) - transport = transport_setup.transport - handler = transport_setup.handler + transport = transport_setups.transport + handler = transport_setups.handler - # Create TaskPushNotificationConfig with required fields - params = TaskPushNotificationConfig( - task_id='task-callback-123', - ) + params = TaskPushNotificationConfig(task_id='task-callback-123') result = await transport.create_task_push_notification_config( request=params ) assert result.id == CALLBACK_CONFIG.id - assert result.id == CALLBACK_CONFIG.id - assert result.url == CALLBACK_CONFIG.url - handler.on_create_task_push_notification_config.assert_awaited_once() - - if hasattr(transport, 'close'): - await transport.close() - - -@pytest.mark.asyncio -async def test_grpc_transport_create_task_push_notification_config( - grpc_server_and_handler: tuple[str, AsyncMock], - agent_card: AgentCard, -) -> None: - server_address, handler = grpc_server_and_handler - - def channel_factory(address: str) -> Channel: - return grpc.aio.insecure_channel(address) - - channel = channel_factory(server_address) - transport = GrpcTransport(channel=channel, agent_card=agent_card) - - # Create TaskPushNotificationConfig with required fields - params = TaskPushNotificationConfig( - task_id='task-callback-123', - ) - result = await transport.create_task_push_notification_config( - request=params + handler.on_create_task_push_notification_config.assert_awaited_once_with( + params, ANY ) - assert result.id == CALLBACK_CONFIG.id - assert result.id == CALLBACK_CONFIG.id - assert result.url == CALLBACK_CONFIG.url - handler.on_create_task_push_notification_config.assert_awaited_once() - await transport.close() @pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_setup_fixture', - [ - pytest.param('jsonrpc_setup', id='JSON-RPC'), - pytest.param('rest_setup', id='REST'), - ], -) -async def test_http_transport_get_task_push_notification_config( - transport_setup_fixture: str, request +async def test_transport_get_task_push_notification_config( + transport_setups, ) -> None: - transport_setup: TransportSetup = request.getfixturevalue( - transport_setup_fixture - ) - transport = transport_setup.transport - handler = transport_setup.handler + transport = transport_setups.transport + handler = transport_setups.handler - # Use GetTaskPushNotificationConfigRequest with name field (resource name) params = GetTaskPushNotificationConfigRequest( - task_id=f'{CALLBACK_CONFIG.task_id}', + task_id=CALLBACK_CONFIG.task_id, id=CALLBACK_CONFIG.id, ) result = await transport.get_task_push_notification_config(request=params) - assert result.task_id == CALLBACK_CONFIG.task_id assert result.id == CALLBACK_CONFIG.id - assert result.url == CALLBACK_CONFIG.url - handler.on_get_task_push_notification_config.assert_awaited_once() - - if hasattr(transport, 'close'): - await transport.close() - - -@pytest.mark.asyncio -async def test_grpc_transport_get_task_push_notification_config( - grpc_server_and_handler: tuple[str, AsyncMock], - agent_card: AgentCard, -) -> None: - server_address, handler = grpc_server_and_handler - - def channel_factory(address: str) -> Channel: - return grpc.aio.insecure_channel(address) - - channel = channel_factory(server_address) - transport = GrpcTransport(channel=channel, agent_card=agent_card) - - # Use GetTaskPushNotificationConfigRequest with name field (resource name) - params = GetTaskPushNotificationConfigRequest( - task_id=f'{CALLBACK_CONFIG.task_id}', - id=CALLBACK_CONFIG.id, + handler.on_get_task_push_notification_config.assert_awaited_once_with( + params, ANY ) - result = await transport.get_task_push_notification_config(request=params) - - assert result.task_id == CALLBACK_CONFIG.task_id - assert result.id == CALLBACK_CONFIG.id - assert result.url == CALLBACK_CONFIG.url - handler.on_get_task_push_notification_config.assert_awaited_once() await transport.close() @pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_setup_fixture', - [ - pytest.param('jsonrpc_setup', id='JSON-RPC'), - pytest.param('rest_setup', id='REST'), - ], -) -async def test_http_transport_list_task_push_notification_configs( - transport_setup_fixture: str, request +async def test_transport_list_task_push_notification_configs( + transport_setups, ) -> None: - transport_setup: TransportSetup = request.getfixturevalue( - transport_setup_fixture - ) - transport = transport_setup.transport - handler = transport_setup.handler + transport = transport_setups.transport + handler = transport_setups.handler params = ListTaskPushNotificationConfigsRequest( - task_id=f'{CALLBACK_CONFIG.task_id}', + task_id=CALLBACK_CONFIG.task_id, ) result = await transport.list_task_push_notification_configs(request=params) assert len(result.configs) == 1 - assert result.configs[0].task_id == CALLBACK_CONFIG.task_id - handler.on_list_task_push_notification_configs.assert_awaited_once() - - if hasattr(transport, 'close'): - await transport.close() - - -@pytest.mark.asyncio -async def test_grpc_transport_list_task_push_notification_configs( - grpc_server_and_handler: tuple[str, AsyncMock], - agent_card: AgentCard, -) -> None: - server_address, handler = grpc_server_and_handler - - def channel_factory(address: str) -> Channel: - return grpc.aio.insecure_channel(address) - - channel = channel_factory(server_address) - transport = GrpcTransport(channel=channel, agent_card=agent_card) - - params = ListTaskPushNotificationConfigsRequest( - task_id=f'{CALLBACK_CONFIG.task_id}', + handler.on_list_task_push_notification_configs.assert_awaited_once_with( + params, ANY ) - result = await transport.list_task_push_notification_configs(request=params) - - assert len(result.configs) == 1 - assert result.configs[0].task_id == CALLBACK_CONFIG.task_id - handler.on_list_task_push_notification_configs.assert_awaited_once() await transport.close() @pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_setup_fixture', - [ - pytest.param('jsonrpc_setup', id='JSON-RPC'), - pytest.param('rest_setup', id='REST'), - ], -) -async def test_http_transport_delete_task_push_notification_config( - transport_setup_fixture: str, request +async def test_transport_delete_task_push_notification_config( + transport_setups, ) -> None: - transport_setup: TransportSetup = request.getfixturevalue( - transport_setup_fixture - ) - transport = transport_setup.transport - handler = transport_setup.handler + transport = transport_setups.transport + handler = transport_setups.handler params = DeleteTaskPushNotificationConfigRequest( - task_id=f'{CALLBACK_CONFIG.task_id}', + task_id=CALLBACK_CONFIG.task_id, id=CALLBACK_CONFIG.id, ) await transport.delete_task_push_notification_config(request=params) - handler.on_delete_task_push_notification_config.assert_awaited_once() - - if hasattr(transport, 'close'): - await transport.close() - - -@pytest.mark.asyncio -async def test_grpc_transport_delete_task_push_notification_config( - grpc_server_and_handler: tuple[str, AsyncMock], - agent_card: AgentCard, -) -> None: - server_address, handler = grpc_server_and_handler - - def channel_factory(address: str) -> Channel: - return grpc.aio.insecure_channel(address) - - channel = channel_factory(server_address) - transport = GrpcTransport(channel=channel, agent_card=agent_card) - - params = DeleteTaskPushNotificationConfigRequest( - task_id=f'{CALLBACK_CONFIG.task_id}', - id=CALLBACK_CONFIG.id, + handler.on_delete_task_push_notification_config.assert_awaited_once_with( + params, ANY ) - await transport.delete_task_push_notification_config(request=params) - - handler.on_delete_task_push_notification_config.assert_awaited_once() await transport.close() @pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_setup_fixture', - [ - pytest.param('jsonrpc_setup', id='JSON-RPC'), - pytest.param('rest_setup', id='REST'), - ], -) -async def test_http_transport_resubscribe( - transport_setup_fixture: str, request -) -> None: - transport_setup: TransportSetup = request.getfixturevalue( - transport_setup_fixture - ) - transport = transport_setup.transport - handler = transport_setup.handler - - # Use SubscribeToTaskRequest with name (AIP resource format) - params = SubscribeToTaskRequest(id=RESUBSCRIBE_EVENT.task_id) - stream = transport.subscribe(request=params) - first_event = await anext(stream) - - # StreamResponse wraps the status update in its 'status_update' field - assert first_event.status_update.task_id == RESUBSCRIBE_EVENT.task_id - handler.on_subscribe_to_task.assert_called_once() - - if hasattr(transport, 'close'): - await transport.close() - - -@pytest.mark.asyncio -async def test_grpc_transport_resubscribe( - grpc_server_and_handler: tuple[str, AsyncMock], - agent_card: AgentCard, -) -> None: - server_address, handler = grpc_server_and_handler +async def test_transport_subscribe(transport_setups) -> None: + transport = transport_setups.transport + handler = transport_setups.handler - def channel_factory(address: str) -> Channel: - return grpc.aio.insecure_channel(address) - - channel = channel_factory(server_address) - transport = GrpcTransport(channel=channel, agent_card=agent_card) - - # Use SubscribeToTaskRequest with name (AIP resource format) params = SubscribeToTaskRequest(id=RESUBSCRIBE_EVENT.task_id) stream = transport.subscribe(request=params) - first_event = await anext(stream) + first_event = await stream.__anext__() - # StreamResponse wraps the status update in its 'status_update' field assert first_event.status_update.task_id == RESUBSCRIBE_EVENT.task_id handler.on_subscribe_to_task.assert_called_once() @@ -871,83 +470,27 @@ def channel_factory(address: str) -> Channel: @pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_setup_fixture', - [ - pytest.param('jsonrpc_setup', id='JSON-RPC'), - pytest.param('rest_setup', id='REST'), - ], -) -async def test_http_transport_get_card( - transport_setup_fixture: str, request, agent_card: AgentCard -) -> None: - transport_setup: TransportSetup = request.getfixturevalue( - transport_setup_fixture - ) - transport = transport_setup.transport - # Access the base card from the agent_card property. - result = transport.agent_card # type: ignore[attr-defined] +async def test_transport_get_card(transport_setups, agent_card) -> None: + transport = transport_setups.transport + result = transport.agent_card assert result.name == agent_card.name - - if hasattr(transport, 'close'): - await transport.close() - - -@pytest.mark.asyncio -async def test_http_transport_get_authenticated_card( - agent_card: AgentCard, - mock_request_handler: AsyncMock, -) -> None: - agent_card.capabilities.extended_agent_card = True - # Create a copy of the agent card for the extended card - extended_agent_card = AgentCard() - extended_agent_card.CopyFrom(agent_card) - extended_agent_card.name = 'Extended Agent Card' - - app_builder = A2ARESTFastAPIApplication( - agent_card, - mock_request_handler, - extended_agent_card=extended_agent_card, - ) - app = app_builder.build() - httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) - - transport = RestTransport( - httpx_client=httpx_client, - agent_card=agent_card, - url=agent_card.supported_interfaces[0].url, - ) - result = await transport.get_extended_agent_card( - GetExtendedAgentCardRequest() - ) - assert result.name == extended_agent_card.name - - if hasattr(transport, 'close'): - await transport.close() + await transport.close() @pytest.mark.asyncio -async def test_grpc_transport_get_card( - grpc_server_and_handler: tuple[str, AsyncMock], - agent_card: AgentCard, +async def test_transport_get_extended_agent_card( + transport_setups, agent_card ) -> None: - server_address, _ = grpc_server_and_handler - - def channel_factory(address: str) -> Channel: - return grpc.aio.insecure_channel(address) - - channel = channel_factory(server_address) - transport = GrpcTransport(channel=channel, agent_card=agent_card) - - # The transport starts with a minimal card, get_extended_agent_card() fetches the full one - assert transport.agent_card is not None + transport = transport_setups.transport + # Ensure capabilities allow extended card transport.agent_card.capabilities.extended_agent_card = True + result = await transport.get_extended_agent_card( GetExtendedAgentCardRequest() ) - - assert result.name == agent_card.name + # The result could be the original card or a slightly modified one depending on transport + assert result.name in [agent_card.name, 'Extended Agent Card'] await transport.close() diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index 63cb2e95e..6a53541f3 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -3,9 +3,16 @@ This module tests the proto utilities including to_stream_response and dictionary normalization. """ +import httpx import pytest +from google.protobuf.json_format import MessageToDict, Parse +from google.protobuf.message import Message as ProtobufMessage +from google.protobuf.timestamp_pb2 import Timestamp from a2a.types.a2a_pb2 import ( + AgentCard, + AgentSkill, + ListTasksRequest, Message, Part, Role, @@ -16,6 +23,7 @@ TaskStatus, TaskStatusUpdateEvent, ) +from starlette.datastructures import QueryParams from a2a.utils import proto_utils @@ -172,4 +180,62 @@ def test_parse_string_integers_in_dict(self): assert result['int'] == 42 assert result['list'] == ['hello', 9999999999999999999, '123'] assert result['nested']['inner_large_string'] == 9999999999999999999 - assert result['nested']['inner_regular'] == 'value' + + +class TestRestParams: + """Unit tests for REST parameter conversion.""" + + def test_rest_params_roundtrip(self): + """Test the comprehensive roundtrip conversion for REST parameters.""" + + original = ListTasksRequest( + tenant='tenant-1', + context_id='ctx-1', + status=TaskState.TASK_STATE_WORKING, + page_size=10, + include_artifacts=True, + status_timestamp_after=Parse('"2024-03-09T16:00:00Z"', Timestamp()), + history_length=5, + ) + + query_params = self._message_to_rest_params(original) + + assert dict(query_params) == { + 'tenant': 'tenant-1', + 'contextId': 'ctx-1', + 'status': 'TASK_STATE_WORKING', + 'pageSize': '10', + 'includeArtifacts': 'true', + 'statusTimestampAfter': '2024-03-09T16:00:00Z', + 'historyLength': '5', + } + + converted = ListTasksRequest() + proto_utils.parse_params(QueryParams(query_params), converted) + + assert converted == original + + @pytest.mark.parametrize( + 'query_string', + [ + 'id=skill-1&tags=tag1&tags=tag2&tags=tag3', + 'id=skill-1&tags=tag1,tag2,tag3', + ], + ) + def test_repeated_fields_parsing(self, query_string: str): + """Test parsing of repeated fields using different query string formats.""" + query_params = QueryParams(query_string) + + converted = AgentSkill() + proto_utils.parse_params(query_params, converted) + + assert converted == AgentSkill( + id='skill-1', tags=['tag1', 'tag2', 'tag3'] + ) + + def _message_to_rest_params(self, message: ProtobufMessage) -> QueryParams: + """Converts a message to REST query parameters.""" + rest_dict = MessageToDict(message) + return httpx.Request( + 'GET', 'http://api.example.com', params=rest_dict + ).url.params