diff --git a/src/pipelines/code_retrieval.py b/src/pipelines/code_retrieval.py index 69e4abc..c58e4c7 100644 --- a/src/pipelines/code_retrieval.py +++ b/src/pipelines/code_retrieval.py @@ -25,6 +25,7 @@ from __future__ import annotations +import asyncio import logging from typing import Any, Callable, Dict, List, Optional @@ -37,7 +38,6 @@ from src.scanner.code_store import CodeStore from src.schemas.code import ( annotations_namespace, - directories_namespace, files_namespace, snippets_namespace, symbols_namespace, @@ -375,17 +375,24 @@ async def run( turn_records: List[SourceRecord] = [] only_read_tools = True - for tc in ai_response.tool_calls: - tool_name = tc["name"] - tool_args = tc["args"] - tool_id = tc["id"] - + async def _process_tool_call(tc: Dict[str, Any]): t1 = _time.perf_counter() records = await self._execute_tool( - tool_name, tool_args, repo=repo, top_k=top_k, + tc["name"], tc["args"], repo=repo, top_k=top_k, user_id=user_id, ) tool_ms = (_time.perf_counter() - t1) * 1000 + return tc, records, tool_ms + + results = await asyncio.gather( + *[_process_tool_call(tc) for tc in ai_response.tool_calls] + ) + + for tc, records, tool_ms in results: + tool_name = tc["name"] + tool_args = tc["args"] + tool_id = tc["id"] + logger.info(" Tool: %s(%s) → %d results (%.0fms)", tool_name, tool_args, len(records), tool_ms) turn_records.extend(records) sources.extend(records) @@ -471,17 +478,24 @@ async def run_stream( if ai_response.tool_calls: yield json.dumps({"type": "status", "content": f"Running {len(ai_response.tool_calls)} search tool(s)..."}) + "\n" - for tc in ai_response.tool_calls: + async def _process_tool_call_stream(tc: Dict[str, Any]): + records = await self._execute_tool( + tc["name"], tc["args"], repo=repo, top_k=top_k, + user_id=user_id, + ) + return tc, records + + results = await asyncio.gather( + *[_process_tool_call_stream(tc) for tc in ai_response.tool_calls] + ) + + for tc, records in results: tool_name = tc["name"] tool_args = tc["args"] tool_id = tc["id"] logger.info(" Tool: %s(%s)", tool_name, tool_args) - records = await self._execute_tool( - tool_name, tool_args, repo=repo, top_k=top_k, - user_id=user_id, - ) sources.extend(records) tool_result_text = self._format_tool_results(records) diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index d54cc0d..24fa612 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -20,8 +20,8 @@ from __future__ import annotations +import asyncio import logging -import os from typing import Any, Callable, Dict, List, Optional from dotenv import load_dotenv @@ -177,16 +177,24 @@ async def run( if ai_response.tool_calls: called_tools = set() - for tc in ai_response.tool_calls: + + async def _process_tool_call(tc: Dict[str, Any]): + records = await self._execute_tool( + tc["name"], tc["args"], user_id, top_k, + ) + return tc, records + + results = await asyncio.gather( + *[_process_tool_call(tc) for tc in ai_response.tool_calls] + ) + + for tc, records in results: tool_name = tc["name"] tool_args = tc["args"] tool_id = tc["id"] logger.info(" Tool call: %s(%s)", tool_name, tool_args) - records = await self._execute_tool( - tool_name, tool_args, user_id, top_k, - ) sources.extend(records) # Build ToolMessage for the LLM