diff --git a/flocks/agent/agent_factory.py b/flocks/agent/agent_factory.py index a8b5bc455..a27649731 100644 --- a/flocks/agent/agent_factory.py +++ b/flocks/agent/agent_factory.py @@ -110,68 +110,84 @@ def load_agent(agent_dir: Path, native: bool = False) -> Optional[AgentInfo]: }) return None + if not isinstance(raw, dict): + log.warn("agent.factory.yaml_invalid", { + "path": str(yaml_path), + "hint": "Expected a YAML mapping", + }) + return None + name = raw.get("name") or agent_dir.name if not name: log.warn("agent.factory.missing_name", {"path": str(yaml_path)}) return None - # ── Prompt resolution ────────────────────────────────────────────────── - prompt: Optional[str] = None - prompt_builder: Optional[str] = None - - prompt_md = agent_dir / "prompt.md" - prompt_builder_py = agent_dir / "prompt_builder.py" - - if prompt_md.is_file(): - prompt = prompt_md.read_text(encoding="utf-8").strip() - elif prompt_builder_py.is_file(): - # Derive Python module path from file location, relative to flocks package root - try: - rel = prompt_builder_py.relative_to(Path(__file__).parent.parent.parent) - module_path = str(rel.with_suffix("")).replace("/", ".").replace("\\", ".") - except ValueError: - # Fallback: use absolute path notation - module_path = str(prompt_builder_py) - prompt_builder = f"{module_path}:inject" - - # ── Tools / legacy permission compatibility ───────────────────────────── - tools_list_raw: Optional[List[str]] = raw.get("tools") - perm_raw = raw.get("permission") - tools_list, legacy_permission = resolve_agent_initial_tools( - tools_list_raw, - perm_raw, - agent_name=name, - ) - - # ── Model ──────────────────────────────────────────────────────────────── - model_raw = raw.get("model") - model = AgentModel(**model_raw) if isinstance(model_raw, dict) else None - - desc_cn = raw.get("description_cn") - if desc_cn is None and isinstance(raw.get("descriptionCn"), str): - desc_cn = raw.get("descriptionCn") - - return AgentInfo( - name=name, - 
description=raw.get("description"), - description_cn=desc_cn, - mode=raw.get("mode", "subagent"), - native=native, - hidden=raw.get("hidden", False), - color=raw.get("color"), - permission=legacy_permission, - model=model, - prompt=prompt, - prompt_builder=prompt_builder, - tools=tools_list, - options=raw.get("options", {}), - steps=raw.get("steps"), - delegatable=raw.get("delegatable"), - temperature=raw.get("temperature"), - top_p=raw.get("top_p"), - prompt_metadata=_parse_prompt_metadata(raw), - tags=raw.get("tags", []), - ) + try: + # ── Prompt resolution ────────────────────────────────────────────── + prompt: Optional[str] = None + prompt_builder: Optional[str] = None + + prompt_md = agent_dir / "prompt.md" + prompt_builder_py = agent_dir / "prompt_builder.py" + + if prompt_md.is_file(): + prompt = prompt_md.read_text(encoding="utf-8").strip() + elif prompt_builder_py.is_file(): + # Derive Python module path from file location, relative to flocks package root + try: + rel = prompt_builder_py.relative_to(Path(__file__).parent.parent.parent) + module_path = str(rel.with_suffix("")).replace("/", ".").replace("\\", ".") + except ValueError: + # Fallback: use absolute path notation + module_path = str(prompt_builder_py) + prompt_builder = f"{module_path}:inject" + + # ── Tools / legacy permission compatibility ───────────────────────── + tools_list_raw: Optional[List[str]] = raw.get("tools") + perm_raw = raw.get("permission") + tools_list, legacy_permission = resolve_agent_initial_tools( + tools_list_raw, + perm_raw, + agent_name=name, + ) + + # ── Model ─────────────────────────────────────────────────────────── + model_raw = raw.get("model") + model = AgentModel(**model_raw) if isinstance(model_raw, dict) else None + + desc_cn = raw.get("description_cn") + if desc_cn is None and isinstance(raw.get("descriptionCn"), str): + desc_cn = raw.get("descriptionCn") + + return AgentInfo( + name=name, + description=raw.get("description"), + description_cn=desc_cn, + 
mode=raw.get("mode", "subagent"), + native=native, + hidden=raw.get("hidden", False), + color=raw.get("color"), + permission=legacy_permission, + model=model, + prompt=prompt, + prompt_builder=prompt_builder, + tools=tools_list, + options=raw.get("options", {}), + steps=raw.get("steps"), + delegatable=raw.get("delegatable"), + temperature=raw.get("temperature"), + top_p=raw.get("top_p"), + prompt_metadata=_parse_prompt_metadata(raw), + tags=raw.get("tags", []), + ) + except Exception as e: + log.error("agent.factory.load_failed", { + "name": name, + "path": str(yaml_path), + "error": str(e), + "type": type(e).__name__, + }) + return None # --------------------------------------------------------------------------- diff --git a/flocks/provider/catalog.json b/flocks/provider/catalog.json index 5d72b814a..168bd71d1 100644 --- a/flocks/provider/catalog.json +++ b/flocks/provider/catalog.json @@ -57,6 +57,12 @@ "family": "minimax", "capabilities": { "supports_tools": true, + "supports_reasoning": true, + "interleaved": { + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -74,6 +80,12 @@ "family": "minimax", "capabilities": { "supports_tools": true, + "supports_reasoning": true, + "interleaved": { + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -92,6 +104,11 @@ "capabilities": { "supports_tools": true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -110,6 +127,11 @@ "capabilities": { "supports_tools": true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -128,6 +150,11 @@ "capabilities": { "supports_tools": 
true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -146,6 +173,12 @@ "capabilities": { "supports_tools": true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "placeholder": " ", + "cross_provider_policy": "placeholder" + }, "supports_streaming": true }, "limits": { @@ -190,6 +223,12 @@ "family": "minimax", "capabilities": { "supports_tools": true, + "supports_reasoning": true, + "interleaved": { + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -207,6 +246,12 @@ "family": "minimax", "capabilities": { "supports_tools": true, + "supports_reasoning": true, + "interleaved": { + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -225,6 +270,11 @@ "capabilities": { "supports_tools": true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -243,6 +293,11 @@ "capabilities": { "supports_tools": true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -261,6 +316,11 @@ "capabilities": { "supports_tools": true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -893,6 +953,12 @@ "capabilities": { "supports_tools": true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "placeholder": " ", + "cross_provider_policy": "placeholder" + }, 
"supports_streaming": true }, "limits": { @@ -1027,6 +1093,12 @@ "capabilities": { "supports_tools": true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "placeholder": " ", + "cross_provider_policy": "placeholder" + }, "supports_streaming": true }, "limits": { @@ -1047,6 +1119,12 @@ "capabilities": { "supports_tools": true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "placeholder": " ", + "cross_provider_policy": "placeholder" + }, "supports_streaming": true }, "limits": { @@ -1114,6 +1192,11 @@ "capabilities": { "supports_tools": true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -1149,6 +1232,11 @@ "capabilities": { "supports_tools": true, "supports_reasoning": true, + "interleaved": { + "field": "reasoning_content", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -1198,6 +1286,12 @@ "family": "minimax-m2", "capabilities": { "supports_tools": true, + "supports_reasoning": true, + "interleaved": { + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { @@ -1215,6 +1309,12 @@ "family": "minimax-m2", "capabilities": { "supports_tools": true, + "supports_reasoning": true, + "interleaved": { + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, "supports_streaming": true }, "limits": { diff --git a/flocks/provider/interleaved.py b/flocks/provider/interleaved.py new file mode 100644 index 000000000..702d815ff --- /dev/null +++ b/flocks/provider/interleaved.py @@ -0,0 +1,86 @@ +"""Runtime inference for interleaved thinking replay.""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + + 
# Replay policy with an explicit single-space placeholder; the
# ``cross_provider_policy`` of "placeholder" is what distinguishes it from the
# two "promote" policies below.
_STRICT_REASONING_CONTENT = {
    "field": "reasoning_content",
    "echo": "tool_calls",
    "placeholder": " ",
    "cross_provider_policy": "placeholder",
}

# Replay policy that echoes through the plain-text ``reasoning_content`` field
# and promotes reasoning across providers.
_PROMOTE_REASONING_CONTENT = {
    "field": "reasoning_content",
    "echo": "tool_calls",
    "cross_provider_policy": "promote",
}

# Same as above, but the replay field is the structured ``reasoning_details``.
_PROMOTE_REASONING_DETAILS = {
    "field": "reasoning_details",
    "echo": "tool_calls",
    "cross_provider_policy": "promote",
}


def _lower(value: Optional[str]) -> str:
    """Return *value* lower-cased, or ``""`` when it is not a string."""
    return value.lower() if isinstance(value, str) else ""


def infer_interleaved_capability(
    *,
    provider_id: str,
    model_id: str,
    base_url: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Infer interleaved replay policy for known reasoning model families.

    Explicit config and catalog metadata should take precedence. This helper is
    only a fallback for runtime-discovered or user-added models so the feature
    works without user-visible toggles.

    Matching is case-insensitive and is done on substrings of ``model_id``,
    on exact ``provider_id`` values, and (for one rule) on ``base_url``.
    The rules are evaluated top to bottom and the first match wins.

    Returns:
        A fresh ``dict`` copy of the matching policy (safe for the caller to
        mutate), or ``None`` when no rule matches.
    """
    pid = _lower(provider_id)
    mid = _lower(model_id)
    burl = _lower(base_url)

    if "minimax" in mid or pid == "minimax":
        return dict(_PROMOTE_REASONING_DETAILS)

    if any(token in mid for token in ("qwen3", "qwq", "qwen-max")) or pid == "alibaba":
        return dict(_PROMOTE_REASONING_CONTENT)

    if any(token in mid for token in ("glm-5", "glm5")) or pid == "zhipu":
        return dict(_PROMOTE_REASONING_CONTENT)

    if any(token in mid for token in ("deepseek-reasoner", "deepseek-r1", "reasoner")):
        return dict(_STRICT_REASONING_CONTENT)
    # The shorter "r1"/"thinking" tokens are only trusted when the endpoint
    # itself is a deepseek.com URL.
    if "deepseek.com" in burl and any(token in mid for token in ("r1", "reasoner", "thinking")):
        return dict(_STRICT_REASONING_CONTENT)

    if any(token in mid for token in ("kimi-k2.5", "kimi-k2.6", "kimi-k2-thinking", "mimo")):
        return dict(_STRICT_REASONING_CONTENT)

    return None


def apply_interleaved_capability_defaults(
    model: Any,
    *,
    provider_id: str,
    base_url: Optional[str] = None,
) -> Any:
    """Return *model* with ``capabilities.interleaved`` filled in when inferable.

    The input object is never mutated. The inferred policy depends on
    ``provider_id`` and ``base_url``, while the same model object may be shared
    between providers (for example through a registry keyed only by model id).
    Writing the inferred value onto the shared object would make the first
    provider's policy stick for every later lookup, so the value is written to
    a deep copy instead.

    Args:
        model: Any object exposing ``id`` and ``capabilities`` attributes.
        provider_id: Provider the model is being resolved for.
        base_url: Optional endpoint URL used by URL-based inference rules.

    Returns:
        ``model`` itself when it has no ``capabilities``, already carries an
        ``interleaved`` policy, or nothing can be inferred; otherwise a deep
        copy of ``model`` whose ``capabilities.interleaved`` holds the inferred
        policy.
    """
    import copy  # local import: keeps this block self-contained

    capabilities = getattr(model, "capabilities", None)
    if capabilities is None or getattr(capabilities, "interleaved", None):
        return model

    inferred = infer_interleaved_capability(
        provider_id=provider_id,
        model_id=getattr(model, "id", ""),
        base_url=base_url,
    )
    if not inferred:
        return model

    # Copy-on-write so provider-specific inference cannot leak into shared state.
    resolved = copy.deepcopy(model)
    resolved.capabilities.interleaved = inferred
    return resolved
@classmethod
def resolve_model(cls, provider_id: str, model_id: str) -> Optional[Any]:
    """Resolve model metadata for one specific provider/model pair.

    Provider-scoped sources are consulted before the global
    ``model_id -> ModelInfo`` registry, so that replay decisions (e.g.
    interleaved reasoning rules) come from the requested provider rather
    than from whichever provider last wrote the shared model ID.

    Sources, in order (first entry whose ``id`` equals ``model_id`` wins):
        1. ``provider.get_model_definitions()``
        2. ``provider._config_models``
        3. ``provider.get_models()``
        4. the global ``cls._models`` registry

    Failures raised while consulting sources 1 and 3 are logged at debug
    level and the lookup moves on to the next source. Every hit is passed
    through ``apply_interleaved_capability_defaults`` before being returned.

    Returns:
        The resolved model object, or ``None`` when no source knows the model.
    """
    cls._ensure_initialized()

    provider = cls._providers.get(provider_id)
    base_url = None
    if provider is not None:
        config = getattr(provider, "_config", None)
        base_url = (
            getattr(config, "base_url", None)
            or getattr(provider, "_base_url", None)
        )

    def _with_defaults(candidate: Any) -> Any:
        # Single place where the provider context is attached to a hit.
        return apply_interleaved_capability_defaults(
            candidate,
            provider_id=provider_id,
            base_url=base_url,
        )

    if provider is not None:
        # 1. Provider model definitions.
        try:
            for candidate in provider.get_model_definitions():
                if getattr(candidate, "id", None) == model_id:
                    return _with_defaults(candidate)
        except Exception as exc:
            log.debug("provider.resolve_model.definitions_failed", {
                "provider_id": provider_id,
                "model_id": model_id,
                "error": str(exc),
            })

        # 2. Models stored on the provider as ``_config_models``.
        for candidate in getattr(provider, "_config_models", []):
            if getattr(candidate, "id", None) == model_id:
                return _with_defaults(candidate)

        # 3. Models reported by ``provider.get_models()``.
        try:
            for candidate in provider.get_models():
                if getattr(candidate, "id", None) == model_id:
                    return _with_defaults(candidate)
        except Exception as exc:
            log.debug("provider.resolve_model.runtime_failed", {
                "provider_id": provider_id,
                "model_id": model_id,
                "error": str(exc),
            })

    # 4. Global registry keyed by model id only.
    registered = cls._models.get(model_id)
    return None if registered is None else _with_defaults(registered)
Fallback to global model registry - if model_info is None: - model_info = cls.get_model(model_id) + model_info = cls.resolve_model(provider_id, model_id) if model_info and hasattr(model_info, 'capabilities') and model_info.capabilities: context_window = getattr(model_info.capabilities, 'context_window', 0) or 0 @@ -674,6 +749,7 @@ async def apply_config(cls, config: Optional[Any] = None, provider_id: Optional[ supports_tools=model_dict.get("supports_tools", True), supports_vision=model_dict.get("supports_vision", False), supports_reasoning=model_dict.get("supports_reasoning", False), + interleaved=model_dict.get("interleaved"), max_tokens=model_dict.get("max_output_tokens") or model_dict.get("max_tokens"), context_window=model_dict.get("context_window"), ), @@ -1034,6 +1110,7 @@ def _build_model_definition(self, model: "ModelInfo") -> "ModelDefinition": supports_tools=model.capabilities.supports_tools, supports_vision=model.capabilities.supports_vision, supports_reasoning=getattr(model.capabilities, "supports_reasoning", False), + interleaved=getattr(model.capabilities, "interleaved", None), ), limits=ModelLimits( context_window=model.capabilities.context_window or 128000, @@ -1086,6 +1163,10 @@ def _apply_config_overrides(self, catalog_def: "ModelDefinition", model: "ModelI overridden.capabilities.supports_reasoning = getattr( model.capabilities, "supports_reasoning", False ) + if "interleaved" in keys: + overridden.capabilities.interleaved = getattr( + model.capabilities, "interleaved", None + ) # Limits — only override when explicitly stored if "context_window" in keys and model.capabilities.context_window is not None: diff --git a/flocks/provider/reasoning_replay.py b/flocks/provider/reasoning_replay.py new file mode 100644 index 000000000..609b7dc63 --- /dev/null +++ b/flocks/provider/reasoning_replay.py @@ -0,0 +1,94 @@ +"""Reasoning replay helpers for thinking/interleaved providers.""" + +from __future__ import annotations + +from typing import Any, Dict, 
List, Optional + +from flocks.provider.provider import ChatMessage + + +def _message_requires_echo(message: ChatMessage, interleaved: Dict[str, Any]) -> bool: + """Return True when the current provider requires a replay field.""" + echo_mode = interleaved.get("echo", "when_present") + if echo_mode == "all_assistant": + return True + if echo_mode == "tool_calls": + return bool(message.tool_calls) + if echo_mode == "when_present": + return bool(message.reasoning_content or message.reasoning_details) + return False + + +def _promotable_reasoning_text(message: ChatMessage) -> tuple[Optional[str], Optional[str]]: + """Return the best-effort text form that can be promoted cross-provider.""" + if isinstance(message.reasoning_content, str) and message.reasoning_content: + return message.reasoning_content, "promoted_reasoning_content" + if isinstance(message.reasoning, str) and message.reasoning: + return message.reasoning, "promoted_reasoning" + return None, None + + +def _summary_reasoning_details(text: str) -> List[Dict[str, Any]]: + """Wrap plain reasoning text into a generic details payload.""" + return [{"type": "reasoning.summary", "text": text}] + + +def prepare_reasoning_for_replay( + *, + provider_id: str, + model_id: str, + message: ChatMessage, + interleaved: Optional[Dict[str, Any]], +) -> ChatMessage: + """Prepare provider-facing reasoning fields for API replay. + + Hermes-style rules: + - Preserve explicit provider-facing fields. + - Upgrade stale empty-string placeholders when the model requires echo. + - Promote internal ``reasoning`` only for providers configured to do so. + - Use a placeholder instead of leaking another provider's CoT to strict + echo providers such as DeepSeek/Kimi. 
+ """ + if message.role != "assistant" or not interleaved: + return message + + prepared = message.model_copy(deep=True) + field = interleaved.get("field", "reasoning_content") + placeholder = interleaved.get("placeholder", " ") + cross_provider_policy = interleaved.get("cross_provider_policy", "promote") + requires_echo = _message_requires_echo(prepared, interleaved) + promoted_text, promoted_source = _promotable_reasoning_text(prepared) + + if field == "reasoning_details": + prepared.reasoning_content = None + if prepared.reasoning_details: + prepared.reasoning_source = prepared.reasoning_source or "native_reasoning_details" + return prepared + if cross_provider_policy == "promote" and promoted_text: + prepared.reasoning_details = _summary_reasoning_details(promoted_text) + prepared.reasoning_source = promoted_source or "promoted_reasoning" + return prepared + if requires_echo: + prepared.reasoning_details = _summary_reasoning_details(placeholder) + prepared.reasoning_source = "placeholder" + return prepared + return prepared + + prepared.reasoning_details = None + if isinstance(prepared.reasoning_content, str): + if prepared.reasoning_content == "" and requires_echo: + prepared.reasoning_content = placeholder + prepared.reasoning_source = "placeholder" + return prepared + + if cross_provider_policy == "promote" and promoted_text: + prepared.reasoning_content = promoted_text + prepared.reasoning_source = promoted_source or "promoted_reasoning" + return prepared + + if requires_echo: + prepared.reasoning_content = placeholder + prepared.reasoning_source = "placeholder" + return prepared + + return prepared diff --git a/flocks/provider/sdk/anthropic.py b/flocks/provider/sdk/anthropic.py index ff8443071..4a63b000b 100644 --- a/flocks/provider/sdk/anthropic.py +++ b/flocks/provider/sdk/anthropic.py @@ -15,6 +15,7 @@ ChatResponse, StreamChunk, ) +from flocks.provider.sdk.openai_base import build_reasoning_metadata from flocks.utils.log import Log log = 
Log.create(service="provider.anthropic") @@ -24,6 +25,8 @@ class AnthropicProvider(BaseProvider): """Anthropic (Claude) provider with tool support.""" CATALOG_ID = "anthropic" + _INTERLEAVED_THINKING_BETA = "interleaved-thinking-2025-05-14" + _FINE_GRAINED_TOOL_STREAMING_BETA = "fine-grained-tool-streaming-2025-05-14" def __init__(self): super().__init__(provider_id="anthropic", name="Anthropic") @@ -129,6 +132,13 @@ def _format_messages_anthropic(messages: List[ChatMessage]) -> list: if msg.role == "assistant": content_blocks: list = [] + anthropic_thinking_blocks = None + if isinstance(msg.custom_settings, dict): + anthropic_thinking_blocks = msg.custom_settings.get("anthropic_thinking_blocks") + if isinstance(anthropic_thinking_blocks, list): + for block in anthropic_thinking_blocks: + if isinstance(block, dict) and block.get("type") in {"thinking", "redacted_thinking"}: + content_blocks.append(block) if msg.content: content_blocks.append({"type": "text", "text": msg.content}) if msg.tool_calls: @@ -174,6 +184,46 @@ def _format_messages_anthropic(messages: List[ChatMessage]) -> list: }) return formatted + @classmethod + def _beta_flags_for_request(cls, *, thinking_enabled: bool, has_tools: bool) -> Optional[List[str]]: + """Return beta feature flags needed for interleaved thinking.""" + if not thinking_enabled: + return None + betas = [cls._INTERLEAVED_THINKING_BETA] + if has_tools: + betas.append(cls._FINE_GRAINED_TOOL_STREAMING_BETA) + return betas + + @staticmethod + def _reasoning_metadata( + *, + provider_id: str, + model_id: str, + reasoning_source: str, + reasoning_field: str = "thinking", + reasoning_content: Optional[str] = None, + thinking_signature: Optional[str] = None, + redacted_thinking_data: Optional[str] = None, + ) -> Dict[str, Any]: + """Build Anthropic reasoning metadata for thinking/redacted blocks.""" + metadata = build_reasoning_metadata( + provider_id=provider_id, + model_id=model_id, + reasoning_content=reasoning_content, + 
reasoning_source=reasoning_source, + reasoning_field=reasoning_field, + ) or { + "providerID": provider_id, + "modelID": model_id, + "reasoningField": reasoning_field, + "reasoningSource": reasoning_source, + } + if thinking_signature: + metadata["thinkingSignature"] = thinking_signature + if redacted_thinking_data: + metadata["redactedThinkingData"] = redacted_thinking_data + return metadata + async def chat( self, model_id: str, @@ -222,8 +272,15 @@ async def chat( request_params["system"] = system_message if tools: request_params["tools"] = tools - - response = await client.messages.create(**request_params) + + betas = self._beta_flags_for_request( + thinking_enabled=bool(kwargs.get("thinking")), + has_tools=bool(tools), + ) + if betas and hasattr(client, "beta") and hasattr(client.beta, "messages"): + response = await client.beta.messages.create(**request_params, betas=betas) + else: + response = await client.messages.create(**request_params) # Parse response content content_parts = [] @@ -312,12 +369,19 @@ async def chat_stream( request_params["system"] = system_message if tools: request_params["tools"] = tools - + + betas = self._beta_flags_for_request( + thinking_enabled=bool(kwargs.get("thinking")), + has_tools=bool(tools), + ) + # Track tool calls during streaming - current_tool_calls: List[Dict[str, Any]] = [] current_tool_id: Optional[str] = None current_tool_name: Optional[str] = None current_tool_input: str = "" + current_reasoning_open = False + current_reasoning_signature: Optional[str] = None + current_redacted_thinking_data: Optional[str] = None # Track token usage from streaming events input_tokens: int = 0 output_tokens: int = 0 @@ -325,7 +389,13 @@ async def chat_stream( cache_write_tokens: int = 0 try: - async with client.messages.stream(**request_params) as stream: + stream_target = client.messages + stream_kwargs = dict(request_params) + if betas and hasattr(client, "beta") and hasattr(client.beta, "messages"): + stream_target = 
client.beta.messages + stream_kwargs["betas"] = betas + + async with stream_target.stream(**stream_kwargs) as stream: async for event in stream: # Handle different event types if event.type == "message_start": @@ -347,8 +417,30 @@ async def chat_stream( current_tool_name = block.name current_tool_input = "" elif block.type == "thinking": - # Start thinking block (reasoning content will stream in deltas) - pass + current_reasoning_open = True + current_reasoning_signature = None + current_redacted_thinking_data = None + yield StreamChunk( + event_type="reasoning-start", + metadata=self._reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_source="anthropic_thinking", + ), + ) + elif block.type == "redacted_thinking": + current_reasoning_open = True + current_reasoning_signature = None + current_redacted_thinking_data = getattr(block, "data", None) + yield StreamChunk( + event_type="reasoning-start", + metadata=self._reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_source="anthropic_redacted_thinking", + redacted_thinking_data=current_redacted_thinking_data, + ), + ) elif event.type == "content_block_delta": delta = event.delta @@ -361,24 +453,54 @@ async def chat_stream( event_type="reasoning", reasoning=delta.thinking, finish_reason=None, + metadata=self._reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_content=delta.thinking, + reasoning_source="anthropic_thinking", + ), ) + elif delta.type == "signature_delta": + current_reasoning_signature = getattr(delta, "signature", None) elif delta.type == "input_json_delta": current_tool_input += delta.partial_json elif event.type == "content_block_stop": # Finalize tool call if we were building one if current_tool_id and current_tool_name: - current_tool_calls.append({ - "id": current_tool_id, - "type": "function", - "function": { - "name": current_tool_name, - "arguments": current_tool_input or "{}", - }, - }) + yield StreamChunk( + delta="", + 
finish_reason=None, + tool_calls=[{ + "id": current_tool_id, + "type": "function", + "function": { + "name": current_tool_name, + "arguments": current_tool_input or "{}", + }, + }], + ) current_tool_id = None current_tool_name = None current_tool_input = "" + elif current_reasoning_open: + yield StreamChunk( + event_type="reasoning-end", + metadata=self._reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_source=( + "anthropic_redacted_thinking" + if current_redacted_thinking_data + else "anthropic_thinking" + ), + thinking_signature=current_reasoning_signature, + redacted_thinking_data=current_redacted_thinking_data, + ), + ) + current_reasoning_open = False + current_reasoning_signature = None + current_redacted_thinking_data = None elif event.type == "message_delta": # Capture output token count (accumulates during streaming) @@ -406,22 +528,11 @@ async def chat_stream( "cache_read": cache_read_tokens, "cache_write": cache_write_tokens, }) - # Yield final chunk with tool calls if any - if current_tool_calls: - yield StreamChunk( - delta="", - finish_reason="tool_calls", - tool_calls=current_tool_calls, - usage=usage_meta if usage_meta else None, - ) - # Clear tool calls after yielding to prevent duplicate sends - current_tool_calls = [] - else: - yield StreamChunk( - delta="", - finish_reason="stop", - usage=usage_meta if usage_meta else None, - ) + yield StreamChunk( + delta="", + finish_reason="stop", + usage=usage_meta if usage_meta else None, + ) except Exception as e: # Catch and log stream errors, but don't propagate harmless connection close errors @@ -430,14 +541,6 @@ async def chat_stream( # This is a known Anthropic SDK issue when stream ends after tool calls # The stream has actually completed successfully, so we can safely ignore this log.debug("anthropic.stream.harmless_close", {"error": str(e)}) - # Make sure we yielded a final chunk (only if not already sent) - if current_tool_calls: - 
log.debug("anthropic.stream.yielding_tools_after_error", {"count": len(current_tool_calls)}) - yield StreamChunk( - delta="", - finish_reason="tool_calls", - tool_calls=current_tool_calls, - ) elif "list index out of range" in error_msg: # Fallback to non-streaming request if streaming fails unexpectedly log.warn("anthropic.stream.fallback_to_chat", {"error": str(e)}) @@ -449,6 +552,56 @@ async def chat_stream( for block in response.content: if block.type == "text": content_parts.append(block.text) + elif block.type == "thinking": + thinking_text = getattr(block, "thinking", None) or "" + yield StreamChunk( + event_type="reasoning-start", + metadata=self._reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_source="anthropic_thinking", + ), + ) + if thinking_text: + yield StreamChunk( + event_type="reasoning", + reasoning=thinking_text, + metadata=self._reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_content=thinking_text, + reasoning_source="anthropic_thinking", + ), + ) + yield StreamChunk( + event_type="reasoning-end", + metadata=self._reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_source="anthropic_thinking", + thinking_signature=getattr(block, "signature", None), + ), + ) + elif block.type == "redacted_thinking": + redacted_data = getattr(block, "data", None) + yield StreamChunk( + event_type="reasoning-start", + metadata=self._reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_source="anthropic_redacted_thinking", + redacted_thinking_data=redacted_data, + ), + ) + yield StreamChunk( + event_type="reasoning-end", + metadata=self._reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_source="anthropic_redacted_thinking", + redacted_thinking_data=redacted_data, + ), + ) elif block.type == "tool_use": tool_calls.append({ "id": block.id, diff --git a/flocks/provider/sdk/openai.py b/flocks/provider/sdk/openai.py index a6d9a2ecc..ad33a18f3 
100644 --- a/flocks/provider/sdk/openai.py +++ b/flocks/provider/sdk/openai.py @@ -18,8 +18,10 @@ ) from flocks.provider.sdk.openai_base import ( DEFAULT_HTTP_TIMEOUT, + build_reasoning_metadata, _coerce_bool, - extract_reasoning_content, + extract_reasoning_content_with_source, + extract_reasoning_details, format_openai_content, format_openai_messages, resolve_verify_ssl, @@ -252,12 +254,21 @@ async def chat_stream( continue # Handle reasoning/thinking content (for o1/o3/gpt-5 models) - reasoning_content = extract_reasoning_content(delta) - if reasoning_content: + reasoning_content, reasoning_source = extract_reasoning_content_with_source(delta) + reasoning_details = extract_reasoning_details(delta) + if reasoning_content is not None or reasoning_details: + reasoning_metadata = build_reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_content=reasoning_content, + reasoning_source=reasoning_source, + reasoning_details=reasoning_details, + ) yield StreamChunk( event_type="reasoning", - reasoning=reasoning_content, + reasoning=reasoning_content or "", finish_reason=None, + metadata=reasoning_metadata, ) # Handle text content diff --git a/flocks/provider/sdk/openai_base.py b/flocks/provider/sdk/openai_base.py index 905d5779c..90a7663fe 100644 --- a/flocks/provider/sdk/openai_base.py +++ b/flocks/provider/sdk/openai_base.py @@ -254,8 +254,13 @@ def format_openai_messages( d["content"] = content if m.tool_calls: d["tool_calls"] = m.tool_calls - if include_reasoning and role == "assistant" and m.reasoning: - d[reasoning_field] = m.reasoning + if role == "assistant": + if m.reasoning_details: + d["reasoning_details"] = m.reasoning_details + elif m.reasoning_content is not None: + d["reasoning_content"] = m.reasoning_content + elif include_reasoning and m.reasoning: + d[reasoning_field] = m.reasoning if m.tool_call_id: d["tool_call_id"] = m.tool_call_id if m.name: @@ -558,34 +563,122 @@ def _partial_suffix(text: str, tag: str) -> int: "thinking", 
"reasoning", ) +_REASONING_DETAILS_FIELDS = ( + "reasoning_details", +) -def extract_reasoning_content(delta) -> Optional[str]: - """Extract reasoning/thinking content from a streaming delta object. +def _normalize_reasoning_details(value: Any) -> Optional[List[Dict[str, Any]]]: + """Normalize provider reasoning_details payloads to a list of dicts.""" + if value is None: + return None - Supports multiple provider/proxy formats: - - Direct attribute: OpenAI o-series, DeepSeek R1 (reasoning_content) - - Anthropic-compatible proxies (thinking, thinking_content) - - model_extra dict: GLM, other OpenAI-compatible APIs + if isinstance(value, tuple): + value = list(value) + elif hasattr(value, "model_dump"): + dumped = value.model_dump() + if isinstance(dumped, list): + value = dumped + else: + value = [dumped] - This is a shared utility used by OpenAIBaseProvider, OpenAIProvider, - and OpenAICompatibleProvider. - """ - if delta is None: + if not isinstance(value, list): return None + + normalized: List[Dict[str, Any]] = [] + for item in value: + if isinstance(item, dict): + normalized.append(item) + elif hasattr(item, "model_dump"): + dumped = item.model_dump() + if isinstance(dumped, dict): + normalized.append(dumped) + elif hasattr(item, "__dict__"): + normalized.append(dict(item.__dict__)) + return normalized or None + + +def extract_reasoning_content_with_source(delta) -> tuple[Optional[str], Optional[str]]: + """Extract reasoning text plus the field/source it came from.""" + if delta is None: + return None, None for field in _REASONING_FIELDS: value = getattr(delta, field, None) if value is not None: - return value + return value, field extra = getattr(delta, "model_extra", None) if extra and isinstance(extra, dict): for field in _REASONING_FIELDS: value = extra.get(field) if value is not None: - return value + return value, field + return None, None + + +def extract_reasoning_details(delta) -> Optional[List[Dict[str, Any]]]: + """Extract structured reasoning 
details from streaming delta objects.""" + if delta is None: + return None + for field in _REASONING_DETAILS_FIELDS: + value = getattr(delta, field, None) + normalized = _normalize_reasoning_details(value) + if normalized: + return normalized + extra = getattr(delta, "model_extra", None) + if extra and isinstance(extra, dict): + for field in _REASONING_DETAILS_FIELDS: + normalized = _normalize_reasoning_details(extra.get(field)) + if normalized: + return normalized return None +def build_reasoning_metadata( + *, + provider_id: str, + model_id: str, + reasoning_content: Optional[str] = None, + reasoning_source: Optional[str] = None, + reasoning_details: Optional[List[Dict[str, Any]]] = None, + reasoning_field: Optional[str] = None, +) -> Optional[Dict[str, Any]]: + """Build normalized metadata for reasoning chunks.""" + if not any((reasoning_content is not None, reasoning_source, reasoning_details, reasoning_field)): + return None + + field = reasoning_field + if field is None: + field = "reasoning_details" if reasoning_details else "reasoning_content" + + metadata: Dict[str, Any] = { + "providerID": provider_id, + "modelID": model_id, + "reasoningField": field, + } + if reasoning_source: + metadata["reasoningSource"] = reasoning_source + if reasoning_content is not None: + metadata["reasoningContent"] = reasoning_content + if reasoning_details: + metadata["reasoningDetails"] = reasoning_details + return metadata + + +def extract_reasoning_content(delta) -> Optional[str]: + """Extract reasoning/thinking content from a streaming delta object. + + Supports multiple provider/proxy formats: + - Direct attribute: OpenAI o-series, DeepSeek R1 (reasoning_content) + - Anthropic-compatible proxies (thinking, thinking_content) + - model_extra dict: GLM, other OpenAI-compatible APIs + + This is a shared utility used by OpenAIBaseProvider, OpenAIProvider, + and OpenAICompatibleProvider. 
+ """ + reasoning, _source = extract_reasoning_content_with_source(delta) + return reasoning + + class OpenAIBaseProvider(BaseProvider): """Base class for providers using OpenAI-compatible API. @@ -887,13 +980,22 @@ async def chat_stream( }) # 1) Native reasoning_content field (OpenAI o-series, DeepSeek R1, etc.) - reasoning = extract_reasoning_content(delta) - if reasoning: + reasoning, reasoning_source = extract_reasoning_content_with_source(delta) + reasoning_details = extract_reasoning_details(delta) + if reasoning is not None or reasoning_details: emitted_substantive_chunk = True + reasoning_metadata = build_reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_content=reasoning, + reasoning_source=reasoning_source, + reasoning_details=reasoning_details, + ) yield StreamChunk( event_type="reasoning", - reasoning=reasoning, + reasoning=reasoning or "", finish_reason=None, + metadata=reasoning_metadata, ) # 2) Regular content – extract inline tags if present @@ -904,10 +1006,17 @@ async def chat_stream( if seg_type == "reasoning": if seg_text: emitted_substantive_chunk = True + reasoning_metadata = build_reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_content=seg_text, + reasoning_source="think_tag", + ) yield StreamChunk( event_type="reasoning", reasoning=seg_text, finish_reason=None, + metadata=reasoning_metadata, ) else: if seg_text: @@ -941,10 +1050,17 @@ async def chat_stream( if seg_type == "reasoning": if seg_text: emitted_substantive_chunk = True + reasoning_metadata = build_reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_content=seg_text, + reasoning_source="think_tag", + ) yield StreamChunk( event_type="reasoning", reasoning=seg_text, finish_reason=None, + metadata=reasoning_metadata, ) else: if seg_text: diff --git a/flocks/provider/sdk/openai_compatible.py b/flocks/provider/sdk/openai_compatible.py index 3214163cf..391582be8 100644 --- 
a/flocks/provider/sdk/openai_compatible.py +++ b/flocks/provider/sdk/openai_compatible.py @@ -24,10 +24,12 @@ from flocks.provider.sdk.openai_base import ( DEFAULT_HTTP_TIMEOUT, ThinkTagExtractor, + build_reasoning_metadata, _coerce_bool, _normalize_stream_usage, _supports_include_usage_fallback, - extract_reasoning_content, + extract_reasoning_content_with_source, + extract_reasoning_details, format_openai_content, format_openai_messages, resolve_verify_ssl, @@ -326,13 +328,22 @@ async def chat_stream( }) # Handle reasoning/thinking content (DeepSeek R1, GLM, Claude proxies, etc.) - reasoning = extract_reasoning_content(delta) - if reasoning: + reasoning, reasoning_source = extract_reasoning_content_with_source(delta) + reasoning_details = extract_reasoning_details(delta) + if reasoning is not None or reasoning_details: emitted_substantive_chunk = True + reasoning_metadata = build_reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_content=reasoning, + reasoning_source=reasoning_source, + reasoning_details=reasoning_details, + ) yield StreamChunk( event_type="reasoning", - reasoning=reasoning, + reasoning=reasoning or "", finish_reason=None, + metadata=reasoning_metadata, ) # Handle text content – extract inline tags if present @@ -343,10 +354,17 @@ async def chat_stream( if seg_type == "reasoning": if seg_text: emitted_substantive_chunk = True + reasoning_metadata = build_reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_content=seg_text, + reasoning_source="think_tag", + ) yield StreamChunk( event_type="reasoning", reasoning=seg_text, finish_reason=None, + metadata=reasoning_metadata, ) else: if seg_text: @@ -362,10 +380,17 @@ async def chat_stream( if seg_type == "reasoning": if seg_text: emitted_substantive_chunk = True + reasoning_metadata = build_reasoning_metadata( + provider_id=self.id, + model_id=model_id, + reasoning_content=seg_text, + reasoning_source="think_tag", + ) yield StreamChunk( 
event_type="reasoning", reasoning=seg_text, finish_reason=None, + metadata=reasoning_metadata, ) else: if seg_text: diff --git a/flocks/provider/types.py b/flocks/provider/types.py index 68410e76a..6754be636 100644 --- a/flocks/provider/types.py +++ b/flocks/provider/types.py @@ -245,6 +245,7 @@ class ModelCapabilitiesV2(BaseModel): supports_tools: bool = True supports_vision: bool = False supports_reasoning: bool = False + interleaved: Optional[Dict[str, Any]] = None supports_temperature: bool = True supports_json_mode: bool = False supports_structured_output: bool = False diff --git a/flocks/server/routes/session.py b/flocks/server/routes/session.py index e2d04c205..b2de4a4d9 100644 --- a/flocks/server/routes/session.py +++ b/flocks/server/routes/session.py @@ -1396,17 +1396,15 @@ async def _prepare_replay_runtime( agent_name = getattr(user_message, "agent", None) or await Agent.default_agent() agent = await Agent.get(agent_name) or await Agent.get(DEFAULT_AGENT) - - model_info = getattr(user_message, "model", None) - provider_id = model_info.get("providerID") if isinstance(model_info, dict) else None - model_id = model_info.get("modelID") if isinstance(model_info, dict) else None - if not provider_id or not model_id: - dummy_request = type( - "_MessageReplayRequest", - (), - {"model": None, "agent": agent_name}, - )() - provider_id, model_id, _ = await _resolve_model(dummy_request, agent, session_id) + # Replay should follow the model that is active *now* for this session + # (current session pin / current default / current agent override), not the + # historical model stored on the original user message being replayed. 
+ dummy_request = type( + "_MessageReplayRequest", + (), + {"model": None, "agent": agent_name}, + )() + provider_id, model_id, _ = await _resolve_model(dummy_request, agent, session_id) Provider._ensure_initialized() config = await Config.get() diff --git a/flocks/session/core/defaults.py b/flocks/session/core/defaults.py index a9e493a90..a36cb9fce 100644 --- a/flocks/session/core/defaults.py +++ b/flocks/session/core/defaults.py @@ -24,3 +24,13 @@ def fallback_model_id() -> str: # Doom-loop detection: if the last N tool calls in a single assistant # message are identical (same tool + same input), stop processing. DOOM_LOOP_THRESHOLD = 3 + +# Default assistant-step budget when an agent does not declare an explicit +# ``steps`` limit. Keeps tool loops finite without being too aggressive for +# longer coding/research tasks. +DEFAULT_MAX_TOOL_STEPS = 1000 + +# Cross-step loop guard thresholds. These complement the per-message doom-loop +# detection in ``stream_processor.py`` by stopping repeated tool-only turns. 
+REPEATED_EXACT_TOOL_CALL_HALT_THRESHOLD = 3 +SAME_TOOL_STREAK_HALT_THRESHOLD = 8 diff --git a/flocks/session/runner.py b/flocks/session/runner.py index 5b5ec791f..4888d21a4 100644 --- a/flocks/session/runner.py +++ b/flocks/session/runner.py @@ -26,7 +26,12 @@ from flocks.session.message import Message, MessageInfo, MessageRole from flocks.session.prompt import SessionPrompt from flocks.session.core.status import SessionStatus, SessionStatusRetry, SessionStatusBusy -from flocks.session.core.defaults import DOOM_LOOP_THRESHOLD +from flocks.session.core.defaults import ( + DEFAULT_MAX_TOOL_STEPS, + DOOM_LOOP_THRESHOLD, + REPEATED_EXACT_TOOL_CALL_HALT_THRESHOLD, + SAME_TOOL_STREAK_HALT_THRESHOLD, +) from flocks.session.lifecycle.retry import SessionRetry from flocks.session.lifecycle.compaction import SessionCompaction, CompactionPolicy from flocks.session.streaming.stream_processor import StreamProcessor @@ -45,6 +50,7 @@ from flocks.agent.agent import AgentInfo from flocks.agent.toolset import agent_declares_tool from flocks.provider.provider import Provider, ChatMessage +from flocks.provider.reasoning_replay import prepare_reasoning_for_replay from flocks.hooks.pipeline import HookPipeline, HookStage from flocks.tool.catalog import ( get_always_load_tool_names, @@ -245,6 +251,123 @@ def __init__( self._memory_bootstrap_data: Optional[Dict[str, Any]] = memory_bootstrap_data self._static_cache = static_cache if static_cache is not None else {} + @staticmethod + def _canonical_tool_signature(tool_name: str, arguments: Dict[str, Any]) -> str: + args_json = json.dumps( + arguments or {}, + ensure_ascii=False, + sort_keys=True, + separators=(",", ":"), + default=str, + ) + return f"{tool_name}:{args_json}" + + def _reset_tool_loop_guard(self, *, last_user_id: Optional[str] = None) -> Dict[str, Any]: + state = { + "last_user_id": last_user_id or "", + "last_signature": "", + "last_tool_name": "", + "exact_count": 0, + "same_tool_count": 0, + } + 
self._static_cache["tool_loop_guard"] = state + return state + + def _get_tool_loop_guard_state(self, *, last_user_id: Optional[str] = None) -> Dict[str, Any]: + state = self._static_cache.get("tool_loop_guard") + if not isinstance(state, dict): + state = self._reset_tool_loop_guard(last_user_id=last_user_id) + elif last_user_id is not None: + cached_user_id = str(state.get("last_user_id") or "") + if cached_user_id and cached_user_id != last_user_id: + state = self._reset_tool_loop_guard(last_user_id=last_user_id) + else: + state["last_user_id"] = last_user_id + return state + + def _should_warn_about_tool_loop(self, *, last_user_id: str) -> bool: + state = self._get_tool_loop_guard_state(last_user_id=last_user_id) + return ( + int(state.get("exact_count", 0)) >= max(2, REPEATED_EXACT_TOOL_CALL_HALT_THRESHOLD - 1) + or int(state.get("same_tool_count", 0)) >= max(4, SAME_TOOL_STREAK_HALT_THRESHOLD // 2) + ) + + def _build_tool_loop_halt_message( + self, + *, + tool_name: str, + reason: str, + count: int, + ) -> str: + if reason == "repeated_exact_tool_call": + return ( + f"Stopped the loop because `{tool_name}` was called {count} times in a row " + "with the same arguments and kept producing a tool-only turn. Change strategy, " + "summarize the blocker, or answer directly instead of repeating the exact same call." + ) + return ( + f"Stopped the loop because `{tool_name}` was the only tool used for {count} consecutive " + "tool-only turns. Change strategy, summarize the blocker, or answer directly instead of " + "continuing the same tool pattern." 
+ ) + + def _update_tool_loop_guard( + self, + result: StepResult, + *, + last_user_id: str, + ) -> Dict[str, Any]: + visible_text = bool((result.content or "").strip()) + if visible_text or len(result.tool_calls) != 1: + self._reset_tool_loop_guard(last_user_id=last_user_id) + return {"action": "allow"} + + tool_call = result.tool_calls[0] + state = self._get_tool_loop_guard_state(last_user_id=last_user_id) + signature = self._canonical_tool_signature(tool_call.name, tool_call.arguments) + + exact_count = 1 + if state.get("last_signature") == signature: + exact_count = int(state.get("exact_count", 0)) + 1 + + same_tool_count = 1 + if state.get("last_tool_name") == tool_call.name: + same_tool_count = int(state.get("same_tool_count", 0)) + 1 + + state.update({ + "last_user_id": last_user_id, + "last_signature": signature, + "last_tool_name": tool_call.name, + "exact_count": exact_count, + "same_tool_count": same_tool_count, + }) + + if exact_count >= REPEATED_EXACT_TOOL_CALL_HALT_THRESHOLD: + return { + "action": "halt", + "reason": "repeated_exact_tool_call", + "tool_name": tool_call.name, + "count": exact_count, + } + if same_tool_count >= SAME_TOOL_STREAK_HALT_THRESHOLD: + return { + "action": "halt", + "reason": "same_tool_streak", + "tool_name": tool_call.name, + "count": same_tool_count, + } + if ( + exact_count >= max(2, REPEATED_EXACT_TOOL_CALL_HALT_THRESHOLD - 1) + or same_tool_count >= max(4, SAME_TOOL_STREAK_HALT_THRESHOLD // 2) + ): + return { + "action": "warn", + "tool_name": tool_call.name, + "exact_count": exact_count, + "same_tool_count": same_tool_count, + } + return {"action": "allow"} + async def _list_callable_tool_infos_for_turn( self, agent: AgentInfo, @@ -874,7 +997,7 @@ async def _process_step( log.debug("runner.session_agent.error", {"error": str(e)}) # Check if we've reached max steps (matching Flocks logic) - max_steps = agent.steps if hasattr(agent, 'steps') and agent.steps is not None else float('inf') + max_steps = agent.steps if 
hasattr(agent, 'steps') and agent.steps is not None else DEFAULT_MAX_TOOL_STEPS is_last_step = self._step >= max_steps # Get provider @@ -964,37 +1087,16 @@ async def channel_context_prompt_factory() -> Optional[str]: from flocks.session.prompt_strings import PROMPT_TOOL_RESULTS_AVAILABLE system_prompts.append(PROMPT_TOOL_RESULTS_AVAILABLE) - # 检查最近几条消息中是否有重复的工具调用(轻量级警告) - if has_tool_result and self._step > 2: - # 收集最近的工具调用签名 - recent_tool_sigs = [] - for msg in reversed(messages[-3:]): # 检查最近3条消息 - if msg.role == MessageRole.ASSISTANT: - msg_parts = await Message.parts(msg.id, self.session.id) - for p in msg_parts: - if (getattr(p, "type", None) == "tool" and - hasattr(p, 'state') and - hasattr(p.state, 'status') and - p.state.status == "completed"): - tool_name = getattr(p, 'tool', '') - tool_input = getattr(p.state, 'input', {}) - sig = f"{tool_name}:{json.dumps(tool_input, sort_keys=True)}" - recent_tool_sigs.append(sig) - - # 如果有重复的工具调用签名,添加提示(不禁用工具) - if recent_tool_sigs: - sig_counts = {} - for sig in recent_tool_sigs: - sig_counts[sig] = sig_counts.get(sig, 0) + 1 - - repeated_sigs = [sig for sig, count in sig_counts.items() if count >= 2] - if repeated_sigs: - log.warn("runner.repeated_tool_calls_detected", { - "repeated_sigs": repeated_sigs, - "step": self._step, - }) - from flocks.session.prompt_strings import PROMPT_REPEATED_TOOL_CALLS - system_prompts.append(PROMPT_REPEATED_TOOL_CALLS) + if has_tool_result and self._should_warn_about_tool_loop(last_user_id=last_user.id): + state = self._get_tool_loop_guard_state(last_user_id=last_user.id) + log.warn("runner.repeated_tool_calls_detected", { + "tool_name": state.get("last_tool_name"), + "exact_count": state.get("exact_count", 0), + "same_tool_count": state.get("same_tool_count", 0), + "step": self._step, + }) + from flocks.session.prompt_strings import PROMPT_REPEATED_TOOL_CALLS + system_prompts.append(PROMPT_REPEATED_TOOL_CALLS) # Hook pipeline: chat.message stage try: @@ -1256,6 +1358,33 @@ async def 
channel_context_prompt_factory() -> Optional[str]: ) return StepResult(action="stop", error=empty_error_msg) + tool_loop_guard = self._update_tool_loop_guard( + result, + last_user_id=last_user.id, + ) + if tool_loop_guard.get("action") == "halt": + halt_message = self._build_tool_loop_halt_message( + tool_name=str(tool_loop_guard.get("tool_name") or "tool"), + reason=str(tool_loop_guard.get("reason") or "same_tool_streak"), + count=int(tool_loop_guard.get("count", 0) or 0), + ) + log.warn("runner.tool_loop_guard_halt", { + "tool_name": tool_loop_guard.get("tool_name"), + "reason": tool_loop_guard.get("reason"), + "count": tool_loop_guard.get("count"), + "step": self._step, + }) + await Message.update( + self.session.id, + assistant_msg.id, + content=halt_message, + ) + result = StepResult( + action="stop", + content=halt_message, + usage=result.usage, + ) + # Success! Update finish reason finish = "tool-calls" if result.tool_calls else "stop" await Message.update(self.session.id, assistant_msg.id, finish=finish) @@ -1872,6 +2001,12 @@ async def _to_chat_messages( ctx_window_tokens = self._get_context_window_tokens() tool_result_refs: List[Dict[str, Any]] = [] turn_index = 0 + active_model = Provider.resolve_model(self.provider_id, self.model_id) + active_interleaved = ( + getattr(active_model.capabilities, "interleaved", None) + if active_model and getattr(active_model, "capabilities", None) + else None + ) # Identify the last USER message — only that one keeps real image # bytes in its content blocks. 
Earlier turns get a short text @@ -2050,6 +2185,10 @@ async def _to_chat_messages( assistant_content_parts = [] assistant_reasoning_parts = [] + assistant_reasoning_content_parts = [] + assistant_reasoning_details: List[Dict[str, Any]] = [] + assistant_reasoning_sources: set[str] = set() + assistant_custom_settings: Dict[str, Any] = {} # Structured tool calls for the assistant message (OpenAI format) structured_tool_calls: List[Dict[str, Any]] = [] # Corresponding tool-result messages (role="tool") @@ -2064,6 +2203,43 @@ async def _to_chat_messages( assistant_content_parts.append(part.text) elif part.type == "reasoning" and hasattr(part, 'text'): assistant_reasoning_parts.append(part.text) + part_metadata = getattr(part, "metadata", None) or {} + reasoning_meta = part_metadata.get("reasoning") if isinstance(part_metadata, dict) else None + reasoning_content = None + reasoning_source = None + reasoning_details = None + if isinstance(reasoning_meta, dict): + reasoning_content = reasoning_meta.get("content") + reasoning_source = reasoning_meta.get("source") + reasoning_details = reasoning_meta.get("details") + if reasoning_content is None and isinstance(part_metadata, dict): + reasoning_content = part_metadata.get("reasoningContent") + if not reasoning_source and isinstance(part_metadata, dict): + reasoning_source = part_metadata.get("reasoningSource") + if reasoning_details is None and isinstance(part_metadata, dict): + reasoning_details = part_metadata.get("reasoningDetails") + + if reasoning_content is not None: + assistant_reasoning_content_parts.append(reasoning_content) + if reasoning_source: + assistant_reasoning_sources.add(reasoning_source) + if isinstance(reasoning_details, list): + for item in reasoning_details: + if isinstance(item, dict): + assistant_reasoning_details.append(item) + thinking_signature = part_metadata.get("thinkingSignature") if isinstance(part_metadata, dict) else None + redacted_thinking = part_metadata.get("redactedThinkingData") if 
isinstance(part_metadata, dict) else None + if redacted_thinking: + assistant_custom_settings.setdefault("anthropic_thinking_blocks", []).append({ + "type": "redacted_thinking", + "data": redacted_thinking, + }) + elif thinking_signature: + assistant_custom_settings.setdefault("anthropic_thinking_blocks", []).append({ + "type": "thinking", + "thinking": part.text, + "signature": thinking_signature, + }) # Tool parts - use structured OpenAI function-calling format elif part.type == "tool" and hasattr(part, 'state'): @@ -2166,12 +2342,23 @@ async def _to_chat_messages( # Add assistant message if assistant_content_parts or structured_tool_calls: - chat_messages.append(ChatMessage( + assistant_message = ChatMessage( role="assistant", content="\n\n".join(assistant_content_parts) if assistant_content_parts else "", reasoning="".join(assistant_reasoning_parts) if assistant_reasoning_parts else None, + reasoning_content="".join(assistant_reasoning_content_parts) if assistant_reasoning_content_parts else None, + reasoning_details=assistant_reasoning_details if assistant_reasoning_details else None, + reasoning_source=sorted(assistant_reasoning_sources)[0] if assistant_reasoning_sources else None, tool_calls=structured_tool_calls if structured_tool_calls else None, - )) + custom_settings=assistant_custom_settings, + ) + assistant_message = prepare_reasoning_for_replay( + provider_id=self.provider_id, + model_id=self.model_id, + message=assistant_message, + interleaved=active_interleaved, + ) + chat_messages.append(assistant_message) # Append tool-result messages immediately after the assistant message chat_messages.extend(pending_tool_results) else: @@ -2482,40 +2669,66 @@ def _build_llm_response_payload( # provider used `event_type == 'reasoning'` to overload `delta` for # reasoning text. 
event_type = getattr(chunk, 'event_type', None) + chunk_metadata = getattr(chunk, 'metadata', None) or {} + reasoning_event_types = {"reasoning", "reasoning-start", "reasoning-end"} + + if event_type == "reasoning-start" and not hasattr(self, '_current_reasoning_id'): + reasoning_id_counter += 1 + self._current_reasoning_id = f"reasoning-{reasoning_id_counter}" + await processor.process_event(ReasoningStartEvent( + id=self._current_reasoning_id, + metadata=chunk_metadata, + )) + + if event_type == "reasoning-end" and hasattr(self, '_current_reasoning_id'): + await processor.process_event(ReasoningEndEvent( + id=self._current_reasoning_id, + metadata=chunk_metadata, + )) + delattr(self, '_current_reasoning_id') chunk_reasoning = getattr(chunk, 'reasoning', None) or None if not chunk_reasoning and event_type == 'reasoning': # Older providers signal reasoning via event_type and put the # reasoning text in `delta` (no separate `reasoning` field). chunk_reasoning = getattr(chunk, 'delta', '') or None + has_reasoning_metadata = bool( + chunk_metadata.get("reasoningDetails") + or chunk_metadata.get("reasoningContent") is not None + or chunk_metadata.get("reasoningField") + ) # Treat `delta` as text only when it isn't already consumed as # reasoning above. This preserves backward compatibility with # providers that emit reasoning-only chunks via `event_type`. chunk_text = '' - if event_type != 'reasoning' or getattr(chunk, 'reasoning', None): + if event_type not in reasoning_event_types or getattr(chunk, 'reasoning', None): chunk_text = getattr(chunk, 'delta', '') or '' chunk_tool_calls = getattr(chunk, 'tool_calls', None) # 1) Process reasoning delta (start reasoning block on first sight). 
- if chunk_reasoning: + if chunk_reasoning or (event_type == 'reasoning' and has_reasoning_metadata): + reasoning_text = chunk_reasoning or "" chunk_counts["reasoning"] += 1 log.debug("runner.reasoning.received", { - "length": len(chunk_reasoning), - "text_preview": chunk_reasoning[:50], + "length": len(reasoning_text), + "text_preview": reasoning_text[:50], }) if not hasattr(self, '_current_reasoning_id'): reasoning_id_counter += 1 self._current_reasoning_id = f"reasoning-{reasoning_id_counter}" await processor.process_event(ReasoningStartEvent( - id=self._current_reasoning_id + id=self._current_reasoning_id, + metadata=chunk_metadata, )) - await processor.process_event(ReasoningDeltaEvent( - id=self._current_reasoning_id, - text=chunk_reasoning, - )) + if chunk_reasoning: + await processor.process_event(ReasoningDeltaEvent( + id=self._current_reasoning_id, + text=chunk_reasoning, + metadata=chunk_metadata, + )) # 2) End reasoning block when this chunk also carries non-reasoning # content (or once the stream moves away from reasoning). 
diff --git a/tests/agent/test_agent_factory.py b/tests/agent/test_agent_factory.py index 296e6c7e2..1c6a1bd40 100644 --- a/tests/agent/test_agent_factory.py +++ b/tests/agent/test_agent_factory.py @@ -219,6 +219,15 @@ def test_loads_model(self, tmp_path): assert agent.model.model_id == "gpt-4" assert agent.model.provider_id == "openai" + def test_returns_none_on_invalid_model_config(self, tmp_path): + agent_dir = _write_agent_dir(tmp_path, """ + name: bad_model_agent + model: + temperature: 0.3 + """) + + assert load_agent(agent_dir) is None + def test_loads_optional_fields(self, tmp_path): agent_dir = _write_agent_dir(tmp_path, """ name: full_agent @@ -341,6 +350,28 @@ def test_scans_extra_dirs(self, tmp_path): result = scan_and_load(dirs=[extra_dir]) assert "extra_agent" in result + def test_skips_invalid_agent_and_continues_scan(self, tmp_path): + """A malformed agent config should not prevent other agents from loading.""" + extra_dir = tmp_path / "extra" + bad_dir = extra_dir / "bad_agent" + good_dir = extra_dir / "good_agent" + bad_dir.mkdir(parents=True) + good_dir.mkdir(parents=True) + (bad_dir / "agent.yaml").write_text( + textwrap.dedent(""" + name: bad_agent + model: + temperature: 0.3 + """), + encoding="utf-8", + ) + (good_dir / "agent.yaml").write_text("name: good_agent\n", encoding="utf-8") + + result = scan_and_load(dirs=[extra_dir]) + + assert "bad_agent" not in result + assert "good_agent" in result + def test_name_conflict_skips_duplicate(self, tmp_path): """When two dirs have the same agent name, first wins.""" first = tmp_path / "first" diff --git a/tests/provider/test_anthropic_interleaved.py b/tests/provider/test_anthropic_interleaved.py new file mode 100644 index 000000000..bec647e4f --- /dev/null +++ b/tests/provider/test_anthropic_interleaved.py @@ -0,0 +1,224 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from flocks.provider.provider import ChatMessage +from 
flocks.provider.sdk.anthropic import AnthropicProvider + + +class _FakeAsyncStream: + def __init__(self, events): + self._events = list(events) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def __aiter__(self): + self._iter = iter(self._events) + return self + + async def __anext__(self): + try: + return next(self._iter) + except StopIteration as exc: + raise StopAsyncIteration from exc + + +def test_anthropic_formatter_includes_preserved_thinking_blocks(): + message = ChatMessage( + role="assistant", + content="Done", + tool_calls=[ + { + "id": "tool_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q":"weather"}'}, + } + ], + custom_settings={ + "anthropic_thinking_blocks": [ + {"type": "thinking", "thinking": "plan", "signature": "sig123"} + ] + }, + ) + + formatted = AnthropicProvider._format_messages_anthropic([message]) + + assert formatted[0]["content"][0] == { + "type": "thinking", + "thinking": "plan", + "signature": "sig123", + } + assert formatted[0]["content"][1] == {"type": "text", "text": "Done"} + assert formatted[0]["content"][2]["type"] == "tool_use" + + +def test_anthropic_formatter_includes_redacted_thinking_blocks(): + message = ChatMessage( + role="assistant", + content="Done", + custom_settings={ + "anthropic_thinking_blocks": [ + {"type": "redacted_thinking", "data": "opaque_blob"} + ] + }, + ) + + formatted = AnthropicProvider._format_messages_anthropic([message]) + + assert formatted[0]["content"][0] == { + "type": "redacted_thinking", + "data": "opaque_blob", + } + assert formatted[0]["content"][1] == {"type": "text", "text": "Done"} + + +@pytest.mark.asyncio +async def test_anthropic_stream_uses_beta_interleaved_and_yields_tools_on_block_stop(): + provider = AnthropicProvider() + beta_stream = MagicMock() + beta_stream.stream.return_value = _FakeAsyncStream( + [ + SimpleNamespace( + type="message_start", + message=SimpleNamespace( + 
usage=SimpleNamespace( + input_tokens=11, + output_tokens=0, + cache_read_input_tokens=2, + cache_creation_input_tokens=0, + ) + ), + ), + SimpleNamespace( + type="content_block_start", + content_block=SimpleNamespace(type="thinking"), + ), + SimpleNamespace( + type="content_block_delta", + delta=SimpleNamespace(type="thinking_delta", thinking="plan"), + ), + SimpleNamespace( + type="content_block_delta", + delta=SimpleNamespace(type="signature_delta", signature="sig123"), + ), + SimpleNamespace(type="content_block_stop"), + SimpleNamespace( + type="content_block_start", + content_block=SimpleNamespace(type="tool_use", id="tool_1", name="search"), + ), + SimpleNamespace( + type="content_block_delta", + delta=SimpleNamespace(type="input_json_delta", partial_json='{"q":"weather"}'), + ), + SimpleNamespace(type="content_block_stop"), + SimpleNamespace( + type="message_delta", + usage=SimpleNamespace(output_tokens=7), + ), + SimpleNamespace(type="message_stop"), + ] + ) + provider._client = SimpleNamespace( + beta=SimpleNamespace(messages=beta_stream), + messages=SimpleNamespace(stream=AsyncMock()), + ) + + chunks = [ + chunk + async for chunk in provider.chat_stream( + "claude-sonnet-4-6", + [ChatMessage(role="user", content="hello")], + tools=[ + { + "type": "function", + "function": { + "name": "search", + "description": "search", + "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, + }, + } + ], + thinking={"type": "enabled", "budget_tokens": 2048}, + ) + ] + + kwargs = beta_stream.stream.call_args.kwargs + assert kwargs["betas"] == [ + "interleaved-thinking-2025-05-14", + "fine-grained-tool-streaming-2025-05-14", + ] + + assert chunks[0].event_type == "reasoning-start" + assert chunks[1].event_type == "reasoning" + assert chunks[1].reasoning == "plan" + assert chunks[2].event_type == "reasoning-end" + assert chunks[2].metadata["thinkingSignature"] == "sig123" + assert chunks[3].tool_calls[0]["id"] == "tool_1" + assert 
chunks[3].finish_reason is None + assert chunks[-1].finish_reason == "stop" + assert chunks[-1].usage == { + "prompt_tokens": 11, + "completion_tokens": 7, + "total_tokens": 18, + "cache_read_input_tokens": 2, + "cache_creation_input_tokens": 0, + } + + +@pytest.mark.asyncio +async def test_anthropic_stream_preserves_redacted_thinking_metadata(): + provider = AnthropicProvider() + beta_stream = MagicMock() + beta_stream.stream.return_value = _FakeAsyncStream( + [ + SimpleNamespace( + type="content_block_start", + content_block=SimpleNamespace(type="redacted_thinking", data="opaque_blob"), + ), + SimpleNamespace(type="content_block_stop"), + SimpleNamespace( + type="content_block_start", + content_block=SimpleNamespace(type="tool_use", id="tool_1", name="search"), + ), + SimpleNamespace( + type="content_block_delta", + delta=SimpleNamespace(type="input_json_delta", partial_json='{"q":"weather"}'), + ), + SimpleNamespace(type="content_block_stop"), + SimpleNamespace(type="message_stop"), + ] + ) + provider._client = SimpleNamespace( + beta=SimpleNamespace(messages=beta_stream), + messages=SimpleNamespace(stream=AsyncMock()), + ) + + chunks = [ + chunk + async for chunk in provider.chat_stream( + "claude-sonnet-4-6", + [ChatMessage(role="user", content="hello")], + tools=[ + { + "type": "function", + "function": { + "name": "search", + "description": "search", + "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, + }, + } + ], + thinking={"type": "enabled", "budget_tokens": 2048}, + ) + ] + + assert chunks[0].event_type == "reasoning-start" + assert chunks[0].metadata["redactedThinkingData"] == "opaque_blob" + assert chunks[1].event_type == "reasoning-end" + assert chunks[1].metadata["redactedThinkingData"] == "opaque_blob" + assert chunks[2].tool_calls[0]["id"] == "tool_1" diff --git a/tests/provider/test_chinese_providers.py b/tests/provider/test_chinese_providers.py index a4c7930a9..7ecfd1291 100644 --- 
a/tests/provider/test_chinese_providers.py +++ b/tests/provider/test_chinese_providers.py @@ -147,6 +147,8 @@ def test_deepseek_catalog(self): r1 = next(m for m in models if m.id == "deepseek-reasoner") assert r1.capabilities.supports_reasoning is True + assert r1.capabilities.interleaved["field"] == "reasoning_content" + assert r1.capabilities.interleaved["placeholder"] == " " assert r1.pricing.currency == "CNY" assert r1.pricing.output == 16.0 @@ -172,6 +174,8 @@ def test_moonshot_catalog(self): k26 = next(m for m in models if m.id == "kimi-k2.6") assert k26.capabilities.supports_reasoning is True + assert k26.capabilities.interleaved["field"] == "reasoning_content" + assert k26.capabilities.interleaved["placeholder"] == " " assert k26.pricing.currency == "CNY" assert k26.pricing.cache_read == 1.3 assert k26.limits.context_window == 256000 @@ -188,6 +192,7 @@ def test_zhipu_catalog(self): } turbo = next(m for m in models if m.id == "glm-5-turbo") + assert turbo.capabilities.interleaved["field"] == "reasoning_content" assert turbo.pricing.output == 26.0 assert turbo.limits.context_window == 202752 @@ -197,6 +202,9 @@ def test_minimax_catalog(self): "minimax-m2.7", "minimax-m2.5", } + m27 = next(m for m in models if m.id == "minimax-m2.7") + assert m27.capabilities.supports_reasoning is True + assert m27.capabilities.interleaved["field"] == "reasoning_details" def test_stepfun_catalog(self): models = get_provider_model_definitions("stepfun") @@ -223,6 +231,7 @@ def test_threatbook_cn_llm_catalog(self): kimi = next(m for m in models if m.id == "kimi-k2.6") assert kimi.capabilities.supports_reasoning is True + assert kimi.capabilities.interleaved["field"] == "reasoning_content" assert kimi.pricing.currency == "CNY" assert kimi.pricing.cache_read == 1.3 assert kimi.pricing.input == 6.5 @@ -246,6 +255,7 @@ def test_threatbook_io_llm_catalog(self): } m27 = next(m for m in models if m.id == "minimax-m2.7") + assert m27.capabilities.interleaved["field"] == 
"reasoning_details" assert m27.pricing.currency == "CNY" assert m27.pricing.input == 2.1 assert m27.limits.context_window == 196608 diff --git a/tests/provider/test_openai_base_provider.py b/tests/provider/test_openai_base_provider.py index 19384ec2a..6f14efafe 100644 --- a/tests/provider/test_openai_base_provider.py +++ b/tests/provider/test_openai_base_provider.py @@ -14,7 +14,13 @@ import pytest import flocks.provider.sdk.openai_base as openai_base_module -from flocks.provider.sdk.openai_base import OpenAIBaseProvider, extract_reasoning_content +from flocks.provider.sdk.openai_base import ( + OpenAIBaseProvider, + build_reasoning_metadata, + extract_reasoning_content, + extract_reasoning_content_with_source, + extract_reasoning_details, +) from flocks.provider.provider import ModelInfo, ModelCapabilities, ProviderConfig @@ -512,6 +518,26 @@ class TestExtractReasoningContent: def test_extract_reasoning_content_none_delta(self): assert extract_reasoning_content(None) is None + def test_extract_reasoning_content_with_source_uses_model_extra(self): + delta = SimpleNamespace(model_extra={"reasoning_content": "deep thought"}) + reasoning, source = extract_reasoning_content_with_source(delta) + assert reasoning == "deep thought" + assert source == "reasoning_content" + + def test_extract_reasoning_details_preserves_dicts(self): + details = [{"type": "reasoning.summary", "text": "step", "signature": "sig"}] + delta = SimpleNamespace(model_extra={"reasoning_details": details}) + assert extract_reasoning_details(delta) == details + + def test_build_reasoning_metadata_prefers_reasoning_details_field(self): + metadata = build_reasoning_metadata( + provider_id="openrouter", + model_id="minimax-m2.7", + reasoning_details=[{"type": "reasoning.summary", "text": "step"}], + ) + assert metadata["reasoningField"] == "reasoning_details" + assert metadata["providerID"] == "openrouter" + async def _stream_from_chunks(*chunks): for chunk in chunks: diff --git 
a/tests/provider/test_provider.py b/tests/provider/test_provider.py index 40ff7bfd2..3859f2d17 100644 --- a/tests/provider/test_provider.py +++ b/tests/provider/test_provider.py @@ -3,6 +3,7 @@ """ import pytest +from types import SimpleNamespace from flocks.provider.provider import ( Provider, ChatMessage, @@ -90,6 +91,79 @@ async def test_get_model(): assert unknown is None +def test_resolve_model_prefers_provider_specific_runtime_model(monkeypatch): + provider_model = SimpleNamespace( + id="shared-model", + capabilities=SimpleNamespace(interleaved={"field": "reasoning_content"}), + ) + wrong_global_model = SimpleNamespace( + id="shared-model", + capabilities=SimpleNamespace(interleaved={"field": "reasoning_details"}), + ) + fake_provider = SimpleNamespace( + get_model_definitions=lambda: [provider_model], + get_models=lambda: [], + _config_models=[], + ) + + monkeypatch.setattr(Provider, "_initialized", True) + monkeypatch.setattr(Provider, "_providers", {"deepseek": fake_provider}) + monkeypatch.setattr(Provider, "_models", {"shared-model": wrong_global_model}) + + resolved = Provider.resolve_model("deepseek", "shared-model") + + assert resolved is provider_model + assert resolved.capabilities.interleaved["field"] == "reasoning_content" + + +def test_resolve_model_infers_interleaved_for_runtime_discovered_reasoning_model(monkeypatch): + provider_model = SimpleNamespace( + id="qwen3-max", + capabilities=SimpleNamespace(interleaved=None), + ) + fake_provider = SimpleNamespace( + get_model_definitions=lambda: [provider_model], + get_models=lambda: [], + _config_models=[], + _config=SimpleNamespace(base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"), + ) + + monkeypatch.setattr(Provider, "_initialized", True) + monkeypatch.setattr(Provider, "_providers", {"openai-compatible": fake_provider}) + monkeypatch.setattr(Provider, "_models", {}) + + resolved = Provider.resolve_model("openai-compatible", "qwen3-max") + + assert resolved is provider_model + assert 
resolved.capabilities.interleaved == { + "field": "reasoning_content", + "echo": "tool_calls", + "cross_provider_policy": "promote", + } + + +def test_resolve_model_does_not_infer_interleaved_for_non_reasoning_model(monkeypatch): + provider_model = SimpleNamespace( + id="deepseek-chat", + capabilities=SimpleNamespace(interleaved=None), + ) + fake_provider = SimpleNamespace( + get_model_definitions=lambda: [provider_model], + get_models=lambda: [], + _config_models=[], + _config=SimpleNamespace(base_url="https://api.deepseek.com/v1"), + ) + + monkeypatch.setattr(Provider, "_initialized", True) + monkeypatch.setattr(Provider, "_providers", {"custom-demo": fake_provider}) + monkeypatch.setattr(Provider, "_models", {}) + + resolved = Provider.resolve_model("custom-demo", "deepseek-chat") + + assert resolved is provider_model + assert resolved.capabilities.interleaved is None + + @pytest.mark.asyncio async def test_provider_models(): """Test provider model listing""" diff --git a/tests/provider/test_reasoning_replay.py b/tests/provider/test_reasoning_replay.py new file mode 100644 index 000000000..65269b046 --- /dev/null +++ b/tests/provider/test_reasoning_replay.py @@ -0,0 +1,180 @@ +from flocks.provider.provider import ChatMessage +from flocks.provider.reasoning_replay import prepare_reasoning_for_replay +from flocks.provider.sdk.openai_base import format_openai_messages + + +def test_prepare_reasoning_promotes_internal_reasoning_for_promote_policy(): + message = ChatMessage( + role="assistant", + content="", + reasoning="Need to inspect the tool result.", + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + ) + + prepared = prepare_reasoning_for_replay( + provider_id="alibaba", + model_id="qwen3-max", + message=message, + interleaved={ + "field": "reasoning_content", + "echo": "tool_calls", + "cross_provider_policy": "promote", + }, + ) + + assert prepared.reasoning_content == "Need to inspect the 
tool result." + assert prepared.reasoning_source == "promoted_reasoning" + + +def test_prepare_reasoning_uses_placeholder_for_strict_echo_provider(): + message = ChatMessage( + role="assistant", + content="", + reasoning="Prior provider chain of thought", + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + ) + + prepared = prepare_reasoning_for_replay( + provider_id="deepseek", + model_id="deepseek-reasoner", + message=message, + interleaved={ + "field": "reasoning_content", + "echo": "tool_calls", + "placeholder": " ", + "cross_provider_policy": "placeholder", + }, + ) + + assert prepared.reasoning_content == " " + assert prepared.reasoning_source == "placeholder" + + +def test_prepare_reasoning_preserves_reasoning_details(): + details = [{"type": "reasoning.summary", "text": "step", "signature": "sig"}] + message = ChatMessage( + role="assistant", + content="answer", + reasoning_details=details, + ) + + prepared = prepare_reasoning_for_replay( + provider_id="minimax", + model_id="minimax-m2.7", + message=message, + interleaved={ + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote", + }, + ) + + assert prepared.reasoning_details == details + assert prepared.reasoning_source == "native_reasoning_details" + + +def test_prepare_reasoning_drops_details_when_target_uses_reasoning_content(): + message = ChatMessage( + role="assistant", + content="", + reasoning="Short internal summary", + reasoning_details=[{"type": "thinking", "thinking": "opaque provider scratchpad"}], + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + ) + + prepared = prepare_reasoning_for_replay( + provider_id="deepseek", + model_id="deepseek-reasoner", + message=message, + interleaved={ + "field": "reasoning_content", + "echo": "tool_calls", + "placeholder": " ", + "cross_provider_policy": "placeholder", + }, + ) + + 
assert prepared.reasoning_content == " " + assert prepared.reasoning_details is None + assert prepared.reasoning_source == "placeholder" + + +def test_prepare_reasoning_promotes_reasoning_content_into_reasoning_details(): + message = ChatMessage( + role="assistant", + content="", + reasoning_content="Native scratchpad", + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + ) + + prepared = prepare_reasoning_for_replay( + provider_id="minimax", + model_id="minimax-m2.7", + message=message, + interleaved={ + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote", + }, + ) + + assert prepared.reasoning_content is None + assert prepared.reasoning_details == [ + {"type": "reasoning.summary", "text": "Native scratchpad"} + ] + assert prepared.reasoning_source == "promoted_reasoning_content" + + +def test_format_openai_messages_prefers_reasoning_details_over_reasoning_content(): + formatted = format_openai_messages( + [ + ChatMessage( + role="assistant", + content="", + reasoning="summary", + reasoning_content="native scratchpad", + reasoning_details=[{"type": "reasoning.summary", "text": "step"}], + ) + ] + ) + + assert formatted[0]["reasoning_details"] == [{"type": "reasoning.summary", "text": "step"}] + assert "reasoning_content" not in formatted[0] + + +def test_format_openai_messages_serializes_reasoning_content_without_include_reasoning(): + formatted = format_openai_messages( + [ + ChatMessage( + role="assistant", + content="", + reasoning_content="native scratchpad", + ) + ] + ) + + assert formatted[0]["reasoning_content"] == "native scratchpad" diff --git a/tests/server/routes/test_session_routes.py b/tests/server/routes/test_session_routes.py index 7ab8593dd..0865fabd8 100644 --- a/tests/server/routes/test_session_routes.py +++ b/tests/server/routes/test_session_routes.py @@ -12,6 +12,9 @@ from __future__ import annotations +from types import SimpleNamespace 
+from unittest.mock import AsyncMock + import pytest from fastapi import HTTPException, status from httpx import AsyncClient @@ -586,6 +589,195 @@ async def _fake_run_existing_user_message( assert len(scheduled_coroutines) == 1 await scheduled_coroutines.pop(0) + @pytest.mark.asyncio + async def test_prepare_replay_runtime_uses_current_model_resolution( + self, + monkeypatch: pytest.MonkeyPatch, + ): + """Replay runtime should ignore the historical user-message model.""" + from flocks.server.routes import session as session_routes + + user_message = SimpleNamespace( + agent="rex", + model={"providerID": "anthropic", "modelID": "old-message-model"}, + ) + + monkeypatch.setattr( + "flocks.agent.registry.Agent.get", + AsyncMock(return_value=SimpleNamespace(name="rex", model=None)), + ) + monkeypatch.setattr( + session_routes, + "_resolve_model", + AsyncMock(return_value=("openai", "gpt-4.1", "session")), + ) + monkeypatch.setattr("flocks.provider.provider.Provider._ensure_initialized", lambda: None) + monkeypatch.setattr("flocks.config.config.Config.get", AsyncMock(return_value=SimpleNamespace())) + monkeypatch.setattr("flocks.provider.provider.Provider.apply_config", AsyncMock()) + monkeypatch.setattr("flocks.provider.provider.Provider.get", lambda _provider_id: object()) + + runtime = await session_routes._prepare_replay_runtime("ses_test", user_message) + + assert runtime == { + "agent_name": "rex", + "provider_id": "openai", + "model_id": "gpt-4.1", + } + + @pytest.mark.asyncio + async def test_resend_uses_current_model_for_replay( + self, + client: AsyncClient, + session_id: str, + monkeypatch: pytest.MonkeyPatch, + ): + """Resend should replay with the session's current model, not the original one.""" + from flocks.server.routes import session as session_routes + + create_resp = await client.post( + f"/api/session/{session_id}/message", + json={ + "parts": [{"type": "text", "text": "Original user text"}], + "noReply": True, + "mockReply": "Initial reply", + }, + 
) + assert create_resp.status_code == status.HTTP_200_OK + + list_resp = await client.get(f"/api/session/{session_id}/message") + messages = list_resp.json() + user_message = next(msg for msg in messages if msg["info"]["role"] == "user") + user_part = next(part for part in user_message["parts"] if part["type"] == "text") + + scheduled_coroutines = [] + captured_runtime = {} + + async def _fake_instance_provide(*, directory, init, fn): + return await fn() + + async def _fake_run_existing_user_message( + session_id: str, + session, + user_message, + working_directory: str, + runtime=None, + ): + captured_runtime.update(runtime or {}) + return {"status": "completed", "sessionID": session_id, "messageID": user_message.id} + + monkeypatch.setattr("flocks.project.instance.Instance.provide", _fake_instance_provide) + monkeypatch.setattr( + session_routes, + "_resolve_model", + AsyncMock(return_value=("openai", "gpt-4.1", "session")), + ) + monkeypatch.setattr( + "flocks.agent.registry.Agent.get", + AsyncMock(return_value=SimpleNamespace(name="rex", model=None)), + ) + monkeypatch.setattr("flocks.provider.provider.Provider._ensure_initialized", lambda: None) + monkeypatch.setattr("flocks.config.config.Config.get", AsyncMock(return_value=SimpleNamespace())) + monkeypatch.setattr("flocks.provider.provider.Provider.apply_config", AsyncMock()) + monkeypatch.setattr("flocks.provider.provider.Provider.get", lambda _provider_id: object()) + monkeypatch.setattr( + "flocks.session.lifecycle.revert.SessionRevert.revert", + AsyncMock(return_value=None), + ) + monkeypatch.setattr(session_routes, "_run_existing_user_message", _fake_run_existing_user_message) + monkeypatch.setattr( + session_routes, + "_schedule_background_coro", + lambda coro, **kwargs: scheduled_coroutines.append(coro), + ) + + resend_resp = await client.post( + f"/api/session/{session_id}/message/{user_message['info']['id']}/resend", + json={"text": "Updated user text", "partID": user_part["id"]}, + ) + assert 
resend_resp.status_code == status.HTTP_202_ACCEPTED + assert len(scheduled_coroutines) == 1 + + await scheduled_coroutines.pop(0) + + assert captured_runtime["provider_id"] == "openai" + assert captured_runtime["model_id"] == "gpt-4.1" + + @pytest.mark.asyncio + async def test_regenerate_uses_current_model_for_replay( + self, + client: AsyncClient, + session_id: str, + monkeypatch: pytest.MonkeyPatch, + ): + """Regenerate should replay with the session's current model, not the original one.""" + from flocks.server.routes import session as session_routes + + create_resp = await client.post( + f"/api/session/{session_id}/message", + json={ + "parts": [{"type": "text", "text": "Question"}], + "noReply": True, + "mockReply": "Assistant answer", + }, + ) + assert create_resp.status_code == status.HTTP_200_OK + + list_resp = await client.get(f"/api/session/{session_id}/message") + messages = list_resp.json() + assistant_message = next(msg for msg in messages if msg["info"]["role"] == "assistant") + + scheduled_coroutines = [] + captured_runtime = {} + + async def _fake_instance_provide(*, directory, init, fn): + return await fn() + + async def _fake_run_existing_user_message( + session_id: str, + session, + user_message, + working_directory: str, + runtime=None, + ): + captured_runtime.update(runtime or {}) + return {"status": "completed", "sessionID": session_id, "messageID": user_message.id} + + monkeypatch.setattr("flocks.project.instance.Instance.provide", _fake_instance_provide) + monkeypatch.setattr( + session_routes, + "_resolve_model", + AsyncMock(return_value=("openai", "gpt-4.1", "session")), + ) + monkeypatch.setattr( + "flocks.agent.registry.Agent.get", + AsyncMock(return_value=SimpleNamespace(name="rex", model=None)), + ) + monkeypatch.setattr("flocks.provider.provider.Provider._ensure_initialized", lambda: None) + monkeypatch.setattr("flocks.config.config.Config.get", AsyncMock(return_value=SimpleNamespace())) + 
monkeypatch.setattr("flocks.provider.provider.Provider.apply_config", AsyncMock()) + monkeypatch.setattr("flocks.provider.provider.Provider.get", lambda _provider_id: object()) + monkeypatch.setattr( + "flocks.session.lifecycle.revert.SessionRevert.revert", + AsyncMock(return_value=None), + ) + monkeypatch.setattr(session_routes, "_run_existing_user_message", _fake_run_existing_user_message) + monkeypatch.setattr( + session_routes, + "_schedule_background_coro", + lambda coro, **kwargs: scheduled_coroutines.append(coro), + ) + + regenerate_resp = await client.post( + f"/api/session/{session_id}/message/{assistant_message['info']['id']}/regenerate", + ) + assert regenerate_resp.status_code == status.HTTP_202_ACCEPTED + assert len(scheduled_coroutines) == 1 + + await scheduled_coroutines.pop(0) + + assert captured_runtime["provider_id"] == "openai" + assert captured_runtime["model_id"] == "gpt-4.1" + # =========================================================================== # Utility endpoints diff --git a/tests/session/test_runner_chunk_handling.py b/tests/session/test_runner_chunk_handling.py index f6a793ea8..36083a8e2 100644 --- a/tests/session/test_runner_chunk_handling.py +++ b/tests/session/test_runner_chunk_handling.py @@ -32,6 +32,7 @@ class FakeStreamChunk: reasoning: Optional[str] = None tool_calls: Optional[List[Dict[str, Any]]] = None event_type: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None finish_reason: Optional[str] = None usage: Optional[Dict[str, int]] = None @@ -52,9 +53,9 @@ def __init__(self): async def process_event(self, ev): cls = type(ev).__name__ if cls == "ReasoningStartEvent": - self.events.append(_Event("reasoning_start", {"id": ev["id"]})) + self.events.append(_Event("reasoning_start", {"id": ev["id"], "metadata": ev.get("metadata")})) elif cls == "ReasoningDeltaEvent": - self.events.append(_Event("reasoning_delta", {"id": ev["id"], "text": ev["text"]})) + self.events.append(_Event("reasoning_delta", {"id": 
ev["id"], "text": ev["text"], "metadata": ev.get("metadata")})) elif cls == "ReasoningEndEvent": self.events.append(_Event("reasoning_end", {"id": ev["id"]})) elif cls == "TextStartEvent": @@ -72,14 +73,14 @@ def __getattr__(self, item): raise AttributeError(item) from exc -def ReasoningStartEvent(*, id): # noqa: N802 – mimic real event class name - e = _EventDict(id=id) +def ReasoningStartEvent(*, id, metadata=None): # noqa: N802 – mimic real event class name + e = _EventDict(id=id, metadata=metadata) e.__class__.__name__ = "ReasoningStartEvent" return e -def ReasoningDeltaEvent(*, id, text): # noqa: N802 - e = _EventDict(id=id, text=text) +def ReasoningDeltaEvent(*, id, text, metadata=None): # noqa: N802 + e = _EventDict(id=id, text=text, metadata=metadata) e.__class__.__name__ = "ReasoningDeltaEvent" return e @@ -125,28 +126,53 @@ async def consume_chunks(chunks, processor, tool_accumulator) -> Dict[str, int]: for chunk in chunks: event_type = getattr(chunk, "event_type", None) + chunk_metadata = getattr(chunk, "metadata", None) or {} + reasoning_event_types = {"reasoning", "reasoning-start", "reasoning-end"} + + if event_type == "reasoning-start" and state["reasoning_id"] is None: + reasoning_id_counter += 1 + state["reasoning_id"] = f"reasoning-{reasoning_id_counter}" + await processor.process_event( + ReasoningStartEvent(id=state["reasoning_id"], metadata=chunk_metadata) + ) + + if event_type == "reasoning-end" and state["reasoning_id"] is not None: + await processor.process_event( + ReasoningEndEvent(id=state["reasoning_id"]) + ) + state["reasoning_id"] = None chunk_reasoning = getattr(chunk, "reasoning", None) or None if not chunk_reasoning and event_type == "reasoning": chunk_reasoning = getattr(chunk, "delta", "") or None + has_reasoning_metadata = bool( + chunk_metadata.get("reasoningDetails") + or chunk_metadata.get("reasoningContent") is not None + or chunk_metadata.get("reasoningField") + ) chunk_text = "" - if event_type != "reasoning" or 
getattr(chunk, "reasoning", None): + if event_type not in reasoning_event_types or getattr(chunk, "reasoning", None): chunk_text = getattr(chunk, "delta", "") or "" chunk_tool_calls = getattr(chunk, "tool_calls", None) - if chunk_reasoning: + if chunk_reasoning or (event_type == "reasoning" and has_reasoning_metadata): chunk_counts["reasoning"] += 1 if state["reasoning_id"] is None: reasoning_id_counter += 1 state["reasoning_id"] = f"reasoning-{reasoning_id_counter}" await processor.process_event( - ReasoningStartEvent(id=state["reasoning_id"]) + ReasoningStartEvent(id=state["reasoning_id"], metadata=chunk_metadata) + ) + if chunk_reasoning: + await processor.process_event( + ReasoningDeltaEvent( + id=state["reasoning_id"], + text=chunk_reasoning, + metadata=chunk_metadata, + ) ) - await processor.process_event( - ReasoningDeltaEvent(id=state["reasoning_id"], text=chunk_reasoning) - ) if (chunk_text or chunk_tool_calls) and state["reasoning_id"] is not None: await processor.process_event( @@ -262,6 +288,54 @@ async def test_separate_chunks_close_block_then_open_text(self): # Then the text block opens cleanly. 
assert kinds[-2:] == ["text_start", "text_delta"] + @pytest.mark.asyncio + async def test_metadata_only_reasoning_chunk_still_opens_reasoning_block(self): + proc = _RecordingProcessor() + acc = _ToolAccumulator() + + chunks = [ + FakeStreamChunk( + event_type="reasoning", + metadata={ + "reasoningField": "reasoning_details", + "reasoningDetails": [{"type": "reasoning.summary", "text": "opaque"}], + }, + ), + FakeStreamChunk( + tool_calls=[{"id": "c1", "function": {"name": "search", "arguments": "{}"}}], + ), + ] + + counts = await consume_chunks(chunks, proc, acc) + + assert counts == {"reasoning": 1, "text": 0, "tool": 1} + assert proc.events[0].kind == "reasoning_start" + assert proc.events[0].payload["metadata"]["reasoningField"] == "reasoning_details" + assert proc.events[1].kind == "reasoning_end" + assert acc.fed[0]["id"] == "c1" + + @pytest.mark.asyncio + async def test_explicit_reasoning_start_and_end_events_are_respected(self): + proc = _RecordingProcessor() + acc = _ToolAccumulator() + + chunks = [ + FakeStreamChunk(event_type="reasoning-start", metadata={"reasoningField": "thinking"}), + FakeStreamChunk(event_type="reasoning", reasoning="step 1"), + FakeStreamChunk(event_type="reasoning-end", metadata={"thinkingSignature": "sig123"}), + FakeStreamChunk(tool_calls=[{"id": "c1", "function": {"name": "search", "arguments": "{}"}}]), + ] + + counts = await consume_chunks(chunks, proc, acc) + + assert counts == {"reasoning": 1, "text": 0, "tool": 1} + assert [e.kind for e in proc.events[:3]] == [ + "reasoning_start", + "reasoning_delta", + "reasoning_end", + ] + assert acc.fed[0]["id"] == "c1" + @pytest.mark.asyncio async def test_usage_only_chunk_does_not_close_reasoning(self): proc = _RecordingProcessor() diff --git a/tests/session/test_runner_step.py b/tests/session/test_runner_step.py index 28fa9146b..e7a72bf5b 100644 --- a/tests/session/test_runner_step.py +++ b/tests/session/test_runner_step.py @@ -31,6 +31,7 @@ ToolCall, ) from flocks.session.prompt 
import SessionPrompt +from flocks.session.core.defaults import DEFAULT_MAX_TOOL_STEPS from flocks.session.session import Session, SessionInfo from flocks.tool.registry import ToolCategory, ToolInfo @@ -106,6 +107,63 @@ def test_error_action(self): assert result.error == "LLM failed" +class TestToolLoopGuard: + def test_halts_after_three_exact_tool_only_steps(self): + runner = _make_runner("ses_runner_tool_loop_exact") + result = StepResult( + action="continue", + tool_calls=[ToolCall(id="c1", name="echo_tool", arguments={"text": "loop"})], + ) + + first = runner._update_tool_loop_guard(result, last_user_id="user-1") + second = runner._update_tool_loop_guard(result, last_user_id="user-1") + third = runner._update_tool_loop_guard(result, last_user_id="user-1") + + assert first["action"] == "allow" + assert second["action"] == "warn" + assert third["action"] == "halt" + assert third["reason"] == "repeated_exact_tool_call" + assert third["count"] == 3 + + def test_halts_after_same_tool_streak_with_varying_args(self): + runner = _make_runner("ses_runner_tool_loop_same_tool") + decision = None + + for idx in range(1, 9): + decision = runner._update_tool_loop_guard( + StepResult( + action="continue", + tool_calls=[ToolCall(id=f"c{idx}", name="echo_tool", arguments={"text": f"loop-{idx}"})], + ), + last_user_id="user-1", + ) + + assert decision is not None + assert decision["action"] == "halt" + assert decision["reason"] == "same_tool_streak" + assert decision["count"] == 8 + + def test_resets_after_text_response(self): + runner = _make_runner("ses_runner_tool_loop_reset") + tool_only = StepResult( + action="continue", + tool_calls=[ToolCall(id="c1", name="echo_tool", arguments={"text": "loop"})], + ) + + runner._update_tool_loop_guard(tool_only, last_user_id="user-1") + warned = runner._update_tool_loop_guard(tool_only, last_user_id="user-1") + reset = runner._update_tool_loop_guard( + StepResult(action="stop", content="done"), + last_user_id="user-1", + ) + restarted = 
runner._update_tool_loop_guard(tool_only, last_user_id="user-1") + + assert warned["action"] == "warn" + assert reset["action"] == "allow" + assert restarted["action"] == "allow" + assert runner._get_tool_loop_guard_state(last_user_id="user-1")["exact_count"] == 1 + + # --------------------------------------------------------------------------- # RunnerCallbacks dataclass # --------------------------------------------------------------------------- @@ -476,7 +534,7 @@ async def test_build_tools_uses_selector_results_and_emits_event(self): assert event_callback.await_args.args[1]["enabledToolCount"] == 3 @pytest.mark.asyncio - async def test_build_tools_keeps_registered_skill_description(self): + async def test_build_tools_refreshes_skill_description_from_enabled_skills(self): runner = _make_runner() agent = _make_agent(name="rex") skill_tool = ToolInfo( @@ -491,11 +549,17 @@ async def test_build_tools_keeps_registered_skill_description(self): SessionRunner, "_list_callable_tool_infos_for_turn", AsyncMock(return_value=([skill_tool], {"enabledToolCount": 3})), + ), patch( + "flocks.skill.skill.Skill.list_enabled", + AsyncMock(return_value=[SimpleNamespace(name="agent-builder")]), + ), patch( + "flocks.tool.system.skill.build_description", + return_value="Refreshed skill description", ): tools = await runner._build_callable_tool_schema(agent, []) assert tools[0]["function"]["name"] == "skill" - assert tools[0]["function"]["description"] == "Original skill description" + assert tools[0]["function"]["description"] == "Refreshed skill description" class TestBuildSystemPrompts: @@ -885,7 +949,7 @@ def test_build_tool_catalog_prompt_for_rex(self): "flocks.session.runner.get_always_load_tool_names", return_value={"question", "tool_search"}, ), patch( - "flocks.tool.system.slash_command.format_tools_catalog_summary", + "flocks.command.direct.format_tools_catalog_summary", return_value="Available Tools (grouped by category):\n\n**custom**\n- plugin_memory: Access project 
memory", ): prompt = runner._build_tool_catalog_prompt(agent) @@ -926,7 +990,7 @@ def test_build_tool_catalog_prompt_for_rex_excludes_builtin_and_always_load_tool "flocks.session.runner.get_always_load_tool_names", return_value={"question", "tool_search"}, ), patch( - "flocks.tool.system.slash_command.format_tools_catalog_summary", + "flocks.command.direct.format_tools_catalog_summary", side_effect=lambda tools, **_: "\n".join(tool.name for tool in tools), ) as formatter_mock: prompt = runner._build_tool_catalog_prompt(agent) @@ -1271,6 +1335,215 @@ async def test_to_chat_messages_preserves_assistant_reasoning_for_replay(): assert chat_messages[1].tool_call_id == "call_reasoning_replay" +@pytest.mark.asyncio +async def test_to_chat_messages_restores_provider_reasoning_fields_from_metadata(monkeypatch): + session = await Session.create( + project_id="test_runner_reasoning_metadata_replay", + directory="/tmp/runner-reasoning-metadata", + ) + assistant_message = await Message.create( + session_id=session.id, + role=MessageRole.ASSISTANT, + content="", + ) + runner = SessionRunner(session=session, static_cache={}) + runner.provider_id = "alibaba" + runner.model_id = "qwen3-max" + + monkeypatch.setattr( + runner_mod.Provider, + "get_model", + lambda _model_id: SimpleNamespace( + capabilities=SimpleNamespace( + interleaved={ + "field": "reasoning_content", + "echo": "tool_calls", + "cross_provider_policy": "promote", + } + ) + ), + ) + + await Message.add_part( + session.id, + assistant_message.id, + ReasoningPart( + sessionID=session.id, + messageID=assistant_message.id, + text="Need to call the tool first.", + metadata={ + "reasoningContent": "Need to call the tool first.", + "reasoningSource": "native_reasoning_content", + }, + time=PartTime(start=1), + ), + ) + await Message.add_part( + session.id, + assistant_message.id, + ToolPart( + sessionID=session.id, + messageID=assistant_message.id, + callID="call_reasoning_metadata", + tool="task", + state=ToolStateRunning( 
+ input={"prompt": "continue"}, + time={"start": 1}, + ), + ), + ) + + chat_messages = await runner._to_chat_messages([assistant_message], []) + + assert len(chat_messages) == 2 + assert chat_messages[0].reasoning == "Need to call the tool first." + assert chat_messages[0].reasoning_content == "Need to call the tool first." + assert chat_messages[0].reasoning_source == "native_reasoning_content" + assert chat_messages[0].tool_calls[0]["function"]["name"] == "task" + + +@pytest.mark.asyncio +async def test_to_chat_messages_restores_redacted_anthropic_thinking_blocks(monkeypatch): + session = await Session.create( + project_id="test_runner_redacted_thinking", + directory="/tmp/runner-redacted-thinking", + ) + assistant_message = await Message.create( + session_id=session.id, + role=MessageRole.ASSISTANT, + content="", + ) + runner = SessionRunner(session=session, static_cache={}) + runner.provider_id = "anthropic" + runner.model_id = "claude-sonnet-4-6" + + monkeypatch.setattr( + runner_mod.Provider, + "get_model", + lambda _model_id: SimpleNamespace( + capabilities=SimpleNamespace(interleaved=None) + ), + ) + + await Message.add_part( + session.id, + assistant_message.id, + ReasoningPart( + sessionID=session.id, + messageID=assistant_message.id, + text="", + metadata={ + "reasoningField": "thinking", + "reasoningSource": "anthropic_redacted_thinking", + "redactedThinkingData": "opaque_blob", + }, + time=PartTime(start=1), + ), + ) + await Message.add_part( + session.id, + assistant_message.id, + ToolPart( + sessionID=session.id, + messageID=assistant_message.id, + callID="call_redacted_reasoning", + tool="task", + state=ToolStateRunning( + input={"prompt": "continue"}, + time={"start": 1}, + ), + ), + ) + + chat_messages = await runner._to_chat_messages([assistant_message], []) + + assert len(chat_messages) == 2 + assert chat_messages[0].custom_settings["anthropic_thinking_blocks"] == [ + {"type": "redacted_thinking", "data": "opaque_blob"} + ] + assert 
chat_messages[0].tool_calls[0]["function"]["name"] == "task" + + +@pytest.mark.asyncio +async def test_to_chat_messages_prefers_provider_specific_interleaved_resolution(monkeypatch): + session = await Session.create( + project_id="test_runner_provider_specific_interleaved", + directory="/tmp/runner-provider-interleaved", + ) + assistant_message = await Message.create( + session_id=session.id, + role=MessageRole.ASSISTANT, + content="", + ) + runner = SessionRunner(session=session, static_cache={}) + runner.provider_id = "deepseek" + runner.model_id = "shared-model" + + monkeypatch.setattr( + runner_mod.Provider, + "resolve_model", + lambda provider_id, model_id: ( + SimpleNamespace( + capabilities=SimpleNamespace( + interleaved={ + "field": "reasoning_content", + "echo": "tool_calls", + "placeholder": " ", + "cross_provider_policy": "placeholder", + } + ) + ) + if provider_id == "deepseek" and model_id == "shared-model" + else None + ), + ) + monkeypatch.setattr( + runner_mod.Provider, + "get_model", + lambda _model_id: SimpleNamespace( + capabilities=SimpleNamespace( + interleaved={ + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote", + } + ) + ), + ) + + await Message.add_part( + session.id, + assistant_message.id, + ReasoningPart( + sessionID=session.id, + messageID=assistant_message.id, + text="Prior provider chain of thought", + time=PartTime(start=1), + ), + ) + await Message.add_part( + session.id, + assistant_message.id, + ToolPart( + sessionID=session.id, + messageID=assistant_message.id, + callID="call_provider_specific_interleaved", + tool="task", + state=ToolStateRunning( + input={"prompt": "continue"}, + time={"start": 1}, + ), + ), + ) + + chat_messages = await runner._to_chat_messages([assistant_message], []) + + assert len(chat_messages) == 2 + assert chat_messages[0].reasoning_content == " " + assert chat_messages[0].reasoning_details is None + assert chat_messages[0].reasoning_source == "placeholder" + + 
@pytest.mark.asyncio async def test_process_step_creates_assistant_message_with_provider_and_model(monkeypatch): runner = _make_runner("ses_runner_provider_model") @@ -1550,6 +1823,172 @@ async def test_process_step_empty_retry_records_usage_per_attempt(monkeypatch): record_mock.assert_any_await(second_usage, message_id=assistant_msg.id) +@pytest.mark.asyncio +async def test_process_step_uses_default_max_steps_when_agent_steps_missing(monkeypatch): + runner = _make_runner("ses_runner_default_max_steps") + runner.callbacks = RunnerCallbacks(on_error=AsyncMock()) + runner._step = DEFAULT_MAX_TOOL_STEPS + + last_user = UserMessageInfo( + id="msg_user_default_max_steps", + sessionID=runner.session.id, + role="user", + time={"created": 1_000}, + agent="rex", + model={"providerID": "anthropic", "modelID": "claude-sonnet"}, + ) + + agent = SimpleNamespace(name="rex", steps=None, mode="primary", prompt="", tools=["read"]) + provider = MagicMock() + provider.is_configured.return_value = True + assistant_msg = SimpleNamespace(id="msg_assistant_default_max_steps") + sentinel_tools = [{"type": "function", "function": {"name": "read", "description": "", "parameters": {}}}] + captured = {} + + monkeypatch.setattr(runner_mod.Agent, "get", AsyncMock(return_value=agent)) + monkeypatch.setattr(runner_mod.Provider, "get", lambda provider_id: provider) + monkeypatch.setattr(runner_mod.Provider, "apply_config", AsyncMock(return_value=None)) + monkeypatch.setattr(runner_mod.SessionPrompt, "build_system_prompts", AsyncMock(return_value=[])) + monkeypatch.setattr(runner, "_build_callable_tool_schema", AsyncMock(return_value=sentinel_tools)) + monkeypatch.setattr( + runner, + "_to_chat_messages", + AsyncMock(return_value=[SimpleNamespace(role="user", content="hi")]), + ) + monkeypatch.setattr(runner_mod.Message, "get_text_content", AsyncMock(return_value="hi")) + monkeypatch.setattr(runner_mod.Message, "parts", AsyncMock(return_value=[])) + monkeypatch.setattr(runner_mod.Message, "create", 
AsyncMock(return_value=assistant_msg)) + monkeypatch.setattr(runner_mod.Message, "update", AsyncMock(return_value=None)) + + async def fake_call_llm(self, provider, messages, tools, agent, assistant_msg): # noqa: ANN001 + captured["tools"] = tools + return StepResult(action="stop", content="done") + + monkeypatch.setattr(SessionRunner, "_call_llm", fake_call_llm) + + result = await runner._process_step([last_user], last_user) + + assert result.action == "stop" + assert captured["tools"] == [] + + +@pytest.mark.asyncio +async def test_process_step_respects_explicit_agent_steps_over_default(monkeypatch): + runner = _make_runner("ses_runner_explicit_max_steps") + runner.callbacks = RunnerCallbacks(on_error=AsyncMock()) + runner._step = DEFAULT_MAX_TOOL_STEPS + + last_user = UserMessageInfo( + id="msg_user_explicit_max_steps", + sessionID=runner.session.id, + role="user", + time={"created": 1_000}, + agent="rex", + model={"providerID": "anthropic", "modelID": "claude-sonnet"}, + ) + + agent = SimpleNamespace(name="rex", steps=DEFAULT_MAX_TOOL_STEPS + 1, mode="primary", prompt="", tools=["read"]) + provider = MagicMock() + provider.is_configured.return_value = True + assistant_msg = SimpleNamespace(id="msg_assistant_explicit_max_steps") + sentinel_tools = [{"type": "function", "function": {"name": "read", "description": "", "parameters": {}}}] + captured = {} + + monkeypatch.setattr(runner_mod.Agent, "get", AsyncMock(return_value=agent)) + monkeypatch.setattr(runner_mod.Provider, "get", lambda provider_id: provider) + monkeypatch.setattr(runner_mod.Provider, "apply_config", AsyncMock(return_value=None)) + monkeypatch.setattr(runner_mod.SessionPrompt, "build_system_prompts", AsyncMock(return_value=[])) + monkeypatch.setattr(runner, "_build_callable_tool_schema", AsyncMock(return_value=sentinel_tools)) + monkeypatch.setattr( + runner, + "_to_chat_messages", + AsyncMock(return_value=[SimpleNamespace(role="user", content="hi")]), + ) + 
monkeypatch.setattr(runner_mod.Message, "get_text_content", AsyncMock(return_value="hi")) + monkeypatch.setattr(runner_mod.Message, "parts", AsyncMock(return_value=[])) + monkeypatch.setattr(runner_mod.Message, "create", AsyncMock(return_value=assistant_msg)) + monkeypatch.setattr(runner_mod.Message, "update", AsyncMock(return_value=None)) + + async def fake_call_llm(self, provider, messages, tools, agent, assistant_msg): # noqa: ANN001 + captured["tools"] = tools + return StepResult(action="stop", content="done") + + monkeypatch.setattr(SessionRunner, "_call_llm", fake_call_llm) + + result = await runner._process_step([last_user], last_user) + + assert result.action == "stop" + assert captured["tools"] == sentinel_tools + + +@pytest.mark.asyncio +async def test_process_step_halts_after_third_exact_tool_only_turn(monkeypatch): + shared_cache = {} + provider = MagicMock() + provider.is_configured.return_value = True + update_mock = AsyncMock(return_value=None) + create_mock = AsyncMock( + side_effect=[ + SimpleNamespace(id="msg_assistant_tool_loop_1"), + SimpleNamespace(id="msg_assistant_tool_loop_2"), + SimpleNamespace(id="msg_assistant_tool_loop_3"), + ] + ) + + async def fake_call_llm(self, provider, messages, tools, agent, assistant_msg): # noqa: ANN001 + del provider, messages, tools, agent, assistant_msg + return StepResult( + action="continue", + tool_calls=[ToolCall(id="c-loop", name="echo_tool", arguments={"text": "loop"})], + ) + + monkeypatch.setattr(runner_mod.Agent, "get", AsyncMock(return_value=SimpleNamespace( + name="rex", + steps=None, + mode="primary", + prompt="", + tools=["echo_tool"], + ))) + monkeypatch.setattr(runner_mod.Provider, "get", lambda provider_id: provider) + monkeypatch.setattr(runner_mod.Provider, "apply_config", AsyncMock(return_value=None)) + monkeypatch.setattr(runner_mod.SessionPrompt, "build_system_prompts", AsyncMock(return_value=[])) + monkeypatch.setattr(runner_mod.Message, "get_text_content", AsyncMock(return_value="hi")) 
+ monkeypatch.setattr(runner_mod.Message, "parts", AsyncMock(return_value=[])) + monkeypatch.setattr(runner_mod.Message, "create", create_mock) + monkeypatch.setattr(runner_mod.Message, "update", update_mock) + monkeypatch.setattr(SessionRunner, "_call_llm", fake_call_llm) + + last_user = UserMessageInfo( + id="msg_user_tool_loop_guard", + sessionID="ses_runner_tool_loop_guard", + role="user", + time={"created": 1_000}, + agent="rex", + model={"providerID": "anthropic", "modelID": "claude-sonnet"}, + ) + + for idx in range(1, 4): + runner = SessionRunner(session=_make_session("ses_runner_tool_loop_guard"), static_cache=shared_cache) + runner.callbacks = RunnerCallbacks(on_error=AsyncMock()) + monkeypatch.setattr(runner, "_build_callable_tool_schema", AsyncMock(return_value=[ + {"type": "function", "function": {"name": "echo_tool", "description": "", "parameters": {}}} + ])) + monkeypatch.setattr( + runner, + "_to_chat_messages", + AsyncMock(return_value=[SimpleNamespace(role="user", content="hi")]), + ) + result = await runner._process_step([last_user], last_user) + if idx < 3: + assert result.action == "continue" + else: + assert result.action == "stop" + assert "Stopped the loop because `echo_tool` was called 3 times in a row" in result.content + + assert update_mock.await_args_list[-2].kwargs["content"].startswith("Stopped the loop because `echo_tool`") + assert update_mock.await_args_list[-1].kwargs["finish"] == "stop" + + @pytest.mark.asyncio async def test_record_usage_if_available_swallows_import_error(): """ImportError from the usage service import must not propagate out of