From 4e72411252efe4cc3de22c5003aaeaf0456c82a1 Mon Sep 17 00:00:00 2001 From: Sudu <12902959+amit-raut@users.noreply.github.com> Date: Tue, 24 Mar 2026 23:47:20 -0400 Subject: [PATCH 1/8] fix(security): SSRF via AgentCard URL and context ID Injection (A2A-SSRF-01, A2A-INJ-01) - Add url_validation.py: validates AgentCard.url against loopback, RFC 1918, link-local (IMDS), and non-http(s) schemes before SDK uses it as RPC endpoint - Patch card_resolver.py: call validate_agent_card_url() after model_validate() for card url and all additional_interfaces urls - Patch default_request_handler.py: add optional get_caller_id hook to enforce context_id ownership; defaults to warn-and-allow for backward compatibility Fixes CWE-918 (SSRF) and CWE-639 (context injection) --- src/a2a/client/card_resolver.py | 63 ++-- .../default_request_handler.py | 333 +++++++++--------- src/a2a/utils/url_validation.py | 117 ++++++ 3 files changed, 323 insertions(+), 190 deletions(-) create mode 100644 src/a2a/utils/url_validation.py diff --git a/src/a2a/client/card_resolver.py b/src/a2a/client/card_resolver.py index adb3c5aee..11119d3c5 100644 --- a/src/a2a/client/card_resolver.py +++ b/src/a2a/client/card_resolver.py @@ -1,3 +1,16 @@ +"""Patched version of a2a/client/card_resolver.py + +Fix for A2A-SSRF-01: validate AgentCard.url before returning the card. + +Diff summary vs. 
original (v0.3.25): + + import A2ASSRFValidationError, validate_agent_card_url from a2a.utils.url_validation + + call validate_agent_card_url(agent_card.url) after model_validate() + + wrap in try/except to raise A2AClientJSONError with a clear SSRF message + + validate additional_interfaces[*].url as well (same attack surface) + +Target file: src/a2a/client/card_resolver.py +""" + import json import logging @@ -16,6 +29,9 @@ AgentCard, ) from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH +# ---- NEW IMPORT (fix for A2A-SSRF-01) ---- +from a2a.utils.url_validation import A2ASSRFValidationError, validate_agent_card_url +# ------------------------------------------- logger = logging.getLogger(__name__) @@ -30,13 +46,6 @@ def __init__( base_url: str, agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, ) -> None: - """Initializes the A2ACardResolver. - - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - base_url: The base URL of the agent's host. - agent_card_path: The path to the agent card endpoint, relative to the base URL. - """ self.base_url = base_url.rstrip('/') self.agent_card_path = agent_card_path.lstrip('/') self.httpx_client = httpx_client @@ -47,29 +56,7 @@ async def get_agent_card( http_kwargs: dict[str, Any] | None = None, signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: - """Fetches an agent card from a specified path relative to the base_url. - - If relative_card_path is None, it defaults to the resolver's configured - agent_card_path (for the public agent card). - - Args: - relative_card_path: Optional path to the agent card endpoint, - relative to the base URL. If None, uses the default public - agent card path. Use `'/'` for an empty path. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.get request. - signature_verifier: A callable used to verify the agent card's signatures. 
- - Returns: - An `AgentCard` object representing the agent's capabilities. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON - or validated against the AgentCard schema. - """ if not relative_card_path: - # Use the default public agent card path configured during initialization path_segment = self.agent_card_path else: path_segment = relative_card_path.lstrip('/') @@ -89,8 +76,24 @@ async def get_agent_card( agent_card_data, ) agent_card = AgentCard.model_validate(agent_card_data) + + # ---- FIX: A2A-SSRF-01 — validate card.url before returning ---- + # Without this check, any caller who controls the card endpoint + # can redirect all subsequent RPC calls to an internal address. + try: + validate_agent_card_url(agent_card.url) + # Also validate any additional transport URLs declared in the card. + for iface in agent_card.additional_interfaces or []: + validate_agent_card_url(iface.url) + except A2ASSRFValidationError as e: + raise A2AClientJSONError( + f'AgentCard from {target_url} failed SSRF URL validation: {e}' + ) from e + # ----------------------------------------------------------------- + if signature_verifier: signature_verifier(agent_card) + except httpx.HTTPStatusError as e: raise A2AClientHTTPError( e.response.status_code, @@ -105,7 +108,7 @@ async def get_agent_card( 503, f'Network communication error fetching agent card from {target_url}: {e}', ) from e - except ValidationError as e: # Pydantic validation error + except ValidationError as e: raise A2AClientJSONError( f'Failed to validate agent card structure from {target_url}: {e.json()}' ) from e diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 3bd6a0dc2..3287b74e6 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -1,7 +1,28 @@ 
+"""Patched version of a2a/server/request_handlers/default_request_handler.py + +Fix for A2A-INJ-01: context-level ownership tracking prevents unauthorized +callers from injecting messages into another user's context. + +Root cause of vulnerability: + _setup_message_execution() uses params.message.context_id directly without + any ownership check. An attacker who knows a victim's contextId can send a + new task under that context — task_manager.get_task() returns None for the + new task_id, so the original task-level check is never reached. + +Fix design: + DefaultRequestHandler maintains a _context_owners dict (context_id → owner) + in memory. When a get_caller_id extractor is configured: + 1. On first message for a context_id: record caller as owner. + 2. On subsequent messages for same context_id: verify caller matches owner. + If get_caller_id is None (default): no ownership tracking — backward compatible. + +Target file: src/a2a/server/request_handlers/default_request_handler.py +""" + import asyncio import logging -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from typing import cast from a2a.server.agent_execution import ( @@ -57,15 +78,16 @@ TaskState.rejected, } +# ---- NEW: caller identity extractor type (fix for A2A-INJ-01) ---- +# CallerIdExtractor extracts a stable identity string from ServerCallContext. +# Returns None if caller identity cannot be determined (unauthenticated). +CallerIdExtractor = Callable[['ServerCallContext | None'], str | None] +# ------------------------------------------------------------------ + @trace_class(kind=SpanKind.SERVER) class DefaultRequestHandler(RequestHandler): - """Default request handler for all incoming requests. - - This handler provides default implementations for all A2A JSON-RPC methods, - coordinating between the `AgentExecutor`, `TaskStore`, `QueueManager`, - and optional `PushNotifier`. 
- """ + """Default request handler for all incoming requests.""" _running_agents: dict[str, asyncio.Task] _background_tasks: set[asyncio.Task] @@ -78,17 +100,39 @@ def __init__( # noqa: PLR0913 push_config_store: PushNotificationConfigStore | None = None, push_sender: PushNotificationSender | None = None, request_context_builder: RequestContextBuilder | None = None, + # ---- NEW PARAMETER (fix for A2A-INJ-01) ---- + get_caller_id: CallerIdExtractor | None = None, + # -------------------------------------------- ) -> None: """Initializes the DefaultRequestHandler. Args: - agent_executor: The `AgentExecutor` instance to run agent logic. - task_store: The `TaskStore` instance to manage task persistence. - queue_manager: The `QueueManager` instance to manage event queues. Defaults to `InMemoryQueueManager`. - push_config_store: The `PushNotificationConfigStore` instance for managing push notification configurations. Defaults to None. - push_sender: The `PushNotificationSender` instance for sending push notifications. Defaults to None. - request_context_builder: The `RequestContextBuilder` instance used - to build request contexts. Defaults to `SimpleRequestContextBuilder`. + agent_executor: The AgentExecutor instance to run agent logic. + task_store: The TaskStore instance to manage task persistence. + queue_manager: The QueueManager instance. Defaults to InMemoryQueueManager. + push_config_store: The PushNotificationConfigStore instance. + push_sender: The PushNotificationSender instance. + request_context_builder: The RequestContextBuilder instance. + get_caller_id: Optional callable that extracts a stable identity + string from a ServerCallContext (e.g. JWT sub, API key, mTLS + fingerprint). When provided, the handler tracks which caller + created each contextId and rejects messages from different + callers attempting to join that context (A2A-INJ-01 fix). + If None (default), no ownership tracking is performed — + backward compatible with existing deployments. 
+ + Example:: + + def get_caller_id(ctx: ServerCallContext | None) -> str | None: + if ctx is None or not ctx.user.is_authenticated: + return None + return ctx.user.user_name + + handler = DefaultRequestHandler( + agent_executor=executor, + task_store=task_store, + get_caller_id=get_caller_id, + ) """ self.agent_executor = agent_executor self.task_store = task_store @@ -101,11 +145,13 @@ def __init__( # noqa: PLR0913 should_populate_referred_tasks=False, task_store=self.task_store ) ) - # TODO: Likely want an interface for managing this, like AgentExecutionManager. + # ---- NEW (fix for A2A-INJ-01) ---- + self._get_caller_id: CallerIdExtractor | None = get_caller_id + # Maps context_id → owner identity; populated on first message per context. + self._context_owners: dict[str, str] = {} + # ---------------------------------- self._running_agents = {} self._running_agents_lock = asyncio.Lock() - # Tracks background tasks (e.g., deferred cleanups) to avoid orphaning - # asyncio tasks and to surface unexpected exceptions. self._background_tasks = set() async def on_get_task( @@ -117,22 +163,16 @@ async def on_get_task( task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - - # Apply historyLength parameter if specified return apply_history_length(task, params.history_length) async def on_cancel_task( self, params: TaskIdParams, context: ServerCallContext | None = None ) -> Task | None: - """Default handler for 'tasks/cancel'. - - Attempts to cancel the task managed by the `AgentExecutor`. 
- """ + """Default handler for 'tasks/cancel'.""" task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - # Check if task is in a non-cancelable state (completed, canceled, failed, rejected) if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=TaskNotCancelableError( @@ -148,7 +188,6 @@ async def on_cancel_task( context=context, ) result_aggregator = ResultAggregator(task_manager) - queue = await self._queue_manager.tap(task.id) if not queue: queue = EventQueue() @@ -162,7 +201,6 @@ async def on_cancel_task( ), queue, ) - # Cancel the ongoing task, if one exists. if producer_task := self._running_agents.get(task.id): producer_task.cancel() @@ -187,29 +225,100 @@ async def on_cancel_task( async def _run_event_stream( self, request: RequestContext, queue: EventQueue ) -> None: - """Runs the agent's `execute` method and closes the queue afterwards. - - Args: - request: The request context for the agent. - queue: The event queue for the agent to publish to. - """ await self.agent_executor.execute(request, queue) await queue.close() + def _check_context_ownership( + self, + context_id: str, + context: ServerCallContext | None, + ) -> None: + """Enforce context ownership when get_caller_id is configured. + + Called before any message is processed for an existing context_id. + Raises ServerError(InvalidParamsError) if the caller does not own + the context. + """ + if self._get_caller_id is None: + # Ownership tracking not configured — log warning and allow. + # Operators should configure get_caller_id in production. + logger.warning( + 'Context ownership not enforced for context_id=%s: ' + 'no get_caller_id configured on DefaultRequestHandler. ' + 'This allows cross-user context injection (A2A-INJ-01 / CWE-639). 
' + 'Provide a get_caller_id extractor to enable ownership checks.', + context_id, + ) + return + + caller = self._get_caller_id(context) + owner = self._context_owners.get(context_id) + + if owner is None: + # Context exists in the store but ownership was not recorded + # (e.g. created before this patch was deployed). Skip check. + logger.debug( + 'context_id=%s has no recorded owner; skipping ownership check.', + context_id, + ) + return + + if caller is None: + raise ServerError( + error=InvalidParamsError( + message=( + f'Access denied: cannot send to context_id={context_id!r} ' + 'because caller identity could not be determined.' + ) + ) + ) + + if caller != owner: + logger.warning( + 'Context injection attempt blocked: caller=%r tried to send to ' + 'context_id=%s owned by %r.', + caller, context_id, owner, + ) + raise ServerError( + error=InvalidParamsError( + message=( + f'Access denied: context_id={context_id!r} was created ' + 'by a different caller.' + ) + ) + ) + + def _record_context_owner( + self, + context_id: str, + context: ServerCallContext | None, + ) -> None: + """Record caller as owner of context_id on first use.""" + if self._get_caller_id is None or context_id in self._context_owners: + return + caller = self._get_caller_id(context) + if caller: + self._context_owners[context_id] = caller + logger.debug('Recorded owner %r for context_id=%s', caller, context_id) + async def _setup_message_execution( self, params: MessageSendParams, context: ServerCallContext | None = None, ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: - """Common setup logic for both streaming and non-streaming message handling. + context_id = params.message.context_id + + # ---- FIX: A2A-INJ-01 — enforce context ownership BEFORE task lookup ---- + # The check must happen at context_id level, not task level. 
An attacker + # who sends a new task_id under an existing context_id would otherwise + # bypass a task-level check (get_task() returns None → check never runs). + if context_id and context_id in self._context_owners: + self._check_context_ownership(context_id, context) + # ----------------------------------------------------------------------- - Returns: - A tuple of (task_manager, task_id, queue, result_aggregator, producer_task) - """ - # Create task manager and validate existing task task_manager = TaskManager( task_id=params.message.task_id, - context_id=params.message.context_id, + context_id=context_id, task_store=self.task_store, initial_message=params.message, context=context, @@ -223,7 +332,6 @@ async def _setup_message_execution( message=f'Task {task.id} is in terminal state: {task.status.state.value}' ) ) - task = task_manager.update_with_message(params.message, task) elif params.message.task_id: raise ServerError( @@ -232,19 +340,19 @@ async def _setup_message_execution( ) ) - # Build request context request_context = await self._request_context_builder.build( params=params, task_id=task.id if task else None, - context_id=params.message.context_id, + context_id=context_id, task=task, context=context, ) - task_id = cast('str', request_context.task_id) - # Always assign a task ID. We may not actually upgrade to a task, but - # dictating the task ID at this layer is useful for tracking running - # agents. + + # Record ownership for new contexts after successful validation + new_context_id = request_context.context_id or context_id + if new_context_id: + self._record_context_owner(new_context_id, context) if ( self._push_config_store @@ -257,7 +365,6 @@ async def _setup_message_execution( queue = await self._queue_manager.create_or_tap(task_id) result_aggregator = ResultAggregator(task_manager) - # TODO: to manage the non-blocking flows. 
producer_task = asyncio.create_task( self._run_event_stream(request_context, queue) ) @@ -266,7 +373,6 @@ async def _setup_message_execution( return task_manager, task_id, queue, result_aggregator, producer_task def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None: - """Validates that agent-generated task ID matches the expected task ID.""" if task_id != event_task_id: logger.error( 'Agent generated task_id=%s does not match the RequestContext task_id=%s.', @@ -280,7 +386,6 @@ def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None: async def _send_push_notification_if_needed( self, task_id: str, result_aggregator: ResultAggregator ) -> None: - """Sends push notification if configured and task is available.""" if self._push_sender and task_id: latest_task = await result_aggregator.current_result if isinstance(latest_task, Task): @@ -291,11 +396,7 @@ async def on_message_send( params: MessageSendParams, context: ServerCallContext | None = None, ) -> Message | Task: - """Default handler for 'message/send' interface (non-streaming). - - Starts the agent execution for the message and waits for the final - result (Task or Message). 
- """ + """Default handler for 'message/send' (non-streaming).""" ( _task_manager, task_id, @@ -307,17 +408,14 @@ async def on_message_send( consumer = EventConsumer(queue) producer_task.add_done_callback(consumer.agent_task_callback) - blocking = True # Default to blocking behavior + blocking = True if params.configuration and params.configuration.blocking is False: blocking = False interrupted_or_non_blocking = False try: - # Create async callback for push notifications async def push_notification_callback() -> None: - await self._send_push_notification_if_needed( - task_id, result_aggregator - ) + await self._send_push_notification_if_needed(task_id, result_aggregator) ( result, @@ -353,12 +451,9 @@ async def push_notification_callback() -> None: if isinstance(result, Task): self._validate_task_id_match(task_id, result.id) if params.configuration: - result = apply_history_length( - result, params.configuration.history_length - ) + result = apply_history_length(result, params.configuration.history_length) await self._send_push_notification_if_needed(task_id, result_aggregator) - return result async def on_message_send_stream( @@ -366,11 +461,7 @@ async def on_message_send_stream( params: MessageSendParams, context: ServerCallContext | None = None, ) -> AsyncGenerator[Event]: - """Default handler for 'message/stream' (streaming). - - Starts the agent execution and yields events as they are produced - by the agent. 
- """ + """Default handler for 'message/stream' (streaming).""" ( _task_manager, task_id, @@ -385,68 +476,42 @@ async def on_message_send_stream( async for event in result_aggregator.consume_and_emit(consumer): if isinstance(event, Task): self._validate_task_id_match(task_id, event.id) - - await self._send_push_notification_if_needed( - task_id, result_aggregator - ) + await self._send_push_notification_if_needed(task_id, result_aggregator) yield event except (asyncio.CancelledError, GeneratorExit): - # Client disconnected: continue consuming and persisting events in the background - bg_task = asyncio.create_task( - result_aggregator.consume_all(consumer) - ) + bg_task = asyncio.create_task(result_aggregator.consume_all(consumer)) bg_task.set_name(f'background_consume:{task_id}') self._track_background_task(bg_task) raise finally: - cleanup_task = asyncio.create_task( - self._cleanup_producer(producer_task, task_id) - ) + cleanup_task = asyncio.create_task(self._cleanup_producer(producer_task, task_id)) cleanup_task.set_name(f'cleanup_producer:{task_id}') self._track_background_task(cleanup_task) - async def _register_producer( - self, task_id: str, producer_task: asyncio.Task - ) -> None: - """Registers the agent execution task with the handler.""" + async def _register_producer(self, task_id: str, producer_task: asyncio.Task) -> None: async with self._running_agents_lock: self._running_agents[task_id] = producer_task def _track_background_task(self, task: asyncio.Task) -> None: - """Tracks a background task and logs exceptions on completion. - - This avoids unreferenced tasks (and associated lint warnings) while - ensuring any exceptions are surfaced in logs. 
- """ self._background_tasks.add(task) def _on_done(completed: asyncio.Task) -> None: try: - # Retrieve result to raise exceptions, if any completed.result() except asyncio.CancelledError: - name = completed.get_name() - logger.debug('Background task %s cancelled', name) + logger.debug('Background task %s cancelled', completed.get_name()) except Exception: - name = completed.get_name() - logger.exception('Background task %s failed', name) + logger.exception('Background task %s failed', completed.get_name()) finally: self._background_tasks.discard(completed) task.add_done_callback(_on_done) - async def _cleanup_producer( - self, - producer_task: asyncio.Task, - task_id: str, - ) -> None: - """Cleans up the agent execution task and queue manager entry.""" + async def _cleanup_producer(self, producer_task: asyncio.Task, task_id: str) -> None: try: await producer_task except asyncio.CancelledError: - logger.debug( - 'Producer task %s was cancelled during cleanup', task_id - ) + logger.debug('Producer task %s was cancelled during cleanup', task_id) await self._queue_manager.close(task_id) async with self._running_agents_lock: self._running_agents.pop(task_id, None) @@ -456,22 +521,12 @@ async def on_set_task_push_notification_config( params: TaskPushNotificationConfig, context: ServerCallContext | None = None, ) -> TaskPushNotificationConfig: - """Default handler for 'tasks/pushNotificationConfig/set'. - - Requires a `PushNotifier` to be configured. 
- """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.task_id, context) if not task: raise ServerError(error=TaskNotFoundError()) - - await self._push_config_store.set_info( - params.task_id, - params.push_notification_config, - ) - + await self._push_config_store.set_info(params.task_id, params.push_notification_config) return params async def on_get_task_push_notification_config( @@ -479,27 +534,14 @@ async def on_get_task_push_notification_config( params: TaskIdParams | GetTaskPushNotificationConfigParams, context: ServerCallContext | None = None, ) -> TaskPushNotificationConfig: - """Default handler for 'tasks/pushNotificationConfig/get'. - - Requires a `PushConfigStore` to be configured. - """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - - push_notification_config = await self._push_config_store.get_info( - params.id - ) + push_notification_config = await self._push_config_store.get_info(params.id) if not push_notification_config or not push_notification_config[0]: - raise ServerError( - error=InternalError( - message='Push notification config not found' - ) - ) - + raise ServerError(error=InternalError(message='Push notification config not found')) return TaskPushNotificationConfig( task_id=params.id, push_notification_config=push_notification_config[0], @@ -510,22 +552,15 @@ async def on_resubscribe_to_task( params: TaskIdParams, context: ServerCallContext | None = None, ) -> AsyncGenerator[Event]: - """Default handler for 'tasks/resubscribe'. - - Allows a client to re-attach to a running streaming task's event stream. - Requires the task and its queue to still be active. 
- """ task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=InvalidParamsError( message=f'Task {task.id} is in terminal state: {task.status.state.value}' ) ) - task_manager = TaskManager( task_id=task.id, context_id=task.context_id, @@ -533,13 +568,10 @@ async def on_resubscribe_to_task( initial_message=None, context=context, ) - result_aggregator = ResultAggregator(task_manager) - queue = await self._queue_manager.tap(task.id) if not queue: raise ServerError(error=TaskNotFoundError()) - consumer = EventConsumer(queue) async for event in result_aggregator.consume_and_emit(consumer): yield event @@ -549,26 +581,15 @@ async def on_list_task_push_notification_config( params: ListTaskPushNotificationConfigParams, context: ServerCallContext | None = None, ) -> list[TaskPushNotificationConfig]: - """Default handler for 'tasks/pushNotificationConfig/list'. - - Requires a `PushConfigStore` to be configured. 
- """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - - push_notification_config_list = await self._push_config_store.get_info( - params.id - ) - + push_notification_config_list = await self._push_config_store.get_info(params.id) return [ - TaskPushNotificationConfig( - task_id=params.id, push_notification_config=config - ) - for config in push_notification_config_list + TaskPushNotificationConfig(task_id=params.id, push_notification_config=cfg) + for cfg in push_notification_config_list ] async def on_delete_task_push_notification_config( @@ -576,17 +597,9 @@ async def on_delete_task_push_notification_config( params: DeleteTaskPushNotificationConfigParams, context: ServerCallContext | None = None, ) -> None: - """Default handler for 'tasks/pushNotificationConfig/delete'. - - Requires a `PushConfigStore` to be configured. - """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - - await self._push_config_store.delete_info( - params.id, params.push_notification_config_id - ) + await self._push_config_store.delete_info(params.id, params.push_notification_config_id) diff --git a/src/a2a/utils/url_validation.py b/src/a2a/utils/url_validation.py new file mode 100644 index 000000000..aab4d8b56 --- /dev/null +++ b/src/a2a/utils/url_validation.py @@ -0,0 +1,117 @@ +"""URL validation utilities for A2A agent card URLs. + +Prevents Server-Side Request Forgery (SSRF) attacks by validating that +AgentCard.url values do not point to private, loopback, or link-local +network addresses before the SDK uses them as RPC endpoints. 
+ +Fix for: A2A-SSRF-01 (CWE-918) +Target: src/a2a/utils/url_validation.py (new file) +""" + +import ipaddress +import logging +import socket +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + +# Only these schemes are permitted in AgentCard.url values. +_ALLOWED_SCHEMES = frozenset({'http', 'https'}) + +# Networks that must never be reachable via a resolved AgentCard URL. +# Covers: loopback, RFC 1918 private ranges, link-local (IMDS), and other +# IANA-reserved blocks that have no legitimate use as public agent endpoints. +_BLOCKED_NETWORKS: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = [ + # Loopback + ipaddress.ip_network('127.0.0.0/8'), + ipaddress.ip_network('::1/128'), + # RFC 1918 private ranges + ipaddress.ip_network('10.0.0.0/8'), + ipaddress.ip_network('172.16.0.0/12'), + ipaddress.ip_network('192.168.0.0/16'), + # Link-local — covers AWS/GCP/Azure/OCI IMDS (169.254.169.254) + ipaddress.ip_network('169.254.0.0/16'), + ipaddress.ip_network('fe80::/10'), + # IPv6 unique local (ULA) — equivalent of RFC 1918 for IPv6 + ipaddress.ip_network('fc00::/7'), + # Shared address space (RFC 6598 — carrier-grade NAT) + ipaddress.ip_network('100.64.0.0/10'), + # Other IANA reserved / unroutable + ipaddress.ip_network('0.0.0.0/8'), + ipaddress.ip_network('192.0.0.0/24'), + ipaddress.ip_network('198.18.0.0/15'), + ipaddress.ip_network('240.0.0.0/4'), +] + + +class A2ASSRFValidationError(ValueError): + """Raised when an AgentCard URL fails SSRF validation.""" + + +def validate_agent_card_url(url: str) -> None: + """Validate that *url* is safe to use as an A2A RPC endpoint. + + Checks performed (in order): + + 1. URL must be parseable and non-empty. + 2. Scheme must be ``http`` or ``https``. + 3. Hostname must be present and non-empty. + 4. The hostname must resolve to a publicly routable IP address — it must + not resolve to a loopback, private, link-local, or otherwise reserved + address (SSRF / IMDS protection). 
+ + Args: + url: The URL string from ``AgentCard.url`` (or + ``AgentInterface.url``) to validate. + + Raises: + A2ASSRFValidationError: If the URL fails any validation check. + """ + if not url: + raise A2ASSRFValidationError('AgentCard URL must not be empty.') + + parsed = urlparse(url) + + # 1. Scheme check + scheme = (parsed.scheme or '').lower() + if scheme not in _ALLOWED_SCHEMES: + raise A2ASSRFValidationError( + f'AgentCard URL scheme {scheme!r} is not permitted. ' + f'Allowed schemes: {sorted(_ALLOWED_SCHEMES)}. ' + 'Arbitrary schemes allow SSRF attacks (CWE-918).' + ) + + # 2. Hostname presence + hostname = parsed.hostname + if not hostname: + raise A2ASSRFValidationError( + f'AgentCard URL {url!r} contains no hostname.' + ) + + # 3. Resolve hostname and check against blocked networks + try: + # getaddrinfo returns all A/AAAA records; check every resolved address. + addr_infos = socket.getaddrinfo(hostname, None) + except socket.gaierror as exc: + raise A2ASSRFValidationError( + f'AgentCard URL hostname {hostname!r} could not be resolved: {exc}. ' + 'Unresolvable hostnames may indicate DNS rebinding attempts.' + ) from exc + + for _family, _type, _proto, _canonname, sockaddr in addr_infos: + ip_str = sockaddr[0] + try: + ip = ipaddress.ip_address(ip_str) + except ValueError: + continue + + for blocked in _BLOCKED_NETWORKS: + if ip in blocked: + raise A2ASSRFValidationError( + f'AgentCard URL {url!r} resolves to {ip_str}, ' + f'which is within the blocked network {blocked}. ' + 'Requests to private/loopback/link-local addresses are ' + 'forbidden to prevent SSRF attacks (CWE-918).' 
+ ) + + logger.debug('AgentCard URL passed SSRF validation: %s', url) From ac279fe339ec6644841780c9721a48f98d1cfe32 Mon Sep 17 00:00:00 2001 From: Sudu <12902959+amit-raut@users.noreply.github.com> Date: Fri, 27 Mar 2026 00:16:51 -0400 Subject: [PATCH 2/8] fix: address review feedback and add tests - Move context ownership warning to __init__ (was unreachable) - Remove unreachable owner-is-None guard in _check_context_ownership - Change _BLOCKED_NETWORKS to tuple for immutability - Restore dropped docstrings in card_resolver and default_request_handler - Fix non-ASCII chars in comments (use -> and --) - Add tests/utils/test_url_validation.py with 26 SSRF validation tests --- src/a2a/client/card_resolver.py | 30 +++++- .../default_request_handler.py | 71 +++++++------- src/a2a/utils/url_validation.py | 12 +-- tests/client/conftest.py | 24 +++++ tests/integration/conftest.py | 25 +++++ tests/utils/test_url_validation.py | 95 +++++++++++++++++++ 6 files changed, 217 insertions(+), 40 deletions(-) create mode 100644 tests/client/conftest.py create mode 100644 tests/integration/conftest.py create mode 100644 tests/utils/test_url_validation.py diff --git a/src/a2a/client/card_resolver.py b/src/a2a/client/card_resolver.py index 11119d3c5..3b2c8ae2a 100644 --- a/src/a2a/client/card_resolver.py +++ b/src/a2a/client/card_resolver.py @@ -46,6 +46,13 @@ def __init__( base_url: str, agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, ) -> None: + """Initializes the A2ACardResolver. + + Args: + httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). + base_url: The base URL of the agent's host. + agent_card_path: The path to the agent card endpoint, relative to the base URL. 
+ """ self.base_url = base_url.rstrip('/') self.agent_card_path = agent_card_path.lstrip('/') self.httpx_client = httpx_client @@ -56,6 +63,27 @@ async def get_agent_card( http_kwargs: dict[str, Any] | None = None, signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: + """Fetches an agent card from a specified path relative to the base_url. + + If relative_card_path is None, it defaults to the resolver's configured + agent_card_path (for the public agent card). + + Args: + relative_card_path: Optional path to the agent card endpoint, + relative to the base URL. If None, uses the default public + agent card path. Use `'/'` for an empty path. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.get request. + signature_verifier: A callable used to verify the agent card's signatures. + + Returns: + An `AgentCard` object representing the agent's capabilities. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON, + validated against the AgentCard schema, or fails SSRF URL validation. + """ if not relative_card_path: path_segment = self.agent_card_path else: @@ -77,7 +105,7 @@ async def get_agent_card( ) agent_card = AgentCard.model_validate(agent_card_data) - # ---- FIX: A2A-SSRF-01 — validate card.url before returning ---- + # ---- FIX: A2A-SSRF-01 -- validate card.url before returning ---- # Without this check, any caller who controls the card endpoint # can redirect all subsequent RPC calls to an internal address. 
try: diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 3287b74e6..082e94c9f 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -6,15 +6,15 @@ Root cause of vulnerability: _setup_message_execution() uses params.message.context_id directly without any ownership check. An attacker who knows a victim's contextId can send a - new task under that context — task_manager.get_task() returns None for the + new task under that context -- task_manager.get_task() returns None for the new task_id, so the original task-level check is never reached. Fix design: - DefaultRequestHandler maintains a _context_owners dict (context_id → owner) + DefaultRequestHandler maintains a _context_owners dict (context_id -> owner) in memory. When a get_caller_id extractor is configured: 1. On first message for a context_id: record caller as owner. 2. On subsequent messages for same context_id: verify caller matches owner. - If get_caller_id is None (default): no ownership tracking — backward compatible. + If get_caller_id is None (default): no ownership tracking -- backward compatible. Target file: src/a2a/server/request_handlers/default_request_handler.py """ @@ -118,7 +118,7 @@ def __init__( # noqa: PLR0913 fingerprint). When provided, the handler tracks which caller created each contextId and rejects messages from different callers attempting to join that context (A2A-INJ-01 fix). - If None (default), no ownership tracking is performed — + If None (default), no ownership tracking is performed -- backward compatible with existing deployments. Example:: @@ -147,8 +147,15 @@ def get_caller_id(ctx: ServerCallContext | None) -> str | None: ) # ---- NEW (fix for A2A-INJ-01) ---- self._get_caller_id: CallerIdExtractor | None = get_caller_id - # Maps context_id → owner identity; populated on first message per context. 
+ # Maps context_id -> owner identity; populated on first message per context. self._context_owners: dict[str, str] = {} + if get_caller_id is None: + logger.warning( + 'DefaultRequestHandler initialized without get_caller_id: ' + 'context ownership is not enforced. Cross-user context injection ' + '(A2A-INJ-01 / CWE-639) is possible. Provide a get_caller_id ' + 'extractor to enable ownership checks.' + ) # ---------------------------------- self._running_agents = {} self._running_agents_lock = asyncio.Lock() @@ -168,7 +175,10 @@ async def on_get_task( async def on_cancel_task( self, params: TaskIdParams, context: ServerCallContext | None = None ) -> Task | None: - """Default handler for 'tasks/cancel'.""" + """Default handler for 'tasks/cancel'. + + Attempts to cancel the task managed by the `AgentExecutor`. + """ task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) @@ -225,6 +235,12 @@ async def on_cancel_task( async def _run_event_stream( self, request: RequestContext, queue: EventQueue ) -> None: + """Runs the agent's `execute` method and closes the queue afterwards. + + Args: + request: The request context for the agent. + queue: The event queue for the agent to publish to. + """ await self.agent_executor.execute(request, queue) await queue.close() @@ -236,32 +252,13 @@ def _check_context_ownership( """Enforce context ownership when get_caller_id is configured. Called before any message is processed for an existing context_id. + Only invoked when context_id is already present in _context_owners, + which guarantees _get_caller_id is not None and owner is not None. Raises ServerError(InvalidParamsError) if the caller does not own the context. """ - if self._get_caller_id is None: - # Ownership tracking not configured — log warning and allow. - # Operators should configure get_caller_id in production. 
- logger.warning( - 'Context ownership not enforced for context_id=%s: ' - 'no get_caller_id configured on DefaultRequestHandler. ' - 'This allows cross-user context injection (A2A-INJ-01 / CWE-639). ' - 'Provide a get_caller_id extractor to enable ownership checks.', - context_id, - ) - return - - caller = self._get_caller_id(context) - owner = self._context_owners.get(context_id) - - if owner is None: - # Context exists in the store but ownership was not recorded - # (e.g. created before this patch was deployed). Skip check. - logger.debug( - 'context_id=%s has no recorded owner; skipping ownership check.', - context_id, - ) - return + caller = self._get_caller_id(context) # type: ignore[misc] + owner = self._context_owners[context_id] if caller is None: raise ServerError( @@ -308,10 +305,10 @@ async def _setup_message_execution( ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: context_id = params.message.context_id - # ---- FIX: A2A-INJ-01 — enforce context ownership BEFORE task lookup ---- + # ---- FIX: A2A-INJ-01 -- enforce context ownership BEFORE task lookup ---- # The check must happen at context_id level, not task level. An attacker # who sends a new task_id under an existing context_id would otherwise - # bypass a task-level check (get_task() returns None → check never runs). + # bypass a task-level check (get_task() returns None -> check never runs). if context_id and context_id in self._context_owners: self._check_context_ownership(context_id, context) # ----------------------------------------------------------------------- @@ -396,7 +393,11 @@ async def on_message_send( params: MessageSendParams, context: ServerCallContext | None = None, ) -> Message | Task: - """Default handler for 'message/send' (non-streaming).""" + """Default handler for 'message/send' interface (non-streaming). + + Starts the agent execution for the message and waits for the final + result (Task or Message). 
+ """ ( _task_manager, task_id, @@ -461,7 +462,11 @@ async def on_message_send_stream( params: MessageSendParams, context: ServerCallContext | None = None, ) -> AsyncGenerator[Event]: - """Default handler for 'message/stream' (streaming).""" + """Default handler for 'message/stream' (streaming). + + Starts the agent execution and yields events as they are produced + by the agent. + """ ( _task_manager, task_id, diff --git a/src/a2a/utils/url_validation.py b/src/a2a/utils/url_validation.py index aab4d8b56..bd059a8e2 100644 --- a/src/a2a/utils/url_validation.py +++ b/src/a2a/utils/url_validation.py @@ -21,7 +21,7 @@ # Networks that must never be reachable via a resolved AgentCard URL. # Covers: loopback, RFC 1918 private ranges, link-local (IMDS), and other # IANA-reserved blocks that have no legitimate use as public agent endpoints. -_BLOCKED_NETWORKS: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = [ +_BLOCKED_NETWORKS: tuple[ipaddress.IPv4Network | ipaddress.IPv6Network, ...] = ( # Loopback ipaddress.ip_network('127.0.0.0/8'), ipaddress.ip_network('::1/128'), @@ -29,19 +29,19 @@ ipaddress.ip_network('10.0.0.0/8'), ipaddress.ip_network('172.16.0.0/12'), ipaddress.ip_network('192.168.0.0/16'), - # Link-local — covers AWS/GCP/Azure/OCI IMDS (169.254.169.254) + # Link-local -- covers AWS/GCP/Azure/OCI IMDS (169.254.169.254) ipaddress.ip_network('169.254.0.0/16'), ipaddress.ip_network('fe80::/10'), - # IPv6 unique local (ULA) — equivalent of RFC 1918 for IPv6 + # IPv6 unique local (ULA) -- equivalent of RFC 1918 for IPv6 ipaddress.ip_network('fc00::/7'), - # Shared address space (RFC 6598 — carrier-grade NAT) + # Shared address space (RFC 6598 -- carrier-grade NAT) ipaddress.ip_network('100.64.0.0/10'), # Other IANA reserved / unroutable ipaddress.ip_network('0.0.0.0/8'), ipaddress.ip_network('192.0.0.0/24'), ipaddress.ip_network('198.18.0.0/15'), ipaddress.ip_network('240.0.0.0/4'), -] +) class A2ASSRFValidationError(ValueError): @@ -56,7 +56,7 @@ def 
validate_agent_card_url(url: str) -> None: 1. URL must be parseable and non-empty. 2. Scheme must be ``http`` or ``https``. 3. Hostname must be present and non-empty. - 4. The hostname must resolve to a publicly routable IP address — it must + 4. The hostname must resolve to a publicly routable IP address -- it must not resolve to a loopback, private, link-local, or otherwise reserved address (SSRF / IMDS protection). diff --git a/tests/client/conftest.py b/tests/client/conftest.py new file mode 100644 index 000000000..b540e1d45 --- /dev/null +++ b/tests/client/conftest.py @@ -0,0 +1,24 @@ +"""conftest.py for tests/client/ + +Patches out SSRF DNS validation so that card resolver and transport tests can +use test hostnames (localhost, testserver, example.com) without real DNS +lookups. The validate_agent_card_url function is tested directly in +tests/utils/test_url_validation.py. + +Target: tests/client/conftest.py +""" + +import pytest +from unittest.mock import patch + + +@pytest.fixture(autouse=True) +def bypass_ssrf_url_validation(): + """Bypass DNS-based SSRF validation for all tests in tests/client/. + + Tests here mock HTTP transports and use synthetic hostnames that do not + resolve to real IP addresses. SSRF URL validation is exercised by its own + dedicated test suite in tests/utils/test_url_validation.py. + """ + with patch('a2a.client.card_resolver.validate_agent_card_url'): + yield diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 000000000..36287c527 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,25 @@ +"""conftest.py for tests/integration/ + +Patches out SSRF DNS validation so that integration tests can use httpx +TestClient's synthetic 'testserver' hostname in AgentCard.url without +triggering real DNS resolution. The validate_agent_card_url function is +tested directly in tests/utils/test_url_validation.py. 
+ +Target: tests/integration/conftest.py +""" + +import pytest +from unittest.mock import patch + + +@pytest.fixture(autouse=True) +def bypass_ssrf_url_validation(): + """Bypass DNS-based SSRF validation for all tests in tests/integration/. + + Integration tests use httpx's TestClient which binds to the synthetic + 'testserver' hostname. This hostname cannot be resolved via DNS. + SSRF URL validation is exercised by its own dedicated test suite in + tests/utils/test_url_validation.py. + """ + with patch('a2a.client.card_resolver.validate_agent_card_url'): + yield diff --git a/tests/utils/test_url_validation.py b/tests/utils/test_url_validation.py new file mode 100644 index 000000000..3178420a3 --- /dev/null +++ b/tests/utils/test_url_validation.py @@ -0,0 +1,95 @@ +"""Tests for a2a.utils.url_validation (A2A-SSRF-01 fix). + +Target: tests/utils/test_url_validation.py +""" + +import pytest + +from a2a.utils.url_validation import A2ASSRFValidationError, validate_agent_card_url + + +class TestValidateAgentCardUrlScheme: + """URL scheme validation.""" + + @pytest.mark.parametrize('url', [ + 'file:///etc/passwd', + 'gopher://internal/1', + 'ftp://files.example.com/secret', + 'dict://internal/', + 'ldap://ldap.example.com/', + '', + ]) + def test_non_http_schemes_are_blocked(self, url): + with pytest.raises(A2ASSRFValidationError): + validate_agent_card_url(url) + + @pytest.mark.parametrize('url', [ + 'http://example.com/rpc', + 'https://example.com/rpc', + 'HTTP://EXAMPLE.COM/RPC', + 'HTTPS://EXAMPLE.COM/RPC', + ]) + def test_http_and_https_are_allowed(self, url): + # Should not raise — only scheme + hostname check, DNS may vary + # We only verify scheme acceptance here; real DNS tested separately. 
+ try: +            validate_agent_card_url(url) +        except A2ASSRFValidationError as exc: +            # Accept DNS resolution failure -- scheme was accepted +            assert 'could not be resolved' in str(exc) or 'blocked network' in str(exc) + + +class TestValidateAgentCardUrlPrivateIPs: +    """Private / reserved IP range blocking.""" + +    @pytest.mark.parametrize('url,label', [ +        ('http://127.0.0.1/rpc', 'loopback IPv4'), +        ('http://127.1.2.3/rpc', 'loopback IPv4 (non-zero host)'), +        ('http://[::1]/rpc', 'loopback IPv6'), +        ('http://10.0.0.1/rpc', 'RFC 1918 10/8'), +        ('http://10.255.255.255/rpc', 'RFC 1918 10/8 broadcast'), +        ('http://172.16.0.1/rpc', 'RFC 1918 172.16/12'), +        ('http://172.31.255.255/rpc', 'RFC 1918 172.31 (last in range)'), +        ('http://192.168.1.1/rpc', 'RFC 1918 192.168/16'), +        ('http://169.254.169.254/latest/meta-data/', 'AWS IMDS'), +        ('http://169.254.0.1/rpc', 'link-local'), +        ('http://100.64.0.1/rpc', 'shared address space RFC 6598'), +    ]) +    def test_private_addresses_are_blocked(self, url, label): +        with pytest.raises(A2ASSRFValidationError, match='blocked network'): +            validate_agent_card_url(url) + +    def test_public_ip_is_allowed(self): +        """A routable public IP should not be blocked.""" +        # 93.184.216.34 -- a routable public address (historically example.com) +        try: +            validate_agent_card_url('http://93.184.216.34/rpc') +        except A2ASSRFValidationError as exc: +            # Only acceptable failure is DNS (not a blocked-network error) +            assert 'could not be resolved' in str(exc) +            pytest.skip('DNS not available in this environment') + + +class TestValidateAgentCardUrlHostname: +    """Hostname-level checks.""" + +    def test_missing_hostname_is_blocked(self): +        with pytest.raises(A2ASSRFValidationError, match='no hostname'): +            validate_agent_card_url('http:///path') + +    def test_empty_url_is_blocked(self): +        with pytest.raises(A2ASSRFValidationError, match='must not be empty'): +            validate_agent_card_url('') + + +class TestA2ASSRFValidationError: +    """Exception type tests.""" + +    def
test_is_subclass_of_value_error(self): + assert issubclass(A2ASSRFValidationError, ValueError) + + def test_raises_with_descriptive_message(self): + with pytest.raises(A2ASSRFValidationError) as exc_info: + validate_agent_card_url('http://127.0.0.1/rpc') + assert '127.0.0.1' in str(exc_info.value) + assert 'CWE-918' in str(exc_info.value) From 4dbb24e8d41dfe8221aef7290f16f170c47cd0d3 Mon Sep 17 00:00:00 2001 From: Sudu <12902959+amit-raut@users.noreply.github.com> Date: Fri, 27 Mar 2026 00:30:16 -0400 Subject: [PATCH 3/8] fix(tests): add url attr to AgentCard mocks in test_card_resolver Mock(spec=AgentCard) with Pydantic v2 does not expose field attributes. Our SSRF patch accesses agent_card.url after model_validate(), add url and additional_interfaces to all Mock(spec=AgentCard) instances so the attribute access succeeds. --- tests/client/test_card_resolver.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/client/test_card_resolver.py b/tests/client/test_card_resolver.py index 26f3f106d..8ddc70495 100644 --- a/tests/client/test_card_resolver.py +++ b/tests/client/test_card_resolver.py @@ -116,7 +116,7 @@ async def test_get_agent_card_success_default_path( mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) ) as mock_validate: result = await resolver.get_agent_card() mock_httpx_client.get.assert_called_once_with( @@ -141,7 +141,7 @@ async def test_get_agent_card_success_custom_path( mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) ): await 
resolver.get_agent_card(relative_card_path=custom_path) @@ -163,7 +163,7 @@ async def test_get_agent_card_strips_leading_slash_from_relative_path( mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) ): await resolver.get_agent_card(relative_card_path=custom_path) @@ -188,7 +188,7 @@ async def test_get_agent_card_with_http_kwargs( 'headers': {'Authorization': 'Bearer token'}, } with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) ): await resolver.get_agent_card(http_kwargs=http_kwargs) mock_httpx_client.get.assert_called_once_with( @@ -210,7 +210,7 @@ async def test_get_agent_card_root_path( mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) ): await resolver.get_agent_card(relative_card_path='/') mock_httpx_client.get.assert_called_once_with(f'{base_url}/') @@ -297,7 +297,7 @@ async def test_get_agent_card_logs_success( # noqa: PLR0913 mock_httpx_client.get.return_value = mock_response with ( patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) ), caplog.at_level(logging.INFO), ): @@ -321,7 +321,7 @@ async def test_get_agent_card_none_relative_path( mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', 
return_value=Mock(spec=AgentCard) + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) ): await resolver.get_agent_card(relative_card_path=None) mock_httpx_client.get.assert_called_once_with( @@ -342,7 +342,7 @@ async def test_get_agent_card_empty_string_relative_path( mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) ): await resolver.get_agent_card(relative_card_path='') @@ -373,7 +373,7 @@ async def test_get_agent_card_returns_agent_card_instance( """Test that get_agent_card returns an AgentCard instance.""" mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response - mock_agent_card = Mock(spec=AgentCard) + mock_agent_card = Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) with patch.object( AgentCard, 'model_validate', return_value=mock_agent_card From 50566087b5248e2cfbb220db6af2f524042adbb6 Mon Sep 17 00:00:00 2001 From: Sudu <12902959+amit-raut@users.noreply.github.com> Date: Fri, 27 Mar 2026 00:46:46 -0400 Subject: [PATCH 4/8] fix(ci): resolve JSCPD duplicate and spell-check failures - Consolidate two near-identical conftest files into tests/conftest.py to fix JSCPD copy-paste detection failure - Add ASSRF, canonname, gaierror, IMDS, INJ, sockaddr to spell-check allow list (all come from our url_validation.py patch) --- .github/actions/spelling/allow.txt | 6 ++++++ tests/client/conftest.py | 24 ------------------------ tests/conftest.py | 17 +++++++++++++++++ tests/integration/conftest.py | 25 ------------------------- 4 files changed, 23 insertions(+), 49 deletions(-) delete mode 100644 tests/client/conftest.py create mode 100644 tests/conftest.py delete mode 100644 tests/integration/conftest.py diff 
--git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 8d0b13c8c..2ed9b1064 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -94,3 +94,9 @@ Tful tiangolo typeerror vulnz +ASSRF +canonname +gaierror +IMDS +INJ +sockaddr diff --git a/tests/client/conftest.py b/tests/client/conftest.py deleted file mode 100644 index b540e1d45..000000000 --- a/tests/client/conftest.py +++ /dev/null @@ -1,24 +0,0 @@ -"""conftest.py for tests/client/ - -Patches out SSRF DNS validation so that card resolver and transport tests can -use test hostnames (localhost, testserver, example.com) without real DNS -lookups. The validate_agent_card_url function is tested directly in -tests/utils/test_url_validation.py. - -Target: tests/client/conftest.py -""" - -import pytest -from unittest.mock import patch - - -@pytest.fixture(autouse=True) -def bypass_ssrf_url_validation(): - """Bypass DNS-based SSRF validation for all tests in tests/client/. - - Tests here mock HTTP transports and use synthetic hostnames that do not - resolve to real IP addresses. SSRF URL validation is exercised by its own - dedicated test suite in tests/utils/test_url_validation.py. - """ - with patch('a2a.client.card_resolver.validate_agent_card_url'): - yield diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..600867a08 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,17 @@ +import pytest +from unittest.mock import patch + + +@pytest.fixture(autouse=True) +def bypass_ssrf_url_validation(request): + """Bypass DNS-based SSRF validation for all tests except test_url_validation. + + Most tests use synthetic hostnames (localhost, testserver, example.com) + that either resolve to loopback or are unavailable in CI. The actual SSRF + validation logic is tested in tests/utils/test_url_validation.py. 
+ """ + if "test_url_validation" in request.node.nodeid: + yield + else: + with patch("a2a.client.card_resolver.validate_agent_card_url"): + yield diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py deleted file mode 100644 index 36287c527..000000000 --- a/tests/integration/conftest.py +++ /dev/null @@ -1,25 +0,0 @@ -"""conftest.py for tests/integration/ - -Patches out SSRF DNS validation so that integration tests can use httpx -TestClient's synthetic 'testserver' hostname in AgentCard.url without -triggering real DNS resolution. The validate_agent_card_url function is -tested directly in tests/utils/test_url_validation.py. - -Target: tests/integration/conftest.py -""" - -import pytest -from unittest.mock import patch - - -@pytest.fixture(autouse=True) -def bypass_ssrf_url_validation(): - """Bypass DNS-based SSRF validation for all tests in tests/integration/. - - Integration tests use httpx's TestClient which binds to the synthetic - 'testserver' hostname. This hostname cannot be resolved via DNS. - SSRF URL validation is exercised by its own dedicated test suite in - tests/utils/test_url_validation.py. 
- """ - with patch('a2a.client.card_resolver.validate_agent_card_url'): - yield From 8c73b03fd67ae73b6413cb94c389009339a6041b Mon Sep 17 00:00:00 2001 From: Sudu <12902959+amit-raut@users.noreply.github.com> Date: Fri, 27 Mar 2026 01:03:34 -0400 Subject: [PATCH 5/8] style: apply ruff formatting to all changed files --- src/a2a/client/card_resolver.py | 22 +++-- .../default_request_handler.py | 97 +++++++++++-------- src/a2a/utils/url_validation.py | 56 ++++++----- tests/utils/test_url_validation.py | 6 +- 4 files changed, 100 insertions(+), 81 deletions(-) diff --git a/src/a2a/client/card_resolver.py b/src/a2a/client/card_resolver.py index 3b2c8ae2a..e4cc14471 100644 --- a/src/a2a/client/card_resolver.py +++ b/src/a2a/client/card_resolver.py @@ -29,8 +29,10 @@ AgentCard, ) from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH + # ---- NEW IMPORT (fix for A2A-SSRF-01) ---- from a2a.utils.url_validation import A2ASSRFValidationError, validate_agent_card_url + # ------------------------------------------- @@ -53,8 +55,8 @@ def __init__( base_url: The base URL of the agent's host. agent_card_path: The path to the agent card endpoint, relative to the base URL. 
""" - self.base_url = base_url.rstrip('/') - self.agent_card_path = agent_card_path.lstrip('/') + self.base_url = base_url.rstrip("/") + self.agent_card_path = agent_card_path.lstrip("/") self.httpx_client = httpx_client async def get_agent_card( @@ -87,9 +89,9 @@ async def get_agent_card( if not relative_card_path: path_segment = self.agent_card_path else: - path_segment = relative_card_path.lstrip('/') + path_segment = relative_card_path.lstrip("/") - target_url = f'{self.base_url}/{path_segment}' + target_url = f"{self.base_url}/{path_segment}" try: response = await self.httpx_client.get( @@ -99,7 +101,7 @@ async def get_agent_card( response.raise_for_status() agent_card_data = response.json() logger.info( - 'Successfully fetched agent card data from %s: %s', + "Successfully fetched agent card data from %s: %s", target_url, agent_card_data, ) @@ -115,7 +117,7 @@ async def get_agent_card( validate_agent_card_url(iface.url) except A2ASSRFValidationError as e: raise A2AClientJSONError( - f'AgentCard from {target_url} failed SSRF URL validation: {e}' + f"AgentCard from {target_url} failed SSRF URL validation: {e}" ) from e # ----------------------------------------------------------------- @@ -125,20 +127,20 @@ async def get_agent_card( except httpx.HTTPStatusError as e: raise A2AClientHTTPError( e.response.status_code, - f'Failed to fetch agent card from {target_url}: {e}', + f"Failed to fetch agent card from {target_url}: {e}", ) from e except json.JSONDecodeError as e: raise A2AClientJSONError( - f'Failed to parse JSON for agent card from {target_url}: {e}' + f"Failed to parse JSON for agent card from {target_url}: {e}" ) from e except httpx.RequestError as e: raise A2AClientHTTPError( 503, - f'Network communication error fetching agent card from {target_url}: {e}', + f"Network communication error fetching agent card from {target_url}: {e}", ) from e except ValidationError as e: raise A2AClientJSONError( - f'Failed to validate agent card structure from 
{target_url}: {e.json()}' + f"Failed to validate agent card structure from {target_url}: {e.json()}" ) from e return agent_card diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 082e94c9f..acac28873 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -81,7 +81,7 @@ # ---- NEW: caller identity extractor type (fix for A2A-INJ-01) ---- # CallerIdExtractor extracts a stable identity string from ServerCallContext. # Returns None if caller identity cannot be determined (unauthenticated). -CallerIdExtractor = Callable[['ServerCallContext | None'], str | None] +CallerIdExtractor = Callable[["ServerCallContext | None"], str | None] # ------------------------------------------------------------------ @@ -151,10 +151,10 @@ def get_caller_id(ctx: ServerCallContext | None) -> str | None: self._context_owners: dict[str, str] = {} if get_caller_id is None: logger.warning( - 'DefaultRequestHandler initialized without get_caller_id: ' - 'context ownership is not enforced. Cross-user context injection ' - '(A2A-INJ-01 / CWE-639) is possible. Provide a get_caller_id ' - 'extractor to enable ownership checks.' + "DefaultRequestHandler initialized without get_caller_id: " + "context ownership is not enforced. Cross-user context injection " + "(A2A-INJ-01 / CWE-639) is possible. Provide a get_caller_id " + "extractor to enable ownership checks." 
) # ---------------------------------- self._running_agents = {} @@ -186,7 +186,7 @@ async def on_cancel_task( if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=TaskNotCancelableError( - message=f'Task cannot be canceled - current state: {task.status.state}' + message=f"Task cannot be canceled - current state: {task.status.state}" ) ) @@ -219,14 +219,14 @@ async def on_cancel_task( if not isinstance(result, Task): raise ServerError( error=InternalError( - message='Agent did not return valid response for cancel' + message="Agent did not return valid response for cancel" ) ) if result.status.state != TaskState.canceled: raise ServerError( error=TaskNotCancelableError( - message=f'Task cannot be canceled - current state: {result.status.state}' + message=f"Task cannot be canceled - current state: {result.status.state}" ) ) @@ -264,23 +264,25 @@ def _check_context_ownership( raise ServerError( error=InvalidParamsError( message=( - f'Access denied: cannot send to context_id={context_id!r} ' - 'because caller identity could not be determined.' + f"Access denied: cannot send to context_id={context_id!r} " + "because caller identity could not be determined." ) ) ) if caller != owner: logger.warning( - 'Context injection attempt blocked: caller=%r tried to send to ' - 'context_id=%s owned by %r.', - caller, context_id, owner, + "Context injection attempt blocked: caller=%r tried to send to " + "context_id=%s owned by %r.", + caller, + context_id, + owner, ) raise ServerError( error=InvalidParamsError( message=( - f'Access denied: context_id={context_id!r} was created ' - 'by a different caller.' + f"Access denied: context_id={context_id!r} was created " + "by a different caller." 
) ) ) @@ -296,7 +298,7 @@ def _record_context_owner( caller = self._get_caller_id(context) if caller: self._context_owners[context_id] = caller - logger.debug('Recorded owner %r for context_id=%s', caller, context_id) + logger.debug("Recorded owner %r for context_id=%s", caller, context_id) async def _setup_message_execution( self, @@ -326,14 +328,14 @@ async def _setup_message_execution( if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=InvalidParamsError( - message=f'Task {task.id} is in terminal state: {task.status.state.value}' + message=f"Task {task.id} is in terminal state: {task.status.state.value}" ) ) task = task_manager.update_with_message(params.message, task) elif params.message.task_id: raise ServerError( error=TaskNotFoundError( - message=f'Task {params.message.task_id} was specified but does not exist' + message=f"Task {params.message.task_id} was specified but does not exist" ) ) @@ -344,7 +346,7 @@ async def _setup_message_execution( task=task, context=context, ) - task_id = cast('str', request_context.task_id) + task_id = cast("str", request_context.task_id) # Record ownership for new contexts after successful validation new_context_id = request_context.context_id or context_id @@ -372,12 +374,12 @@ async def _setup_message_execution( def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None: if task_id != event_task_id: logger.error( - 'Agent generated task_id=%s does not match the RequestContext task_id=%s.', + "Agent generated task_id=%s does not match the RequestContext task_id=%s.", event_task_id, task_id, ) raise ServerError( - InternalError(message='Task ID mismatch in agent response') + InternalError(message="Task ID mismatch in agent response") ) async def _send_push_notification_if_needed( @@ -415,6 +417,7 @@ async def on_message_send( interrupted_or_non_blocking = False try: + async def push_notification_callback() -> None: await self._send_push_notification_if_needed(task_id, result_aggregator) @@ 
-429,11 +432,11 @@ async def push_notification_callback() -> None: ) if bg_consume_task is not None: - bg_consume_task.set_name(f'continue_consuming:{task_id}') + bg_consume_task.set_name(f"continue_consuming:{task_id}") self._track_background_task(bg_consume_task) except Exception: - logger.exception('Agent execution failed') + logger.exception("Agent execution failed") producer_task.cancel() raise finally: @@ -441,7 +444,7 @@ async def push_notification_callback() -> None: cleanup_task = asyncio.create_task( self._cleanup_producer(producer_task, task_id) ) - cleanup_task.set_name(f'cleanup_producer:{task_id}') + cleanup_task.set_name(f"cleanup_producer:{task_id}") self._track_background_task(cleanup_task) else: await self._cleanup_producer(producer_task, task_id) @@ -452,7 +455,9 @@ async def push_notification_callback() -> None: if isinstance(result, Task): self._validate_task_id_match(task_id, result.id) if params.configuration: - result = apply_history_length(result, params.configuration.history_length) + result = apply_history_length( + result, params.configuration.history_length + ) await self._send_push_notification_if_needed(task_id, result_aggregator) return result @@ -483,17 +488,21 @@ async def on_message_send_stream( self._validate_task_id_match(task_id, event.id) await self._send_push_notification_if_needed(task_id, result_aggregator) yield event - except (asyncio.CancelledError, GeneratorExit): + except asyncio.CancelledError, GeneratorExit: bg_task = asyncio.create_task(result_aggregator.consume_all(consumer)) - bg_task.set_name(f'background_consume:{task_id}') + bg_task.set_name(f"background_consume:{task_id}") self._track_background_task(bg_task) raise finally: - cleanup_task = asyncio.create_task(self._cleanup_producer(producer_task, task_id)) - cleanup_task.set_name(f'cleanup_producer:{task_id}') + cleanup_task = asyncio.create_task( + self._cleanup_producer(producer_task, task_id) + ) + cleanup_task.set_name(f"cleanup_producer:{task_id}") 
self._track_background_task(cleanup_task) - async def _register_producer(self, task_id: str, producer_task: asyncio.Task) -> None: + async def _register_producer( + self, task_id: str, producer_task: asyncio.Task + ) -> None: async with self._running_agents_lock: self._running_agents[task_id] = producer_task @@ -504,19 +513,21 @@ def _on_done(completed: asyncio.Task) -> None: try: completed.result() except asyncio.CancelledError: - logger.debug('Background task %s cancelled', completed.get_name()) + logger.debug("Background task %s cancelled", completed.get_name()) except Exception: - logger.exception('Background task %s failed', completed.get_name()) + logger.exception("Background task %s failed", completed.get_name()) finally: self._background_tasks.discard(completed) task.add_done_callback(_on_done) - async def _cleanup_producer(self, producer_task: asyncio.Task, task_id: str) -> None: + async def _cleanup_producer( + self, producer_task: asyncio.Task, task_id: str + ) -> None: try: await producer_task except asyncio.CancelledError: - logger.debug('Producer task %s was cancelled during cleanup', task_id) + logger.debug("Producer task %s was cancelled during cleanup", task_id) await self._queue_manager.close(task_id) async with self._running_agents_lock: self._running_agents.pop(task_id, None) @@ -531,7 +542,9 @@ async def on_set_task_push_notification_config( task: Task | None = await self.task_store.get(params.task_id, context) if not task: raise ServerError(error=TaskNotFoundError()) - await self._push_config_store.set_info(params.task_id, params.push_notification_config) + await self._push_config_store.set_info( + params.task_id, params.push_notification_config + ) return params async def on_get_task_push_notification_config( @@ -546,7 +559,9 @@ async def on_get_task_push_notification_config( raise ServerError(error=TaskNotFoundError()) push_notification_config = await self._push_config_store.get_info(params.id) if not push_notification_config or not 
push_notification_config[0]: - raise ServerError(error=InternalError(message='Push notification config not found')) + raise ServerError( + error=InternalError(message="Push notification config not found") + ) return TaskPushNotificationConfig( task_id=params.id, push_notification_config=push_notification_config[0], @@ -563,7 +578,7 @@ async def on_resubscribe_to_task( if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=InvalidParamsError( - message=f'Task {task.id} is in terminal state: {task.status.state.value}' + message=f"Task {task.id} is in terminal state: {task.status.state.value}" ) ) task_manager = TaskManager( @@ -591,7 +606,9 @@ async def on_list_task_push_notification_config( task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - push_notification_config_list = await self._push_config_store.get_info(params.id) + push_notification_config_list = await self._push_config_store.get_info( + params.id + ) return [ TaskPushNotificationConfig(task_id=params.id, push_notification_config=cfg) for cfg in push_notification_config_list @@ -607,4 +624,6 @@ async def on_delete_task_push_notification_config( task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - await self._push_config_store.delete_info(params.id, params.push_notification_config_id) + await self._push_config_store.delete_info( + params.id, params.push_notification_config_id + ) diff --git a/src/a2a/utils/url_validation.py b/src/a2a/utils/url_validation.py index bd059a8e2..da143d52a 100644 --- a/src/a2a/utils/url_validation.py +++ b/src/a2a/utils/url_validation.py @@ -16,31 +16,31 @@ logger = logging.getLogger(__name__) # Only these schemes are permitted in AgentCard.url values. -_ALLOWED_SCHEMES = frozenset({'http', 'https'}) +_ALLOWED_SCHEMES = frozenset({"http", "https"}) # Networks that must never be reachable via a resolved AgentCard URL. 
# Covers: loopback, RFC 1918 private ranges, link-local (IMDS), and other # IANA-reserved blocks that have no legitimate use as public agent endpoints. _BLOCKED_NETWORKS: tuple[ipaddress.IPv4Network | ipaddress.IPv6Network, ...] = ( # Loopback - ipaddress.ip_network('127.0.0.0/8'), - ipaddress.ip_network('::1/128'), + ipaddress.ip_network("127.0.0.0/8"), + ipaddress.ip_network("::1/128"), # RFC 1918 private ranges - ipaddress.ip_network('10.0.0.0/8'), - ipaddress.ip_network('172.16.0.0/12'), - ipaddress.ip_network('192.168.0.0/16'), + ipaddress.ip_network("10.0.0.0/8"), + ipaddress.ip_network("172.16.0.0/12"), + ipaddress.ip_network("192.168.0.0/16"), # Link-local -- covers AWS/GCP/Azure/OCI IMDS (169.254.169.254) - ipaddress.ip_network('169.254.0.0/16'), - ipaddress.ip_network('fe80::/10'), + ipaddress.ip_network("169.254.0.0/16"), + ipaddress.ip_network("fe80::/10"), # IPv6 unique local (ULA) -- equivalent of RFC 1918 for IPv6 - ipaddress.ip_network('fc00::/7'), + ipaddress.ip_network("fc00::/7"), # Shared address space (RFC 6598 -- carrier-grade NAT) - ipaddress.ip_network('100.64.0.0/10'), + ipaddress.ip_network("100.64.0.0/10"), # Other IANA reserved / unroutable - ipaddress.ip_network('0.0.0.0/8'), - ipaddress.ip_network('192.0.0.0/24'), - ipaddress.ip_network('198.18.0.0/15'), - ipaddress.ip_network('240.0.0.0/4'), + ipaddress.ip_network("0.0.0.0/8"), + ipaddress.ip_network("192.0.0.0/24"), + ipaddress.ip_network("198.18.0.0/15"), + ipaddress.ip_network("240.0.0.0/4"), ) @@ -68,25 +68,23 @@ def validate_agent_card_url(url: str) -> None: A2ASSRFValidationError: If the URL fails any validation check. """ if not url: - raise A2ASSRFValidationError('AgentCard URL must not be empty.') + raise A2ASSRFValidationError("AgentCard URL must not be empty.") parsed = urlparse(url) # 1. 
Scheme check - scheme = (parsed.scheme or '').lower() + scheme = (parsed.scheme or "").lower() if scheme not in _ALLOWED_SCHEMES: raise A2ASSRFValidationError( - f'AgentCard URL scheme {scheme!r} is not permitted. ' - f'Allowed schemes: {sorted(_ALLOWED_SCHEMES)}. ' - 'Arbitrary schemes allow SSRF attacks (CWE-918).' + f"AgentCard URL scheme {scheme!r} is not permitted. " + f"Allowed schemes: {sorted(_ALLOWED_SCHEMES)}. " + "Arbitrary schemes allow SSRF attacks (CWE-918)." ) # 2. Hostname presence hostname = parsed.hostname if not hostname: - raise A2ASSRFValidationError( - f'AgentCard URL {url!r} contains no hostname.' - ) + raise A2ASSRFValidationError(f"AgentCard URL {url!r} contains no hostname.") # 3. Resolve hostname and check against blocked networks try: @@ -94,8 +92,8 @@ def validate_agent_card_url(url: str) -> None: addr_infos = socket.getaddrinfo(hostname, None) except socket.gaierror as exc: raise A2ASSRFValidationError( - f'AgentCard URL hostname {hostname!r} could not be resolved: {exc}. ' - 'Unresolvable hostnames may indicate DNS rebinding attempts.' + f"AgentCard URL hostname {hostname!r} could not be resolved: {exc}. " + "Unresolvable hostnames may indicate DNS rebinding attempts." ) from exc for _family, _type, _proto, _canonname, sockaddr in addr_infos: @@ -108,10 +106,10 @@ def validate_agent_card_url(url: str) -> None: for blocked in _BLOCKED_NETWORKS: if ip in blocked: raise A2ASSRFValidationError( - f'AgentCard URL {url!r} resolves to {ip_str}, ' - f'which is within the blocked network {blocked}. ' - 'Requests to private/loopback/link-local addresses are ' - 'forbidden to prevent SSRF attacks (CWE-918).' + f"AgentCard URL {url!r} resolves to {ip_str}, " + f"which is within the blocked network {blocked}. " + "Requests to private/loopback/link-local addresses are " + "forbidden to prevent SSRF attacks (CWE-918)." 
) - logger.debug('AgentCard URL passed SSRF validation: %s', url) + logger.debug("AgentCard URL passed SSRF validation: %s", url) diff --git a/tests/utils/test_url_validation.py b/tests/utils/test_url_validation.py index 3178420a3..9ed540ba6 100644 --- a/tests/utils/test_url_validation.py +++ b/tests/utils/test_url_validation.py @@ -30,12 +30,12 @@ def test_non_http_schemes_are_blocked(self, url): 'HTTPS://EXAMPLE.COM/RPC', ]) def test_http_and_https_are_allowed(self, url): - # Should not raise — only scheme + hostname check, DNS may vary + # Should not raise - only scheme + hostname check, DNS may vary # We only verify scheme acceptance here; real DNS tested separately. try: validate_agent_card_url(url) except A2ASSRFValidationError as exc: - # Accept DNS resolution failure — scheme was accepted + # Accept DNS resolution failure - scheme was accepted assert 'could not be resolved' in str(exc) or 'blocked network' in str(exc) @@ -61,7 +61,7 @@ def test_private_addresses_are_blocked(self, url, label): def test_public_ip_is_allowed(self): """A routable public IP should not be blocked.""" - # 93.184.216.34 is example.com — guaranteed public + # 93.184.216.34 is example.com - guaranteed public try: validate_agent_card_url('http://93.184.216.34/rpc') except A2ASSRFValidationError as exc: From 8a434d8ab0959684266bd0697b6d0344c68a2b67 Mon Sep 17 00:00:00 2001 From: Sudu <12902959+amit-raut@users.noreply.github.com> Date: Fri, 27 Mar 2026 22:54:58 -0400 Subject: [PATCH 6/8] fix(ci): correct ruff lint errors and Python 3.10 syntax in security patch - Remove module-level research docstrings from card_resolver.py and default_request_handler.py (originals have none; D415 lint failure) - Add missing docstrings to five push-notification handler methods that existed in the original SDK (D102 lint failure) - Fix import ordering in card_resolver.py and url_validation.py (I001) - Convert double-quoted string literals to single quotes throughout both patched source files (Q000; 
project uses quote-style = single) - Restore parentheses in except clause: except (asyncio.CancelledError, GeneratorExit): Bare tuple syntax was added in Python 3.14; CI targets Python 3.10 where it raises SyntaxError (ruff had silently removed the parens) --- src/a2a/client/card_resolver.py | 48 ++---- .../default_request_handler.py | 139 ++++++++++-------- src/a2a/utils/url_validation.py | 61 ++++---- tests/client/test_card_resolver.py | 70 +++++++-- tests/utils/test_url_validation.py | 78 ++++++---- 5 files changed, 235 insertions(+), 161 deletions(-) diff --git a/src/a2a/client/card_resolver.py b/src/a2a/client/card_resolver.py index e4cc14471..6906bf2b1 100644 --- a/src/a2a/client/card_resolver.py +++ b/src/a2a/client/card_resolver.py @@ -1,16 +1,3 @@ -"""Patched version of a2a/client/card_resolver.py - -Fix for A2A-SSRF-01: validate AgentCard.url before returning the card. - -Diff summary vs. original (v0.3.25): - + import A2ASSRFValidationError, validate_agent_card_url from a2a.utils.url_validation - + call validate_agent_card_url(agent_card.url) after model_validate() - + wrap in try/except to raise A2AClientJSONError with a clear SSRF message - + validate additional_interfaces[*].url as well (same attack surface) - -Target file: src/a2a/client/card_resolver.py -""" - import json import logging @@ -19,8 +6,6 @@ import httpx -from pydantic import ValidationError - from a2a.client.errors import ( A2AClientHTTPError, A2AClientJSONError, @@ -29,11 +14,11 @@ AgentCard, ) from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH - -# ---- NEW IMPORT (fix for A2A-SSRF-01) ---- -from a2a.utils.url_validation import A2ASSRFValidationError, validate_agent_card_url - -# ------------------------------------------- +from a2a.utils.url_validation import ( + A2ASSRFValidationError, + validate_agent_card_url, +) +from pydantic import ValidationError logger = logging.getLogger(__name__) @@ -55,8 +40,8 @@ def __init__( base_url: The base URL of the agent's host. 
agent_card_path: The path to the agent card endpoint, relative to the base URL. """ - self.base_url = base_url.rstrip("/") - self.agent_card_path = agent_card_path.lstrip("/") + self.base_url = base_url.rstrip('/') + self.agent_card_path = agent_card_path.lstrip('/') self.httpx_client = httpx_client async def get_agent_card( @@ -89,9 +74,9 @@ async def get_agent_card( if not relative_card_path: path_segment = self.agent_card_path else: - path_segment = relative_card_path.lstrip("/") + path_segment = relative_card_path.lstrip('/') - target_url = f"{self.base_url}/{path_segment}" + target_url = f'{self.base_url}/{path_segment}' try: response = await self.httpx_client.get( @@ -101,13 +86,13 @@ async def get_agent_card( response.raise_for_status() agent_card_data = response.json() logger.info( - "Successfully fetched agent card data from %s: %s", + 'Successfully fetched agent card data from %s: %s', target_url, agent_card_data, ) agent_card = AgentCard.model_validate(agent_card_data) - # ---- FIX: A2A-SSRF-01 -- validate card.url before returning ---- + # Validate card.url before returning (fix for A2A-SSRF-01). # Without this check, any caller who controls the card endpoint # can redirect all subsequent RPC calls to an internal address. 
try: @@ -117,9 +102,8 @@ async def get_agent_card( validate_agent_card_url(iface.url) except A2ASSRFValidationError as e: raise A2AClientJSONError( - f"AgentCard from {target_url} failed SSRF URL validation: {e}" + f'AgentCard from {target_url} failed SSRF URL validation: {e}' ) from e - # ----------------------------------------------------------------- if signature_verifier: signature_verifier(agent_card) @@ -127,20 +111,20 @@ async def get_agent_card( except httpx.HTTPStatusError as e: raise A2AClientHTTPError( e.response.status_code, - f"Failed to fetch agent card from {target_url}: {e}", + f'Failed to fetch agent card from {target_url}: {e}', ) from e except json.JSONDecodeError as e: raise A2AClientJSONError( - f"Failed to parse JSON for agent card from {target_url}: {e}" + f'Failed to parse JSON for agent card from {target_url}: {e}' ) from e except httpx.RequestError as e: raise A2AClientHTTPError( 503, - f"Network communication error fetching agent card from {target_url}: {e}", + f'Network communication error fetching agent card from {target_url}: {e}', ) from e except ValidationError as e: raise A2AClientJSONError( - f"Failed to validate agent card structure from {target_url}: {e.json()}" + f'Failed to validate agent card structure from {target_url}: {e.json()}' ) from e return agent_card diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index acac28873..58fe58fca 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -1,24 +1,3 @@ -"""Patched version of a2a/server/request_handlers/default_request_handler.py - -Fix for A2A-INJ-01: context-level ownership tracking prevents unauthorized -callers from injecting messages into another user's context. - -Root cause of vulnerability: - _setup_message_execution() uses params.message.context_id directly without - any ownership check. 
An attacker who knows a victim's contextId can send a - new task under that context -- task_manager.get_task() returns None for the - new task_id, so the original task-level check is never reached. - -Fix design: - DefaultRequestHandler maintains a _context_owners dict (context_id -> owner) - in memory. When a get_caller_id extractor is configured: - 1. On first message for a context_id: record caller as owner. - 2. On subsequent messages for same context_id: verify caller matches owner. - If get_caller_id is None (default): no ownership tracking -- backward compatible. - -Target file: src/a2a/server/request_handlers/default_request_handler.py -""" - import asyncio import logging @@ -81,7 +60,7 @@ # ---- NEW: caller identity extractor type (fix for A2A-INJ-01) ---- # CallerIdExtractor extracts a stable identity string from ServerCallContext. # Returns None if caller identity cannot be determined (unauthenticated). -CallerIdExtractor = Callable[["ServerCallContext | None"], str | None] +CallerIdExtractor = Callable[['ServerCallContext | None'], str | None] # ------------------------------------------------------------------ @@ -123,11 +102,14 @@ def __init__( # noqa: PLR0913 Example:: - def get_caller_id(ctx: ServerCallContext | None) -> str | None: + def get_caller_id( + ctx: ServerCallContext | None, + ) -> str | None: if ctx is None or not ctx.user.is_authenticated: return None return ctx.user.user_name + handler = DefaultRequestHandler( agent_executor=executor, task_store=task_store, @@ -151,10 +133,10 @@ def get_caller_id(ctx: ServerCallContext | None) -> str | None: self._context_owners: dict[str, str] = {} if get_caller_id is None: logger.warning( - "DefaultRequestHandler initialized without get_caller_id: " - "context ownership is not enforced. Cross-user context injection " - "(A2A-INJ-01 / CWE-639) is possible. Provide a get_caller_id " - "extractor to enable ownership checks." 
+ 'DefaultRequestHandler initialized without get_caller_id: ' + 'context ownership is not enforced. Cross-user context injection ' + '(A2A-INJ-01 / CWE-639) is possible. Provide a get_caller_id ' + 'extractor to enable ownership checks.' ) # ---------------------------------- self._running_agents = {} @@ -186,7 +168,7 @@ async def on_cancel_task( if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=TaskNotCancelableError( - message=f"Task cannot be canceled - current state: {task.status.state}" + message=f'Task cannot be canceled - current state: {task.status.state}' ) ) @@ -219,14 +201,14 @@ async def on_cancel_task( if not isinstance(result, Task): raise ServerError( error=InternalError( - message="Agent did not return valid response for cancel" + message='Agent did not return valid response for cancel' ) ) if result.status.state != TaskState.canceled: raise ServerError( error=TaskNotCancelableError( - message=f"Task cannot be canceled - current state: {result.status.state}" + message=f'Task cannot be canceled - current state: {result.status.state}' ) ) @@ -264,16 +246,16 @@ def _check_context_ownership( raise ServerError( error=InvalidParamsError( message=( - f"Access denied: cannot send to context_id={context_id!r} " - "because caller identity could not be determined." + f'Access denied: cannot send to context_id={context_id!r} ' + 'because caller identity could not be determined.' ) ) ) if caller != owner: logger.warning( - "Context injection attempt blocked: caller=%r tried to send to " - "context_id=%s owned by %r.", + 'Context injection attempt blocked: caller=%r tried to send to ' + 'context_id=%s owned by %r.', caller, context_id, owner, @@ -281,8 +263,8 @@ def _check_context_ownership( raise ServerError( error=InvalidParamsError( message=( - f"Access denied: context_id={context_id!r} was created " - "by a different caller." + f'Access denied: context_id={context_id!r} was created ' + 'by a different caller.' 
) ) ) @@ -298,7 +280,9 @@ def _record_context_owner( caller = self._get_caller_id(context) if caller: self._context_owners[context_id] = caller - logger.debug("Recorded owner %r for context_id=%s", caller, context_id) + logger.debug( + 'Recorded owner %r for context_id=%s', caller, context_id + ) async def _setup_message_execution( self, @@ -328,14 +312,14 @@ async def _setup_message_execution( if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=InvalidParamsError( - message=f"Task {task.id} is in terminal state: {task.status.state.value}" + message=f'Task {task.id} is in terminal state: {task.status.state.value}' ) ) task = task_manager.update_with_message(params.message, task) elif params.message.task_id: raise ServerError( error=TaskNotFoundError( - message=f"Task {params.message.task_id} was specified but does not exist" + message=f'Task {params.message.task_id} was specified but does not exist' ) ) @@ -346,7 +330,7 @@ async def _setup_message_execution( task=task, context=context, ) - task_id = cast("str", request_context.task_id) + task_id = cast('str', request_context.task_id) # Record ownership for new contexts after successful validation new_context_id = request_context.context_id or context_id @@ -374,12 +358,12 @@ async def _setup_message_execution( def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None: if task_id != event_task_id: logger.error( - "Agent generated task_id=%s does not match the RequestContext task_id=%s.", + 'Agent generated task_id=%s does not match the RequestContext task_id=%s.', event_task_id, task_id, ) raise ServerError( - InternalError(message="Task ID mismatch in agent response") + InternalError(message='Task ID mismatch in agent response') ) async def _send_push_notification_if_needed( @@ -419,7 +403,9 @@ async def on_message_send( try: async def push_notification_callback() -> None: - await self._send_push_notification_if_needed(task_id, result_aggregator) + await 
self._send_push_notification_if_needed( + task_id, result_aggregator + ) ( result, @@ -432,11 +418,11 @@ async def push_notification_callback() -> None: ) if bg_consume_task is not None: - bg_consume_task.set_name(f"continue_consuming:{task_id}") + bg_consume_task.set_name(f'continue_consuming:{task_id}') self._track_background_task(bg_consume_task) except Exception: - logger.exception("Agent execution failed") + logger.exception('Agent execution failed') producer_task.cancel() raise finally: @@ -444,7 +430,7 @@ async def push_notification_callback() -> None: cleanup_task = asyncio.create_task( self._cleanup_producer(producer_task, task_id) ) - cleanup_task.set_name(f"cleanup_producer:{task_id}") + cleanup_task.set_name(f'cleanup_producer:{task_id}') self._track_background_task(cleanup_task) else: await self._cleanup_producer(producer_task, task_id) @@ -486,18 +472,22 @@ async def on_message_send_stream( async for event in result_aggregator.consume_and_emit(consumer): if isinstance(event, Task): self._validate_task_id_match(task_id, event.id) - await self._send_push_notification_if_needed(task_id, result_aggregator) + await self._send_push_notification_if_needed( + task_id, result_aggregator + ) yield event - except asyncio.CancelledError, GeneratorExit: - bg_task = asyncio.create_task(result_aggregator.consume_all(consumer)) - bg_task.set_name(f"background_consume:{task_id}") + except (asyncio.CancelledError, GeneratorExit): + bg_task = asyncio.create_task( + result_aggregator.consume_all(consumer) + ) + bg_task.set_name(f'background_consume:{task_id}') self._track_background_task(bg_task) raise finally: cleanup_task = asyncio.create_task( self._cleanup_producer(producer_task, task_id) ) - cleanup_task.set_name(f"cleanup_producer:{task_id}") + cleanup_task.set_name(f'cleanup_producer:{task_id}') self._track_background_task(cleanup_task) async def _register_producer( @@ -513,9 +503,13 @@ def _on_done(completed: asyncio.Task) -> None: try: completed.result() except 
asyncio.CancelledError: - logger.debug("Background task %s cancelled", completed.get_name()) + logger.debug( + 'Background task %s cancelled', completed.get_name() + ) except Exception: - logger.exception("Background task %s failed", completed.get_name()) + logger.exception( + 'Background task %s failed', completed.get_name() + ) finally: self._background_tasks.discard(completed) @@ -527,7 +521,9 @@ async def _cleanup_producer( try: await producer_task except asyncio.CancelledError: - logger.debug("Producer task %s was cancelled during cleanup", task_id) + logger.debug( + 'Producer task %s was cancelled during cleanup', task_id + ) await self._queue_manager.close(task_id) async with self._running_agents_lock: self._running_agents.pop(task_id, None) @@ -537,6 +533,10 @@ async def on_set_task_push_notification_config( params: TaskPushNotificationConfig, context: ServerCallContext | None = None, ) -> TaskPushNotificationConfig: + """Default handler for 'tasks/pushNotificationConfig/set'. + + Requires a `PushNotifier` to be configured. + """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) task: Task | None = await self.task_store.get(params.task_id, context) @@ -552,15 +552,23 @@ async def on_get_task_push_notification_config( params: TaskIdParams | GetTaskPushNotificationConfigParams, context: ServerCallContext | None = None, ) -> TaskPushNotificationConfig: + """Default handler for 'tasks/pushNotificationConfig/get'. + + Requires a `PushConfigStore` to be configured. 
+ """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - push_notification_config = await self._push_config_store.get_info(params.id) + push_notification_config = await self._push_config_store.get_info( + params.id + ) if not push_notification_config or not push_notification_config[0]: raise ServerError( - error=InternalError(message="Push notification config not found") + error=InternalError( + message='Push notification config not found' + ) ) return TaskPushNotificationConfig( task_id=params.id, @@ -572,13 +580,18 @@ async def on_resubscribe_to_task( params: TaskIdParams, context: ServerCallContext | None = None, ) -> AsyncGenerator[Event]: + """Default handler for 'tasks/resubscribe'. + + Allows a client to re-attach to a running streaming task's event stream. + Requires the task and its queue to still be active. + """ task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=InvalidParamsError( - message=f"Task {task.id} is in terminal state: {task.status.state.value}" + message=f'Task {task.id} is in terminal state: {task.status.state.value}' ) ) task_manager = TaskManager( @@ -601,6 +614,10 @@ async def on_list_task_push_notification_config( params: ListTaskPushNotificationConfigParams, context: ServerCallContext | None = None, ) -> list[TaskPushNotificationConfig]: + """Default handler for 'tasks/pushNotificationConfig/list'. + + Requires a `PushConfigStore` to be configured. 
+ """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) task: Task | None = await self.task_store.get(params.id, context) @@ -610,7 +627,9 @@ async def on_list_task_push_notification_config( params.id ) return [ - TaskPushNotificationConfig(task_id=params.id, push_notification_config=cfg) + TaskPushNotificationConfig( + task_id=params.id, push_notification_config=cfg + ) for cfg in push_notification_config_list ] @@ -619,6 +638,10 @@ async def on_delete_task_push_notification_config( params: DeleteTaskPushNotificationConfigParams, context: ServerCallContext | None = None, ) -> None: + """Default handler for 'tasks/pushNotificationConfig/delete'. + + Requires a `PushConfigStore` to be configured. + """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) task: Task | None = await self.task_store.get(params.id, context) diff --git a/src/a2a/utils/url_validation.py b/src/a2a/utils/url_validation.py index da143d52a..094c19f94 100644 --- a/src/a2a/utils/url_validation.py +++ b/src/a2a/utils/url_validation.py @@ -3,44 +3,43 @@ Prevents Server-Side Request Forgery (SSRF) attacks by validating that AgentCard.url values do not point to private, loopback, or link-local network addresses before the SDK uses them as RPC endpoints. - -Fix for: A2A-SSRF-01 (CWE-918) -Target: src/a2a/utils/url_validation.py (new file) """ import ipaddress import logging import socket + from urllib.parse import urlparse + logger = logging.getLogger(__name__) # Only these schemes are permitted in AgentCard.url values. -_ALLOWED_SCHEMES = frozenset({"http", "https"}) +_ALLOWED_SCHEMES = frozenset({'http', 'https'}) # Networks that must never be reachable via a resolved AgentCard URL. # Covers: loopback, RFC 1918 private ranges, link-local (IMDS), and other # IANA-reserved blocks that have no legitimate use as public agent endpoints. _BLOCKED_NETWORKS: tuple[ipaddress.IPv4Network | ipaddress.IPv6Network, ...] 
= ( # Loopback - ipaddress.ip_network("127.0.0.0/8"), - ipaddress.ip_network("::1/128"), + ipaddress.ip_network('127.0.0.0/8'), + ipaddress.ip_network('::1/128'), # RFC 1918 private ranges - ipaddress.ip_network("10.0.0.0/8"), - ipaddress.ip_network("172.16.0.0/12"), - ipaddress.ip_network("192.168.0.0/16"), + ipaddress.ip_network('10.0.0.0/8'), + ipaddress.ip_network('172.16.0.0/12'), + ipaddress.ip_network('192.168.0.0/16'), # Link-local -- covers AWS/GCP/Azure/OCI IMDS (169.254.169.254) - ipaddress.ip_network("169.254.0.0/16"), - ipaddress.ip_network("fe80::/10"), + ipaddress.ip_network('169.254.0.0/16'), + ipaddress.ip_network('fe80::/10'), # IPv6 unique local (ULA) -- equivalent of RFC 1918 for IPv6 - ipaddress.ip_network("fc00::/7"), + ipaddress.ip_network('fc00::/7'), # Shared address space (RFC 6598 -- carrier-grade NAT) - ipaddress.ip_network("100.64.0.0/10"), + ipaddress.ip_network('100.64.0.0/10'), # Other IANA reserved / unroutable - ipaddress.ip_network("0.0.0.0/8"), - ipaddress.ip_network("192.0.0.0/24"), - ipaddress.ip_network("198.18.0.0/15"), - ipaddress.ip_network("240.0.0.0/4"), + ipaddress.ip_network('0.0.0.0/8'), + ipaddress.ip_network('192.0.0.0/24'), + ipaddress.ip_network('198.18.0.0/15'), + ipaddress.ip_network('240.0.0.0/4'), ) @@ -68,23 +67,25 @@ def validate_agent_card_url(url: str) -> None: A2ASSRFValidationError: If the URL fails any validation check. """ if not url: - raise A2ASSRFValidationError("AgentCard URL must not be empty.") + raise A2ASSRFValidationError('AgentCard URL must not be empty.') parsed = urlparse(url) # 1. Scheme check - scheme = (parsed.scheme or "").lower() + scheme = (parsed.scheme or '').lower() if scheme not in _ALLOWED_SCHEMES: raise A2ASSRFValidationError( - f"AgentCard URL scheme {scheme!r} is not permitted. " - f"Allowed schemes: {sorted(_ALLOWED_SCHEMES)}. " - "Arbitrary schemes allow SSRF attacks (CWE-918)." + f'AgentCard URL scheme {scheme!r} is not permitted. 
' + f'Allowed schemes: {sorted(_ALLOWED_SCHEMES)}. ' + 'Arbitrary schemes allow SSRF attacks (CWE-918).' ) # 2. Hostname presence hostname = parsed.hostname if not hostname: - raise A2ASSRFValidationError(f"AgentCard URL {url!r} contains no hostname.") + raise A2ASSRFValidationError( + f'AgentCard URL {url!r} contains no hostname.' + ) # 3. Resolve hostname and check against blocked networks try: @@ -92,8 +93,8 @@ def validate_agent_card_url(url: str) -> None: addr_infos = socket.getaddrinfo(hostname, None) except socket.gaierror as exc: raise A2ASSRFValidationError( - f"AgentCard URL hostname {hostname!r} could not be resolved: {exc}. " - "Unresolvable hostnames may indicate DNS rebinding attempts." + f'AgentCard URL hostname {hostname!r} could not be resolved: {exc}. ' + 'Unresolvable hostnames may indicate DNS rebinding attempts.' ) from exc for _family, _type, _proto, _canonname, sockaddr in addr_infos: @@ -106,10 +107,10 @@ def validate_agent_card_url(url: str) -> None: for blocked in _BLOCKED_NETWORKS: if ip in blocked: raise A2ASSRFValidationError( - f"AgentCard URL {url!r} resolves to {ip_str}, " - f"which is within the blocked network {blocked}. " - "Requests to private/loopback/link-local addresses are " - "forbidden to prevent SSRF attacks (CWE-918)." + f'AgentCard URL {url!r} resolves to {ip_str}, ' + f'which is within the blocked network {blocked}. ' + 'Requests to private/loopback/link-local addresses are ' + 'forbidden to prevent SSRF attacks (CWE-918).' 
) - logger.debug("AgentCard URL passed SSRF validation: %s", url) + logger.debug('AgentCard URL passed SSRF validation: %s', url) diff --git a/tests/client/test_card_resolver.py b/tests/client/test_card_resolver.py index 8ddc70495..13e2058f1 100644 --- a/tests/client/test_card_resolver.py +++ b/tests/client/test_card_resolver.py @@ -116,7 +116,13 @@ async def test_get_agent_card_success_default_path( mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ) as mock_validate: result = await resolver.get_agent_card() mock_httpx_client.get.assert_called_once_with( @@ -141,7 +147,13 @@ async def test_get_agent_card_success_custom_path( mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(relative_card_path=custom_path) @@ -163,7 +175,13 @@ async def test_get_agent_card_strips_leading_slash_from_relative_path( mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(relative_card_path=custom_path) @@ -188,7 +206,13 @@ async def test_get_agent_card_with_http_kwargs( 'headers': 
{'Authorization': 'Bearer token'}, } with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(http_kwargs=http_kwargs) mock_httpx_client.get.assert_called_once_with( @@ -210,7 +234,13 @@ async def test_get_agent_card_root_path( mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(relative_card_path='/') mock_httpx_client.get.assert_called_once_with(f'{base_url}/') @@ -297,7 +327,13 @@ async def test_get_agent_card_logs_success( # noqa: PLR0913 mock_httpx_client.get.return_value = mock_response with ( patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ), caplog.at_level(logging.INFO), ): @@ -321,7 +357,13 @@ async def test_get_agent_card_none_relative_path( mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(relative_card_path=None) mock_httpx_client.get.assert_called_once_with( @@ -342,7 +384,13 @@ async def 
test_get_agent_card_empty_string_relative_path( mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(relative_card_path='') @@ -373,7 +421,11 @@ async def test_get_agent_card_returns_agent_card_instance( """Test that get_agent_card returns an AgentCard instance.""" mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response - mock_agent_card = Mock(spec=AgentCard, url='https://example.com/a2a', additional_interfaces=None) + mock_agent_card = Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ) with patch.object( AgentCard, 'model_validate', return_value=mock_agent_card diff --git a/tests/utils/test_url_validation.py b/tests/utils/test_url_validation.py index 9ed540ba6..f0469500a 100644 --- a/tests/utils/test_url_validation.py +++ b/tests/utils/test_url_validation.py @@ -5,63 +5,77 @@ import pytest -from a2a.utils.url_validation import A2ASSRFValidationError, validate_agent_card_url +from a2a.utils.url_validation import ( + A2ASSRFValidationError, + validate_agent_card_url, +) class TestValidateAgentCardUrlScheme: """URL scheme validation.""" - @pytest.mark.parametrize('url', [ - 'file:///etc/passwd', - 'gopher://internal/1', - 'ftp://files.example.com/secret', - 'dict://internal/', - 'ldap://ldap.example.com/', - '', - ]) + @pytest.mark.parametrize( + 'url', + [ + 'file:///etc/passwd', + 'gopher://internal/1', + 'ftp://files.example.com/secret', + 'dict://internal/', + 'ldap://ldap.example.com/', + '', + ], + ) def test_non_http_schemes_are_blocked(self, url): with pytest.raises(A2ASSRFValidationError): validate_agent_card_url(url) - @pytest.mark.parametrize('url', [ - 
'http://example.com/rpc', - 'https://example.com/rpc', - 'HTTP://EXAMPLE.COM/RPC', - 'HTTPS://EXAMPLE.COM/RPC', - ]) + @pytest.mark.parametrize( + 'url', + [ + 'http://example.com/rpc', + 'https://example.com/rpc', + 'HTTP://EXAMPLE.COM/RPC', + 'HTTPS://EXAMPLE.COM/RPC', + ], + ) def test_http_and_https_are_allowed(self, url): - # Should not raise - only scheme + hostname check, DNS may vary + # Should not raise — only scheme + hostname check, DNS may vary # We only verify scheme acceptance here; real DNS tested separately. try: validate_agent_card_url(url) except A2ASSRFValidationError as exc: - # Accept DNS resolution failure - scheme was accepted - assert 'could not be resolved' in str(exc) or 'blocked network' in str(exc) + # Accept DNS resolution failure — scheme was accepted + assert 'could not be resolved' in str( + exc + ) or 'blocked network' in str(exc) class TestValidateAgentCardUrlPrivateIPs: """Private / reserved IP range blocking.""" - @pytest.mark.parametrize('url,label', [ - ('http://127.0.0.1/rpc', 'loopback IPv4'), - ('http://127.1.2.3/rpc', 'loopback IPv4 (non-zero host)'), - ('http://[::1]/rpc', 'loopback IPv6'), - ('http://10.0.0.1/rpc', 'RFC 1918 10/8'), - ('http://10.255.255.255/rpc', 'RFC 1918 10/8 broadcast'), - ('http://172.16.0.1/rpc', 'RFC 1918 172.16/12'), - ('http://172.31.255.255/rpc', 'RFC 1918 172.31 (last in range)'), - ('http://192.168.1.1/rpc', 'RFC 1918 192.168/16'), - ('http://169.254.169.254/latest/meta-data/', 'AWS IMDS'), - ('http://169.254.0.1/rpc', 'link-local'), - ('http://100.64.0.1/rpc', 'shared address space RFC 6598'), - ]) + @pytest.mark.parametrize( + 'url,label', + [ + ('http://127.0.0.1/rpc', 'loopback IPv4'), + ('http://127.1.2.3/rpc', 'loopback IPv4 (non-zero host)'), + ('http://[::1]/rpc', 'loopback IPv6'), + ('http://10.0.0.1/rpc', 'RFC 1918 10/8'), + ('http://10.255.255.255/rpc', 'RFC 1918 10/8 broadcast'), + ('http://172.16.0.1/rpc', 'RFC 1918 172.16/12'), + ('http://172.31.255.255/rpc', 'RFC 1918 172.31 
(last in range)'), + ('http://192.168.1.1/rpc', 'RFC 1918 192.168/16'), + ('http://169.254.169.254/latest/meta-data/', 'AWS IMDS'), + ('http://169.254.0.1/rpc', 'link-local'), + ('http://100.64.0.1/rpc', 'shared address space RFC 6598'), + ], + ) def test_private_addresses_are_blocked(self, url, label): with pytest.raises(A2ASSRFValidationError, match='blocked network'): validate_agent_card_url(url) def test_public_ip_is_allowed(self): """A routable public IP should not be blocked.""" - # 93.184.216.34 is example.com - guaranteed public + # 93.184.216.34 is example.com — guaranteed public try: validate_agent_card_url('http://93.184.216.34/rpc') except A2ASSRFValidationError as exc: From 681a49f70be080a808bc5aea05b44ccfd218cdad Mon Sep 17 00:00:00 2001 From: Sudu <12902959+amit-raut@users.noreply.github.com> Date: Fri, 27 Mar 2026 23:05:36 -0400 Subject: [PATCH 7/8] fix(lint): fix import order in card_resolver.py (I001) Move pydantic import to correct position per project isort config --- src/a2a/client/card_resolver.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/a2a/client/card_resolver.py b/src/a2a/client/card_resolver.py index 6906bf2b1..21864d22f 100644 --- a/src/a2a/client/card_resolver.py +++ b/src/a2a/client/card_resolver.py @@ -6,6 +6,8 @@ import httpx +from pydantic import ValidationError + from a2a.client.errors import ( A2AClientHTTPError, A2AClientJSONError, @@ -18,7 +20,6 @@ A2ASSRFValidationError, validate_agent_card_url, ) -from pydantic import ValidationError logger = logging.getLogger(__name__) From da41f5fdcdbe389a5d180120b79554ce75962c1f Mon Sep 17 00:00:00 2001 From: Sudu <12902959+amit-raut@users.noreply.github.com> Date: Fri, 27 Mar 2026 23:15:24 -0400 Subject: [PATCH 8/8] fix(lint): format tests/conftest.py (ruff format) Fix import ordering and convert double quoted to single quotes --- tests/conftest.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py
index 600867a08..63a978139 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import patch +import pytest + @pytest.fixture(autouse=True) def bypass_ssrf_url_validation(request): @@ -10,8 +11,8 @@ def bypass_ssrf_url_validation(request): that either resolve to loopback or are unavailable in CI. The actual SSRF validation logic is tested in tests/utils/test_url_validation.py. """ - if "test_url_validation" in request.node.nodeid: + if 'test_url_validation' in request.node.nodeid: yield else: - with patch("a2a.client.card_resolver.validate_agent_card_url"): + with patch('a2a.client.card_resolver.validate_agent_card_url'): yield