From 4ae40aed4537ff54276a26d87835104a2b4477dc Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Mon, 23 Mar 2026 06:59:26 +0000 Subject: [PATCH] CopyingTaskStoreAdapter. --- src/a2a/server/tasks/copying_task_store.py | 61 ++++++ src/a2a/server/tasks/inmemory_task_store.py | 61 +++++- .../integration/test_copying_observability.py | 184 ++++++++++++++++++ tests/server/tasks/test_copying_task_store.py | 132 +++++++++++++ .../server/tasks/test_inmemory_task_store.py | 35 ++++ 5 files changed, 469 insertions(+), 4 deletions(-) create mode 100644 src/a2a/server/tasks/copying_task_store.py create mode 100644 tests/integration/test_copying_observability.py create mode 100644 tests/server/tasks/test_copying_task_store.py diff --git a/src/a2a/server/tasks/copying_task_store.py b/src/a2a/server/tasks/copying_task_store.py new file mode 100644 index 000000000..6bfda5e74 --- /dev/null +++ b/src/a2a/server/tasks/copying_task_store.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import logging + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from a2a.server.context import ServerCallContext +from a2a.server.tasks.task_store import TaskStore +from a2a.types.a2a_pb2 import ListTasksRequest, ListTasksResponse, Task + + +logger = logging.getLogger(__name__) + + +class CopyingTaskStoreAdapter(TaskStore): + """An adapter that ensures deep copies of tasks are passed to and returned from the underlying TaskStore. + + This prevents accidental shared mutable state bugs where code modifies a Task object + retrieved from the store without explicitly saving it, which hides missing save calls. + """ + + def __init__(self, underlying_store: TaskStore): + self._store = underlying_store + + async def save( + self, task: Task, context: ServerCallContext | None = None + ) -> None: + """Saves a copy of the task to the underlying store.""" + task_copy = Task() + task_copy.CopyFrom(task) + await self._store.save(task_copy, context) + + async def get( + self, task_id: str, context: ServerCallContext | None = None + ) -> Task | None: + """Retrieves a task from the underlying store and returns a copy.""" + task = await self._store.get(task_id, context) + if task is None: + return None + task_copy = Task() + task_copy.CopyFrom(task) + return task_copy + + async def list( + self, + params: ListTasksRequest, + context: ServerCallContext | None = None, + ) -> ListTasksResponse: + """Retrieves a list of tasks from the underlying store and returns a copy.""" + response = await self._store.list(params, context) + response_copy = ListTasksResponse() + response_copy.CopyFrom(response) + return response_copy + + async def delete( + self, task_id: str, context: ServerCallContext | None = None + ) -> None: + """Deletes a task from the underlying store.""" + await self._store.delete(task_id, context) diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index eb596ca4b..f887b77ba 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -3,6 +3,7 @@ from a2a.server.context import ServerCallContext from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope +from a2a.server.tasks.copying_task_store import CopyingTaskStoreAdapter from a2a.server.tasks.task_store import TaskStore from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import Task @@ -14,8 +15,8 @@ logger = logging.getLogger(__name__) -class InMemoryTaskStore(TaskStore): - """In-memory implementation of TaskStore. +class _InMemoryTaskStoreImpl(TaskStore): + """Internal In-memory implementation of TaskStore. Stores task objects in a nested dictionary in memory, keyed by owner then task_id. Task data is lost when the server process stops. @@ -25,8 +26,8 @@ def __init__( self, owner_resolver: OwnerResolver = resolve_user_scope, ) -> None: - """Initializes the InMemoryTaskStore.""" - logger.debug('Initializing InMemoryTaskStore') + """Initializes the internal _InMemoryTaskStoreImpl.""" + logger.debug('Initializing _InMemoryTaskStoreImpl') self.tasks: dict[str, dict[str, Task]] = {} self.lock = asyncio.Lock() self.owner_resolver = owner_resolver @@ -183,3 +184,55 @@ async def delete( if not owner_tasks: del self.tasks[owner] logger.debug('Removed empty owner %s from store.', owner) + + +class InMemoryTaskStore(TaskStore): + """In-memory implementation of TaskStore. + + Can optionally use CopyingTaskStoreAdapter to wrap the internal dictionary-based + implementation, preventing shared mutable state issues by always returning and + storing deep copies. + """ + + def __init__( + self, + owner_resolver: OwnerResolver = resolve_user_scope, + use_copying: bool = True, + ) -> None: + """Initializes the InMemoryTaskStore. + + Args: + owner_resolver: Resolver for task owners. + use_copying: If True, the store will return and save deep copies of tasks. + Copying behavior is consistent with database task stores. + """ + self._impl = _InMemoryTaskStoreImpl(owner_resolver=owner_resolver) + self._store: TaskStore = ( + CopyingTaskStoreAdapter(self._impl) if use_copying else self._impl + ) + + async def save( + self, task: Task, context: ServerCallContext | None = None + ) -> None: + """Saves or updates a task in the store.""" + await self._store.save(task, context) + + async def get( + self, task_id: str, context: ServerCallContext | None = None + ) -> Task | None: + """Retrieves a task from the store by ID.""" + return await self._store.get(task_id, context) + + async def list( + self, + params: a2a_pb2.ListTasksRequest, + context: ServerCallContext | None = None, + ) -> a2a_pb2.ListTasksResponse: + """Retrieves a list of tasks from the store.""" + return await self._store.list(params, context) + + async def delete( + self, task_id: str, context: ServerCallContext | None = None + ) -> None: + """Deletes a task from the store by ID.""" + await self._store.delete(task_id, context) diff --git a/tests/integration/test_copying_observability.py b/tests/integration/test_copying_observability.py new file mode 100644 index 000000000..9ef1c0483 --- /dev/null +++ b/tests/integration/test_copying_observability.py @@ -0,0 +1,184 @@ +import httpx +import pytest +from typing import NamedTuple + +from starlette.applications import Starlette + +from a2a.client.client import Client, ClientConfig +from a2a.client.client_factory import ClientFactory +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes +from a2a.server.events import EventQueue +from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import TaskUpdater +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentInterface, + Artifact, + GetTaskRequest, + Message, + Part, + Role, + SendMessageRequest, + TaskState, +) +from a2a.utils import TransportProtocol + + +class MockMutatingAgentExecutor(AgentExecutor): + async def execute(self, context: RequestContext, event_queue: EventQueue): + assert context.task_id is not None + assert context.context_id is not None + task_updater = TaskUpdater( + event_queue, + context.task_id, + context.context_id, + ) + + user_input = context.get_user_input() + + if user_input == 'Init task': + # Explicitly save status change to ensure task exists with some state + await task_updater.update_status( + TaskState.TASK_STATE_WORKING, + message=task_updater.new_agent_message( + [Part(text='task working')] + ), + ) + else: + # Mutate the task WITHOUT saving it properly + assert context.current_task is not None + context.current_task.artifacts.append( + Artifact( + name='leaked-artifact', + parts=[Part(text='leaked artifact')], + ) + ) + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + raise NotImplementedError('Cancellation is not supported') + + +@pytest.fixture +def agent_card() -> AgentCard: + return AgentCard( + name='Mutating Agent', + description='Real in-memory integration testing.', + version='1.0.0', + capabilities=AgentCapabilities( + streaming=True, push_notifications=False + ), + skills=[], + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + supported_interfaces=[ + AgentInterface( + protocol_binding=TransportProtocol.JSONRPC, + url='http://testserver', + ), + ], + ) + + +class ClientSetup(NamedTuple): + client: Client + task_store: InMemoryTaskStore + use_copying: bool + + +def setup_client(agent_card: AgentCard, use_copying: bool) -> ClientSetup: + task_store = InMemoryTaskStore(use_copying=use_copying) + handler = DefaultRequestHandler( + agent_executor=MockMutatingAgentExecutor(), + task_store=task_store, + queue_manager=InMemoryQueueManager(), + ) + agent_card_routes = create_agent_card_routes( + agent_card=agent_card, card_url='/' + ) + jsonrpc_routes = create_jsonrpc_routes( + agent_card=agent_card, + request_handler=handler, + extended_agent_card=agent_card, + rpc_url='/', + ) + app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url='http://testserver' + ) + factory = ClientFactory( + config=ClientConfig( + httpx_client=httpx_client, + supported_protocol_bindings=[TransportProtocol.JSONRPC], + ) + ) + client = factory.create(agent_card) + return ClientSetup( + client=client, + task_store=task_store, + use_copying=use_copying, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('use_copying', [True, False]) +async def test_mutation_observability(agent_card: AgentCard, use_copying: bool): + """Tests that task mutations are observable when copying is disabled. + + When copying is disabled, the agent mutates the task in-place and the + changes are observable by the client. When copying is enabled, the agent + mutates a copy of the task and the changes are not observable by the client. + + It is ok to remove the `use_copying` parameter from the system in the future + to make InMemoryTaskStore consistent with other task stores. + """ + client_setup = setup_client(agent_card, use_copying) + client = client_setup.client + + # 1. First message to create the task + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-mut-init', + parts=[Part(text='Init task')], + ) + + events = [ + event + async for event in client.send_message( + request=SendMessageRequest(message=message_to_send) + ) + ] + + task = events[-1][1] + assert task is not None + task_id = task.id + + # 2. Second message to mutate it + message_to_send_2 = Message( + role=Role.ROLE_USER, + message_id='msg-mut-do', + task_id=task_id, + parts=[Part(text='Update task without saving it')], + ) + + _ = [ + event + async for event in client.send_message( + request=SendMessageRequest(message=message_to_send_2) + ) + ] + + # 3. Get task via client + retrieved_task = await client.get_task(request=GetTaskRequest(id=task_id)) + + # 4. Assert behavior based on `use_copying` + if use_copying: + # The un-saved artifact IS NOT leaked to the client + assert len(retrieved_task.artifacts) == 0 + else: + # The un-saved artifact IS leaked to the client + assert len(retrieved_task.artifacts) == 1 + assert retrieved_task.artifacts[0].name == 'leaked-artifact' diff --git a/tests/server/tasks/test_copying_task_store.py b/tests/server/tasks/test_copying_task_store.py new file mode 100644 index 000000000..5e07b909b --- /dev/null +++ b/tests/server/tasks/test_copying_task_store.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import unittest +import pytest + +from unittest.mock import AsyncMock + +from a2a.server.context import ServerCallContext +from a2a.server.tasks.copying_task_store import CopyingTaskStoreAdapter +from a2a.server.tasks.task_store import TaskStore +from a2a.types.a2a_pb2 import ( + ListTasksRequest, + ListTasksResponse, + Task, + TaskState, +) + + +@pytest.mark.asyncio +async def test_copying_task_store_save(): + """Test that the adapter makes a copy of the task when saving.""" + mock_store = AsyncMock(spec=TaskStore) + adapter = CopyingTaskStoreAdapter(mock_store) + + original_task = Task( + id='test_task', status={'state': TaskState.TASK_STATE_WORKING} + ) + context = ServerCallContext() + + await adapter.save(original_task, context) + + # Verify underlying store was called + mock_store.save.assert_awaited_once() + + # Get the saved task + saved_task = mock_store.save.call_args[0][0] + saved_context = mock_store.save.call_args[0][1] + + # Verify context is passed correctly + assert saved_context is context + + # Verify content is identical + assert saved_task.id == original_task.id + assert saved_task.status.state == original_task.status.state + + # Verify it is a COPY, not the same reference + assert saved_task is not original_task + + +@pytest.mark.asyncio +async def test_copying_task_store_get(): + """Test that the adapter returns a copy of the task retrieved.""" + mock_store = AsyncMock(spec=TaskStore) + adapter = CopyingTaskStoreAdapter(mock_store) + + stored_task = Task( + id='test_task', status={'state': TaskState.TASK_STATE_WORKING} + ) + mock_store.get.return_value = stored_task + context = ServerCallContext() + + retrieved_task = await adapter.get('test_task', context) + + # Verify underlying store was called + mock_store.get.assert_awaited_once_with('test_task', context) + + # Verify retrieved task has identical content + assert retrieved_task is not None + assert retrieved_task.id == stored_task.id + assert retrieved_task.status.state == stored_task.status.state + + # Verify it is a COPY, not the same reference + assert retrieved_task is not stored_task + + +@pytest.mark.asyncio +async def test_copying_task_store_get_none(): + """Test that the adapter properly returns None when no task is found.""" + mock_store = AsyncMock(spec=TaskStore) + adapter = CopyingTaskStoreAdapter(mock_store) + + mock_store.get.return_value = None + context = ServerCallContext() + + retrieved_task = await adapter.get('test_task', context) + + # Verify underlying store was called + mock_store.get.assert_awaited_once_with('test_task', context) + assert retrieved_task is None + + +@pytest.mark.asyncio +async def test_copying_task_store_list(): + """Test that the adapter returns a copy of the list response.""" + mock_store = AsyncMock(spec=TaskStore) + adapter = CopyingTaskStoreAdapter(mock_store) + + task1 = Task(id='test_task_1') + task2 = Task(id='test_task_2') + stored_response = ListTasksResponse(tasks=[task1, task2]) + mock_store.list.return_value = stored_response + context = ServerCallContext() + request = ListTasksRequest(page_size=10) + + retrieved_response = await adapter.list(request, context) + + # Verify underlying store was called + mock_store.list.assert_awaited_once_with(request, context) + + # Verify retrieved response has identical content + assert len(retrieved_response.tasks) == 2 + assert retrieved_response.tasks[0].id == 'test_task_1' + assert retrieved_response.tasks[1].id == 'test_task_2' + + # Verify it is a COPY, not the same reference + assert retrieved_response is not stored_response + # Also verify inner tasks are copies + assert retrieved_response.tasks[0] is not task1 + assert retrieved_response.tasks[1] is not task2 + + +@pytest.mark.asyncio +async def test_copying_task_store_delete(): + """Test that the adapter calls delete on underlying store.""" + mock_store = AsyncMock(spec=TaskStore) + adapter = CopyingTaskStoreAdapter(mock_store) + context = ServerCallContext() + + await adapter.delete('test_task', context) + + # Verify underlying store was called + mock_store.delete.assert_awaited_once_with('test_task', context) diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index 2184c2116..af3531e33 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -330,3 +330,38 @@ async def test_owner_resource_scoping() -> None: # Cleanup remaining tasks await store.delete('u1-task2', context_user1) await store.delete('u2-task1', context_user2) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('use_copying', [True, False]) +async def test_inmemory_task_store_copying_behavior(use_copying: bool): + """Verify that tasks are copied (or not) based on use_copying parameter.""" + store = InMemoryTaskStore(use_copying=use_copying) + + original_task = Task( + id='test_task', status=TaskStatus(state=TaskState.TASK_STATE_WORKING) + ) + await store.save(original_task) + + # Retrieve it + retrieved_task = await store.get('test_task') + assert retrieved_task is not None + + if use_copying: + assert retrieved_task is not original_task + else: + assert retrieved_task is original_task + + # Modify retrieved task + retrieved_task.status.state = TaskState.TASK_STATE_COMPLETED + + # Retrieve it again, it should NOT be modified in the store if use_copying=True + retrieved_task_2 = await store.get('test_task') + assert retrieved_task_2 is not None + + if use_copying: + assert retrieved_task_2.status.state == TaskState.TASK_STATE_WORKING + assert retrieved_task_2 is not retrieved_task + else: + assert retrieved_task_2.status.state == TaskState.TASK_STATE_COMPLETED + assert retrieved_task_2 is retrieved_task