-
Notifications
You must be signed in to change notification settings - Fork 411
Expand file tree
/
Copy pathclient.py
More file actions
236 lines (208 loc) · 7.32 KB
/
client.py
File metadata and controls
236 lines (208 loc) · 7.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import json
from collections.abc import AsyncGenerator
from typing import Any
from uuid import uuid4
import httpx
from httpx_sse import SSEError, aconnect_sse
from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError
from a2a.types import (
AgentCard,
CancelTaskRequest,
CancelTaskResponse,
GetTaskPushNotificationConfigRequest,
GetTaskPushNotificationConfigResponse,
GetTaskRequest,
GetTaskResponse,
SendMessageRequest,
SendMessageResponse,
SendStreamingMessageRequest,
SendStreamingMessageResponse,
SetTaskPushNotificationConfigRequest,
SetTaskPushNotificationConfigResponse,
)
from a2a.utils.telemetry import SpanKind, trace_class
class A2ACardResolver:
    """Fetches an agent card over HTTP and parses it into an ``AgentCard``."""

    def __init__(
        self,
        httpx_client: httpx.AsyncClient,
        base_url: str,
        agent_card_path: str = '/.well-known/agent.json',
    ):
        """Stores the HTTP client and the normalised URL pieces.

        Args:
            httpx_client: Async client used to perform the GET request.
            base_url: Base URL of the agent; trailing slashes are removed.
            agent_card_path: Path of the card under ``base_url``; leading
                slashes are removed so the two parts join with one ``/``.
        """
        self.httpx_client = httpx_client
        self.agent_card_path = agent_card_path.lstrip('/')
        self.base_url = base_url.rstrip('/')

    async def get_agent_card(
        self, http_kwargs: dict[str, Any] | None = None
    ) -> AgentCard:
        """Downloads the agent card and validates it.

        Args:
            http_kwargs: Optional extra keyword arguments forwarded to
                ``httpx_client.get``.

        Returns:
            The ``AgentCard`` built from the JSON body of the response.

        Raises:
            A2AClientHTTPError: On a non-success HTTP status (carrying that
                status code) or on a transport-level failure (reported
                with code 503).
            A2AClientJSONError: If the response body is not valid JSON.
        """
        card_url = f'{self.base_url}/{self.agent_card_path}'
        extra_kwargs = http_kwargs or {}
        try:
            response = await self.httpx_client.get(card_url, **extra_kwargs)
            response.raise_for_status()
            card_payload = response.json()
            return AgentCard.model_validate(card_payload)
        except httpx.HTTPStatusError as status_err:
            raise A2AClientHTTPError(
                status_err.response.status_code, str(status_err)
            ) from status_err
        except json.JSONDecodeError as decode_err:
            raise A2AClientJSONError(str(decode_err)) from decode_err
        except httpx.RequestError as request_err:
            raise A2AClientHTTPError(
                503, f'Network communication error: {request_err}'
            ) from request_err
@trace_class(kind=SpanKind.CLIENT)
class A2AClient:
    """Client that sends A2A JSON-RPC requests to a single agent URL.

    All requests are POSTed to ``self.url`` through the supplied
    ``httpx.AsyncClient``. HTTP-status, transport and JSON-decoding
    failures are re-raised as ``A2AClientHTTPError`` or
    ``A2AClientJSONError``.
    """

    def __init__(
        self,
        httpx_client: httpx.AsyncClient,
        agent_card: AgentCard | None = None,
        url: str | None = None,
    ):
        """Initialises the client with a target URL.

        Args:
            httpx_client: Async client used for every request.
            agent_card: Agent card whose ``url`` is used as the endpoint.
                Takes precedence over ``url`` when both are given.
            url: Endpoint URL, used only when no agent card is supplied.

        Raises:
            ValueError: If neither ``agent_card`` nor ``url`` is provided.
        """
        if agent_card:
            self.url = agent_card.url
        elif url:
            self.url = url
        else:
            raise ValueError('Must provide either agent_card or url')

        self.httpx_client = httpx_client

    @staticmethod
    async def get_client_from_agent_card_url(
        httpx_client: httpx.AsyncClient,
        base_url: str,
        agent_card_path: str = '/.well-known/agent.json',
        http_kwargs: dict[str, Any] | None = None,
    ) -> 'A2AClient':
        """Builds a client from an agent card fetched over HTTP.

        Args:
            httpx_client: Async client used both to fetch the card and by
                the returned ``A2AClient``.
            base_url: Base URL of the agent hosting the card.
            agent_card_path: Path of the card relative to ``base_url``.
            http_kwargs: Optional extra keyword arguments forwarded to the
                card-fetching GET request.

        Returns:
            An ``A2AClient`` targeting the URL declared in the fetched card.
        """
        agent_card: AgentCard = await A2ACardResolver(
            httpx_client, base_url=base_url, agent_card_path=agent_card_path
        ).get_agent_card(http_kwargs=http_kwargs)
        return A2AClient(httpx_client=httpx_client, agent_card=agent_card)

    async def send_message(
        self,
        request: SendMessageRequest,
        *,
        http_kwargs: dict[str, Any] | None = None,
    ) -> SendMessageResponse:
        """Sends a message request and returns the parsed response.

        Args:
            request: Request to send. If its ``id`` is falsy, a fresh UUID4
                string is assigned to it in place.
            http_kwargs: Optional extra keyword arguments forwarded to the
                underlying ``httpx`` POST call.

        Returns:
            The response payload wrapped in ``SendMessageResponse``.
        """
        if not request.id:
            request.id = str(uuid4())

        return SendMessageResponse(
            **await self._send_request(
                request.model_dump(mode='json', exclude_none=True),
                http_kwargs,
            )
        )

    async def send_message_streaming(
        self,
        request: SendStreamingMessageRequest,
        *,
        http_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[SendStreamingMessageResponse]:
        """Sends a streaming message request and yields each SSE event.

        Args:
            request: Request to send. If its ``id`` is falsy, a fresh UUID4
                string is assigned to it in place.
            http_kwargs: Optional extra keyword arguments forwarded to
                ``aconnect_sse``. A ``timeout`` entry here overrides the
                default of no timeout.

        Yields:
            One ``SendStreamingMessageResponse`` per server-sent event,
            built from the event's JSON-decoded ``data``.

        Raises:
            A2AClientHTTPError: On a non-success HTTP status (carrying that
                status code), on an SSE protocol error (code 400), or on a
                transport-level failure (code 503).
            A2AClientJSONError: If an event's data is not valid JSON.
        """
        if not request.id:
            request.id = str(uuid4())

        # Default to no timeout for streaming, can be overridden by http_kwargs
        http_kwargs_with_timeout: dict[str, Any] = {
            'timeout': None,
            **(http_kwargs or {}),
        }

        # The try wraps the whole context manager so that failures raised
        # while opening the connection are translated the same way as
        # failures raised while reading events.
        try:
            async with aconnect_sse(
                self.httpx_client,
                'POST',
                self.url,
                json=request.model_dump(mode='json', exclude_none=True),
                **http_kwargs_with_timeout,
            ) as event_source:
                # Report 4xx/5xx replies with their real status code, in
                # line with the non-streaming request path.
                event_source.response.raise_for_status()
                async for sse in event_source.aiter_sse():
                    yield SendStreamingMessageResponse(**json.loads(sse.data))
        except httpx.HTTPStatusError as e:
            raise A2AClientHTTPError(e.response.status_code, str(e)) from e
        except SSEError as e:
            raise A2AClientHTTPError(
                400,
                f'Invalid SSE response or protocol error: {e}',
            ) from e
        except json.JSONDecodeError as e:
            raise A2AClientJSONError(str(e)) from e
        except httpx.RequestError as e:
            raise A2AClientHTTPError(
                503, f'Network communication error: {e}'
            ) from e

    async def _send_request(
        self,
        rpc_request_payload: dict[str, Any],
        http_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Sends a non-streaming JSON-RPC request to the agent.

        Args:
            rpc_request_payload: JSON RPC payload for sending the request.
            http_kwargs: Optional extra keyword arguments to pass to the
                httpx client's ``post`` call.

        Returns:
            The JSON-decoded body of the response.

        Raises:
            A2AClientHTTPError: On a non-success HTTP status (carrying that
                status code) or on a transport-level failure (code 503).
            A2AClientJSONError: If the response body is not valid JSON.
        """
        try:
            response = await self.httpx_client.post(
                self.url, json=rpc_request_payload, **(http_kwargs or {})
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            raise A2AClientHTTPError(e.response.status_code, str(e)) from e
        except json.JSONDecodeError as e:
            raise A2AClientJSONError(str(e)) from e
        except httpx.RequestError as e:
            raise A2AClientHTTPError(
                503, f'Network communication error: {e}'
            ) from e

    async def get_task(
        self,
        request: GetTaskRequest,
        *,
        http_kwargs: dict[str, Any] | None = None,
    ) -> GetTaskResponse:
        """Sends a get-task request and returns the parsed response.

        Args:
            request: Request to send. If its ``id`` is falsy, a fresh UUID4
                string is assigned to it in place.
            http_kwargs: Optional extra keyword arguments forwarded to the
                underlying ``httpx`` POST call.

        Returns:
            The response payload wrapped in ``GetTaskResponse``.
        """
        if not request.id:
            request.id = str(uuid4())

        return GetTaskResponse(
            **await self._send_request(
                request.model_dump(mode='json', exclude_none=True),
                http_kwargs,
            )
        )

    async def cancel_task(
        self,
        request: CancelTaskRequest,
        *,
        http_kwargs: dict[str, Any] | None = None,
    ) -> CancelTaskResponse:
        """Sends a cancel-task request and returns the parsed response.

        Args:
            request: Request to send. If its ``id`` is falsy, a fresh UUID4
                string is assigned to it in place.
            http_kwargs: Optional extra keyword arguments forwarded to the
                underlying ``httpx`` POST call.

        Returns:
            The response payload wrapped in ``CancelTaskResponse``.
        """
        if not request.id:
            request.id = str(uuid4())

        return CancelTaskResponse(
            **await self._send_request(
                request.model_dump(mode='json', exclude_none=True),
                http_kwargs,
            )
        )

    async def set_task_callback(
        self,
        request: SetTaskPushNotificationConfigRequest,
        *,
        http_kwargs: dict[str, Any] | None = None,
    ) -> SetTaskPushNotificationConfigResponse:
        """Sends a set-push-notification-config request.

        Args:
            request: Request to send. If its ``id`` is falsy, a fresh UUID4
                string is assigned to it in place.
            http_kwargs: Optional extra keyword arguments forwarded to the
                underlying ``httpx`` POST call.

        Returns:
            The response payload wrapped in
            ``SetTaskPushNotificationConfigResponse``.
        """
        if not request.id:
            request.id = str(uuid4())

        return SetTaskPushNotificationConfigResponse(
            **await self._send_request(
                request.model_dump(mode='json', exclude_none=True),
                http_kwargs,
            )
        )

    async def get_task_callback(
        self,
        request: GetTaskPushNotificationConfigRequest,
        *,
        http_kwargs: dict[str, Any] | None = None,
    ) -> GetTaskPushNotificationConfigResponse:
        """Sends a get-push-notification-config request.

        Args:
            request: Request to send. If its ``id`` is falsy, a fresh UUID4
                string is assigned to it in place.
            http_kwargs: Optional extra keyword arguments forwarded to the
                underlying ``httpx`` POST call.

        Returns:
            The response payload wrapped in
            ``GetTaskPushNotificationConfigResponse``.
        """
        if not request.id:
            request.id = str(uuid4())

        return GetTaskPushNotificationConfigResponse(
            **await self._send_request(
                request.model_dump(mode='json', exclude_none=True),
                http_kwargs,
            )
        )