diff --git a/flocks/hooks/pipeline.py b/flocks/hooks/pipeline.py index 14bfeb09f..4faf4294a 100644 --- a/flocks/hooks/pipeline.py +++ b/flocks/hooks/pipeline.py @@ -4,6 +4,8 @@ Provides a lightweight hook registry and execution pipeline that mirrors oh-my-opencode's lifecycle stages: - chat.message +- llm.call.before +- llm.call.after - tool.execute.before - tool.execute.after - event @@ -20,6 +22,8 @@ class HookStage: CHAT_MESSAGE = "chat.message" + LLM_BEFORE = "llm.call.before" + LLM_AFTER = "llm.call.after" TOOL_BEFORE = "tool.execute.before" TOOL_AFTER = "tool.execute.after" EVENT = "event" @@ -39,6 +43,12 @@ class HookBase: async def chat_message(self, ctx: HookContext) -> None: # pragma: no cover - default no-op return None + async def llm_before(self, ctx: HookContext) -> None: # pragma: no cover - default no-op + return None + + async def llm_after(self, ctx: HookContext) -> None: # pragma: no cover - default no-op + return None + async def tool_before(self, ctx: HookContext) -> None: # pragma: no cover - default no-op return None @@ -98,6 +108,22 @@ async def run_chat_message( ) -> HookContext: return await cls._run_stage(HookStage.CHAT_MESSAGE, input_data, output_data) + @classmethod + async def run_llm_before( + cls, + input_data: Dict[str, Any], + output_data: Optional[Dict[str, Any]] = None, + ) -> HookContext: + return await cls._run_stage(HookStage.LLM_BEFORE, input_data, output_data) + + @classmethod + async def run_llm_after( + cls, + input_data: Dict[str, Any], + output_data: Optional[Dict[str, Any]] = None, + ) -> HookContext: + return await cls._run_stage(HookStage.LLM_AFTER, input_data, output_data) + @classmethod async def run_tool_before( cls, @@ -174,6 +200,10 @@ async def _run_stage( def _resolve_handler(hook: HookBase, stage: str) -> Optional[Callable[[HookContext], Awaitable[None]]]: if stage == HookStage.CHAT_MESSAGE: return getattr(hook, "chat_message", None) + if stage == HookStage.LLM_BEFORE: + return getattr(hook, "llm_before", None) + if stage == HookStage.LLM_AFTER: + return getattr(hook, "llm_after", None) if stage == HookStage.TOOL_BEFORE: return getattr(hook, "tool_before", None) if stage == HookStage.TOOL_AFTER: diff --git a/flocks/session/runner.py b/flocks/session/runner.py index f5731ab00..9f22b9a3f 100644 --- a/flocks/session/runner.py +++ b/flocks/session/runner.py @@ -13,6 +13,7 @@ import json import os import sys +import time from datetime import datetime from typing import Optional, Dict, Any, List, Callable, Awaitable, Set, Tuple from dataclasses import dataclass, field @@ -41,6 +42,7 @@ from flocks.agent.agent import AgentInfo from flocks.agent.toolset import agent_declares_tool from flocks.provider.provider import Provider, ChatMessage +from flocks.hooks.pipeline import HookPipeline from flocks.tool.catalog import get_tool_catalog_metadata, list_tool_catalog_infos from flocks.tool.registry import ToolRegistry, ToolResult from flocks.utils.langfuse import generation_scope, trace_scope @@ -1796,6 +1798,64 @@ async def _call_llm( Uses StreamProcessor to handle events and execute tools synchronously. Ported from Flocks' SessionProcessor.process() behavior. """ + def _summarize_content(content: Any) -> Dict[str, Any]: + if isinstance(content, str): + return { + "type": "text", + "length": len(content), + "preview": content[:500], + } + if isinstance(content, list): + part_summaries = [] + for part in content[:5]: + if isinstance(part, dict): + part_summary = {"type": part.get("type", "object")} + text_value = part.get("text") + if isinstance(text_value, str): + part_summary["textLength"] = len(text_value) + part_summary["textPreview"] = text_value[:160] + mime_type = part.get("mimeType") + if mime_type: + part_summary["mimeType"] = mime_type + part_summaries.append(part_summary) + else: + part_summaries.append({"type": type(part).__name__}) + return { + "type": "parts", + "partCount": len(content), + "parts": part_summaries, + } + return { + "type": type(content).__name__, + "preview": str(content)[:500], + } + + def _summarize_message(message: ChatMessage) -> Dict[str, Any]: + tool_calls = getattr(message, "tool_calls", None) or [] + tool_call_names = [] + for tool_call in tool_calls[:20]: + if isinstance(tool_call, dict): + tool_call_names.append(tool_call.get("function", {}).get("name", "")) + reasoning = getattr(message, "reasoning", None) or "" + return { + "role": message.role, + "content": _summarize_content(message.content), + "toolCallCount": len(tool_calls), + "toolCallNames": tool_call_names, + "reasoningLength": len(reasoning), + } + + def _summarize_tools(tool_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + summaries = [] + for tool in tool_list[:50]: + function_meta = tool.get("function", {}) if isinstance(tool, dict) else {} + summaries.append({ + "name": function_meta.get("name", ""), + "description": (function_meta.get("description", "") or "")[:240], + "hasParameters": bool(function_meta.get("parameters")), + }) + return summaries + # Create stream processor main_session_key = self.session.id config_data: Dict[str, Any] = {} @@ -1932,101 +1992,144 @@ async def _call_llm( "tool_count": len(tools), }) - async for chunk in _iter_with_chunk_timeout( - provider.chat_stream( - model_id=self.model_id, - messages=messages, - tools=provider_tools, - # session_id is forwarded via kwargs so providers that need - # to look up persisted session data (e.g. Gemini's DB-backed - # reasoning replay) can do so. Providers that don't care - # simply ignore unknown kwargs. - session_id=self.session.id, - **provider_options, - ), - first_chunk_timeout_s=LLM_STREAM_FIRST_CHUNK_TIMEOUT_S, - ongoing_chunk_timeout_s=LLM_STREAM_ONGOING_CHUNK_TIMEOUT_S, - ): - chunk_counts["total"] += 1 - - chunk_finish = getattr(chunk, 'finish_reason', None) - if chunk_finish: - stream_finish_reason = chunk_finish - - # Capture usage from chunk (providers may include it in the final chunk) - if hasattr(chunk, 'usage') and chunk.usage: - stream_usage = chunk.usage - - # Check for abort - if self.is_aborted: - break - - # Determine event type from chunk. A single chunk may carry any - # combination of reasoning / text / tool_calls (e.g. Gemini bundles - # them). We must not drop non-reasoning content when reasoning is - # present, and we must not double-emit `delta` as text when the - # provider used `event_type == 'reasoning'` to overload `delta` for - # reasoning text. - event_type = getattr(chunk, 'event_type', None) - - chunk_reasoning = getattr(chunk, 'reasoning', None) or None - if not chunk_reasoning and event_type == 'reasoning': - # Older providers signal reasoning via event_type and put the - # reasoning text in `delta` (no separate `reasoning` field). - chunk_reasoning = getattr(chunk, 'delta', '') or None - - # Treat `delta` as text only when it isn't already consumed as - # reasoning above. This preserves backward compatibility with - # providers that emit reasoning-only chunks via `event_type`. - chunk_text = '' - if event_type != 'reasoning' or getattr(chunk, 'reasoning', None): - chunk_text = getattr(chunk, 'delta', '') or '' - - chunk_tool_calls = getattr(chunk, 'tool_calls', None) - - # 1) Process reasoning delta (start reasoning block on first sight). - if chunk_reasoning: - chunk_counts["reasoning"] += 1 - log.debug("runner.reasoning.received", { - "length": len(chunk_reasoning), - "text_preview": chunk_reasoning[:50], - }) - if not hasattr(self, '_current_reasoning_id'): - reasoning_id_counter += 1 - self._current_reasoning_id = f"reasoning-{reasoning_id_counter}" - await processor.process_event(ReasoningStartEvent( + llm_hook_input = { + "sessionID": self.session.id, + "messageID": assistant_msg.id, + "workspace": self.session.directory, + "agent": agent.name, + "step": self._step, + "model": { + "providerID": self.provider_id, + "modelID": self.model_id, + }, + "request": { + "messageCount": len(messages), + "messages": [_summarize_message(message) for message in messages], + "toolCount": len(tools), + "tools": _summarize_tools(tools), + "providerOptions": dict(provider_options), + "providerToolsEnabled": provider_tools is not None, + }, + } + try: + await HookPipeline.run_llm_before(llm_hook_input) + except Exception as exc: + log.debug("runner.hook.llm_before.error", {"error": str(exc)}) + + llm_call_started_at = time.perf_counter() + try: + async for chunk in _iter_with_chunk_timeout( + provider.chat_stream( + model_id=self.model_id, + messages=messages, + tools=provider_tools, + # session_id is forwarded via kwargs so providers that need + # to look up persisted session data (e.g. Gemini's DB-backed + # reasoning replay) can do so. Providers that don't care + # simply ignore unknown kwargs. + session_id=self.session.id, + **provider_options, + ), + first_chunk_timeout_s=LLM_STREAM_FIRST_CHUNK_TIMEOUT_S, + ongoing_chunk_timeout_s=LLM_STREAM_ONGOING_CHUNK_TIMEOUT_S, + ): + chunk_counts["total"] += 1 + + chunk_finish = getattr(chunk, 'finish_reason', None) + if chunk_finish: + stream_finish_reason = chunk_finish + + # Capture usage from chunk (providers may include it in the final chunk) + if hasattr(chunk, 'usage') and chunk.usage: + stream_usage = chunk.usage + + # Check for abort + if self.is_aborted: + break + + # Determine event type from chunk. A single chunk may carry any + # combination of reasoning / text / tool_calls (e.g. Gemini bundles + # them). We must not drop non-reasoning content when reasoning is + # present, and we must not double-emit `delta` as text when the + # provider used `event_type == 'reasoning'` to overload `delta` for + # reasoning text. + event_type = getattr(chunk, 'event_type', None) + + chunk_reasoning = getattr(chunk, 'reasoning', None) or None + if not chunk_reasoning and event_type == 'reasoning': + # Older providers signal reasoning via event_type and put the + # reasoning text in `delta` (no separate `reasoning` field). + chunk_reasoning = getattr(chunk, 'delta', '') or None + + # Treat `delta` as text only when it isn't already consumed as + # reasoning above. This preserves backward compatibility with + # providers that emit reasoning-only chunks via `event_type`. + chunk_text = '' + if event_type != 'reasoning' or getattr(chunk, 'reasoning', None): + chunk_text = getattr(chunk, 'delta', '') or '' + + chunk_tool_calls = getattr(chunk, 'tool_calls', None) + + # 1) Process reasoning delta (start reasoning block on first sight). + if chunk_reasoning: + chunk_counts["reasoning"] += 1 + log.debug("runner.reasoning.received", { + "length": len(chunk_reasoning), + "text_preview": chunk_reasoning[:50], + }) + if not hasattr(self, '_current_reasoning_id'): + reasoning_id_counter += 1 + self._current_reasoning_id = f"reasoning-{reasoning_id_counter}" + await processor.process_event(ReasoningStartEvent( + id=self._current_reasoning_id + )) + + await processor.process_event(ReasoningDeltaEvent( + id=self._current_reasoning_id, + text=chunk_reasoning, + )) + + # 2) End reasoning block when this chunk also carries non-reasoning + # content (or once the stream moves away from reasoning). + if (chunk_text or chunk_tool_calls) and hasattr(self, '_current_reasoning_id'): + await processor.process_event(ReasoningEndEvent( id=self._current_reasoning_id )) + delattr(self, '_current_reasoning_id') - await processor.process_event(ReasoningDeltaEvent( - id=self._current_reasoning_id, - text=chunk_reasoning, - )) - - # 2) End reasoning block when this chunk also carries non-reasoning - # content (or once the stream moves away from reasoning). - if (chunk_text or chunk_tool_calls) and hasattr(self, '_current_reasoning_id'): - await processor.process_event(ReasoningEndEvent( - id=self._current_reasoning_id - )) - delattr(self, '_current_reasoning_id') - - # 3) Process text delta. - if chunk_text: - chunk_counts["text"] += 1 - if not text_started: - await processor.process_event(TextStartEvent()) - text_started = True - - await processor.process_event(TextDeltaEvent( - text=chunk_text, - )) - - # 4) Process tool calls. - if chunk_tool_calls: - chunk_counts["tool"] += 1 - for tc in chunk_tool_calls: - await tool_accumulator.feed_chunk(tc) + # 3) Process text delta. + if chunk_text: + chunk_counts["text"] += 1 + if not text_started: + await processor.process_event(TextStartEvent()) + text_started = True + + await processor.process_event(TextDeltaEvent( + text=chunk_text, + )) + + # 4) Process tool calls. + if chunk_tool_calls: + chunk_counts["tool"] += 1 + for tc in chunk_tool_calls: + await tool_accumulator.feed_chunk(tc) + except Exception as exc: + try: + await HookPipeline.run_llm_after( + llm_hook_input, + { + "durationMs": int((time.perf_counter() - llm_call_started_at) * 1000), + "error": { + "type": type(exc).__name__, + "message": str(exc), + }, + "usage": stream_usage, + "chunkCounts": dict(chunk_counts), + }, + ) + except Exception as hook_exc: + log.debug("runner.hook.llm_after.error", {"error": str(hook_exc)}) + raise log.info("runner.stream.summary", { "total_chunks": chunk_counts["total"], @@ -2103,6 +2206,27 @@ async def _call_llm( ) for tc_state in processor.tool_calls.values() ] + result_action = "continue" if tool_calls_for_result else "stop" + try: + await HookPipeline.run_llm_after( + llm_hook_input, + { + "durationMs": int((time.perf_counter() - llm_call_started_at) * 1000), + "finishReason": processor.get_finish_reason(), + "contentLength": len(content), + "reasoningLength": len(reasoning), + "toolCallCount": len(tool_calls_for_result), + "toolCalls": [ + {"id": tool_call.id, "name": tool_call.name} + for tool_call in tool_calls_for_result[:30] + ], + "usage": stream_usage, + "chunkCounts": dict(chunk_counts), + "action": result_action, + }, + ) + except Exception as exc: + log.debug("runner.hook.llm_after.error", {"error": str(exc)}) if tool_calls_for_result: self._end_observability( @@ -2127,7 +2251,7 @@ async def _call_llm( }, ) return StepResult( - action="continue", + action=result_action, content=content, tool_calls=tool_calls_for_result, usage=stream_usage, @@ -2152,7 +2276,7 @@ async def _call_llm( "finish_reason": processor.get_finish_reason(), }, ) - return StepResult(action="stop", content=content, usage=stream_usage) + return StepResult(action=result_action, content=content, usage=stream_usage) @staticmethod def _end_observability( diff --git a/flocks/tool/catalog.py b/flocks/tool/catalog.py index 98ed38b26..50388bf20 100644 --- a/flocks/tool/catalog.py +++ b/flocks/tool/catalog.py @@ -7,6 +7,7 @@ from __future__ import annotations +import re from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from pydantic import BaseModel, Field @@ -112,11 +113,49 @@ def list_tool_catalog_infos(tool_names: Optional[Iterable[str]] = None) -> List[ return result +def canonical_tool_token(value: str) -> str: + """Normalize tool names and aliases for exact selection/search matching.""" + canonical = "".join(ch.lower() for ch in value if ch.isalnum()) + if canonical.endswith("tool"): + canonical = canonical[:-4] + return canonical + + +def normalize_tool_search_query(query: str) -> str: + query = query.strip() + if query.lower().startswith("select:"): + query = query[len("select:"):] + return " ".join( + canonical_tool_token(term) + for term in re.split(r"[\s,]+", query) + if term + ) + + +def _format_tool_catalog_match(tool_info: Any, matched_tags: List[str], score: int) -> Dict[str, Any]: + metadata = get_tool_catalog_metadata(tool_info.name, tool_info) + return { + "name": tool_info.name, + "description": tool_info.description, + "category": getattr(tool_info.category, "value", str(tool_info.category)), + "requires_confirmation": getattr(tool_info, "requires_confirmation", False), + "source": getattr(tool_info, "source", None), + "native": getattr(tool_info, "native", False), + "always_load": metadata.always_load, + "tags": metadata.tags, + "matchedTags": matched_tags, + "score": score, + } + + def _score_tool_catalog_match(query: str, category: Optional[str], tool_info: Any) -> Tuple[int, List[str]]: q = (query or "").strip().lower() - tokens = [token for token in q.split() if token] + tokens = [token for token in re.split(r"[\s,]+", q) if token] + canonical_tokens = [canonical_tool_token(token) for token in tokens] name = tool_info.name.lower() + canonical_name = canonical_tool_token(tool_info.name) desc = (tool_info.description or "").lower() + normalized_desc = normalize_tool_search_query(tool_info.description or "") source = (getattr(tool_info, "source", None) or "").lower() tool_category = getattr(tool_info.category, "value", str(tool_info.category)).lower() metadata = get_tool_catalog_metadata(tool_info.name, tool_info) @@ -128,12 +167,18 @@ def _score_tool_catalog_match(query: str, category: Optional[str], tool_info: An score += 10 if q and q in name: score += 120 + if canonical_tokens and any(token == canonical_name for token in canonical_tokens): + score += 140 if q and any(token in name for token in tokens): score += 55 + if canonical_tokens and any(token and token in canonical_name for token in canonical_tokens): + score += 65 if q and q in desc: score += 40 if q and any(token in desc for token in tokens): score += 20 + if canonical_tokens and any(token and token in normalized_desc for token in canonical_tokens): + score += 20 if q and q in source: score += 10 if q and any(token in tag for token in tokens for tag in tags): @@ -152,6 +197,45 @@ def _score_tool_catalog_match(query: str, category: Optional[str], tool_info: An return score, matched_tags +def _select_tool_catalog( + query: str, + *, + category: Optional[str], + limit: int, +) -> Optional[Tuple[List[Dict[str, Any]], List[str]]]: + lowered = (query or "").strip().lower() + if not lowered.startswith("select:"): + return None + + wanted = [ + canonical_tool_token(part) + for part in lowered[len("select:"):].split(",") + if part.strip() + ] + if not wanted: + return [], [] + + tools_by_canonical = { + canonical_tool_token(tool_info.name): tool_info + for tool_info in list_tool_catalog_infos() + if not category + or getattr(tool_info.category, "value", str(tool_info.category)).lower() == category.lower() + } + + matches: List[Dict[str, Any]] = [] + seen: Set[str] = set() + for canonical_name in wanted: + tool_info = tools_by_canonical.get(canonical_name) + if tool_info is None or tool_info.name in seen: + continue + seen.add(tool_info.name) + matches.append(_format_tool_catalog_match(tool_info, [], 10_000)) + if len(matches) >= limit: + break + + return matches, [] + + def search_tool_catalog( query: Optional[str] = None, *, @@ -159,6 +243,10 @@ def search_tool_catalog( limit: int = 8, ) -> Tuple[List[Dict[str, Any]], List[str]]: limit = max(1, min(limit or 8, 20)) + selected = _select_tool_catalog(query or "", category=category, limit=limit) + if selected is not None: + return selected + ranked: List[Tuple[int, Any, List[str]]] = [] for tool_info in list_tool_catalog_infos(): @@ -176,19 +264,7 @@ def search_tool_catalog( matched_tag_set: Set[str] = set() for score, tool_info, matched_tags in ranked[:limit]: - metadata = get_tool_catalog_metadata(tool_info.name, tool_info) matched_tag_set.update(matched_tags) - matches.append({ - "name": tool_info.name, - "description": tool_info.description, - "category": getattr(tool_info.category, "value", str(tool_info.category)), - "requires_confirmation": getattr(tool_info, "requires_confirmation", False), - "source": getattr(tool_info, "source", None), - "native": getattr(tool_info, "native", False), - "always_load": metadata.always_load, - "tags": metadata.tags, - "matchedTags": matched_tags, - "score": score, - }) + matches.append(_format_tool_catalog_match(tool_info, matched_tags, score)) return matches, sorted(matched_tag_set) diff --git a/flocks/tool/system/tool_search.py b/flocks/tool/system/tool_search.py index 17a8e3257..0deda0fd3 100644 --- a/flocks/tool/system/tool_search.py +++ b/flocks/tool/system/tool_search.py @@ -9,7 +9,7 @@ from typing import Optional -from flocks.tool.catalog import search_tool_catalog +from flocks.tool.catalog import normalize_tool_search_query, search_tool_catalog from flocks.tool.registry import ( ParameterType, ToolCategory, @@ -21,11 +21,13 @@ from flocks.session.callable_state import add_session_callable_tools -DESCRIPTION = """Search available tools by task intent, keyword, or category. +DESCRIPTION = """Search available tools by task intent, keyword, category, or exact names. Use this tool when you need to discover a tool that is not already exposed in the current turn. Search by user goal, capability, or keyword. Matching tools -returned here are added to the current session callable tool set immediately.""" +returned here are added to the current session callable tool set immediately. +If you already know the needed tool names, prefer one exact batch query such as +`select:websearch,webfetch,skill` instead of multiple separate searches.""" @ToolRegistry.register_function( name="tool_search", @@ -35,7 +37,10 @@ ToolParameter( name="query", type=ParameterType.STRING, - description="Search query describing the capability or task intent", + description=( + "Search query describing the capability or task intent. " + "Use select:tool_a,tool_b to expose multiple known tools in one call." + ), required=False, ), ToolParameter( @@ -61,12 +66,14 @@ async def tool_search( ) -> ToolResult: limit = max(1, min(limit or 8, 20)) matches, matched_tags = search_tool_catalog(query, category=category, limit=limit) + normalized_query = normalize_tool_search_query(query or "") callable_candidates = [match["name"] for match in matches] callable_tools = await add_session_callable_tools(ctx.session_id, callable_candidates) if ctx.event_publish_callback: await ctx.event_publish_callback("runtime.tool_discovery", { "sessionID": ctx.session_id, "query": query or "", + "normalizedQuery": normalized_query, "category": category, "returnedToolCount": len(matches), "callableToolCount": len(callable_tools), @@ -78,6 +85,7 @@ async def tool_search( success=True, output={ "query": query or "", + "normalizedQuery": normalized_query, "category": category, "count": len(matches), "matchedTags": matched_tags, diff --git a/tests/session/test_runner_llm_hooks.py b/tests/session/test_runner_llm_hooks.py new file mode 100644 index 000000000..18ac49a94 --- /dev/null +++ b/tests/session/test_runner_llm_hooks.py @@ -0,0 +1,272 @@ +"""Tests for LLM lifecycle hooks in SessionRunner and HookPipeline.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +import flocks.session.runner as runner_mod +from flocks.hooks.pipeline import HookBase, HookPipeline +from flocks.provider.provider import ChatMessage +from flocks.session.runner import SessionRunner +from flocks.session.session import SessionInfo + + +def _make_session(session_id: str = "ses_runner_llm_hooks") -> SessionInfo: + return SessionInfo.model_construct( + id=session_id, + slug="test", + project_id="proj_runner", + directory="/tmp", + title="Runner Hook Test", + ) + + +def _make_runner(session_id: str = "ses_runner_llm_hooks") -> SessionRunner: + return SessionRunner( + session=_make_session(session_id), + provider_id="anthropic", + model_id="claude-sonnet", + ) + + +class _FakeProcessor: + def __init__(self, **_: object): + self._text_parts: list[str] = [] + self._reasoning_parts: list[str] = [] + self.finish_reason = "stop" + self.tool_calls = {} + self._langfuse_generation = None + + async def process_event(self, event) -> None: + event_name = type(event).__name__ + if event_name == "TextDeltaEvent": + self._text_parts.append(event.text) + elif event_name == "ReasoningDeltaEvent": + self._reasoning_parts.append(event.text) + elif event_name == "FinishEvent": + self.finish_reason = event.finish_reason + + def get_text_content(self) -> str: + return "".join(self._text_parts) + + def get_reasoning_content(self) -> str: + return "".join(self._reasoning_parts) + + def get_finish_reason(self): + return self.finish_reason + + +class _FakeToolAccumulator: + def __init__(self, processor): + self.processor = processor + + async def feed_chunk(self, tool_call) -> None: + return None + + async def flush_remaining(self, finish_reason) -> None: + return None + + +@pytest.mark.asyncio +async def test_hook_pipeline_runs_llm_stages(): + seen: list[tuple[str, str]] = [] + + class _RecordingHook(HookBase): + async def llm_before(self, ctx) -> None: + seen.append((ctx.stage, ctx.input["request_id"])) + + async def llm_after(self, ctx) -> None: + seen.append((ctx.stage, ctx.output["status"])) + + HookPipeline.register("test-llm-stage-hook", _RecordingHook()) + try: + await HookPipeline.run_llm_before({"request_id": "req-1"}) + await HookPipeline.run_llm_after({"request_id": "req-1"}, {"status": "ok"}) + finally: + HookPipeline.unregister("test-llm-stage-hook") + + assert seen == [ + ("llm.call.before", "req-1"), + ("llm.call.after", "ok"), + ] + + +@pytest.mark.asyncio +async def test_call_llm_emits_hooks_on_success(monkeypatch: pytest.MonkeyPatch): + runner = _make_runner("ses_runner_llm_hooks_success") + assistant_msg = SimpleNamespace(id="msg_assistant_success") + agent = SimpleNamespace(name="rex") + usage = {"prompt_tokens": 7, "completion_tokens": 11, "total_tokens": 18} + order: list[str] = [] + + async def _before(payload): + order.append("before") + assert payload["request"]["toolCount"] == 1 + assert payload["request"]["providerToolsEnabled"] is True + + async def _after(payload, result): + order.append("after") + assert payload["sessionID"] == runner.session.id + assert result["action"] == "stop" + assert result["finishReason"] == "stop" + assert result["contentLength"] == len("hello") + assert result["reasoningLength"] == len("thinking") + assert result["toolCallCount"] == 0 + assert result["usage"] == usage + assert result["chunkCounts"] == {"total": 1, "reasoning": 1, "text": 1, "tool": 0} + + monkeypatch.setattr(runner_mod, "StreamProcessor", _FakeProcessor) + monkeypatch.setattr( + runner_mod.HookPipeline, + "run_llm_before", + AsyncMock(side_effect=_before), + ) + monkeypatch.setattr( + runner_mod.HookPipeline, + "run_llm_after", + AsyncMock(side_effect=_after), + ) + monkeypatch.setattr( + runner_mod.SessionRunner, + "_end_observability", + staticmethod(lambda *args, **kwargs: None), + ) + monkeypatch.setattr( + "flocks.provider.options.build_provider_options", + lambda provider_id, model_id: {"temperature": 0.2}, + ) + monkeypatch.setattr( + "flocks.session.streaming.tool_accumulator.ToolCallAccumulator", + _FakeToolAccumulator, + ) + monkeypatch.setattr(runner_mod.Message, "update", AsyncMock(return_value=None)) + monkeypatch.setattr( + runner_mod, + "trace_scope", + lambda **kwargs: SimpleNamespace(observation=None), + ) + monkeypatch.setattr( + runner_mod, + "generation_scope", + lambda **kwargs: SimpleNamespace(observation=None), + ) + + class _Provider: + def chat_stream(self, **kwargs): + assert kwargs["model_id"] == runner.model_id + assert kwargs["session_id"] == runner.session.id + + async def _gen(): + order.append("provider") + yield SimpleNamespace( + delta="hello", + reasoning="thinking", + tool_calls=None, + event_type=None, + finish_reason="stop", + usage=usage, + ) + + return _gen() + + result = await runner._call_llm( + provider=_Provider(), + messages=[ChatMessage(role="user", content="hello from user")], + tools=[ + { + "type": "function", + "function": { + "name": "search_docs", + "description": "Search docs", + "parameters": {"type": "object"}, + }, + } + ], + agent=agent, + assistant_msg=assistant_msg, + ) + + assert result.action == "stop" + assert result.content == "hello" + assert result.usage == usage + assert order == ["before", "provider", "after"] + + +@pytest.mark.asyncio +async def test_call_llm_emits_after_hook_on_error(monkeypatch: pytest.MonkeyPatch): + runner = _make_runner("ses_runner_llm_hooks_error") + assistant_msg = SimpleNamespace(id="msg_assistant_error") + agent = SimpleNamespace(name="rex") + order: list[str] = [] + + async def _before(payload): + order.append("before") + assert payload["request"]["messageCount"] == 1 + + async def _after(payload, result): + order.append("after") + assert payload["messageID"] == assistant_msg.id + assert result["chunkCounts"] == {"total": 0, "reasoning": 0, "text": 0, "tool": 0} + assert result["error"]["type"] == "RuntimeError" + assert "provider boom" in result["error"]["message"] + + monkeypatch.setattr(runner_mod, "StreamProcessor", _FakeProcessor) + monkeypatch.setattr( + runner_mod.HookPipeline, + "run_llm_before", + AsyncMock(side_effect=_before), + ) + monkeypatch.setattr( + runner_mod.HookPipeline, + "run_llm_after", + AsyncMock(side_effect=_after), + ) + monkeypatch.setattr( + runner_mod.SessionRunner, + "_end_observability", + staticmethod(lambda *args, **kwargs: None), + ) + monkeypatch.setattr( + "flocks.provider.options.build_provider_options", + lambda provider_id, model_id: {}, + ) + monkeypatch.setattr( + "flocks.session.streaming.tool_accumulator.ToolCallAccumulator", + _FakeToolAccumulator, + ) + monkeypatch.setattr(runner_mod.Message, "update", AsyncMock(return_value=None)) + monkeypatch.setattr( + runner_mod, + "trace_scope", + lambda **kwargs: SimpleNamespace(observation=None), + ) + monkeypatch.setattr( + runner_mod, + "generation_scope", + lambda **kwargs: SimpleNamespace(observation=None), + ) + + class _Provider: + def chat_stream(self, **kwargs): + assert kwargs["model_id"] == runner.model_id + + async def _gen(): + order.append("provider") + raise RuntimeError("provider boom") + yield # pragma: no cover + + return _gen() + + with pytest.raises(RuntimeError, match="provider boom"): + await runner._call_llm( + provider=_Provider(), + messages=[ChatMessage(role="user", content="hello from user")], + tools=[], + agent=agent, + assistant_msg=assistant_msg, + ) + + assert order == ["before", "provider", "after"] diff --git a/tests/tool/test_tool_catalog.py b/tests/tool/test_tool_catalog.py index d0840510d..f7ec413b3 100644 --- a/tests/tool/test_tool_catalog.py +++ b/tests/tool/test_tool_catalog.py @@ -1,7 +1,10 @@ from flocks.tool.catalog import ( apply_tool_catalog_defaults, + canonical_tool_token, get_tool_catalog_metadata, list_tool_catalog_infos, + normalize_tool_search_query, + search_tool_catalog, ) from flocks.tool.registry import ToolCategory, ToolInfo, ToolRegistry @@ -61,6 +64,46 @@ def test_explicit_tags_are_merged_with_defaults() -> None: assert "web" in enriched.tags +def test_tool_search_query_normalization_supports_aliases() -> None: + assert canonical_tool_token("WebSearchTool") == "websearch" + assert normalize_tool_search_query("WebSearchTool, web-fetch") == "websearch webfetch" + + +def test_search_tool_catalog_selects_exact_tools_in_order(monkeypatch) -> None: + tools = [ + ToolInfo( + name="websearch", + description="Search the web", + category=ToolCategory.BROWSER, + native=True, + enabled=True, + ), + ToolInfo( + name="webfetch", + description="Fetch a web page", + category=ToolCategory.BROWSER, + native=True, + enabled=True, + ), + ToolInfo( + name="read", + description="Read file contents", + category=ToolCategory.FILE, + native=True, + enabled=True, + ), + ] + monkeypatch.setattr( + "flocks.tool.registry.ToolRegistry.list_tools", + lambda: tools, + ) + + matches, matched_tags = search_tool_catalog("select:WebFetchTool,websearch", limit=5) + + assert [match["name"] for match in matches] == ["webfetch", "websearch"] + assert matched_tags == [] + + def test_list_tool_catalog_infos_excludes_disabled_tools(monkeypatch) -> None: enabled_tool = ToolInfo( name="read", diff --git a/tests/tool/test_tool_search_discovery.py b/tests/tool/test_tool_search_discovery.py index 512eb2335..dd2a28b78 100644 --- a/tests/tool/test_tool_search_discovery.py +++ b/tests/tool/test_tool_search_discovery.py @@ -82,6 +82,30 @@ async def test_tool_search_supports_category_and_tag_matching( assert result.output["matchedTags"] == ["research"] +@pytest.mark.asyncio +async def test_tool_search_supports_exact_batch_select_and_aliases( + monkeypatch: pytest.MonkeyPatch, +) -> None: + tools = [ + _tool("websearch", ToolCategory.BROWSER), + _tool("webfetch", ToolCategory.BROWSER), + _tool("read", ToolCategory.FILE), + ] + add_callable = AsyncMock(return_value={"websearch", "webfetch"}) + + monkeypatch.setattr("flocks.tool.system.tool_search.ToolRegistry.list_tools", lambda: tools) + monkeypatch.setattr("flocks.tool.system.tool_search.add_session_callable_tools", add_callable) + + ctx = SimpleNamespace(session_id="session-select", event_publish_callback=AsyncMock()) + result = await tool_search(ctx, query="select:WebSearchTool,webfetch", limit=5) + + assert result.success is True + assert result.output["normalizedQuery"] == "websearch webfetch" + assert result.output["callableToolNames"] == ["webfetch", "websearch"] + assert [match["name"] for match in result.output["matches"]] == ["websearch", "webfetch"] + add_callable.assert_awaited_once_with("session-select", ["websearch", "webfetch"]) + + @pytest.mark.asyncio async def test_tool_search_returns_user_plugin_tools( monkeypatch: pytest.MonkeyPatch,