diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index ca0c5843a3f..92e83c5df4d 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -82,6 +82,7 @@ agent_framework/ - **`MCPStdioTool`** / **`MCPStreamableHTTPTool`** / **`MCPWebsocketTool`** - Transport-specific subclasses. - **Argument allowlist (`_prepare_call_kwargs`)** - Before each `tools/call`, kwargs are filtered to an **allowlist** built from the tool's declared parameters (`inputSchema.properties`) plus any user-configured extras. Framework runtime kwargs injected through the function-invocation pipeline (e.g. `thread`, `conversation_id`, `chat_options`, `options`, `response_format`) are stripped by default rather than forwarded. A tool that declares no usable `properties` (including schemas with `additionalProperties: true`) forwards only the configured extras. The `_MCP_FRAMEWORK_DENYLIST` is a safety net for framework-named params a server *declares* in its schema (those are dropped); names explicitly opted in via `additional_tool_argument_names` always win. The reserved `_meta` key is extracted as MCP request metadata, never forwarded as an argument. - **`additional_tool_argument_names`** (constructor arg on all `MCPTool` subclasses) - Opt extra argument names back into the allowlist. Accepts a `Sequence[str]` (applied to every tool) or a `Mapping[str, Sequence[str]]` keyed by **remote tool name**, where the reserved key `"*"` denotes global extras. It is configured only in user code at construction; there is **no per-call/runtime override**, so a model-issued tool call cannot change which names pass through. To use a server that accepts `additionalProperties: true`, list the extra names here and then either (1) manually extend that tool's `inputSchema` (via the `.functions` list after connecting) so the model is prompted to supply them, or (2) supply the values yourself via `function_invocation_kwargs`. 
If a name is supplied by both the model and `function_invocation_kwargs`, the model-supplied value wins. +- **Sampling guardrails** (`sampling_callback`) - Passing `client=` advertises `SamplingCapability` so the server can send `sampling/createMessage`. Because remote servers are untrusted (confused-deputy risk), the default `sampling_callback` is **deny-by-default** and applies, in order: a per-session rate limit (`sampling_max_requests`, default `_DEFAULT_SAMPLING_MAX_REQUESTS`), an approval gate (`sampling_approval_callback`), and a `maxTokens` cap (`sampling_max_tokens`, default `_DEFAULT_SAMPLING_MAX_TOKENS`). The approval callback (constructor arg on all subclasses; exported type alias `SamplingApprovalCallback`) receives the raw `CreateMessageRequestParams`, may be sync or async, and must return truthy to approve. When it is `None` (the default) every sampling request is denied; pass `lambda params: True` to restore legacy auto-approve as an explicit opt-in. Requests and denials are logged at WARNING (content is not logged). When a limit is configured, the per-session counter is incremented before the approval gate runs, so requests the gate denies still count toward `sampling_max_requests`; it resets in `_reset_session_state`. - **`MCPTaskOptions`** (experimental, `MCP_LONG_RUNNING_TASKS` feature, **frozen**) - Per-tool-instance options controlling the SEP-2663 long-running task lifecycle. When the server advertises a tool with `execution.taskSupport == "required"`, `MCPTool.call_tool` transparently routes through `call_tool_as_task`, which sends an augmented `tools/call`, polls `tasks/get` until terminal, and reinterprets `tasks/result` as a normal `CallToolResult`. Instances are immutable; replace via `MCPTool.task_options = MCPTaskOptions(...)`. Fields: - `default_ttl: timedelta | None` — forwarded to the server as `params.task.ttl` (milliseconds). When `None`, the server's default applies. - `cancel_remote_task_on_local_cancellation: bool = True` — only gates the `CancelledError` path. Abandonment paths (see below) always cancel. 
diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 33d16cd9bbb..b287b4be579 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -124,7 +124,7 @@ TodoSessionStore, TodoStore, ) -from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool +from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPTaskOptions, MCPWebsocketTool, SamplingApprovalCallback from ._middleware import ( AgentContext, AgentMiddleware, @@ -472,6 +472,7 @@ "RunContext", "Runner", "RunnerContext", + "SamplingApprovalCallback", "SecretString", "SelectiveToolCallCompactionStrategy", "SessionContext", diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 784c618302d..ccb0be3b70f 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -16,6 +16,7 @@ from dataclasses import dataclass from datetime import timedelta from functools import partial +from inspect import isawaitable from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast from opentelemetry import propagate @@ -99,6 +100,22 @@ class MCPSpecificApproval(TypedDict, total=False): MCP_DEFAULT_TIMEOUT = 30 MCP_DEFAULT_SSE_READ_TIMEOUT = 60 * 5 +# Default safety limits applied to server-initiated MCP sampling requests +# (``sampling/createMessage``). MCP servers are untrusted third parties, so the +# default ``sampling_callback`` denies requests unless an approval callback is +# supplied, and bounds the cost of any approved request. +# - ``_DEFAULT_SAMPLING_MAX_TOKENS`` clamps the server-requested ``maxTokens``. +# - ``_DEFAULT_SAMPLING_MAX_REQUESTS`` caps the number of sampling requests per +# session connection (the counter resets on reconnect). 
+_DEFAULT_SAMPLING_MAX_TOKENS = 4096 +_DEFAULT_SAMPLING_MAX_REQUESTS = 25 + +# A user-supplied gate invoked before each server-initiated sampling request is +# forwarded to the chat client. It receives the raw ``CreateMessageRequestParams`` +# and returns (or awaits to) a truthy value to approve the request or a falsy +# value to deny it. Both synchronous and asynchronous callables are supported. +SamplingApprovalCallback = Callable[["types.CreateMessageRequestParams"], "bool | Coroutine[Any, Any, bool]"] + # region: Helpers LOG_LEVEL_MAPPING: dict[str, int] = { @@ -345,6 +362,9 @@ def __init__( session: ClientSession | None = None, request_timeout: int | None = None, client: SupportsChatGetResponse | None = None, + sampling_approval_callback: SamplingApprovalCallback | None = None, + sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS, + sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS, additional_properties: dict[str, Any] | None = None, task_options: MCPTaskOptions | None = None, additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, @@ -378,6 +398,20 @@ def __init__( session: An existing MCP client session to use. request_timeout: Timeout in seconds for MCP requests. client: A chat client for sampling callbacks. + sampling_approval_callback: Optional gate invoked before each server-initiated + ``sampling/createMessage`` request is forwarded to ``client``. It receives the + raw ``CreateMessageRequestParams`` and may be synchronous or asynchronous; + returning a truthy value approves the request and a falsy value denies it. When + ``None`` (the default), every sampling request is **denied** because MCP servers + are untrusted third parties (confused-deputy risk). To restore the legacy + auto-approve behavior, pass ``lambda params: True`` as an explicit, conscious + opt-in. + sampling_max_tokens: Upper bound applied to the server-requested ``maxTokens`` for an + approved sampling request. 
The effective value is ``min(requested, cap)``. Set to + ``None`` to disable the cap. Defaults to ``_DEFAULT_SAMPLING_MAX_TOKENS``. + sampling_max_requests: Maximum number of sampling requests allowed per session + connection; further requests are rejected. The counter resets on reconnect. Set + to ``None`` to disable the limit. Defaults to ``_DEFAULT_SAMPLING_MAX_REQUESTS``. additional_properties: Additional properties for the tool. task_options: Options controlling how long-running MCP tasks are driven for tools that advertise ``execution.taskSupport == "required"``. When ``None``, @@ -410,6 +444,10 @@ def __init__( self.session = session self.request_timeout = request_timeout self.client = client + self.sampling_approval_callback = sampling_approval_callback + self.sampling_max_tokens = sampling_max_tokens + self.sampling_max_requests = sampling_max_requests + self._sampling_request_count = 0 self._functions: list[FunctionTool] = [] self._tool_call_meta_by_name: dict[str, dict[str, Any]] = {} self._tool_task_support_by_name: dict[str, str] = {} @@ -840,6 +878,7 @@ def _reset_session_state(self) -> None: self._supports_prompts = True self._supports_logging = None self._ping_available = True + self._sampling_request_count = 0 def _set_server_capabilities(self, capabilities: types.ServerCapabilities | None) -> None: self._server_capabilities = capabilities @@ -994,6 +1033,49 @@ async def _connect_on_owner(self, *, reset: bool = False, load_configured: bool except Exception as exc: logger.warning("Failed to set log level to %s", logger.level, exc_info=exc) + async def _sampling_request_approved(self, params: types.CreateMessageRequestParams) -> bool: + """Run the configured sampling approval gate. + + Returns ``True`` only when an approval callback is configured and approves the request. + When no callback is set, the request is denied (safe default for untrusted servers). 
+ """ + callback = self.sampling_approval_callback + if callback is None: + logger.warning( + "Denying MCP sampling request from '%s': no 'sampling_approval_callback' configured.", + self.name, + ) + return False + try: + outcome = callback(params) + if isawaitable(outcome): + outcome = await outcome + except Exception as ex: + logger.warning( + "Denying MCP sampling request from '%s': approval callback raised %s.", + self.name, + ex, + exc_info=True, + ) + return False + approved = bool(outcome) + if not approved: + logger.warning("MCP sampling request from '%s' was denied by the approval callback.", self.name) + return approved + + def _capped_sampling_max_tokens(self, requested: int) -> int: + """Clamp the server-requested ``maxTokens`` to ``sampling_max_tokens`` when configured.""" + cap = self.sampling_max_tokens + if cap is not None and requested > cap: + logger.warning( + "Capping MCP sampling maxTokens for '%s' from %d to %d.", + self.name, + requested, + cap, + ) + return cap + return requested + async def sampling_callback( self, context: RequestContext[ClientSession, Any], @@ -1001,20 +1083,32 @@ async def sampling_callback( ) -> types.CreateMessageResult | types.ErrorData: """Callback function for sampling. - This function is called when the MCP server needs to get a message completed. - It uses the configured chat client to generate responses. + This function is called when the MCP server sends a ``sampling/createMessage`` + request. It enforces safety guardrails and, if the request is approved, uses the + configured chat client to generate a response. + + Safety: + MCP servers are untrusted third parties, so forwarding server-controlled prompts + to the chat client without review is a confused-deputy risk. 
This callback + therefore applies, in order: a per-session rate limit + (``sampling_max_requests``), an approval gate (``sampling_approval_callback``, + which **denies by default** when not configured), and a ``maxTokens`` cap + (``sampling_max_tokens``). To allow sampling, pass a ``sampling_approval_callback`` + that returns a truthy value (use ``lambda params: True`` to auto-approve as an + explicit opt-in). Note: - This is a simple version of this function. It can be overridden to allow - more complex sampling. It gets added to the session at initialization time, - so overriding it is the best way to customize this behavior. + This is the default implementation. It can be overridden to allow more complex + sampling. It gets added to the session at initialization time, so overriding it is + the best way to customize this behavior. Args: context: The request context from the MCP server. params: The message creation request parameters. Returns: - Either a CreateMessageResult with the generated message or ErrorData if generation fails. + Either a CreateMessageResult with the generated message or ErrorData if the request + is denied, rate limited, or generation fails. """ from mcp import types @@ -1023,7 +1117,38 @@ async def sampling_callback( code=types.INTERNAL_ERROR, message="No chat client available. 
Please set a chat client.", ) - logger.debug("Sampling callback called with params: %s", params) + + logger.warning( + "MCP server '%s' sent a sampling/createMessage request (%d message(s), maxTokens=%s).", + self.name, + len(params.messages), + params.maxTokens, + ) + + if self.sampling_max_requests is not None: + if self._sampling_request_count >= self.sampling_max_requests: + logger.warning( + "Denying MCP sampling request from '%s': per-session limit of %d reached.", + self.name, + self.sampling_max_requests, + ) + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Sampling rate limit exceeded for this MCP session.", + ) + self._sampling_request_count += 1 + + if not await self._sampling_request_approved(params): + if self.sampling_approval_callback is None: + message = ( + "Sampling request denied. MCP sampling is disabled by default for untrusted " + "servers; provide a 'sampling_approval_callback' that approves the request to " + "enable it." + ) + else: + message = "Sampling request denied by the 'sampling_approval_callback'." 
+ return types.ErrorData(code=types.INVALID_REQUEST, message=message) + messages: list[Message] = [] for msg in params.messages: messages.append(self._parse_message_from_mcp(msg)) @@ -1045,7 +1170,7 @@ async def sampling_callback( if params.temperature is not None: options["temperature"] = params.temperature - options["max_tokens"] = params.maxTokens + options["max_tokens"] = self._capped_sampling_max_tokens(params.maxTokens) if params.stopSequences is not None: options["stop"] = params.stopSequences @@ -2219,6 +2344,9 @@ def __init__( env: dict[str, str] | None = None, encoding: str | None = None, client: SupportsChatGetResponse | None = None, + sampling_approval_callback: SamplingApprovalCallback | None = None, + sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS, + sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS, additional_properties: dict[str, Any] | None = None, task_options: MCPTaskOptions | None = None, additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, @@ -2266,6 +2394,16 @@ def __init__( env: The environment variables to set for the command. encoding: The encoding to use for the command output. client: The chat client to use for sampling. + sampling_approval_callback: Optional gate run before each server-initiated + ``sampling/createMessage`` request reaches ``client``. Receives the raw + ``CreateMessageRequestParams`` (sync or async); a truthy return approves the + request, a falsy return denies it. When ``None`` (the default) every sampling + request is **denied**, since MCP servers are untrusted (confused-deputy risk). + Pass ``lambda params: True`` to auto-approve as an explicit opt-in. + sampling_max_tokens: Cap applied to an approved request's ``maxTokens`` + (``min(requested, cap)``); ``None`` disables it. + sampling_max_requests: Per-session cap on the number of sampling requests; further + requests are rejected. Resets on reconnect. ``None`` disables it. 
task_options: Options for tools that advertise ``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`. additional_tool_argument_names: Extra argument names to forward to the MCP server in @@ -2300,6 +2438,9 @@ def __init__( request_timeout=request_timeout, task_options=task_options, additional_tool_argument_names=additional_tool_argument_names, + sampling_approval_callback=sampling_approval_callback, + sampling_max_tokens=sampling_max_tokens, + sampling_max_requests=sampling_max_requests, ) self.command = command self.args = args or [] @@ -2375,6 +2516,9 @@ def __init__( allowed_tools: Collection[str] | None = None, terminate_on_close: bool | None = None, client: SupportsChatGetResponse | None = None, + sampling_approval_callback: SamplingApprovalCallback | None = None, + sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS, + sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS, additional_properties: dict[str, Any] | None = None, http_client: AsyncClient | None = None, header_provider: Callable[[dict[str, Any]], dict[str, str]] | None = None, @@ -2423,6 +2567,16 @@ def __init__( additional_properties: Additional properties. terminate_on_close: Close the transport when the MCP client is terminated. client: The chat client to use for sampling. + sampling_approval_callback: Optional gate run before each server-initiated + ``sampling/createMessage`` request reaches ``client``. Receives the raw + ``CreateMessageRequestParams`` (sync or async); a truthy return approves the + request, a falsy return denies it. When ``None`` (the default) every sampling + request is **denied**, since MCP servers are untrusted (confused-deputy risk). + Pass ``lambda params: True`` to auto-approve as an explicit opt-in. + sampling_max_tokens: Cap applied to an approved request's ``maxTokens`` + (``min(requested, cap)``); ``None`` disables it. + sampling_max_requests: Per-session cap on the number of sampling requests; further + requests are rejected. 
Resets on reconnect. ``None`` disables it. http_client: Optional asyncClient to use. If not provided, the ``streamable_http_client`` API will create and manage a default client. To configure headers, timeouts, or other HTTP client settings, create @@ -2466,6 +2620,9 @@ def __init__( request_timeout=request_timeout, task_options=task_options, additional_tool_argument_names=additional_tool_argument_names, + sampling_approval_callback=sampling_approval_callback, + sampling_max_tokens=sampling_max_tokens, + sampling_max_requests=sampling_max_requests, ) self.url = url self.terminate_on_close = terminate_on_close @@ -2590,6 +2747,9 @@ def __init__( approval_mode: (Literal["always_require", "never_require"] | MCPSpecificApproval | None) = None, allowed_tools: Collection[str] | None = None, client: SupportsChatGetResponse | None = None, + sampling_approval_callback: SamplingApprovalCallback | None = None, + sampling_max_tokens: int | None = _DEFAULT_SAMPLING_MAX_TOKENS, + sampling_max_requests: int | None = _DEFAULT_SAMPLING_MAX_REQUESTS, additional_properties: dict[str, Any] | None = None, task_options: MCPTaskOptions | None = None, additional_tool_argument_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, @@ -2635,6 +2795,16 @@ def __init__( allowed_tools: A list of tools that are allowed to use this tool. additional_properties: Additional properties. client: The chat client to use for sampling. + sampling_approval_callback: Optional gate run before each server-initiated + ``sampling/createMessage`` request reaches ``client``. Receives the raw + ``CreateMessageRequestParams`` (sync or async); a truthy return approves the + request, a falsy return denies it. When ``None`` (the default) every sampling + request is **denied**, since MCP servers are untrusted (confused-deputy risk). + Pass ``lambda params: True`` to auto-approve as an explicit opt-in. 
+ sampling_max_tokens: Cap applied to an approved request's ``maxTokens`` + (``min(requested, cap)``); ``None`` disables it. + sampling_max_requests: Per-session cap on the number of sampling requests; further + requests are rejected. Resets on reconnect. ``None`` disables it. task_options: Options for tools that advertise ``execution.taskSupport == "required"``. See :class:`MCPTaskOptions`. additional_tool_argument_names: Extra argument names to forward to the MCP server in @@ -2669,6 +2839,9 @@ def __init__( request_timeout=request_timeout, task_options=task_options, additional_tool_argument_names=additional_tool_argument_names, + sampling_approval_callback=sampling_approval_callback, + sampling_max_tokens=sampling_max_tokens, + sampling_max_requests=sampling_max_requests, ) self.url = url self._client_kwargs = kwargs diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 7c45296cbb0..ce69e9766b2 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -1813,6 +1813,18 @@ async def blocking_load_tools(): assert len(tool._pending_reload_tasks) == 0 +def _approve(_params: object) -> bool: + """Approving sampling gate used by tests that exercise forwarding behavior.""" + return True + + +def _make_sampling_response(text: str = "response", model: str = "test-model") -> Mock: + mock_response = Mock() + mock_response.messages = [Message(role="assistant", contents=[Content.from_text(text)])] + mock_response.model = model + return mock_response + + async def test_mcp_tool_sampling_callback_no_client(): """Test sampling callback error path when no chat client is available.""" tool = MCPStdioTool(name="test_tool", command="python") @@ -1828,9 +1840,190 @@ async def test_mcp_tool_sampling_callback_no_client(): assert "No chat client available" in result.message +async def test_mcp_tool_sampling_callback_denies_by_default(): + """Sampling is denied when no approval 
callback is configured (safe default).""" + tool = MCPStdioTool(name="test_tool", command="python") + mock_chat_client = AsyncMock() + tool.client = mock_chat_client + + params = Mock() + params.messages = [] + params.maxTokens = 128 + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.ErrorData) + assert result.code == types.INVALID_REQUEST + assert "denied" in result.message + assert "sampling_approval_callback" in result.message + mock_chat_client.get_response.assert_not_called() + + +async def test_mcp_tool_sampling_callback_denied_by_callback(): + """Sampling is denied when the approval callback returns a falsy value.""" + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=lambda params: False) + mock_chat_client = AsyncMock() + tool.client = mock_chat_client + + params = Mock() + params.messages = [] + params.maxTokens = 128 + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.ErrorData) + assert result.code == types.INVALID_REQUEST + assert "denied by the 'sampling_approval_callback'" in result.message + mock_chat_client.get_response.assert_not_called() + + +async def test_mcp_tool_sampling_callback_callback_exception_denies(): + """An approval callback that raises results in denial, not an LLM call.""" + + def boom(_params: object) -> bool: + raise RuntimeError("approval error") + + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=boom) + mock_chat_client = AsyncMock() + tool.client = mock_chat_client + + params = Mock() + params.messages = [] + params.maxTokens = 128 + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.ErrorData) + assert result.code == types.INVALID_REQUEST + mock_chat_client.get_response.assert_not_called() + + +async def test_mcp_tool_sampling_callback_async_approval(): + """An async approval callback that approves allows the request through.""" + + 
async def approve(_params: object) -> bool: + return True + + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=approve) + mock_chat_client = AsyncMock() + mock_chat_client.get_response.return_value = _make_sampling_response("ok") + tool.client = mock_chat_client + + params = Mock() + params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))] + params.temperature = None + params.maxTokens = 100 + params.stopSequences = None + params.systemPrompt = None + params.tools = None + params.toolChoice = None + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.CreateMessageResult) + assert result.content.text == "ok" + mock_chat_client.get_response.assert_awaited_once() + + +async def test_mcp_tool_sampling_callback_clamps_max_tokens(): + """An approved request's maxTokens is clamped to sampling_max_tokens.""" + tool = MCPStdioTool( + name="test_tool", + command="python", + sampling_approval_callback=_approve, + sampling_max_tokens=512, + ) + mock_chat_client = AsyncMock() + mock_chat_client.get_response.return_value = _make_sampling_response() + tool.client = mock_chat_client + + params = Mock() + params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))] + params.temperature = None + params.maxTokens = 1_000_000 + params.stopSequences = None + params.systemPrompt = None + params.tools = None + params.toolChoice = None + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.CreateMessageResult) + options = mock_chat_client.get_response.call_args.kwargs.get("options") or {} + assert options["max_tokens"] == 512 + + +async def test_mcp_tool_sampling_callback_does_not_clamp_under_cap(): + """A request below the cap keeps its requested maxTokens.""" + tool = MCPStdioTool( + name="test_tool", + command="python", + sampling_approval_callback=_approve, + sampling_max_tokens=512, + 
) + mock_chat_client = AsyncMock() + mock_chat_client.get_response.return_value = _make_sampling_response() + tool.client = mock_chat_client + + params = Mock() + params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))] + params.temperature = None + params.maxTokens = 100 + params.stopSequences = None + params.systemPrompt = None + params.tools = None + params.toolChoice = None + + result = await tool.sampling_callback(Mock(), params) + + assert isinstance(result, types.CreateMessageResult) + options = mock_chat_client.get_response.call_args.kwargs.get("options") or {} + assert options["max_tokens"] == 100 + + +async def test_mcp_tool_sampling_callback_rate_limited(): + """Sampling requests beyond sampling_max_requests are rejected per session.""" + tool = MCPStdioTool( + name="test_tool", + command="python", + sampling_approval_callback=_approve, + sampling_max_requests=2, + ) + mock_chat_client = AsyncMock() + mock_chat_client.get_response.return_value = _make_sampling_response() + tool.client = mock_chat_client + + def make_params() -> Mock: + params = Mock() + params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))] + params.temperature = None + params.maxTokens = 100 + params.stopSequences = None + params.systemPrompt = None + params.tools = None + params.toolChoice = None + return params + + first = await tool.sampling_callback(Mock(), make_params()) + second = await tool.sampling_callback(Mock(), make_params()) + third = await tool.sampling_callback(Mock(), make_params()) + + assert isinstance(first, types.CreateMessageResult) + assert isinstance(second, types.CreateMessageResult) + assert isinstance(third, types.ErrorData) + assert third.code == types.INVALID_REQUEST + assert "rate limit" in third.message.lower() + assert mock_chat_client.get_response.await_count == 2 + + # The counter resets on a session reset. 
+ tool._reset_session_state() + fourth = await tool.sampling_callback(Mock(), make_params()) + assert isinstance(fourth, types.CreateMessageResult) + + async def test_mcp_tool_sampling_callback_chat_client_exception(): """Test sampling callback when chat client raises exception.""" - tool = MCPStdioTool(name="test_tool", command="python") + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve) # Mock chat client that raises exception mock_chat_client = AsyncMock() @@ -1846,7 +2039,7 @@ async def test_mcp_tool_sampling_callback_chat_client_exception(): mock_message.content.text = "Test question" params.messages = [mock_message] params.temperature = None - params.maxTokens = None + params.maxTokens = 100 params.stopSequences = None params.systemPrompt = None params.tools = None @@ -1863,7 +2056,7 @@ async def test_mcp_tool_sampling_callback_no_valid_content(): """Test sampling callback when response has no valid content types.""" from agent_framework import Message - tool = MCPStdioTool(name="test_tool", command="python") + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve) # Mock chat client with response containing only invalid content types mock_chat_client = AsyncMock() @@ -1892,7 +2085,7 @@ async def test_mcp_tool_sampling_callback_no_valid_content(): mock_message.content.text = "Test question" params.messages = [mock_message] params.temperature = None - params.maxTokens = None + params.maxTokens = 100 params.stopSequences = None params.systemPrompt = None params.tools = None @@ -1905,18 +2098,18 @@ async def test_mcp_tool_sampling_callback_no_valid_content(): assert "Failed to get right content types from the response." 
in result.message mock_chat_client.get_response.assert_awaited_once() _, kwargs = mock_chat_client.get_response.await_args - assert kwargs["options"] == {"max_tokens": None} + assert kwargs["options"] == {"max_tokens": 100} async def test_mcp_tool_sampling_callback_no_response_and_successful_message_creation(): """Test sampling callback when the chat client returns no response and then valid content.""" - tool = MCPStdioTool(name="test_tool", command="python") + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve) tool.client = AsyncMock() params = Mock() params.messages = [types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hi"))] params.temperature = None - params.maxTokens = None + params.maxTokens = 100 params.stopSequences = None params.systemPrompt = None params.tools = None @@ -1955,7 +2148,7 @@ async def test_mcp_tool_sampling_callback_forwards_system_prompt(): """Test sampling callback passes systemPrompt as instructions in options.""" from agent_framework import Message - tool = MCPStdioTool(name="test_tool", command="python") + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve) mock_chat_client = AsyncMock() mock_response = Mock() @@ -1972,7 +2165,7 @@ async def test_mcp_tool_sampling_callback_forwards_system_prompt(): mock_message.content.text = "Test question" params.messages = [mock_message] params.temperature = None - params.maxTokens = None + params.maxTokens = 100 params.stopSequences = None params.systemPrompt = "You are a helpful assistant" params.tools = None @@ -1990,7 +2183,7 @@ async def test_mcp_tool_sampling_callback_forwards_tools(): """Test sampling callback converts MCP tools to FunctionTools and passes them in options.""" from agent_framework import FunctionTool, Message - tool = MCPStdioTool(name="test_tool", command="python") + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve) 
mock_chat_client = AsyncMock() mock_response = Mock() @@ -2013,7 +2206,7 @@ async def test_mcp_tool_sampling_callback_forwards_tools(): mock_message.content.text = "Test question" params.messages = [mock_message] params.temperature = None - params.maxTokens = None + params.maxTokens = 100 params.stopSequences = None params.systemPrompt = None params.tools = [mcp_tool] @@ -2036,7 +2229,7 @@ async def test_mcp_tool_sampling_callback_forwards_tool_choice(): """Test sampling callback passes toolChoice mode in options.""" from agent_framework import Message - tool = MCPStdioTool(name="test_tool", command="python") + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve) mock_chat_client = AsyncMock() mock_response = Mock() @@ -2053,7 +2246,7 @@ async def test_mcp_tool_sampling_callback_forwards_tool_choice(): mock_message.content.text = "Test question" params.messages = [mock_message] params.temperature = None - params.maxTokens = None + params.maxTokens = 100 params.stopSequences = None params.systemPrompt = None params.tools = None @@ -2071,7 +2264,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_system_prompt(): """Test sampling callback forwards empty string systemPrompt as instructions.""" from agent_framework import Message - tool = MCPStdioTool(name="test_tool", command="python") + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve) mock_chat_client = AsyncMock() mock_response = Mock() @@ -2088,7 +2281,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_system_prompt(): mock_message.content.text = "Test question" params.messages = [mock_message] params.temperature = None - params.maxTokens = None + params.maxTokens = 100 params.stopSequences = None params.systemPrompt = "" params.tools = None @@ -2106,7 +2299,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_tools_list(): """Test sampling callback forwards empty tools list in options.""" from 
agent_framework import Message - tool = MCPStdioTool(name="test_tool", command="python") + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve) mock_chat_client = AsyncMock() mock_response = Mock() @@ -2123,7 +2316,7 @@ async def test_mcp_tool_sampling_callback_forwards_empty_tools_list(): mock_message.content.text = "Test question" params.messages = [mock_message] params.temperature = None - params.maxTokens = None + params.maxTokens = 100 params.stopSequences = None params.systemPrompt = None params.tools = [] @@ -2141,7 +2334,7 @@ async def test_mcp_tool_sampling_callback_forwards_generation_params_in_options( """Test sampling callback passes temperature, max_tokens, and stop in options.""" from agent_framework import Message - tool = MCPStdioTool(name="test_tool", command="python") + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve) mock_chat_client = AsyncMock() mock_response = Mock() @@ -2182,7 +2375,7 @@ async def test_mcp_tool_sampling_callback_omits_temperature_when_none(): """Test sampling callback does not set temperature in options when it is None.""" from agent_framework import Message - tool = MCPStdioTool(name="test_tool", command="python") + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve) mock_chat_client = AsyncMock() mock_response = Mock() @@ -2219,7 +2412,7 @@ async def test_mcp_tool_sampling_callback_always_passes_max_tokens(): """Test sampling callback always sets max_tokens in options since maxTokens is a required int field.""" from agent_framework import Message - tool = MCPStdioTool(name="test_tool", command="python") + tool = MCPStdioTool(name="test_tool", command="python", sampling_approval_callback=_approve) mock_chat_client = AsyncMock() mock_response = Mock() diff --git a/python/samples/02-agents/mcp/README.md b/python/samples/02-agents/mcp/README.md index de57286320e..53af7d31a83 100644 --- 
a/python/samples/02-agents/mcp/README.md +++ b/python/samples/02-agents/mcp/README.md @@ -14,6 +14,7 @@ The Model Context Protocol (MCP) is an open standard for connecting AI agents to | **API Key Authentication** | [`mcp_api_key_auth.py`](mcp_api_key_auth.py) | Demonstrates API key authentication with MCP servers using `header_provider`, runtime invocation kwargs, and a command-line API key argument | | **GitHub Integration with PAT** | [`mcp_github_pat.py`](mcp_github_pat.py) | Demonstrates connecting to GitHub's MCP server using Personal Access Token (PAT) authentication | | **Long-Running Task** | [`mcp_long_running_task.py`](mcp_long_running_task.py) | Demonstrates transparent SEP-2663 long-running task handling for MCP tools that advertise `taskSupport=required`. Self-spawns a stdio MCP child server | +| **Sampling Approval** | [`mcp_sampling_approval.py`](mcp_sampling_approval.py) | Demonstrates gating server-initiated `sampling/createMessage` requests with a `sampling_approval_callback`, plus the `sampling_max_tokens` and `sampling_max_requests` guardrails. MCP sampling is denied by default | ## Prerequisites diff --git a/python/samples/02-agents/mcp/mcp_sampling_approval.py b/python/samples/02-agents/mcp/mcp_sampling_approval.py new file mode 100644 index 00000000000..0d359b7aecb --- /dev/null +++ b/python/samples/02-agents/mcp/mcp_sampling_approval.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio + +from agent_framework import Agent, MCPStreamableHTTPTool +from agent_framework.openai import OpenAIChatClient +from dotenv import load_dotenv +from mcp import types + +# Load environment variables from .env file +load_dotenv() + +""" +MCP Sampling Approval Example + +MCP servers can send the client a ``sampling/createMessage`` request, asking the +client to run an LLM completion on the server's behalf. 
Because remote MCP
+servers are untrusted third parties, forwarding these server-controlled prompts
+to your chat client without review is a confused-deputy risk: a malicious server
+could exfiltrate context, force tool calls, or burn through your token budget.
+
+For that reason Agent Framework **denies MCP sampling by default**. To allow it,
+pass a ``sampling_approval_callback`` to the MCP tool. The callback receives the
+raw ``CreateMessageRequestParams`` and returns ``True`` to approve or ``False``
+to deny. It may be synchronous or asynchronous, so you can implement a
+human-in-the-loop prompt, a policy check, or an audit log.
+
+Two further guardrails apply to approved requests:
+- ``sampling_max_tokens`` caps the server-requested ``maxTokens``.
+- ``sampling_max_requests`` limits how many sampling requests a single session
+  may make.
+
+To restore the legacy "always approve" behavior (only do this for servers you
+trust), pass ``sampling_approval_callback=lambda params: True``.
+"""
+
+
+async def approve_sampling(params: types.CreateMessageRequestParams) -> bool:
+    """Human-in-the-loop approval gate for server-initiated sampling.
+
+    Shows the server-supplied system prompt and messages, then asks the user on
+    stdin to approve or deny.
+
+    Args:
+        params: The raw ``sampling/createMessage`` request sent by the server.
+
+    Returns:
+        ``True`` only when the user answers ``y`` or ``yes`` (case-insensitive).
+        Any other answer, or a closed/exhausted stdin, returns ``False``, which
+        rejects the request.
+    """
+    print("\n--- MCP server requested a sampling/createMessage ---")
+    if params.systemPrompt:
+        print(f"System prompt: {params.systemPrompt}")
+    for message in params.messages:
+        # Content without a ``text`` attribute falls back to printing the
+        # content object itself.
+        text = getattr(message.content, "text", message.content)
+        print(f"{message.role}: {text}")
+    try:
+        # ``input`` blocks, so run it in a worker thread to keep the event loop free.
+        answer = await asyncio.to_thread(input, "Approve this sampling request? [y/N]: ")
+    except EOFError:
+        # Fail closed: if nobody can answer (stdin closed or exhausted), deny the
+        # untrusted server's request instead of letting the gate raise.
+        return False
+    return answer.strip().lower() in {"y", "yes"}
+
+
+async def main() -> None:
+    """Run an agent against an MCP server with a sampling approval gate."""
+    async with Agent(
+        client=OpenAIChatClient(),
+        name="Agent",
+        instructions="You are a helpful assistant. Use your MCP tool when answering the user's question.",
+        tools=MCPStreamableHTTPTool(
+            name="MCP tool",
+            description="MCP tool description.",
+            # NOTE(review): the URL is empty here; presumably it must be set to your
+            # MCP server's streamable HTTP endpoint before running. Confirm.
+            url="",
+            # Passing ``client`` enables sampling; the approval callback gates it.
+            client=OpenAIChatClient(),
+            sampling_approval_callback=approve_sampling,
+            sampling_max_tokens=2048,
+            sampling_max_requests=5,
+        ),
+    ) as agent:
+        query = "Use your MCP tool to help answer this question."
+        print(f"User: {query}")
+        result = await agent.run(query)
+        print(f"Agent: {result.text}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())