diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 5caef0b263..4aac78bc90 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -322,11 +322,21 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: if tools: if tool_list := tools.get_func_desc_anthropic_style(): payloads["tools"] = tool_list - payloads["tool_choice"] = { - "type": "any" - if payloads.get("tool_choice") == "required" - else "auto" - } + # 转换为 Anthropic API 要求的 tool_choice 格式 + # 参考: https://platform.claude.com/docs/en/agents-and-tools/tool-use/define-tools#providing-tool-use-examples + if "tool_choice" in payloads: + tool_choice = payloads["tool_choice"] + if isinstance(tool_choice, dict): + payloads["tool_choice"] = tool_choice + elif tool_choice == "required": + # 兼容 OpenAI 命名 + payloads["tool_choice"] = {"type": "any"} + elif tool_choice in ("auto", "any", "none", "tool"): + payloads["tool_choice"] = {"type": tool_choice} + else: + payloads["tool_choice"] = {"type": "auto"} + else: + payloads["tool_choice"] = {"type": "auto"} extra_body = self.provider_config.get("custom_extra_body", {}) @@ -409,11 +419,20 @@ async def _query_stream( if tools: if tool_list := tools.get_func_desc_anthropic_style(): payloads["tools"] = tool_list - payloads["tool_choice"] = { - "type": "any" - if payloads.get("tool_choice") == "required" - else "auto" - } + # 转换为 Anthropic API 要求的 tool_choice 格式 + if "tool_choice" in payloads: + tool_choice = payloads["tool_choice"] + if isinstance(tool_choice, dict): + payloads["tool_choice"] = tool_choice + elif tool_choice == "required": + # 兼容 OpenAI 命名 + payloads["tool_choice"] = {"type": "any"} + elif tool_choice in ("auto", "any", "none", "tool"): + payloads["tool_choice"] = {"type": tool_choice} + else: + payloads["tool_choice"] = {"type": "auto"} + else: + payloads["tool_choice"] = {"type": "auto"} # 用于累积工具调用信息 tool_use_buffer = {} @@ -569,7 +588,7 @@ async def text_chat( tool_calls_result=None, model=None, extra_user_content_parts=None, - tool_choice: Literal["auto", "required"] = "auto", + tool_choice: Literal["auto", "any", "tool", "none"] | dict[str, str] = "auto", **kwargs, ) -> LLMResponse: if contexts is None: @@ -598,8 +617,8 @@ async def text_chat( if not isinstance(tool_calls_result, list): context_query.extend(tool_calls_result.to_openai_messages()) else: - for tcr in tool_calls_result: - context_query.extend(tcr.to_openai_messages()) + for tool_choicer in tool_calls_result: + context_query.extend(tool_choicer.to_openai_messages()) system_prompt, new_messages = self._prepare_payload(context_query) @@ -637,7 +656,7 @@ async def text_chat_stream( tool_calls_result=None, model=None, extra_user_content_parts=None, - tool_choice: Literal["auto", "required"] = "auto", + tool_choice: Literal["auto", "any", "tool", "none"] | dict[str, str] = "auto", **kwargs, ): if contexts is None: @@ -665,8 +684,8 @@ async def text_chat_stream( if not isinstance(tool_calls_result, list): context_query.extend(tool_calls_result.to_openai_messages()) else: - for tcr in tool_calls_result: - context_query.extend(tcr.to_openai_messages()) + for tool_choicer in tool_calls_result: + context_query.extend(tool_choicer.to_openai_messages()) system_prompt, new_messages = self._prepare_payload(context_query) diff --git a/tests/test_anthropic_kimi_code_provider.py b/tests/test_anthropic_kimi_code_provider.py index b9d84d1a93..05bd58aa8b 100644 --- a/tests/test_anthropic_kimi_code_provider.py +++ b/tests/test_anthropic_kimi_code_provider.py @@ -416,3 +416,172 @@ def test_prepare_payload_does_not_merge_non_consecutive_tool_results(): ], }, ] + + +# ---- tool_choice 转换测试 ---- + + +class _FakeToolSet: + """模拟包含工具的 ToolSet""" + + def get_func_desc_anthropic_style(self): + return [{"name": "get_weather", "description": "Get weather"}] + + def empty(self): + return False + + +class _FakeMessages: + """模拟 AsyncAnthropic.messages 命名空间""" + + +async def _capture_payloads_create(**kwargs): + """捕获 payloads 并返回一个真实的 Message 实例""" + from anthropic.types import Message, TextBlock, Usage + + _capture_payloads_create.last_kwargs = kwargs + return Message( + id="msg_fake", + content=[TextBlock(type="text", text="Hello")], + model="claude-test", + role="assistant", + stop_reason=None, + stop_sequence=None, + type="message", + usage=Usage(input_tokens=10, output_tokens=5), + ) + + +def _setup_provider_with_mock_client(monkeypatch) -> anthropic_source.ProviderAnthropic: + """创建 provider 并 mock 底层 API 调用""" + monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic) + + provider = anthropic_source.ProviderAnthropic( + provider_config={ + "id": "anthropic-test", + "type": "anthropic_chat_completion", + "model": "claude-test", + "key": ["test-key"], + }, + provider_settings={}, + ) + + fakeMessages = _FakeMessages() + fakeMessages.create = _capture_payloads_create + provider.client.messages = fakeMessages + + return provider + + +@pytest.mark.asyncio +async def test_tool_choice_auto_converts_to_dict(monkeypatch): + """tool_choice='auto' 应转换为 {'type': 'auto'}""" + provider = _setup_provider_with_mock_client(monkeypatch) + + await provider.text_chat( + prompt="hello", + func_tool=_FakeToolSet(), + tool_choice="auto", + ) + + assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "auto"} + + +@pytest.mark.asyncio +async def test_tool_choice_any_converts_to_dict(monkeypatch): + """tool_choice='any' 应转换为 {'type': 'any'}""" + provider = _setup_provider_with_mock_client(monkeypatch) + + await provider.text_chat( + prompt="hello", + func_tool=_FakeToolSet(), + tool_choice="any", + ) + + assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "any"} + + +@pytest.mark.asyncio +async def test_tool_choice_none_converts_to_dict(monkeypatch): + """tool_choice='none' 应转换为 {'type': 'none'}""" + provider = _setup_provider_with_mock_client(monkeypatch) + + await provider.text_chat( + prompt="hello", + func_tool=_FakeToolSet(), + tool_choice="none", + ) + + assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "none"} + + +@pytest.mark.asyncio +async def test_tool_choice_required_legacy_compat(monkeypatch): + """tool_choice='required'(OpenAI 命名) 应兼容转换为 {'type': 'any'}""" + provider = _setup_provider_with_mock_client(monkeypatch) + + await provider.text_chat( + prompt="hello", + func_tool=_FakeToolSet(), + tool_choice="required", + ) + + assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "any"} + + +@pytest.mark.asyncio +async def test_tool_choice_dict_passthrough(monkeypatch): + """tool_choice 为 dict 时应直接透传""" + provider = _setup_provider_with_mock_client(monkeypatch) + + await provider.text_chat( + prompt="hello", + func_tool=_FakeToolSet(), + tool_choice={"type": "tool", "name": "get_weather"}, + ) + + assert _capture_payloads_create.last_kwargs["tool_choice"] == { + "type": "tool", + "name": "get_weather", + } + + +@pytest.mark.asyncio +async def test_tool_choice_default_when_not_set(monkeypatch): + """未传 tool_choice 时,默认应为 {'type': 'auto'}""" + provider = _setup_provider_with_mock_client(monkeypatch) + + await provider.text_chat( + prompt="hello", + func_tool=_FakeToolSet(), + ) + + assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "auto"} + + +@pytest.mark.asyncio +async def test_tool_choice_invalid_string_falls_back_to_auto(monkeypatch): + """无效的 tool_choice 字符串应回退为 {'type': 'auto'}""" + provider = _setup_provider_with_mock_client(monkeypatch) + + await provider.text_chat( + prompt="hello", + func_tool=_FakeToolSet(), + tool_choice="invalid_value", + ) + + assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "auto"} + + +@pytest.mark.asyncio +async def test_tool_choice_no_tools_skips_tool_choice(monkeypatch): + """无工具时不应设置 tool_choice""" + provider = _setup_provider_with_mock_client(monkeypatch) + + await provider.text_chat( + prompt="hello", + func_tool=None, + tool_choice="any", + ) + + assert "tool_choice" not in _capture_payloads_create.last_kwargs