Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
0479790
feat: Add utilities for managing service parameters and A2A extensions.
guglielmo-san Mar 6, 2026
9e5761a
wip
guglielmo-san Mar 6, 2026
abf24a1
feat: Add utilities for managing service parameters and A2A extensions.
guglielmo-san Mar 6, 2026
506598c
wip
guglielmo-san Mar 6, 2026
a86c2f4
Merge branch 'guglielmoc/refactor_base_client' of https://github.com/…
guglielmo-san Mar 6, 2026
b0d41b9
wip refactoring
guglielmo-san Mar 6, 2026
eae38e9
fix tests
guglielmo-san Mar 6, 2026
4c23416
refactor: use `ClientCallContext` for HTTP arguments in stream reques…
guglielmo-san Mar 7, 2026
b0f2033
Refactor transport request methods to use explicit `json` and `params…
guglielmo-san Mar 7, 2026
11eecb9
refactor: Extract common HTTP argument parsing logic into a shared he…
guglielmo-san Mar 8, 2026
6186a9e
refactor: qualify ParseDict call with json_format module
guglielmo-san Mar 9, 2026
729f8b4
Merge branch '1.0-dev' into guglielmoc/refactor_base_client
guglielmo-san Mar 9, 2026
3eced82
refactor: Migrate gRPC metadata handling from direct extensions param…
guglielmo-san Mar 9, 2026
816c512
refactor: Remove `extensions` handling from `grpc_transport` by utili…
guglielmo-san Mar 9, 2026
17302b2
style: remove trailing comma from agent_card type hint in `CompatGrpc…
guglielmo-san Mar 9, 2026
36e818a
refactor: remove query parameter conversion utilities from REST trans…
guglielmo-san Mar 9, 2026
5b3d711
refactor: reformat `CompatGrpcTransport` constructor parameters for c…
guglielmo-san Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
wip refactoring
  • Loading branch information
guglielmo-san committed Mar 6, 2026
commit b0d41b91d6a9565ee86d2c2e6a13a10eba451e69
112 changes: 43 additions & 69 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from collections.abc import AsyncGenerator, AsyncIterator, Callable
from typing import Any

from a2a.client.client import (
Client,
ClientCallContext,
ClientConfig,
ClientEvent,
Consumer,
)
from a2a.client.client_task_manager import ClientTaskManager
from a2a.client.middleware import ClientCallInterceptor
from a2a.client.service_parameters import ServiceParameters
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.types.a2a_pb2 import (
AgentCard,
Expand All @@ -24,8 +21,6 @@
ListTaskPushNotificationConfigsResponse,
ListTasksRequest,
ListTasksResponse,
Message,
SendMessageConfiguration,
SendMessageRequest,
StreamResponse,
SubscribeToTaskRequest,
Expand All @@ -34,16 +29,6 @@
)


@dataclasses.dataclass
class RequestOptions:
"""Options for configuring A2A client requests."""

service_parameters: ServiceParameters | None = None

context: ClientCallContext | None = None



class BaseClient(Client):
"""Base implementation of the A2A client, containing transport-independent logic."""

Expand All @@ -64,7 +49,7 @@ async def send_message(
self,
request: SendMessageRequest,
*,
options: RequestOptions | None = None,
context: ClientCallContext | None = None,
) -> AsyncIterator[ClientEvent]:
"""Sends a message to the agent.

Expand All @@ -74,27 +59,32 @@ async def send_message(

Args:
request: The message to send to the agent.
configuration: Optional per-call overrides for message sending behavior.
context: The client call context.
request_metadata: Extensions Metadata attached to the request.
extensions: List of extensions to be activated.
context: Optional client call context.

Yields:
An async iterator of `ClientEvent`
"""
if request.configuration:
if not request.configuration.blocking and self._config.polling:
request.configuration.blocking = self._config.blocking
if not request.configuration.push_notification_config and self._config.push_notification_configs:
request.configuration.push_notification_config = self._config.push_notification_configs[0]
if not request.configuration.accepted_output_modes and self._config.accepted_output_modes:
request.configuration.accepted_output_modes = self._config.accepted_output_modes
if not request.configuration.history_length and self._config.history_length:
request.configuration.history_length = self._config.history_length
request.configuration.blocking = self._config.polling
if (
not request.configuration.push_notification_config
and self._config.push_notification_configs
):
request.configuration.push_notification_config = (
self._config.push_notification_configs[0]
)
if (
not request.configuration.accepted_output_modes
and self._config.accepted_output_modes
):
request.configuration.accepted_output_modes.extend(
self._config.accepted_output_modes
)

if not self._config.streaming or not self._card.capabilities.streaming:
response = await self._transport.send_message(
request, context=options.context, extensions=options.extensions
request, context=context
)

# In non-streaming case we convert to a StreamResponse so that the
Expand All @@ -116,7 +106,7 @@ async def send_message(
return

stream = self._transport.send_message_streaming(
request, context=context, extensions=extensions
request, context=context
)
async for client_event in self._process_stream(stream):
yield client_event
Expand Down Expand Up @@ -146,27 +136,24 @@ async def get_task(
self,
request: GetTaskRequest,
*,
options: RequestOptions | None = None,
context: ClientCallContext | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task.

Args:
request: The `GetTaskRequest` object specifying the task ID.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `Task` object representing the current state of the task.
"""
return await self._transport.get_task(
request, context=context, extensions=extensions
)
return await self._transport.get_task(request, context=context)

async def list_tasks(
self,
request: ListTasksRequest,
*,
options: RequestOptions | None = None,
context: ClientCallContext | None = None,
) -> ListTasksResponse:
"""Retrieves tasks for an agent."""
return await self._transport.list_tasks(request, context=context)
Expand All @@ -176,113 +163,104 @@ async def cancel_task(
request: CancelTaskRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task.

Args:
request: The `CancelTaskRequest` object specifying the task ID.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `Task` object containing the updated task status.
"""
return await self._transport.cancel_task(
request, context=context, extensions=extensions
)
return await self._transport.cancel_task(request, context=context)

async def create_task_push_notification_config(
self,
request: CreateTaskPushNotificationConfigRequest,
*,
options: RequestOptions | None = None,
context: ClientCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task.

Args:
request: The `TaskPushNotificationConfig` object with the new configuration.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
The created or updated `TaskPushNotificationConfig` object.
"""
return await self._transport.create_task_push_notification_config(
request, context=context, extensions=extensions
request, context=context
)

async def get_task_push_notification_config(
self,
request: GetTaskPushNotificationConfigRequest,
*,
options: RequestOptions | None = None,
context: ClientCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task.

Args:
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `TaskPushNotificationConfig` object containing the configuration.
"""
return await self._transport.get_task_push_notification_config(
request, context=context, extensions=extensions
request, context=context
)

async def list_task_push_notification_configs(
self,
request: ListTaskPushNotificationConfigsRequest,
*,
options: RequestOptions | None = None,
context: ClientCallContext | None = None,
) -> ListTaskPushNotificationConfigsResponse:
"""Lists push notification configurations for a specific task.

Args:
request: The `ListTaskPushNotificationConfigsRequest` object specifying the request.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Returns:
A `ListTaskPushNotificationConfigsResponse` object.
"""
return await self._transport.list_task_push_notification_configs(
request, context=context, extensions=extensions
request, context=context
)

async def delete_task_push_notification_config(
self,
request: DeleteTaskPushNotificationConfigRequest,
*,
options: RequestOptions | None = None,
context: ClientCallContext | None = None,
) -> None:
"""Deletes the push notification configuration for a specific task.

Args:
request: The `DeleteTaskPushNotificationConfigRequest` object specifying the request.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.
"""
await self._transport.delete_task_push_notification_config(
request, context=context, extensions=extensions
request, context=context
)

async def subscribe(
self,
request: SubscribeToTaskRequest,
*,
options: RequestOptions | None = None,
context: ClientCallContext | None = None,
) -> AsyncIterator[ClientEvent]:
"""Resubscribes to a task's event stream.

This is only available if both the client and server support streaming.

Args:
request: Parameters to identify the task to resubscribe to.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.

Yields:
An async iterator of `ClientEvent` objects.
Expand All @@ -298,17 +276,15 @@ async def subscribe(
# Note: resubscribe can only be called on an existing task. As such,
# we should never see Message updates, despite the typing of the service
# definition indicating it may be possible.
stream = self._transport.subscribe(
request, context=context, extensions=extensions
)
stream = self._transport.subscribe(request, context=context)
async for client_event in self._process_stream(stream):
yield client_event

async def get_extended_agent_card(
self,
request: GetExtendedAgentCardRequest,
*,
options: RequestOptions | None = None,
context: ClientCallContext | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card.
Expand All @@ -318,8 +294,7 @@ async def get_extended_agent_card(

Args:
request: The `GetExtendedAgentCardRequest` object specifying the request.
context: The client call context.
extensions: List of extensions to be activated.
context: Optional client call context.
signature_verifier: A callable used to verify the agent card's signatures.

Returns:
Expand All @@ -328,7 +303,6 @@ async def get_extended_agent_card(
card = await self._transport.get_extended_agent_card(
request,
context=context,
extensions=extensions,
signature_verifier=signature_verifier,
)
self._card = card
Expand Down
Loading