Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .jules/bolt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## 2023-10-25 - Concurrent tool call execution in Retrieval Pipelines
**Learning:** Sequential execution of multiple tool calls in a LangChain AIMessage (e.g. a response that requests both a Pinecone query and a Neo4j query) creates a significant latency bottleneck in `RetrievalPipeline` and `CodeRetrievalPipeline`. A similar sequential bottleneck exists when searching across multiple repositories.
**Action:** Use `asyncio.gather` to execute independent tool calls and multi-namespace searches concurrently, then process results sequentially to safely update shared state.
67 changes: 41 additions & 26 deletions src/pipelines/code_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,17 +375,23 @@ async def run(
turn_records: List[SourceRecord] = []
only_read_tools = True

for tc in ai_response.tool_calls:
tool_name = tc["name"]
tool_args = tc["args"]
tool_id = tc["id"]

async def _process_tool_call(tc: Dict[str, Any]) -> tuple[str, str, str, List[SourceRecord], float]:
t_name = tc["name"]
t_args = tc["args"]
t_id = tc["id"]
t1 = _time.perf_counter()
records = await self._execute_tool(
tool_name, tool_args, repo=repo, top_k=top_k,
user_id=user_id,
recs = await self._execute_tool(
t_name, t_args, repo=repo, top_k=top_k, user_id=user_id,
)
tool_ms = (_time.perf_counter() - t1) * 1000
t_ms = (_time.perf_counter() - t1) * 1000
return t_name, t_args, t_id, recs, t_ms

import asyncio
results = await asyncio.gather(*(
_process_tool_call(tc) for tc in ai_response.tool_calls
))

for tool_name, tool_args, tool_id, records, tool_ms in results:
logger.info(" Tool: %s(%s) → %d results (%.0fms)", tool_name, tool_args, len(records), tool_ms)
turn_records.extend(records)
sources.extend(records)
Expand Down Expand Up @@ -471,17 +477,22 @@ async def run_stream(
if ai_response.tool_calls:
yield json.dumps({"type": "status", "content": f"Running {len(ai_response.tool_calls)} search tool(s)..."}) + "\n"

for tc in ai_response.tool_calls:
tool_name = tc["name"]
tool_args = tc["args"]
tool_id = tc["id"]
async def _process_stream_tool_call(tc: Dict[str, Any]) -> tuple[str, str, str, List[SourceRecord]]:
t_name = tc["name"]
t_args = tc["args"]
t_id = tc["id"]
logger.info(" Tool: %s(%s)", t_name, t_args)
recs = await self._execute_tool(
t_name, t_args, repo=repo, top_k=top_k, user_id=user_id,
)
return t_name, t_args, t_id, recs

logger.info(" Tool: %s(%s)", tool_name, tool_args)
import asyncio
results = await asyncio.gather(*(
_process_stream_tool_call(tc) for tc in ai_response.tool_calls
))

records = await self._execute_tool(
tool_name, tool_args, repo=repo, top_k=top_k,
user_id=user_id,
)
for tool_name, tool_args, tool_id, records in results:
sources.extend(records)

tool_result_text = self._format_tool_results(records)
Expand Down Expand Up @@ -589,14 +600,16 @@ async def _search_symbols(
) -> List[SourceRecord]:
if not repo:
logger.warning("search_symbols called without repo — searching all repos")
results = []
for r in self.repos:
results.extend(await self._search_namespace(
import asyncio
repo_results = await asyncio.gather(*(
self._search_namespace(
namespace=symbols_namespace(self.org_id, r),
query=query,
domain="symbol",
top_k=top_k,
))
) for r in self.repos
))
results = [item for sublist in repo_results for item in sublist]
return results[:top_k]

return await self._search_namespace(
Expand All @@ -612,14 +625,16 @@ async def _search_files(
self, query: str, repo: str, top_k: int = 10,
) -> List[SourceRecord]:
if not repo:
results = []
for r in self.repos:
results.extend(await self._search_namespace(
import asyncio
repo_results = await asyncio.gather(*(
self._search_namespace(
namespace=files_namespace(self.org_id, r),
query=query,
domain="file",
top_k=top_k,
))
) for r in self.repos
))
results = [item for sublist in repo_results for item in sublist]
return results[:top_k]

return await self._search_namespace(
Expand Down
24 changes: 16 additions & 8 deletions src/pipelines/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,24 @@ async def run(

if ai_response.tool_calls:
called_tools = set()
for tc in ai_response.tool_calls:
tool_name = tc["name"]
tool_args = tc["args"]
tool_id = tc["id"]

logger.info(" Tool call: %s(%s)", tool_name, tool_args)
# Helper to execute tool calls concurrently
async def _process_tool_call(tc: Dict[str, Any]) -> tuple[str, str, str, List[SourceRecord]]:
t_name = tc["name"]
t_args = tc["args"]
t_id = tc["id"]
logger.info(" Tool call: %s(%s)", t_name, t_args)
recs = await self._execute_tool(t_name, t_args, user_id, top_k)
return t_name, t_args, t_id, recs

import asyncio
# Execute all tool calls concurrently
results = await asyncio.gather(*(
_process_tool_call(tc) for tc in ai_response.tool_calls
))

records = await self._execute_tool(
tool_name, tool_args, user_id, top_k,
)
# Process results sequentially to safely extend shared state
for tool_name, tool_args, tool_id, records in results:
sources.extend(records)

# Build ToolMessage for the LLM
Expand Down