diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 63d0fdc74..c777e34cc 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -555,11 +555,15 @@ async def on_subscribe_to_task( if task.status.state in TERMINAL_TASK_STATES: raise ServerError( - error=InvalidParamsError( + error=UnsupportedOperationError( message=f'Task {task.id} is in terminal state: {task.status.state}' ) ) + # The operation MUST return a Task object as the first event in the stream + # https://a2a-protocol.org/latest/specification/#316-subscribe-to-task + yield task + task_manager = TaskManager( task_id=task.id, context_id=task.context_id, diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 9a00ba6c6..350d595a4 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -1499,7 +1499,10 @@ async def exec_side_effect(_request, queue: EventQueue): # Allow producer to emit the next event allow_second_event.set() - received = await resub_gen.__anext__() + first_subscribe_event = await anext(resub_gen) + assert first_subscribe_event == task_for_resub + + received = await anext(resub_gen) assert received == second_event # Finish producer to allow cleanup paths to complete @@ -2706,7 +2709,7 @@ async def test_on_subscribe_to_task_in_terminal_state(terminal_state): async for _ in request_handler.on_subscribe_to_task(params, context): pass # pragma: no cover - assert isinstance(exc_info.value.error, InvalidParamsError) + assert isinstance(exc_info.value.error, UnsupportedOperationError) assert exc_info.value.error.message assert ( f'Task {task_id} is in terminal state: {terminal_state}' diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index fca1175af..a9e940a03 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -703,7 +703,9 @@ async def streaming_coro(): collected_events: list[Any] = [] async for event in response: collected_events.append(event) - assert len(collected_events) == len(events) + assert ( + len(collected_events) == len(events) + 1 + ) # First event is task itself assert mock_task.history is not None and len(mock_task.history) == 0 async def test_on_subscribe_no_existing_task_error(self) -> None: