diff --git a/src/pipelines/code_retrieval.py b/src/pipelines/code_retrieval.py index 69e4abc..df39698 100644 --- a/src/pipelines/code_retrieval.py +++ b/src/pipelines/code_retrieval.py @@ -25,6 +25,7 @@ from __future__ import annotations +import asyncio import logging from typing import Any, Callable, Dict, List, Optional @@ -37,7 +38,6 @@ from src.scanner.code_store import CodeStore from src.schemas.code import ( annotations_namespace, - directories_namespace, files_namespace, snippets_namespace, symbols_namespace, @@ -375,18 +375,26 @@ async def run( turn_records: List[SourceRecord] = [] only_read_tools = True - for tc in ai_response.tool_calls: + async def _process_tool_call(tc: dict) -> tuple[dict, list[SourceRecord]]: + t1 = _time.perf_counter() tool_name = tc["name"] tool_args = tc["args"] - tool_id = tc["id"] - - t1 = _time.perf_counter() - records = await self._execute_tool( + recs = await self._execute_tool( tool_name, tool_args, repo=repo, top_k=top_k, user_id=user_id, ) tool_ms = (_time.perf_counter() - t1) * 1000 - logger.info(" Tool: %s(%s) → %d results (%.0fms)", tool_name, tool_args, len(records), tool_ms) + logger.info(" Tool: %s(%s) → %d results (%.0fms)", tool_name, tool_args, len(recs), tool_ms) + return tc, recs + + tool_results = await asyncio.gather(*( + _process_tool_call(tc) for tc in ai_response.tool_calls + )) + + for tc, records in tool_results: + tool_name = tc["name"] + tool_id = tc["id"] + turn_records.extend(records) sources.extend(records) @@ -471,17 +479,22 @@ async def run_stream( if ai_response.tool_calls: yield json.dumps({"type": "status", "content": f"Running {len(ai_response.tool_calls)} search tool(s)..."}) + "\n" - for tc in ai_response.tool_calls: + async def _process_tool_call_stream(tc: dict) -> tuple[dict, list[SourceRecord]]: tool_name = tc["name"] tool_args = tc["args"] - tool_id = tc["id"] - logger.info(" Tool: %s(%s)", tool_name, tool_args) - - records = await self._execute_tool( + recs = await self._execute_tool( tool_name, tool_args, repo=repo, top_k=top_k, user_id=user_id, ) + return tc, recs + + tool_results = await asyncio.gather(*( + _process_tool_call_stream(tc) for tc in ai_response.tool_calls + )) + + for tc, records in tool_results: + tool_id = tc["id"] sources.extend(records) tool_result_text = self._format_tool_results(records) diff --git a/src/pipelines/ingest.py b/src/pipelines/ingest.py index c1cb3a4..7eee427 100644 --- a/src/pipelines/ingest.py +++ b/src/pipelines/ingest.py @@ -82,7 +82,7 @@ ) from src.schemas.events import EventResult from src.schemas.image import ImageResult -from src.schemas.judge import JudgeDomain, JudgeResult, OperationType +from src.schemas.judge import JudgeDomain, JudgeResult from src.schemas.profile import ProfileResult from src.schemas.summary import SummaryResult from src.schemas.weaver import WeaverResult diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index d54cc0d..203fad1 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -20,8 +20,8 @@ from __future__ import annotations +import asyncio import logging -import os from typing import Any, Callable, Dict, List, Optional from dotenv import load_dotenv @@ -177,16 +177,24 @@ async def run( if ai_response.tool_calls: called_tools = set() - for tc in ai_response.tool_calls: + + async def _process_tool_call(tc: dict) -> tuple[dict, list[SourceRecord]]: tool_name = tc["name"] tool_args = tc["args"] - tool_id = tc["id"] - logger.info(" Tool call: %s(%s)", tool_name, tool_args) - - records = await self._execute_tool( + recs = await self._execute_tool( tool_name, tool_args, user_id, top_k, ) + return tc, recs + + tool_results = await asyncio.gather(*( + _process_tool_call(tc) for tc in ai_response.tool_calls + )) + + for tc, records in tool_results: + tool_name = tc["name"] + tool_id = tc["id"] + sources.extend(records) # Build ToolMessage for the LLM