From f6ba91b120b97e38a4cdb3e61b226cef348679eb Mon Sep 17 00:00:00 2001 From: Adrian Date: Mon, 18 May 2026 16:53:20 -0700 Subject: [PATCH 1/2] Runtime handling updates (#3451) ## Summary - Refresh runtime handling around session and tool-call flows. - Adjust model configuration metadata used by runtime integrations. - Add focused coverage for the updated behavior. ## Validation - .venv/bin/python -m pytest tests/model_settings/test_serialization.py tests/models/test_trace_config.py tests/mcp/test_streamable_http_client_factory.py tests/test_run_context_approvals.py tests/test_run_state.py::TestRunState::test_trace_api_key_serialization_is_opt_in tests/realtime/test_session.py - .venv/bin/ruff check - .venv/bin/ruff format --check - git diff --check --- src/agents/extensions/models/any_llm_model.py | 23 +- src/agents/extensions/models/litellm_model.py | 15 +- src/agents/mcp/server.py | 7 +- src/agents/model_settings.py | 26 ++ src/agents/models/_trace.py | 31 ++ src/agents/models/openai_chatcompletions.py | 5 +- src/agents/realtime/session.py | 377 ++++++++++++------ src/agents/run_context.py | 19 +- .../run_internal/agent_runner_helpers.py | 2 - tests/mcp/test_mcp_auth_params.py | 3 + .../test_streamable_http_client_factory.py | 8 +- tests/model_settings/test_serialization.py | 23 ++ tests/models/test_trace_config.py | 32 ++ tests/realtime/test_session.py | 227 +++++++++++ tests/test_run_context_approvals.py | 23 ++ tests/test_run_state.py | 15 +- 16 files changed, 679 insertions(+), 157 deletions(-) create mode 100644 src/agents/models/_trace.py create mode 100644 tests/models/test_trace_config.py diff --git a/src/agents/extensions/models/any_llm_model.py b/src/agents/extensions/models/any_llm_model.py index 02f58dcf68..ca4ecf10ad 100644 --- a/src/agents/extensions/models/any_llm_model.py +++ b/src/agents/extensions/models/any_llm_model.py @@ -34,6 +34,7 @@ response_terminal_failure_error, ) from ...models._retry_runtime import should_disable_provider_managed_retries +from ...models._trace import model_config_for_trace from ...models.chatcmpl_converter import Converter from ...models.chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE, ChatCmplHelpers from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler @@ -467,12 +468,11 @@ async def _get_response_via_chat( ) -> ModelResponse: with generation_span( model=str(self.model), - model_config=model_settings.to_json_dict() - | { - "base_url": str(self.base_url or ""), - "provider": self._provider_name, - "model_impl": "any-llm", - }, + model_config=model_config_for_trace( + model_settings, + base_url=self.base_url or "", + extra_config={"provider": self._provider_name, "model_impl": "any-llm"}, + ), disabled=tracing.is_disabled(), ) as span_generation: response = await self._fetch_chat_response( @@ -570,12 +570,11 @@ async def _stream_response_via_chat( ) -> AsyncIterator[TResponseStreamEvent]: with generation_span( model=str(self.model), - model_config=model_settings.to_json_dict() - | { - "base_url": str(self.base_url or ""), - "provider": self._provider_name, - "model_impl": "any-llm", - }, + model_config=model_config_for_trace( + model_settings, + base_url=self.base_url or "", + extra_config={"provider": self._provider_name, "model_impl": "any-llm"}, + ), disabled=tracing.is_disabled(), ) as span_generation: response, stream = await self._fetch_chat_response( diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index f078336741..4e649feb08 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -43,6 +43,7 @@ from ...model_settings import ModelSettings from ...models._openai_retry import get_openai_retry_advice from ...models._retry_runtime import should_disable_provider_managed_retries +from ...models._trace import model_config_for_trace from ...models.chatcmpl_converter import Converter from ...models.chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE, ChatCmplHelpers from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler @@ -213,8 +214,11 @@ async def get_response( ) -> ModelResponse: with generation_span( model=str(self.model), - model_config=model_settings.to_json_dict() - | {"base_url": str(self.base_url or ""), "model_impl": "litellm"}, + model_config=model_config_for_trace( + model_settings, + base_url=self.base_url or "", + extra_config={"model_impl": "litellm"}, + ), disabled=tracing.is_disabled(), ) as span_generation: response = await self._fetch_response( @@ -327,8 +331,11 @@ async def stream_response( ) -> AsyncIterator[TResponseStreamEvent]: with generation_span( model=str(self.model), - model_config=model_settings.to_json_dict() - | {"base_url": str(self.base_url or ""), "model_impl": "litellm"}, + model_config=model_config_for_trace( + model_settings, + base_url=self.base_url or "", + extra_config={"model_impl": "litellm"}, + ), disabled=tracing.is_disabled(), ) as span_generation: response, stream = await self._fetch_response( diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 268b0893da..be595f11c5 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -99,7 +99,7 @@ def _create_default_streamable_http_client( timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None, ) -> httpx.AsyncClient: - kwargs: dict[str, Any] = {"follow_redirects": True} + kwargs: dict[str, Any] = {"follow_redirects": False} if timeout is not None: kwargs["timeout"] = timeout if headers is not None: @@ -1441,8 +1441,9 @@ def create_streams( auth=self.params.get("auth"), transport_factory=_InitializedNotificationTolerantStreamableHTTPTransport, ) - if httpx_client_factory is not None: - kwargs["httpx_client_factory"] = httpx_client_factory + kwargs["httpx_client_factory"] = ( + httpx_client_factory or _create_default_streamable_http_client + ) if "auth" in self.params: kwargs["auth"] = self.params["auth"] return streamablehttp_client(**kwargs) diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index 3b2c93ddf8..1ef9822f52 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -60,6 +60,27 @@ class MCPToolChoice: Headers: TypeAlias = Mapping[str, str | Omit] ToolChoice: TypeAlias = Literal["auto", "required", "none"] | str | MCPToolChoice | None +_TRACEABLE_MODEL_SETTING_FIELDS = ( + "temperature", + "top_p", + "frequency_penalty", + "presence_penalty", + "tool_choice", + "parallel_tool_calls", + "truncation", + "max_tokens", + "reasoning", + "verbosity", + "metadata", + "store", + "prompt_cache_retention", + "include_usage", + "response_include", + "top_logprobs", + "retry", + "context_management", +) + @dataclass class ModelSettings: @@ -199,6 +220,11 @@ def resolve(self, override: ModelSettings | None) -> ModelSettings: def to_json_dict(self) -> dict[str, Any]: return cast(dict[str, Any], TypeAdapter(ModelSettings).dump_python(self, mode="json")) + def to_traceable_dict(self) -> dict[str, Any]: + """Serialize settings for tracing without provider-specific request extras.""" + payload = self.to_json_dict() + return {key: payload[key] for key in _TRACEABLE_MODEL_SETTING_FIELDS if key in payload} + def _merge_retry_settings( inherited: ModelRetrySettings | None, diff --git a/src/agents/models/_trace.py b/src/agents/models/_trace.py new file mode 100644 index 0000000000..30026ebb50 --- /dev/null +++ b/src/agents/models/_trace.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Any +from urllib.parse import urlsplit, urlunsplit + +from ..model_settings import ModelSettings + + +def sanitize_url_for_trace(url: object) -> str: + """Return a URL safe for tracing by removing auth material and request parameters.""" + try: + parts = urlsplit(str(url)) + except ValueError: + return "" + + netloc = parts.netloc.rsplit("@", 1)[-1] + return urlunsplit((parts.scheme, netloc, parts.path, "", "")) + + +def model_config_for_trace( + model_settings: ModelSettings, + *, + base_url: object | None = None, + extra_config: dict[str, Any] | None = None, +) -> dict[str, Any]: + config = model_settings.to_traceable_dict() + if base_url is not None: + config["base_url"] = sanitize_url_for_trace(base_url) + if extra_config: + config.update(extra_config) + return config diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index dbf4045b53..cba01163e9 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -33,6 +33,7 @@ from ..util._json import _to_dump_compatible from ._openai_retry import get_openai_retry_advice from ._retry_runtime import should_disable_provider_managed_retries +from ._trace import model_config_for_trace from .chatcmpl_converter import Converter from .chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE, ChatCmplHelpers from .chatcmpl_stream_handler import ChatCmplStreamHandler @@ -147,7 +148,7 @@ async def get_response( with generation_span( model=str(self.model), - model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, + model_config=model_config_for_trace(model_settings, base_url=self._client.base_url), disabled=tracing.is_disabled(), ) as span_generation: response = await self._fetch_response( @@ -281,7 +282,7 @@ async def stream_response( with generation_span( model=str(self.model), - model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, + model_config=model_config_for_trace(model_settings, base_url=self._client.base_url), disabled=tracing.is_disabled(), ) as span_generation: response, stream = await self._fetch_response( diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 2b45ccaed6..ca809dd9c4 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -104,6 +104,21 @@ def _serialize_tool_output(output: Any) -> str: return str(output) +@dataclasses.dataclass +class _PendingToolOutput: + tool_call: RealtimeModelToolCallEvent + output: str + start_response: bool + tool_end_event: RealtimeToolEnd | None = None + session_update: RealtimeModelSendSessionUpdate | None = None + + +class _PendingToolOutputSendError(RuntimeError): + def __init__(self, call_id: str, cause: BaseException) -> None: + super().__init__(str(cause)) + self.call_id = call_id + + class RealtimeSession(RealtimeModelListener): """A connection to a realtime model. It streams events from the model to you, and allows you to send messages and audio to the model. @@ -163,6 +178,9 @@ def __init__( self._pending_tool_calls: dict[ str, tuple[RealtimeModelToolCallEvent, RealtimeAgent, FunctionTool, ToolApprovalItem] ] = {} + self._active_tool_call_ids: set[str] = set() + self._completed_tool_call_ids: set[str] = set() + self._pending_tool_outputs: dict[str, _PendingToolOutput] = {} # Guardrails state tracking self._interrupted_response_ids: set[str] = set() @@ -556,23 +574,42 @@ async def _send_tool_rejection( tool=tool, call_id=event.call_id, ) - await self._model.send_event( - RealtimeModelSendToolOutput( + await self._send_tool_output_completion( + _PendingToolOutput( tool_call=event, output=rejection_message, start_response=True, + tool_end_event=RealtimeToolEnd( + info=self._event_info, + tool=tool, + output=rejection_message, + agent=agent, + arguments=event.arguments, + ), ) ) - await self._put_event( - RealtimeToolEnd( - info=self._event_info, - tool=tool, - output=rejection_message, - agent=agent, - arguments=event.arguments, + async def _send_tool_output_completion(self, pending_output: _PendingToolOutput) -> None: + call_id = pending_output.tool_call.call_id + self._pending_tool_outputs[call_id] = pending_output + try: + await self._send_pending_tool_output(pending_output) + except Exception as exc: + raise _PendingToolOutputSendError(call_id, exc) from exc + self._pending_tool_outputs.pop(call_id, None) + + async def _send_pending_tool_output(self, pending_output: _PendingToolOutput) -> None: + if pending_output.session_update is not None: + await self._model.send_event(pending_output.session_update) + await self._model.send_event( + RealtimeModelSendToolOutput( + tool_call=pending_output.tool_call, + output=pending_output.output, + start_response=pending_output.start_response, ) ) + if pending_output.tool_end_event is not None: + await self._put_event(pending_output.tool_end_event) async def _resolve_approval_rejection_message(self, *, tool: FunctionTool, call_id: str) -> str: """Resolve model-visible output text for approval rejections.""" @@ -624,12 +661,30 @@ async def approve_tool_call(self, call_id: str, *, always: bool = False) -> None return tool_call, agent_snapshot, function_tool, approval_item = pending - self._context_wrapper.approve_tool(approval_item, always_approve=always) + if not self._begin_tool_call(call_id, from_pending_approval=True): + return - if self._async_tool_calls: - self._enqueue_tool_call_task(tool_call, agent_snapshot) - else: - await self._handle_tool_call(tool_call, agent_snapshot=agent_snapshot) + try: + self._context_wrapper.approve_tool(approval_item, always_approve=always) + + if self._async_tool_calls: + self._enqueue_tool_call_task( + tool_call, + agent_snapshot, + from_pending_approval=True, + call_id_reserved=True, + ) + else: + await self._handle_tool_call( + tool_call, + agent_snapshot=agent_snapshot, + from_pending_approval=True, + call_id_reserved=True, + ) + except Exception: + if call_id in self._active_tool_call_ids: + self._finish_tool_call(call_id, mark_completed=False) + raise async def reject_tool_call( self, @@ -643,148 +698,185 @@ async def reject_tool_call( if pending is None: return + if not self._begin_tool_call(call_id, from_pending_approval=True): + return + + mark_completed = False tool_call, agent_snapshot, function_tool, approval_item = pending - self._context_wrapper.reject_tool( - approval_item, - always_reject=always, - rejection_message=rejection_message, - ) - await self._send_tool_rejection(tool_call, tool=function_tool, agent=agent_snapshot) + try: + self._context_wrapper.reject_tool( + approval_item, + always_reject=always, + rejection_message=rejection_message, + ) + await self._send_tool_rejection(tool_call, tool=function_tool, agent=agent_snapshot) + mark_completed = True + finally: + self._finish_tool_call(call_id, mark_completed=mark_completed) async def _handle_tool_call( self, event: RealtimeModelToolCallEvent, *, agent_snapshot: RealtimeAgent | None = None, + from_pending_approval: bool = False, + call_id_reserved: bool = False, ) -> None: """Handle a tool call event.""" - agent = agent_snapshot or self._current_agent - tools, handoffs = await asyncio.gather( - agent.get_all_tools(self._context_wrapper), - self._get_handoffs(agent, self._context_wrapper), - ) - function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)} - handoff_map = {handoff.tool_name: handoff for handoff in handoffs} + mark_completed = False + if not call_id_reserved and not self._begin_tool_call( + event.call_id, from_pending_approval=from_pending_approval + ): + return - if event.name in function_map: - func_tool = function_map[event.name] - approval_status = await self._maybe_request_tool_approval( - event, function_tool=func_tool, agent=agent - ) - if approval_status is False: - await self._send_tool_rejection(event, tool=func_tool, agent=agent) - return - if approval_status is None: + agent = agent_snapshot or self._current_agent + try: + pending_output = self._pending_tool_outputs.get(event.call_id) + if pending_output is not None: + await self._send_tool_output_completion(pending_output) + mark_completed = True return - await self._put_event( - RealtimeToolStart( - info=self._event_info, - tool=func_tool, - agent=agent, - arguments=event.arguments, - ) + tools, handoffs = await asyncio.gather( + agent.get_all_tools(self._context_wrapper), + self._get_handoffs(agent, self._context_wrapper), ) + function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)} + handoff_map = {handoff.tool_name: handoff for handoff in handoffs} - tool_context = ToolContext( - context=self._context_wrapper.context, - usage=self._context_wrapper.usage, - tool_name=event.name, - tool_call_id=event.call_id, - tool_arguments=event.arguments, - agent=agent, - ) - result = await invoke_function_tool( - function_tool=func_tool, - context=tool_context, - arguments=event.arguments, - ) + if event.name in function_map: + func_tool = function_map[event.name] + approval_status = await self._maybe_request_tool_approval( + event, function_tool=func_tool, agent=agent + ) + if approval_status is False: + await self._send_tool_rejection(event, tool=func_tool, agent=agent) + mark_completed = True + return + if approval_status is None: + return - await self._model.send_event( - RealtimeModelSendToolOutput( - tool_call=event, - output=_serialize_tool_output(result), - start_response=True, + await self._put_event( + RealtimeToolStart( + info=self._event_info, + tool=func_tool, + agent=agent, + arguments=event.arguments, + ) ) - ) - await self._put_event( - RealtimeToolEnd( - info=self._event_info, - tool=func_tool, - output=result, + tool_context = ToolContext( + context=self._context_wrapper.context, + usage=self._context_wrapper.usage, + tool_name=event.name, + tool_call_id=event.call_id, + tool_arguments=event.arguments, agent=agent, + ) + result = await invoke_function_tool( + function_tool=func_tool, + context=tool_context, arguments=event.arguments, ) - ) - elif event.name in handoff_map: - handoff = handoff_map[event.name] - tool_context = ToolContext( - context=self._context_wrapper.context, - usage=self._context_wrapper.usage, - tool_name=event.name, - tool_call_id=event.call_id, - tool_arguments=event.arguments, - agent=agent, - ) - # Execute the handoff to get the new agent - result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments) - if not isinstance(result, RealtimeAgent): - raise UserError( - f"Handoff {handoff.tool_name} returned invalid result: {type(result)}" + await self._send_tool_output_completion( + _PendingToolOutput( + tool_call=event, + output=_serialize_tool_output(result), + start_response=True, + tool_end_event=RealtimeToolEnd( + info=self._event_info, + tool=func_tool, + output=result, + agent=agent, + arguments=event.arguments, + ), + ) + ) + mark_completed = True + elif event.name in handoff_map: + handoff = handoff_map[event.name] + tool_context = ToolContext( + context=self._context_wrapper.context, + usage=self._context_wrapper.usage, + tool_name=event.name, + tool_call_id=event.call_id, + tool_arguments=event.arguments, + agent=agent, ) - # Store previous agent for event - previous_agent = agent + # Execute the handoff to get the new agent + result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments) + if not isinstance(result, RealtimeAgent): + raise UserError( + f"Handoff {handoff.tool_name} returned invalid result: {type(result)}" + ) - # Update current agent - self._current_agent = result + # Store previous agent for event + previous_agent = agent - # Get updated model settings from new agent - updated_settings = await self._get_updated_model_settings_from_agent( - starting_settings=None, - agent=self._current_agent, - ) + # Update current agent + self._current_agent = result - # Send handoff event - await self._put_event( - RealtimeHandoffEvent( - from_agent=previous_agent, - to_agent=self._current_agent, - info=self._event_info, + # Get updated model settings from new agent + updated_settings = await self._get_updated_model_settings_from_agent( + starting_settings=None, + agent=self._current_agent, ) - ) - # First, send the session update so the model receives the new instructions - await self._model.send_event( - RealtimeModelSendSessionUpdate(session_settings=updated_settings) - ) + # Send handoff event + await self._put_event( + RealtimeHandoffEvent( + from_agent=previous_agent, + to_agent=self._current_agent, + info=self._event_info, + ) + ) - # Then send tool output to complete the handoff (this triggers a new response) - transfer_message = handoff.get_transfer_message(result) - await self._model.send_event( - RealtimeModelSendToolOutput( - tool_call=event, - output=transfer_message, - start_response=True, + # Send the session update before the tool output that triggers a new response. + transfer_message = handoff.get_transfer_message(result) + await self._send_tool_output_completion( + _PendingToolOutput( + tool_call=event, + output=transfer_message, + start_response=True, + session_update=RealtimeModelSendSessionUpdate( + session_settings=updated_settings + ), + ) ) - ) - else: - error_message = f"Tool {event.name} not found" - await self._model.send_event( - RealtimeModelSendToolOutput( - tool_call=event, - output=error_message, - start_response=False, + mark_completed = True + else: + error_message = f"Tool {event.name} not found" + await self._send_tool_output_completion( + _PendingToolOutput( + tool_call=event, + output=error_message, + start_response=False, + ) ) - ) - await self._put_event( - RealtimeError( - info=self._event_info, - error={"message": error_message}, + mark_completed = True + await self._put_event( + RealtimeError( + info=self._event_info, + error={"message": error_message}, + ) ) - ) + finally: + self._finish_tool_call(event.call_id, mark_completed=mark_completed) + + def _begin_tool_call(self, call_id: str, *, from_pending_approval: bool) -> bool: + if call_id in self._active_tool_call_ids or call_id in self._completed_tool_call_ids: + return False + if not from_pending_approval and call_id in self._pending_tool_calls: + return False + self._active_tool_call_ids.add(call_id) + return True + + def _finish_tool_call(self, call_id: str, *, mark_completed: bool) -> None: + self._active_tool_call_ids.discard(call_id) + if mark_completed: + self._completed_tool_call_ids.add(call_id) @classmethod def _get_new_history( @@ -1064,10 +1156,21 @@ def _cleanup_guardrail_tasks(self) -> None: self._guardrail_tasks.clear() def _enqueue_tool_call_task( - self, event: RealtimeModelToolCallEvent, agent_snapshot: RealtimeAgent + self, + event: RealtimeModelToolCallEvent, + agent_snapshot: RealtimeAgent, + *, + from_pending_approval: bool = False, + call_id_reserved: bool = False, ) -> None: """Run tool calls in the background to avoid blocking realtime transport.""" - task = asyncio.create_task(self._handle_tool_call(event, agent_snapshot=agent_snapshot)) + handle_kwargs: dict[str, Any] = {"agent_snapshot": agent_snapshot} + if from_pending_approval: + handle_kwargs["from_pending_approval"] = True + if call_id_reserved: + handle_kwargs["call_id_reserved"] = True + + task = asyncio.create_task(self._handle_tool_call(event, **handle_kwargs)) self._tool_call_tasks.add(task) task.add_done_callback(self._on_tool_call_task_done) @@ -1081,6 +1184,27 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None: if exception is None: return + if isinstance(exception, _PendingToolOutputSendError): + logger.warning( + "Realtime tool output send failed for call %s; cached output will be retried", + exception.call_id, + exc_info=exception, + ) + asyncio.create_task( + self._put_event( + RealtimeError( + info=self._event_info, + error={ + "message": ( + "Tool output send failed; cached output will be retried: " + f"{exception}" + ) + }, + ) + ) + ) + return + logger.exception("Realtime tool call task failed", exc_info=exception) if self._stored_exception is None: @@ -1123,6 +1247,7 @@ async def _cleanup(self) -> None: # Clear pending approval tracking self._pending_tool_calls.clear() + self._pending_tool_outputs.clear() # Mark as closed self._closed = True diff --git a/src/agents/run_context.py b/src/agents/run_context.py index df7047eb38..1dd74a6040 100644 --- a/src/agents/run_context.py +++ b/src/agents/run_context.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Mapping from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Generic @@ -225,6 +226,14 @@ def _get_rejection_message_for_key(record: _ApprovalRecord, call_id: str) -> str return record.rejection_messages.get(call_id) return None + @staticmethod + def _restore_approval_value(value: Any) -> bool | list[str]: + if isinstance(value, bool): + return value + if isinstance(value, list): + return [item for item in value if isinstance(item, str)] + return [] + def get_rejection_message( self, tool_name: str, @@ -435,13 +444,17 @@ def get_approval_status( break return status - def _rebuild_approvals(self, approvals: dict[str, dict[str, Any]]) -> None: + def _rebuild_approvals(self, approvals: Any) -> None: """Restore approvals from serialized state.""" self._approvals = {} + if not isinstance(approvals, Mapping): + return for tool_name, record_dict in approvals.items(): + if not isinstance(tool_name, str) or not isinstance(record_dict, dict): + continue record = _ApprovalRecord() - record.approved = record_dict.get("approved", []) - record.rejected = record_dict.get("rejected", []) + record.approved = self._restore_approval_value(record_dict.get("approved", [])) + record.rejected = self._restore_approval_value(record_dict.get("rejected", [])) rejection_messages = record_dict.get("rejection_messages", {}) if isinstance(rejection_messages, dict): record.rejection_messages = { diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index f60c78227a..84c67d6b8f 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -262,8 +262,6 @@ def resolve_trace_settings( group_id = trace_state.group_id if metadata is None and trace_state.metadata is not None: metadata = dict(trace_state.metadata) - if tracing is None and trace_state.tracing_api_key: - tracing = {"api_key": trace_state.tracing_api_key} metadata = add_openai_harness_id_to_metadata( metadata, diff --git a/tests/mcp/test_mcp_auth_params.py b/tests/mcp/test_mcp_auth_params.py index 92b6760b88..14f52faf1e 100644 --- a/tests/mcp/test_mcp_auth_params.py +++ b/tests/mcp/test_mcp_auth_params.py @@ -8,6 +8,7 @@ import pytest from agents.mcp import MCPServerSse, MCPServerStreamableHttp +from agents.mcp.server import _create_default_streamable_http_client class TestMCPServerSseAuthAndFactory: @@ -120,6 +121,7 @@ async def test_streamable_http_default_no_auth(self): timeout=5, sse_read_timeout=300, terminate_on_close=True, + httpx_client_factory=_create_default_streamable_http_client, ) @pytest.mark.asyncio @@ -138,6 +140,7 @@ async def test_streamable_http_with_auth(self): timeout=5, sse_read_timeout=300, terminate_on_close=True, + httpx_client_factory=_create_default_streamable_http_client, auth=auth, ) diff --git a/tests/mcp/test_streamable_http_client_factory.py b/tests/mcp/test_streamable_http_client_factory.py index 068407a2fd..32f258b0f9 100644 --- a/tests/mcp/test_streamable_http_client_factory.py +++ b/tests/mcp/test_streamable_http_client_factory.py @@ -37,17 +37,16 @@ async def test_default_httpx_client_factory(self): } ) - # Create streams should not pass httpx_client_factory when not provided server.create_streams() - # Verify streamablehttp_client was called with correct parameters + # Verify streamablehttp_client was called with the hardened default factory. mock_client.assert_called_once_with( url="http://localhost:8000/mcp", headers={"Authorization": "Bearer token"}, timeout=10, sse_read_timeout=300, # Default value terminate_on_close=True, # Default value - # httpx_client_factory should not be passed when not provided + httpx_client_factory=_create_default_streamable_http_client, ) @pytest.mark.asyncio @@ -334,6 +333,7 @@ async def test_streamable_http_server_passes_ignore_initialized_notification_fai assert kwargs["timeout"] == 5 assert kwargs["sse_read_timeout"] == 300 assert kwargs["terminate_on_close"] is True + assert kwargs["httpx_client_factory"] is _create_default_streamable_http_client assert ( kwargs["transport_factory"] is _InitializedNotificationTolerantStreamableHTTPTransport ) @@ -437,6 +437,6 @@ async def test_default_streamable_http_client_matches_expected_defaults(): assert client.timeout.write == timeout.write assert client.timeout.pool == timeout.pool assert client.auth is auth - assert client.follow_redirects is True + assert client.follow_redirects is False finally: await client.aclose() diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py index ee7f64bfcf..2e1cde6466 100644 --- a/tests/model_settings/test_serialization.py +++ b/tests/model_settings/test_serialization.py @@ -106,6 +106,29 @@ def test_extra_args_serialization() -> None: verify_serialization(model_settings) +def test_traceable_serialization_omits_request_extras() -> None: + model_settings = ModelSettings( + temperature=0.5, + extra_headers={"Authorization": "Bearer provider-token"}, + extra_query={"api-key": "query-token"}, + extra_body={"secret": "body-token"}, + extra_args={"api_key": "arg-token"}, + ) + + json_dict = model_settings.to_json_dict() + assert json_dict["extra_headers"] == {"Authorization": "Bearer provider-token"} + assert json_dict["extra_query"] == {"api-key": "query-token"} + assert json_dict["extra_body"] == {"secret": "body-token"} + assert json_dict["extra_args"] == {"api_key": "arg-token"} + + traceable = model_settings.to_traceable_dict() + assert traceable["temperature"] == 0.5 + assert "extra_headers" not in traceable + assert "extra_query" not in traceable + assert "extra_body" not in traceable + assert "extra_args" not in traceable + + def test_extra_args_resolve() -> None: """Test that extra_args are properly merged in the resolve method.""" base_settings = ModelSettings( diff --git a/tests/models/test_trace_config.py b/tests/models/test_trace_config.py new file mode 100644 index 0000000000..f28a0f701a --- /dev/null +++ b/tests/models/test_trace_config.py @@ -0,0 +1,32 @@ +from agents.model_settings import ModelSettings +from agents.models._trace import model_config_for_trace, sanitize_url_for_trace + + +def test_sanitize_url_for_trace_strips_auth_query_and_fragment() -> None: + assert ( + sanitize_url_for_trace("https://user:pass@example.com/v1?api-key=secret#fragment") + == "https://example.com/v1" + ) + assert sanitize_url_for_trace("https://example.com/v1?token=secret") == "https://example.com/v1" + + +def test_model_config_for_trace_sanitizes_base_url_and_omits_request_extras() -> None: + config = model_config_for_trace( + ModelSettings( + temperature=0.5, + extra_headers={"Authorization": "Bearer provider-token"}, + extra_query={"api-key": "query-token"}, + extra_body={"secret": "body-token"}, + extra_args={"api_key": "arg-token"}, + ), + base_url="https://user:pass@example.com/v1?api-key=secret#fragment", + extra_config={"model_impl": "test-model"}, + ) + + assert config["temperature"] == 0.5 + assert config["base_url"] == "https://example.com/v1" + assert config["model_impl"] == "test-model" + assert "extra_headers" not in config + assert "extra_query" not in config + assert "extra_body" not in config + assert "extra_args" not in config diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 67cf717aa5..e289bc3c9e 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -57,6 +57,7 @@ RealtimeModelSendAudio, RealtimeModelSendInterrupt, RealtimeModelSendSessionUpdate, + RealtimeModelSendToolOutput, RealtimeModelSendUserInput, ) from agents.realtime.session import REJECTION_MESSAGE, RealtimeSession, _serialize_tool_output @@ -1053,6 +1054,105 @@ async def test_function_tool_execution_success( assert tool_end_event.agent == mock_agent assert tool_end_event.arguments == '{"param": "value"}' + @pytest.mark.asyncio + async def test_duplicate_function_tool_call_id_is_ignored( + self, mock_model, mock_agent, mock_function_tool + ): + """Duplicate function call IDs should not re-run side-effecting tools.""" + mock_agent.get_all_tools.return_value = [mock_function_tool] + session = RealtimeSession(mock_model, mock_agent, None) + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_duplicate", arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + await session._handle_tool_call(tool_call_event) + + mock_function_tool.on_invoke_tool.assert_called_once() + assert len(mock_model.sent_tool_outputs) == 1 + + @pytest.mark.asyncio + async def test_function_tool_send_failure_retries_cached_output_without_rerun( + self, mock_agent, mock_function_tool + ): + """A post-execution send failure should retry output without rerunning the tool.""" + + class FailingToolOutputModel(MockRealtimeModel): + def __init__(self): + super().__init__() + self.fail_next_tool_output = True + + async def send_event(self, event): + if isinstance(event, RealtimeModelSendToolOutput) and self.fail_next_tool_output: + self.fail_next_tool_output = False + raise RuntimeError("send failed") + await super().send_event(event) + + mock_agent.get_all_tools.return_value = [mock_function_tool] + mock_model = FailingToolOutputModel() + session = RealtimeSession(mock_model, mock_agent, None) + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_retry_output", arguments="{}" + ) + + with pytest.raises(RuntimeError, match="send failed"): + await session._handle_tool_call(tool_call_event) + + mock_function_tool.on_invoke_tool.assert_called_once() + assert len(mock_model.sent_tool_outputs) == 0 + + await session._handle_tool_call(tool_call_event) + + mock_function_tool.on_invoke_tool.assert_called_once() + assert len(mock_model.sent_tool_outputs) == 1 + + @pytest.mark.asyncio + async def test_async_function_tool_send_failure_retries_cached_output_without_rerun( + self, mock_agent, mock_function_tool + ): + """The async task path should keep cached outputs retryable after send failure.""" + + class FailingToolOutputModel(MockRealtimeModel): + def __init__(self): + super().__init__() + self.fail_next_tool_output = True + + async def send_event(self, event): + if isinstance(event, RealtimeModelSendToolOutput) and self.fail_next_tool_output: + self.fail_next_tool_output = False + raise RuntimeError("send failed") + await super().send_event(event) + + mock_agent.get_all_tools.return_value = [mock_function_tool] + mock_model = FailingToolOutputModel() + session = RealtimeSession(mock_model, mock_agent, None) + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_async_retry_output", arguments="{}" + ) + + await session.on_event(tool_call_event) + tool_call_tasks = list(session._tool_call_tasks) + assert len(tool_call_tasks) == 1 + task_results = await asyncio.gather(*tool_call_tasks, return_exceptions=True) + await asyncio.sleep(0) + + assert len(task_results) == 1 + assert isinstance(task_results[0], RuntimeError) + assert session._stored_exception is None + assert tool_call_event.call_id in session._pending_tool_outputs + mock_function_tool.on_invoke_tool.assert_called_once() + assert len(mock_model.sent_tool_outputs) == 0 + + await session.on_event(tool_call_event) + tool_call_tasks = list(session._tool_call_tasks) + assert len(tool_call_tasks) == 1 + await asyncio.gather(*tool_call_tasks) + + assert session._stored_exception is None + assert tool_call_event.call_id not in session._pending_tool_outputs + mock_function_tool.on_invoke_tool.assert_called_once() + assert len(mock_model.sent_tool_outputs) == 1 + @pytest.mark.asyncio async def test_function_tool_timeout_returns_result_message(self, mock_model, mock_agent): async def invoke_slow_tool(_ctx: ToolContext[Any], _arguments: str) -> str: @@ -1342,6 +1442,40 @@ async def test_function_tool_needs_approval_emits_event( assert approval_event.call_id == tool_call_event.call_id assert approval_event.tool == mock_function_tool + @pytest.mark.asyncio + async def test_duplicate_pending_approval_call_id_is_ignored_and_approval_runs_once( + self, mock_model, mock_agent, mock_function_tool + ): + """A duplicate approval-gated call should not enqueue another approval or run twice.""" + mock_function_tool.needs_approval = True + mock_agent.get_all_tools.return_value = [mock_function_tool] + session = RealtimeSession( + mock_model, + mock_agent, + None, + run_config={"async_tool_calls": False}, + ) + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_duplicate_approval", arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + await session._handle_tool_call(tool_call_event) + + assert list(session._pending_tool_calls) == [tool_call_event.call_id] + approval_events = [] + while not session._event_queue.empty(): + event = await session._event_queue.get() + if isinstance(event, RealtimeToolApprovalRequired): + approval_events.append(event) + assert len(approval_events) == 1 + + await session.approve_tool_call(tool_call_event.call_id) + await session._handle_tool_call(tool_call_event) + + mock_function_tool.on_invoke_tool.assert_called_once() + assert len(mock_model.sent_tool_outputs) == 1 + @pytest.mark.asyncio async def test_approve_pending_tool_call_runs_tool( self, mock_model, mock_agent, mock_function_tool @@ -1375,6 +1509,59 @@ async def test_approve_pending_tool_call_runs_tool( assert any(isinstance(ev, RealtimeToolStart) for ev in events) assert any(isinstance(ev, RealtimeToolEnd) for ev in events) + @pytest.mark.asyncio + async def test_async_approve_pending_tool_call_reserves_call_id_before_task_runs( + self, mock_model + ): + """A duplicate event after approval should not outrun the approved async task.""" + approved_calls: list[str] = [] + duplicate_calls: list[str] = [] + + async def invoke_approved_tool(_ctx: ToolContext[Any], _arguments: str) -> str: + approved_calls.append("approved") + return "approved_result" + + async def invoke_duplicate_tool(_ctx: ToolContext[Any], _arguments: str) -> str: + duplicate_calls.append("duplicate") + return "duplicate_result" + + approved_tool = FunctionTool( + name="test_function", + description="approved", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=invoke_approved_tool, + needs_approval=True, + ) + duplicate_tool = FunctionTool( + name="test_function", + description="duplicate", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=invoke_duplicate_tool, + needs_approval=False, + ) + approved_agent = RealtimeAgent(name="approved_agent", tools=[approved_tool]) + duplicate_agent = RealtimeAgent(name="duplicate_agent", tools=[duplicate_tool]) + session = RealtimeSession(mock_model, approved_agent, None) + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_async_approval_race", arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + await session.approve_tool_call(tool_call_event.call_id) + + assert tool_call_event.call_id in session._active_tool_call_ids + await session._handle_tool_call(tool_call_event, agent_snapshot=duplicate_agent) + + tool_call_tasks = list(session._tool_call_tasks) + assert len(tool_call_tasks) == 1 + await asyncio.gather(*tool_call_tasks) + + assert approved_calls == ["approved"] + assert duplicate_calls == [] + assert len(mock_model.sent_tool_outputs) == 1 + _sent_call, sent_output, _start_response = mock_model.sent_tool_outputs[0] + assert sent_output == "approved_result" + @pytest.mark.asyncio async def test_always_approve_namespaced_tool_call_does_not_approve_bare_tool(self, mock_model): """Always approval should stay scoped to the namespaced tool key.""" @@ -1454,6 +1641,7 @@ async def test_reject_pending_tool_call_sends_rejection_output( await session._handle_tool_call(tool_call_event) await session.reject_tool_call(tool_call_event.call_id) + await session._handle_tool_call(tool_call_event) assert mock_function_tool.on_invoke_tool.call_count == 0 assert len(mock_model.sent_tool_outputs) == 1 @@ -1470,6 +1658,45 @@ async def test_reject_pending_tool_call_sends_rejection_output( isinstance(ev, RealtimeToolEnd) and ev.output == REJECTION_MESSAGE for ev in events ) + @pytest.mark.asyncio + async def test_reject_pending_tool_call_reserves_call_id_before_sending( + self, mock_agent, mock_function_tool + ): + """A duplicate event during rejection output sending should not emit a second output.""" + + class BlockingToolOutputModel(MockRealtimeModel): + def __init__(self): + super().__init__() + self.started = asyncio.Event() + self.release = asyncio.Event() + self.block_next_tool_output = True + + async def send_event(self, event): + if isinstance(event, RealtimeModelSendToolOutput) and self.block_next_tool_output: + self.block_next_tool_output = False + self.started.set() + await self.release.wait() + await super().send_event(event) + + mock_function_tool.needs_approval = True + mock_agent.get_all_tools.return_value = [mock_function_tool] + mock_model = BlockingToolOutputModel() + session = RealtimeSession(mock_model, mock_agent, None) + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_reject_race", arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + reject_task = asyncio.create_task(session.reject_tool_call(tool_call_event.call_id)) + await asyncio.wait_for(mock_model.started.wait(), timeout=1) + + await session._handle_tool_call(tool_call_event) + + mock_model.release.set() + await reject_task + + assert len(mock_model.sent_tool_outputs) == 1 + @pytest.mark.asyncio async def test_reject_pending_tool_call_uses_run_level_formatter( self, mock_model, mock_agent, mock_function_tool diff --git a/tests/test_run_context_approvals.py b/tests/test_run_context_approvals.py index 4acf8bdde1..79b34ac2ba 100644 --- a/tests/test_run_context_approvals.py +++ b/tests/test_run_context_approvals.py @@ -161,6 +161,29 @@ def test_deferred_top_level_legacy_permanent_approval_key_still_restores() -> No ) +def test_rebuild_approvals_ignores_malformed_approval_values() -> None: + context_wrapper = RunContextWrapper(context=None) + + context_wrapper._rebuild_approvals(["not", "a", "mapping"]) # noqa: SLF001 + assert context_wrapper._approvals == {} + + context_wrapper._rebuild_approvals( # noqa: SLF001 + { + "get_weather": { + "approved": {"not": "valid"}, + "rejected": ["call-denied", 123], + "rejection_messages": {"call-denied": "no"}, + }, + 123: {"approved": True}, + } + ) + + assert context_wrapper.is_tool_approved("get_weather", "any-call") is None + assert context_wrapper.is_tool_approved("get_weather", "call-denied") is False + assert context_wrapper.get_rejection_message("get_weather", "call-denied") == "no" + assert context_wrapper.is_tool_approved("123", "any-call") is None + + def test_deferred_top_level_approval_does_not_alias_to_visible_bare_sibling() -> None: agent = Agent(name="test-agent") context_wrapper = RunContextWrapper(context=None) diff --git a/tests/test_run_state.py b/tests/test_run_state.py index 56fc6abbc2..7b2de6b859 100644 --- a/tests/test_run_state.py +++ b/tests/test_run_state.py @@ -29,7 +29,7 @@ from openai.types.responses.tool_param import Mcp from pydantic import BaseModel -from agents import Agent, Model, ModelSettings, Runner, handoff, trace +from agents import Agent, Model, ModelSettings, RunConfig, Runner, handoff, trace from agents.computer import Computer from agents.exceptions import UserError from agents.guardrail import ( @@ -56,6 +56,7 @@ TResponseStreamEvent, ) from agents.run_context import RunContextWrapper +from agents.run_internal.agent_runner_helpers import resolve_trace_settings from agents.run_internal.items import run_items_to_input_items from agents.run_internal.run_loop import ( NextStepInterruption, @@ -723,6 +724,18 @@ async def test_trace_api_key_serialization_is_opt_in(self): == default_json["trace"]["tracing_api_key_hash"] ) + *_, restored_config = resolve_trace_settings( + run_state=restored_with_key, + run_config=RunConfig(), + ) + assert restored_config is None + + *_, explicit_config = resolve_trace_settings( + run_state=restored_with_key, + run_config=RunConfig(tracing={"api_key": "explicit-trace-key"}), + ) + assert explicit_config == {"api_key": "explicit-trace-key"} + async def test_throws_error_if_schema_version_is_missing_or_invalid(self): """Test that deserialization fails with missing or invalid schema version.""" context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) From 17f7caeaa33d97eee7152f9834af5e706b8f90e2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 19 May 2026 10:24:42 +0900 Subject: [PATCH 2/2] Release 0.17.3 (#3417) --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f64090f8aa..4d0122049f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-agents" -version = "0.17.2" +version = "0.17.3" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.10" diff --git a/uv.lock b/uv.lock index d9d25ce302..3e5cb31b70 100644 --- a/uv.lock +++ b/uv.lock @@ -2431,7 +2431,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.17.2" +version = "0.17.3" source = { editable = "." } dependencies = [ { name = "griffelib" },