-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
fix(anthropic): 修复 Anthropic API tool_choice 格式转换及参数支持 #8328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -322,11 +322,21 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: | |||||||||
| if tools: | ||||||||||
| if tool_list := tools.get_func_desc_anthropic_style(): | ||||||||||
| payloads["tools"] = tool_list | ||||||||||
| payloads["tool_choice"] = { | ||||||||||
| "type": "any" | ||||||||||
| if payloads.get("tool_choice") == "required" | ||||||||||
| else "auto" | ||||||||||
| } | ||||||||||
| # 转换为 Anthropic API 要求的 tool_choice 格式 | ||||||||||
| # 参考: https://platform.claude.com/docs/en/agents-and-tools/tool-use/define-tools#providing-tool-use-examples | ||||||||||
| if "tool_choice" in payloads: | ||||||||||
| tool_choice = payloads["tool_choice"] | ||||||||||
| if isinstance(tool_choice, dict): | ||||||||||
| payloads["tool_choice"] = tool_choice | ||||||||||
| elif tool_choice == "required": | ||||||||||
| # 兼容 OpenAI 命名 | ||||||||||
| payloads["tool_choice"] = {"type": "any"} | ||||||||||
| elif tool_choice in ("auto", "any", "none", "tool"): | ||||||||||
| payloads["tool_choice"] = {"type": tool_choice} | ||||||||||
|
Comment on lines
+334
to
+335
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): Using {"type": "tool"} without a name is likely invalid per Anthropic’s tool_choice schema. Per Anthropic’s schema,
Comment on lines
+334
to
+335
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的逻辑存在两个问题:
Suggested change
|
||||||||||
| else: | ||||||||||
| payloads["tool_choice"] = {"type": "auto"} | ||||||||||
| else: | ||||||||||
| payloads["tool_choice"] = {"type": "auto"} | ||||||||||
|
Comment on lines
+325
to
+339
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的 此外,根据项目规则,新功能(如工具选择处理)应伴随相应的单元测试。 References
|
||||||||||
|
|
||||||||||
| extra_body = self.provider_config.get("custom_extra_body", {}) | ||||||||||
|
|
||||||||||
|
|
@@ -409,11 +419,20 @@ async def _query_stream( | |||||||||
| if tools: | ||||||||||
| if tool_list := tools.get_func_desc_anthropic_style(): | ||||||||||
| payloads["tools"] = tool_list | ||||||||||
| payloads["tool_choice"] = { | ||||||||||
| "type": "any" | ||||||||||
| if payloads.get("tool_choice") == "required" | ||||||||||
| else "auto" | ||||||||||
| } | ||||||||||
| # 转换为 Anthropic API 要求的 tool_choice 格式 | ||||||||||
| if "tool_choice" in payloads: | ||||||||||
| tool_choice = payloads["tool_choice"] | ||||||||||
| if isinstance(tool_choice, dict): | ||||||||||
| payloads["tool_choice"] = tool_choice | ||||||||||
| elif tool_choice == "required": | ||||||||||
| # 兼容 OpenAI 命名 | ||||||||||
| payloads["tool_choice"] = {"type": "any"} | ||||||||||
| elif tool_choice in ("auto", "any", "none", "tool"): | ||||||||||
| payloads["tool_choice"] = {"type": tool_choice} | ||||||||||
|
Comment on lines
+430
to
+431
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||
| else: | ||||||||||
| payloads["tool_choice"] = {"type": "auto"} | ||||||||||
| else: | ||||||||||
| payloads["tool_choice"] = {"type": "auto"} | ||||||||||
|
|
||||||||||
| # 用于累积工具调用信息 | ||||||||||
| tool_use_buffer = {} | ||||||||||
|
|
@@ -569,7 +588,7 @@ async def text_chat( | |||||||||
| tool_calls_result=None, | ||||||||||
| model=None, | ||||||||||
| extra_user_content_parts=None, | ||||||||||
| tool_choice: Literal["auto", "required"] = "auto", | ||||||||||
| tool_choice: Literal["auto", "any", "tool", "none"] | dict[str, str] = "auto", | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议从
Suggested change
|
||||||||||
| **kwargs, | ||||||||||
| ) -> LLMResponse: | ||||||||||
| if contexts is None: | ||||||||||
|
|
@@ -598,8 +617,8 @@ async def text_chat( | |||||||||
| if not isinstance(tool_calls_result, list): | ||||||||||
| context_query.extend(tool_calls_result.to_openai_messages()) | ||||||||||
| else: | ||||||||||
| for tcr in tool_calls_result: | ||||||||||
| context_query.extend(tcr.to_openai_messages()) | ||||||||||
| for tool_choicer in tool_calls_result: | ||||||||||
| context_query.extend(tool_choicer.to_openai_messages()) | ||||||||||
|
Comment on lines
+620
to
+621
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 变量名
Suggested change
|
||||||||||
|
|
||||||||||
| system_prompt, new_messages = self._prepare_payload(context_query) | ||||||||||
|
|
||||||||||
|
|
@@ -637,7 +656,7 @@ async def text_chat_stream( | |||||||||
| tool_calls_result=None, | ||||||||||
| model=None, | ||||||||||
| extra_user_content_parts=None, | ||||||||||
| tool_choice: Literal["auto", "required"] = "auto", | ||||||||||
| tool_choice: Literal["auto", "any", "tool", "none"] | dict[str, str] = "auto", | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||
| **kwargs, | ||||||||||
| ): | ||||||||||
| if contexts is None: | ||||||||||
|
|
@@ -665,8 +684,8 @@ async def text_chat_stream( | |||||||||
| if not isinstance(tool_calls_result, list): | ||||||||||
| context_query.extend(tool_calls_result.to_openai_messages()) | ||||||||||
| else: | ||||||||||
| for tcr in tool_calls_result: | ||||||||||
| context_query.extend(tcr.to_openai_messages()) | ||||||||||
| for tool_choicer in tool_calls_result: | ||||||||||
| context_query.extend(tool_choicer.to_openai_messages()) | ||||||||||
|
Comment on lines
+687
to
+688
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||
|
|
||||||||||
| system_prompt, new_messages = self._prepare_payload(context_query) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -416,3 +416,172 @@ def test_prepare_payload_does_not_merge_non_consecutive_tool_results(): | |
| ], | ||
| }, | ||
| ] | ||
|
|
||
|
|
||
| # ---- tool_choice 转换测试 ---- | ||
|
|
||
|
|
||
| class _FakeToolSet: | ||
| """模拟包含工具的 ToolSet""" | ||
|
|
||
| def get_func_desc_anthropic_style(self): | ||
| return [{"name": "get_weather", "description": "Get weather"}] | ||
|
|
||
| def empty(self): | ||
|
Comment on lines
+424
to
+430
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Consider adding a test for the case where tools are present but Right now Suggested implementation: # ---- tool_choice 转换测试 ----
# ---- tool_choice 转换测试 ----
class _FakeToolSet:
"""模拟包含工具的 ToolSet"""
def get_func_desc_anthropic_style(self):
return [{"name": "get_weather", "description": "Get weather"}]
def empty(self) -> bool:
"""指示该 ToolSet 非空"""
return False
class _EmptyToolSet:
"""模拟包含空工具列表的 ToolSet,用于测试 tool_choice 在无工具时不应被设置"""
def get_func_desc_anthropic_style(self):
# 显式返回空列表,触发代码中对 falsy tool_list 分支的处理
return []
def empty(self) -> bool:
"""对于一些实现会检查 tools.empty(),这里应返回 True。"""
return True要完整实现你的建议,还需要在本文件中新增或扩展测试用例,大致步骤如下:
|
||
| return False | ||
|
|
||
|
|
||
| class _FakeMessages: | ||
| """模拟 AsyncAnthropic.messages 命名空间""" | ||
|
|
||
|
|
||
| async def _capture_payloads_create(**kwargs): | ||
| """捕获 payloads 并返回一个真实的 Message 实例""" | ||
| from anthropic.types import Message, TextBlock, Usage | ||
|
|
||
| _capture_payloads_create.last_kwargs = kwargs | ||
| return Message( | ||
| id="msg_fake", | ||
| content=[TextBlock(type="text", text="Hello")], | ||
| model="claude-test", | ||
| role="assistant", | ||
| stop_reason=None, | ||
| stop_sequence=None, | ||
| type="message", | ||
| usage=Usage(input_tokens=10, output_tokens=5), | ||
| ) | ||
|
|
||
|
|
||
| def _setup_provider_with_mock_client(monkeypatch) -> anthropic_source.ProviderAnthropic: | ||
| """创建 provider 并 mock 底层 API 调用""" | ||
| monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic) | ||
|
|
||
| provider = anthropic_source.ProviderAnthropic( | ||
| provider_config={ | ||
| "id": "anthropic-test", | ||
| "type": "anthropic_chat_completion", | ||
| "model": "claude-test", | ||
| "key": ["test-key"], | ||
| }, | ||
| provider_settings={}, | ||
| ) | ||
|
|
||
| fakeMessages = _FakeMessages() | ||
| fakeMessages.create = _capture_payloads_create | ||
| provider.client.messages = fakeMessages | ||
|
|
||
| return provider | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tool_choice_auto_converts_to_dict(monkeypatch): | ||
| """tool_choice='auto' 应转换为 {'type': 'auto'}""" | ||
| provider = _setup_provider_with_mock_client(monkeypatch) | ||
|
|
||
| await provider.text_chat( | ||
| prompt="hello", | ||
| func_tool=_FakeToolSet(), | ||
| tool_choice="auto", | ||
| ) | ||
|
|
||
| assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "auto"} | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tool_choice_any_converts_to_dict(monkeypatch): | ||
| """tool_choice='any' 应转换为 {'type': 'any'}""" | ||
| provider = _setup_provider_with_mock_client(monkeypatch) | ||
|
|
||
| await provider.text_chat( | ||
| prompt="hello", | ||
| func_tool=_FakeToolSet(), | ||
| tool_choice="any", | ||
| ) | ||
|
|
||
| assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "any"} | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tool_choice_none_converts_to_dict(monkeypatch): | ||
| """tool_choice='none' 应转换为 {'type': 'none'}""" | ||
| provider = _setup_provider_with_mock_client(monkeypatch) | ||
|
|
||
| await provider.text_chat( | ||
| prompt="hello", | ||
| func_tool=_FakeToolSet(), | ||
| tool_choice="none", | ||
| ) | ||
|
|
||
| assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "none"} | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tool_choice_required_legacy_compat(monkeypatch): | ||
| """tool_choice='required'(OpenAI 命名) 应兼容转换为 {'type': 'any'}""" | ||
| provider = _setup_provider_with_mock_client(monkeypatch) | ||
|
|
||
| await provider.text_chat( | ||
| prompt="hello", | ||
| func_tool=_FakeToolSet(), | ||
| tool_choice="required", | ||
| ) | ||
|
|
||
| assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "any"} | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tool_choice_dict_passthrough(monkeypatch): | ||
| """tool_choice 为 dict 时应直接透传""" | ||
| provider = _setup_provider_with_mock_client(monkeypatch) | ||
|
|
||
| await provider.text_chat( | ||
| prompt="hello", | ||
| func_tool=_FakeToolSet(), | ||
| tool_choice={"type": "tool", "name": "get_weather"}, | ||
| ) | ||
|
|
||
| assert _capture_payloads_create.last_kwargs["tool_choice"] == { | ||
| "type": "tool", | ||
| "name": "get_weather", | ||
| } | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tool_choice_default_when_not_set(monkeypatch): | ||
| """未传 tool_choice 时,默认应为 {'type': 'auto'}""" | ||
| provider = _setup_provider_with_mock_client(monkeypatch) | ||
|
|
||
| await provider.text_chat( | ||
| prompt="hello", | ||
| func_tool=_FakeToolSet(), | ||
| ) | ||
|
|
||
| assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "auto"} | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tool_choice_invalid_string_falls_back_to_auto(monkeypatch): | ||
| """无效的 tool_choice 字符串应回退为 {'type': 'auto'}""" | ||
| provider = _setup_provider_with_mock_client(monkeypatch) | ||
|
|
||
| await provider.text_chat( | ||
| prompt="hello", | ||
| func_tool=_FakeToolSet(), | ||
| tool_choice="invalid_value", | ||
| ) | ||
|
|
||
| assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "auto"} | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tool_choice_no_tools_skips_tool_choice(monkeypatch): | ||
| """无工具时不应设置 tool_choice""" | ||
| provider = _setup_provider_with_mock_client(monkeypatch) | ||
|
|
||
| await provider.text_chat( | ||
| prompt="hello", | ||
| func_tool=None, | ||
| tool_choice="any", | ||
| ) | ||
|
|
||
| assert "tool_choice" not in _capture_payloads_create.last_kwargs | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: The tool_choice normalization logic is duplicated between _query and _query_stream and could be extracted to a shared helper.
A helper like
_normalize_tool_choice(payloads)would centralize this Anthropic-specific mapping and reduce the risk that sync and streaming diverge if the API or supported values change.