Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
fix(client): suppress stale task notifications at the start of receive_response()

When a background task (spawned via run_in_background=True) completed
between turns, its TaskNotificationMessage sat in the message buffer
and was the first thing yielded by the next receive_response() call.
This caused the notification to appear before the actual Turn N+1
response — and in some cases caused the model to respond to the stale
task context instead of the new user prompt.

Fix: ClaudeSDKClient now tracks which turn each background task was
started in (_task_turn_map).  receive_response() defers any task
lifecycle events that arrive before the first non-task message of the
current turn.  When the first substantive message arrives, deferred
events are flushed — unless the event is a TaskNotificationMessage for
a task started in an earlier turn, in which case it is discarded as
stale cross-turn noise.

Notifications that arrive AFTER the first AssistantMessage (mid-turn)
are still yielded normally.  Notifications for tasks with no recorded
start (unknown task_id) are yielded as current-turn (safe default).
Map entries are cleaned up when a notification is processed to prevent
unbounded growth on long-lived clients.

The raw receive_messages() stream is unchanged: callers who need every
event regardless of turn boundaries should use that method instead.

Closes #788
  • Loading branch information
qozle committed Apr 5, 2026
commit 4c93ae7211c2898a1cac951efa70abb2a4365747
62 changes: 62 additions & 0 deletions src/claude_agent_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
Message,
PermissionMode,
ResultMessage,
TaskNotificationMessage,
TaskStartedMessage,
)


Expand Down Expand Up @@ -74,6 +76,14 @@ def __init__(
self._transport: Transport | None = None
self._query: Any | None = None

# Turn tracking for background-task notification hygiene.
# Each ResultMessage increments _current_turn. When a TaskStartedMessage
# arrives, we record task_id → turn so that receive_response() can
# suppress TaskNotificationMessages whose task was started in a prior
# turn and would otherwise leak into the next turn's response stream.
self._current_turn: int = 0
self._task_turn_map: dict[str, int] = {}

def _convert_hooks_to_internal_format(
self, hooks: dict[HookEvent, list[HookMatcher]]
) -> dict[str, list[dict[str, Any]]]:
Expand Down Expand Up @@ -524,10 +534,62 @@ async def receive_response(self) -> AsyncIterator[Message]:
Note:
To collect all messages: `messages = [msg async for msg in client.receive_response()]`
The final message in the list will always be a ResultMessage.

Background task notifications:
If a background task (spawned via the Agent tool with run_in_background=True)
was started in an earlier turn and its TaskNotificationMessage arrives before
the first non-task message of this turn, that notification is suppressed from
this iterator. It arrived between turns and would otherwise appear before the
first assistant response, making it look like stale context from a prior
conversation. Notifications for tasks started in the current turn, or for
task IDs with no recorded start, are not suppressed.

Task completions that arrive *during* the current turn (after the first
assistant message) are still yielded normally. For the full unfiltered stream
including all task events, use receive_messages() instead.
"""
if not self._query:
raise CLIConnectionError("Not connected. Call connect() first.")

# We hold any task-lifecycle events that arrive before the first
# non-task message of this turn. Once a non-task message arrives we
# know the CLI is processing our latest query, so deferred events are
# re-yielded in order. Events for tasks started in a previous turn
# are discarded at that point because they are stale cross-turn noise.
deferred: list[Message] = []
turn_started = False

async for message in self.receive_messages():
# Track task IDs so we know which turn they were spawned in.
if isinstance(message, TaskStartedMessage):
self._task_turn_map[message.task_id] = self._current_turn

if not turn_started:
if isinstance(message, (TaskStartedMessage, TaskNotificationMessage)):
# Arrival before the first non-task message: could be
# a stale notification from a previous turn. Defer.
deferred.append(message)
continue

# First non-task message — we are now inside the current turn.
turn_started = True
for deferred_msg in deferred:
if isinstance(deferred_msg, TaskNotificationMessage):
task_turn = self._task_turn_map.get(deferred_msg.task_id)
# Clean up the map entry regardless of outcome.
self._task_turn_map.pop(deferred_msg.task_id, None)
if task_turn is not None and task_turn < self._current_turn:
# Stale: started in a previous turn, completed
# between turns. Drop it.
continue
yield deferred_msg
deferred.clear()

# Clean up map entries when a notification is yielded mid-turn.
if isinstance(message, TaskNotificationMessage):
self._task_turn_map.pop(message.task_id, None)

yield message
if isinstance(message, ResultMessage):
self._current_turn += 1
return

async def disconnect(self) -> None:
Expand Down
210 changes: 210 additions & 0 deletions tests/test_streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
UserMessage,
query,
)
from claude_agent_sdk.types import TaskNotificationMessage, TaskStartedMessage
from claude_agent_sdk._internal.transport.subprocess_cli import SubprocessCLITransport


Expand Down Expand Up @@ -1312,3 +1313,212 @@ async def mock_receive():
assert isinstance(messages[-1], ResultMessage)

anyio.run(_test)


# ---------------------------------------------------------------------------
# Task notification hygiene tests (issue #788)
# ---------------------------------------------------------------------------

def _make_assistant_msg(text: str = "4") -> dict[str, object]:
    """Build a raw ``assistant`` message dict carrying a single text block.

    Args:
        text: Text placed in the only content block of the message.

    Returns:
        A fresh dict with ``type == "assistant"`` and a nested ``message``
        payload (role, content list, model name).
    """
    return {
        "type": "assistant",
        "message": {
            "role": "assistant",
            "content": [{"type": "text", "text": text}],
            "model": "claude-sonnet-4-5",
        },
    }


def _make_result_msg() -> dict[str, object]:
    """Build a raw successful ``result`` message dict.

    Returns:
        A fresh dict with ``type == "result"`` and ``subtype == "success"``,
        populated with fixed duration, cost, turn-count and session fields.
    """
    return {
        "type": "result",
        "subtype": "success",
        "duration_ms": 100,
        "duration_api_ms": 80,
        "is_error": False,
        "num_turns": 1,
        "session_id": "test",
        "total_cost_usd": 0.001,
    }


def _make_task_started(task_id: str = "task-1") -> dict[str, object]:
    """Build a raw ``system``/``task_started`` message dict.

    Args:
        task_id: Identifier of the background task; also embedded in the
            message ``uuid`` as ``uuid-<task_id>``.

    Returns:
        A fresh dict describing the start of the given task.
    """
    return {
        "type": "system",
        "subtype": "task_started",
        "task_id": task_id,
        "description": "background work",
        "uuid": f"uuid-{task_id}",
        "session_id": "test",
    }


def _make_task_notification(
    task_id: str = "task-1", *, status: str = "completed"
) -> dict[str, object]:
    """Build a raw ``system``/``task_notification`` message dict.

    Args:
        task_id: Identifier of the background task; also embedded in the
            message ``uuid`` as ``notif-<task_id>``.
        status: Value placed in the ``status`` field. Defaults to
            ``"completed"``, which is what every existing caller relies on.

    Returns:
        A fresh dict describing the end-of-task notification.
    """
    return {
        "type": "system",
        "subtype": "task_notification",
        "task_id": task_id,
        "status": status,
        "output_file": "/tmp/out.md",
        "summary": "done",
        "uuid": f"notif-{task_id}",
        "session_id": "test",
    }


class TestReceiveResponseTaskNotificationHygiene:
    """receive_response() must not leak between-turn task notifications (issue #788)."""

    # Dotted path patched by every test so that ClaudeSDKClient is handed the
    # mock transport instead of spawning a real subprocess.
    _TRANSPORT_PATH = (
        "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
    )

    def _make_transport_with_messages(self, messages: list[dict]):
        """Build a mock transport that yields the given messages after init.

        Args:
            messages: Raw message dicts to emit, in order, once the
                initialize handshake (if any) has been answered.

        Returns:
            An ``AsyncMock`` transport whose ``write`` records every payload
            and whose ``read_messages`` first answers the recorded
            ``initialize`` control request and then yields ``messages``.
        """
        mock_transport = AsyncMock()
        mock_transport.connect = AsyncMock()
        mock_transport.close = AsyncMock()
        mock_transport.end_input = AsyncMock()
        mock_transport.is_ready = Mock(return_value=True)

        written_messages: list[str] = []

        async def mock_write(data):
            written_messages.append(data)

        mock_transport.write = AsyncMock(side_effect=mock_write)

        async def msg_gen():
            # Respond to initialize request first.
            # NOTE(review): the short sleep presumably lets the client write
            # its initialize request before we scan written_messages.
            await asyncio.sleep(0.01)
            for msg_str in written_messages:
                try:
                    msg = json.loads(msg_str.strip())
                except json.JSONDecodeError:
                    # Not a JSON payload; it cannot be the initialize request.
                    continue
                if not isinstance(msg, dict):
                    continue
                if (
                    msg.get("type") == "control_request"
                    and msg.get("request", {}).get("subtype") == "initialize"
                ):
                    yield {
                        "type": "control_response",
                        "response": {
                            "request_id": msg.get("request_id"),
                            "subtype": "success",
                            "commands": [],
                        },
                    }
                    break
            for m in messages:
                yield m

        mock_transport.read_messages = msg_gen
        return mock_transport

    def test_stale_notification_before_turn2_is_suppressed(self):
        """TaskNotificationMessage buffered before Turn 2 starts is NOT yielded."""

        async def _test():
            with patch(self._TRANSPORT_PATH) as mock_cls:
                # Stream: Turn 1 (task starts + result), then stale notification,
                # then Turn 2 (assistant + result).
                msgs = [
                    _make_task_started("t1"),
                    _make_assistant_msg("Spawning"),
                    _make_result_msg(),  # End Turn 1
                    _make_task_notification("t1"),  # Stale: between turns
                    _make_assistant_msg("4"),
                    _make_result_msg(),  # End Turn 2
                ]
                mock_cls.return_value = self._make_transport_with_messages(msgs)

                async with ClaudeSDKClient() as client:
                    # Consume Turn 1
                    turn1 = [m async for m in client.receive_response()]
                    assert any(isinstance(m, TaskStartedMessage) for m in turn1)
                    assert isinstance(turn1[-1], ResultMessage)

                    # Turn 2: stale notification must not appear
                    turn2 = [m async for m in client.receive_response()]
                    assert not any(
                        isinstance(m, TaskNotificationMessage) for m in turn2
                    ), "Stale TaskNotificationMessage leaked into Turn 2"
                    assert any(isinstance(m, AssistantMessage) for m in turn2)
                    assert isinstance(turn2[-1], ResultMessage)

                    # Dropping the stale notification must still release its
                    # bookkeeping entry, otherwise the map grows unbounded.
                    assert "t1" not in client._task_turn_map

        anyio.run(_test)

    def test_notification_arriving_mid_turn_is_yielded(self):
        """TaskNotificationMessage that arrives after the first AssistantMessage IS yielded."""

        async def _test():
            with patch(self._TRANSPORT_PATH) as mock_cls:
                # Stream: Turn 1 starts, notification arrives after first assistant msg.
                msgs = [
                    _make_task_started("t2"),
                    _make_assistant_msg("thinking..."),
                    _make_task_notification("t2"),  # Arrives mid-turn — should show
                    _make_assistant_msg("done"),
                    _make_result_msg(),
                ]
                mock_cls.return_value = self._make_transport_with_messages(msgs)

                async with ClaudeSDKClient() as client:
                    turn1 = [m async for m in client.receive_response()]

                notifications = [
                    m for m in turn1 if isinstance(m, TaskNotificationMessage)
                ]
                assert len(notifications) == 1, (
                    "TaskNotificationMessage that arrived mid-turn should be yielded"
                )

        anyio.run(_test)

    def test_turn_counter_increments_and_cleans_map(self):
        """_current_turn increments per result; _task_turn_map is cleaned up."""

        async def _test():
            with patch(self._TRANSPORT_PATH) as mock_cls:
                msgs = [
                    _make_task_started("t3"),
                    _make_task_notification("t3"),  # Completes in Turn 1
                    _make_assistant_msg("hi"),
                    _make_result_msg(),
                ]
                mock_cls.return_value = self._make_transport_with_messages(msgs)

                async with ClaudeSDKClient() as client:
                    assert client._current_turn == 0
                    turn1 = [m async for m in client.receive_response()]
                    assert client._current_turn == 1
                    # The notification was deferred (it preceded the first
                    # assistant message) but belongs to this turn, so it must
                    # be flushed rather than dropped.
                    assert any(
                        isinstance(m, TaskNotificationMessage) for m in turn1
                    ), "Same-turn deferred notification should be flushed"
                    # Map entry cleaned up after notification was processed
                    assert "t3" not in client._task_turn_map

        anyio.run(_test)

    def test_unknown_task_id_notification_is_yielded(self):
        """Notification for an unknown task_id (no TaskStartedMessage seen) is yielded."""

        async def _test():
            with patch(self._TRANSPORT_PATH) as mock_cls:
                # No task_started, just a notification — must not crash or suppress.
                msgs = [
                    _make_task_notification("unknown-task"),
                    _make_assistant_msg("hi"),
                    _make_result_msg(),
                ]
                mock_cls.return_value = self._make_transport_with_messages(msgs)

                async with ClaudeSDKClient() as client:
                    turn1 = [m async for m in client.receive_response()]

                notifications = [
                    m for m in turn1 if isinstance(m, TaskNotificationMessage)
                ]
                assert len(notifications) == 1, (
                    "Notification for unknown task_id should be yielded as current-turn"
                )

        anyio.run(_test)