From f8e3af0e6847d9fc566362e67be3df79de76795f Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Wed, 28 May 2025 22:36:24 +0000 Subject: [PATCH 1/2] fix: Correctly adapt starlette BaseUser to A2A User --- src/a2a/server/apps/starlette_app.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/a2a/server/apps/starlette_app.py b/src/a2a/server/apps/starlette_app.py index 84ef75774..fb7a99081 100644 --- a/src/a2a/server/apps/starlette_app.py +++ b/src/a2a/server/apps/starlette_app.py @@ -45,8 +45,20 @@ logger = logging.getLogger(__name__) -# Register Starlette User as an implementation of a2a.auth.user.User -A2AUser.register(BaseUser) + +class StarletteUserProxy(A2AUser): + """Adapts the Starlette User class to the A2A user representation.""" + + def __init__(self, user: BaseUser): + self._user = user + + @property + def is_authenticated(self): + return self._user.is_authenticated + + @property + def user_name(self): + return self._user.display_name class CallContextBuilder(ABC): @@ -64,7 +76,7 @@ def build(self, request: Request) -> ServerCallContext: user = UnauthenticatedUser() state = {} with contextlib.suppress(Exception): - user = request.user + user = StarletteUserProxy(request.user) state['auth'] = request.auth return ServerCallContext(user=user, state=state) From 2844180ceff7901302e2176d0c0778bdc417bf6d Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Mon, 2 Jun 2025 16:24:44 +0000 Subject: [PATCH 2/2] Add integration test for server auth --- tests/server/test_integration.py | 116 +++++++++++++++++++++++++++---- 1 file changed, 101 insertions(+), 15 deletions(-) diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index c0a54e94b..0f69fca6f 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -3,17 +3,42 @@ from unittest import mock import pytest +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + BaseUser, + SimpleUser, +) +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import HTTPConnection from starlette.responses import JSONResponse from starlette.routing import Route from starlette.testclient import TestClient from a2a.server.apps.starlette_app import A2AStarletteApplication -from a2a.types import (AgentCapabilities, AgentCard, Artifact, DataPart, - InternalError, InvalidRequestError, JSONParseError, - Part, PushNotificationConfig, Task, - TaskArtifactUpdateEvent, TaskPushNotificationConfig, - TaskState, TaskStatus, TextPart, - UnsupportedOperationError) +from a2a.types import ( + AgentCapabilities, + AgentCard, + Artifact, + DataPart, + InternalError, + InvalidRequestError, + JSONParseError, + Message, + Part, + PushNotificationConfig, + Role, + SendMessageResponse, + SendMessageSuccessResponse, + Task, + TaskArtifactUpdateEvent, + TaskPushNotificationConfig, + TaskState, + TaskStatus, + TextPart, + UnsupportedOperationError, +) from a2a.utils.errors import MethodNotImplementedError # === TEST SETUP === @@ -106,9 +131,9 @@ def app(agent_card: AgentCard, handler: mock.AsyncMock): @pytest.fixture -def client(app: A2AStarletteApplication): +def client(app: A2AStarletteApplication, **kwargs): """Create a test client with the app.""" - return TestClient(app.build()) + return TestClient(app.build(**kwargs)) # === BASIC FUNCTIONALITY TESTS === @@ -135,7 +160,7 @@ def test_authenticated_extended_agent_card_endpoint_not_supported( # So, building the app and trying to hit it should result in 404 from Starlette itself client = TestClient(app_instance.build()) response = client.get('/agent/authenticatedExtendedCard') - assert response.status_code == 404 # Starlette's default for no route + assert response.status_code == 404 # Starlette's default for no route def test_authenticated_extended_agent_card_endpoint_supported_with_specific_extended_card( @@ -144,7 +169,9 @@ def test_authenticated_extended_agent_card_endpoint_supported_with_specific_exte handler: mock.AsyncMock, ): """Test extended card endpoint returns the specific extended card when provided.""" - agent_card.supportsAuthenticatedExtendedCard = True # Main card must support it + agent_card.supportsAuthenticatedExtendedCard = ( + True # Main card must support it + ) app_instance = A2AStarletteApplication( agent_card, handler, extended_agent_card=extended_agent_card_fixture ) @@ -157,10 +184,9 @@ def test_authenticated_extended_agent_card_endpoint_supported_with_specific_exte assert data['name'] == extended_agent_card_fixture.name assert data['version'] == extended_agent_card_fixture.version assert len(data['skills']) == len(extended_agent_card_fixture.skills) - assert any( - skill['id'] == 'skill-extended' for skill in data['skills'] - ), "Extended skill not found in served card" - + assert any(skill['id'] == 'skill-extended' for skill in data['skills']), ( + 'Extended skill not found in served card' + ) def test_agent_card_custom_url( @@ -233,7 +259,6 @@ def test_send_message(client: TestClient, handler: mock.AsyncMock): mock_task = Task( id='task1', contextId='session-xyz', - state='completed', status=task_status, ) handler.on_message_send.return_value = mock_task @@ -402,6 +427,67 @@ def test_get_push_notification_config( handler.on_get_task_push_notification_config.assert_awaited_once() +def test_server_auth(app: A2AStarletteApplication, handler: mock.AsyncMock): + class TestAuthMiddleware(AuthenticationBackend): + async def authenticate( + self, conn: HTTPConnection + ) -> tuple[AuthCredentials, BaseUser] | None: + # For the purposes of this test, all requests are authenticated! + return (AuthCredentials(['authenticated']), SimpleUser('test_user')) + + client = TestClient( + app.build( + middleware=[ + Middleware( + AuthenticationMiddleware, backend=TestAuthMiddleware() + ) + ] + ) + ) + + # Set the output message to be the authenticated user name + handler.on_message_send.side_effect = lambda params, context: Message( + contextId='session-xyz', + messageId='112', + role=Role.agent, + parts=[ + Part(TextPart(text=context.user.user_name)), + ], + ) + + # Send request + response = client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'message/send', + 'params': { + 'message': { + 'role': 'agent', + 'parts': [{'kind': 'text', 'text': 'Hello'}], + 'messageId': '111', + 'kind': 'message', + 'taskId': 'task1', + 'contextId': 'session-xyz', + } + }, + }, + ) + + # Verify response + assert response.status_code == 200 + result = SendMessageResponse.model_validate(response.json()) + assert isinstance(result.root, SendMessageSuccessResponse) + assert isinstance(result.root.result, Message) + message = result.root.result + assert isinstance(message.parts[0].root, TextPart) + assert message.parts[0].root.text == 'test_user' + + # Verify handler was called + handler.on_message_send.assert_awaited_once() + + # === STREAMING TESTS ===