diff --git a/.jules/bolt.md b/.jules/bolt.md new file mode 100644 index 0000000..5c09cb3 --- /dev/null +++ b/.jules/bolt.md @@ -0,0 +1,3 @@ +## 2023-10-25 - Concurrent tool call execution in Retrieval Pipelines +**Learning:** Executing the multiple tool calls of a single LangChain `AIMessage` sequentially (e.g. querying Pinecone and then Neo4j in the same turn) creates a significant latency bottleneck in `RetrievalPipeline` and `CodeRetrievalPipeline`. A similar sequential bottleneck exists when searching across multiple repositories. +**Action:** Use `asyncio.gather` to execute independent tool calls and multi-namespace searches concurrently, then process the results sequentially so that shared state is updated safely. diff --git a/src/pipelines/code_retrieval.py b/src/pipelines/code_retrieval.py index 69e4abc..6ff2329 100644 --- a/src/pipelines/code_retrieval.py +++ b/src/pipelines/code_retrieval.py @@ -375,17 +375,23 @@ async def run( turn_records: List[SourceRecord] = [] only_read_tools = True - for tc in ai_response.tool_calls: - tool_name = tc["name"] - tool_args = tc["args"] - tool_id = tc["id"] - + async def _process_tool_call(tc: Dict[str, Any]) -> tuple[str, str, str, List[SourceRecord], float]: + t_name = tc["name"] + t_args = tc["args"] + t_id = tc["id"] t1 = _time.perf_counter() - records = await self._execute_tool( - tool_name, tool_args, repo=repo, top_k=top_k, - user_id=user_id, + recs = await self._execute_tool( + t_name, t_args, repo=repo, top_k=top_k, user_id=user_id, ) - tool_ms = (_time.perf_counter() - t1) * 1000 + t_ms = (_time.perf_counter() - t1) * 1000 + return t_name, t_args, t_id, recs, t_ms + + import asyncio + results = await asyncio.gather(*( + _process_tool_call(tc) for tc in ai_response.tool_calls + )) + + for tool_name, tool_args, tool_id, records, tool_ms in results: logger.info(" Tool: %s(%s) → %d results (%.0fms)", tool_name, tool_args, len(records), tool_ms) turn_records.extend(records) sources.extend(records) @@ -471,17 +477,22 @@ async def run_stream( if
ai_response.tool_calls: yield json.dumps({"type": "status", "content": f"Running {len(ai_response.tool_calls)} search tool(s)..."}) + "\n" - for tc in ai_response.tool_calls: - tool_name = tc["name"] - tool_args = tc["args"] - tool_id = tc["id"] + async def _process_stream_tool_call(tc: Dict[str, Any]) -> tuple[str, str, str, List[SourceRecord]]: + t_name = tc["name"] + t_args = tc["args"] + t_id = tc["id"] + logger.info(" Tool: %s(%s)", t_name, t_args) + recs = await self._execute_tool( + t_name, t_args, repo=repo, top_k=top_k, user_id=user_id, + ) + return t_name, t_args, t_id, recs - logger.info(" Tool: %s(%s)", tool_name, tool_args) + import asyncio + results = await asyncio.gather(*( + _process_stream_tool_call(tc) for tc in ai_response.tool_calls + )) - records = await self._execute_tool( - tool_name, tool_args, repo=repo, top_k=top_k, - user_id=user_id, - ) + for tool_name, tool_args, tool_id, records in results: sources.extend(records) tool_result_text = self._format_tool_results(records) @@ -589,14 +600,16 @@ async def _search_symbols( ) -> List[SourceRecord]: if not repo: logger.warning("search_symbols called without repo — searching all repos") - results = [] - for r in self.repos: - results.extend(await self._search_namespace( + import asyncio + repo_results = await asyncio.gather(*( + self._search_namespace( namespace=symbols_namespace(self.org_id, r), query=query, domain="symbol", top_k=top_k, - )) + ) for r in self.repos + )) + results = [item for sublist in repo_results for item in sublist] return results[:top_k] return await self._search_namespace( @@ -612,14 +625,16 @@ async def _search_files( self, query: str, repo: str, top_k: int = 10, ) -> List[SourceRecord]: if not repo: - results = [] - for r in self.repos: - results.extend(await self._search_namespace( + import asyncio + repo_results = await asyncio.gather(*( + self._search_namespace( namespace=files_namespace(self.org_id, r), query=query, domain="file", top_k=top_k, - )) + ) for r in 
self.repos + )) + results = [item for sublist in repo_results for item in sublist] return results[:top_k] return await self._search_namespace( diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index d54cc0d..e3be97c 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -177,16 +177,24 @@ async def run( if ai_response.tool_calls: called_tools = set() - for tc in ai_response.tool_calls: - tool_name = tc["name"] - tool_args = tc["args"] - tool_id = tc["id"] - logger.info(" Tool call: %s(%s)", tool_name, tool_args) + # Helper to execute tool calls concurrently + async def _process_tool_call(tc: Dict[str, Any]) -> tuple[str, str, str, List[SourceRecord]]: + t_name = tc["name"] + t_args = tc["args"] + t_id = tc["id"] + logger.info(" Tool call: %s(%s)", t_name, t_args) + recs = await self._execute_tool(t_name, t_args, user_id, top_k) + return t_name, t_args, t_id, recs + + import asyncio + # Execute all tool calls concurrently + results = await asyncio.gather(*( + _process_tool_call(tc) for tc in ai_response.tool_calls + )) - records = await self._execute_tool( - tool_name, tool_args, user_id, top_k, - ) + # Process results sequentially to safely extend shared state + for tool_name, tool_args, tool_id, records in results: sources.extend(records) # Build ToolMessage for the LLM