diff --git a/.jules/bolt.md b/.jules/bolt.md new file mode 100644 index 0000000..62e023b --- /dev/null +++ b/.jules/bolt.md @@ -0,0 +1,3 @@ +## 2024-04-11 - Parallelize Tool Executions in Retrieval Pipelines +**Learning:** Sequential execution of LLM tool calls in `RetrievalPipeline` and `CodeRetrievalPipeline` caused unnecessary blocking during query processing. Refactoring the loop to use `asyncio.gather` reduces latency. +**Action:** When working with LLM responses that request multiple tool calls, evaluate if the calls are independent. If so, process them concurrently with `asyncio.gather` and then update shared state sequentially after awaiting the results to ensure thread safety and optimal performance. diff --git a/src/pipelines/code_retrieval.py b/src/pipelines/code_retrieval.py index 69e4abc..2d911e9 100644 --- a/src/pipelines/code_retrieval.py +++ b/src/pipelines/code_retrieval.py @@ -25,6 +25,7 @@ from __future__ import annotations +import asyncio import logging from typing import Any, Callable, Dict, List, Optional @@ -37,7 +38,6 @@ from src.scanner.code_store import CodeStore from src.schemas.code import ( annotations_namespace, - directories_namespace, files_namespace, snippets_namespace, symbols_namespace, @@ -375,11 +375,11 @@ async def run( turn_records: List[SourceRecord] = [] only_read_tools = True - for tc in ai_response.tool_calls: + + + async def _process_tool_call(tc): tool_name = tc["name"] tool_args = tc["args"] - tool_id = tc["id"] - t1 = _time.perf_counter() records = await self._execute_tool( tool_name, tool_args, repo=repo, top_k=top_k, @@ -387,6 +387,14 @@ async def run( ) tool_ms = (_time.perf_counter() - t1) * 1000 logger.info(" Tool: %s(%s) → %d results (%.0fms)", tool_name, tool_args, len(records), tool_ms) + return tc, records + + tool_results = await asyncio.gather(*[_process_tool_call(tc) for tc in ai_response.tool_calls]) + + for tc, records in tool_results: + tool_name = tc["name"] + tool_id = tc["id"] + turn_records.extend(records) sources.extend(records) @@ -471,17 +479,23 @@ async def run_stream( if ai_response.tool_calls: yield json.dumps({"type": "status", "content": f"Running {len(ai_response.tool_calls)} search tool(s)..."}) + "\n" - for tc in ai_response.tool_calls: + + + async def _process_stream_tool_call(tc): tool_name = tc["name"] tool_args = tc["args"] - tool_id = tc["id"] - logger.info(" Tool: %s(%s)", tool_name, tool_args) records = await self._execute_tool( tool_name, tool_args, repo=repo, top_k=top_k, user_id=user_id, ) + return tc, records + + tool_results = await asyncio.gather(*[_process_stream_tool_call(tc) for tc in ai_response.tool_calls]) + + for tc, records in tool_results: + tool_id = tc["id"] sources.extend(records) tool_result_text = self._format_tool_results(records) diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index d54cc0d..2844a44 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -20,8 +20,8 @@ from __future__ import annotations +import asyncio import logging -import os from typing import Any, Callable, Dict, List, Optional from dotenv import load_dotenv @@ -176,17 +176,26 @@ async def run( tool_messages: List[ToolMessage] = [] if ai_response.tool_calls: + + called_tools = set() - for tc in ai_response.tool_calls: + + async def _process_tool_call(tc): tool_name = tc["name"] tool_args = tc["args"] - tool_id = tc["id"] - logger.info(" Tool call: %s(%s)", tool_name, tool_args) records = await self._execute_tool( tool_name, tool_args, user_id, top_k, ) + return tc, records + + tool_results = await asyncio.gather(*[_process_tool_call(tc) for tc in ai_response.tool_calls]) + + for tc, records in tool_results: + tool_name = tc["name"] + tool_id = tc["id"] + sources.extend(records) # Build ToolMessage for the LLM @@ -351,7 +360,7 @@ async def _search_temporal( top_k: int = 3, ) -> List[SourceRecord]: """Semantic search over temporal events in Neo4j.""" - import asyncio + from functools import partial loop = asyncio.get_running_loop()