Skip to content

Commit 46d7fcb

Browse files
fix: address critical Tool Search issues
- Fix unconditional bridge tool hijacking that broke user tools when disabled
- Resolve duplicate ToolSearchConfig classes by consolidating to tools module
- Add missing resolve_tool_search resolver function
- Fix hot path performance issues in BM25 scoring
- Add input validation for tool_call bridge args
- Fix cache metadata mismatches in chat_mixin
- Remove unnecessary f-string prefixes
- Strip __praisonai_deferrable__ from provider payloads
- Preserve tool_search config in clone_for_channel
- Improve type validation in agent constructor
- Fix overly broad MCP name heuristic

Co-authored-by: Mervin Praison <MervinPraison@users.noreply.github.com>
1 parent ab44847 commit 46d7fcb

6 files changed

Lines changed: 86 additions & 65 deletions

File tree

src/praisonai-agents/praisonaiagents/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -384,6 +384,7 @@ def _get_lazy_cache():
384384
'MultiAgentExecutionConfig': ('praisonaiagents.config.feature_configs', 'MultiAgentExecutionConfig'),
385385
'MultiAgentPlanningConfig': ('praisonaiagents.config.feature_configs', 'MultiAgentPlanningConfig'),
386386
'MultiAgentMemoryConfig': ('praisonaiagents.config.feature_configs', 'MultiAgentMemoryConfig'),
387+
'ToolSearchConfig': ('praisonaiagents.config.feature_configs', 'ToolSearchConfig'),
387388

388389
# Parameter resolver
389390
'resolve': ('praisonaiagents.config.param_resolver', 'resolve'),

src/praisonai-agents/praisonaiagents/agent/agent.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1458,8 +1458,14 @@ def __init__(
14581458
from ..tools.tool_search import ToolSearchConfig as _ToolSearchConfig
14591459
self._tool_search_config = _ToolSearchConfig(**tool_search)
14601460
else:
1461-
# Assume ToolSearchConfig instance or similar
1462-
self._tool_search_config = tool_search
1461+
from ..tools.tool_search import ToolSearchConfig as _ToolSearchConfig
1462+
if isinstance(tool_search, _ToolSearchConfig):
1463+
self._tool_search_config = tool_search
1464+
else:
1465+
raise TypeError(
1466+
"tool_search must be False/None, True, a mode string, "
1467+
"a dict of ToolSearchConfig fields, or ToolSearchConfig"
1468+
)
14631469

14641470
# ============================================================
14651471
# END CONSOLIDATED PARAMS EXTRACTION
@@ -2057,6 +2063,7 @@ def clone_for_channel(self) -> "Agent":
20572063
'skills': getattr(self, '_skills_config', None),
20582064
'approval': getattr(self, '_approval_config', None),
20592065
'learn': getattr(self, '_learn_config', None),
2066+
'tool_search': getattr(self, '_tool_search_config', None),
20602067

20612068
# Tool configuration
20622069
'tool_timeout': getattr(self, '_tool_timeout', None),

src/praisonai-agents/praisonaiagents/agent/chat_mixin.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -349,8 +349,10 @@ def _format_tools_for_completion(self, tools=None):
349349
tools_key = self._get_tools_cache_key(tools)
350350
tool_search_enabled = getattr(self, '_tool_search_config', None) is not None
351351
cache_key = f"{tools_key}:tool_search={tool_search_enabled}"
352-
cached_tools = self._cache_get(self._formatted_tools_cache, cache_key)
353-
if cached_tools is not None:
352+
cached_entry = self._cache_get(self._formatted_tools_cache, cache_key)
353+
if cached_entry is not None:
354+
cached_tools, cached_metadata = cached_entry
355+
self._tool_search_metadata = cached_metadata
354356
return cached_tools
355357

356358
formatted_tools = []
@@ -424,9 +426,24 @@ def _format_tools_for_completion(self, tools=None):
424426
# Tool search module not available, continue with original tools
425427
logging.warning("Tool search requested but tool_search module not available")
426428

427-
# Cache the formatted tools with LRU eviction
428-
self._cache_put(self._formatted_tools_cache, cache_key, formatted_tools)
429-
return formatted_tools
429+
# Strip __praisonai_deferrable__ from provider-facing tool payloads
430+
# Keep the marker only for internal tool classification
431+
cleaned_tools = []
432+
for tool in formatted_tools:
433+
if isinstance(tool, dict) and "__praisonai_deferrable__" in tool:
434+
tool_copy = tool.copy()
435+
tool_copy.pop("__praisonai_deferrable__", None)
436+
cleaned_tools.append(tool_copy)
437+
else:
438+
cleaned_tools.append(tool)
439+
440+
# Cache the formatted tools with LRU eviction, including tool search metadata
441+
self._cache_put(
442+
self._formatted_tools_cache,
443+
cache_key,
444+
(cleaned_tools, getattr(self, "_tool_search_metadata", None)),
445+
)
446+
return cleaned_tools
430447

431448
def _build_multimodal_prompt(
432449
self,

src/praisonai-agents/praisonaiagents/agent/tool_execution.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,9 @@ def execute_tool(self, function_name: str, arguments: Dict[str, Any], tool_call_
136136
logging.debug(f"{self.name} executing tool {function_name} with arguments: {arguments}")
137137

138138
# Handle bridge tool unwrapping BEFORE trace/stream/hooks (design invariant #6)
139-
if function_name in ("tool_search", "tool_describe", "tool_call"):
139+
# Only intercept when tool_search is active; otherwise fall through to real tool execution
140+
if (getattr(self, '_tool_search_config', None) is not None and
141+
function_name in ("tool_search", "tool_describe", "tool_call")):
140142
return self._handle_bridge_tool_call(function_name, arguments, tool_call_id)
141143

142144
# NOTE: tool_call callback is triggered by display_tool_call in openai_client.py
@@ -976,13 +978,13 @@ def _handle_bridge_tool_call(self, function_name: str, arguments: Dict[str, Any]
976978
"""
977979
# Ensure tool search metadata is available
978980
if not hasattr(self, '_tool_search_metadata') or self._tool_search_metadata is None:
979-
return f"Tool search not available or not in bridge mode"
981+
return "Tool search not available or not in bridge mode"
980982

981983
metadata = self._tool_search_metadata
982984

983985
# Check if we're in bridge mode
984986
if not metadata.get("bridge_mode", False):
985-
return f"Tool search not in bridge mode"
987+
return "Tool search not in bridge mode"
986988

987989
# Get deferrable tools from metadata
988990
deferrable_tools = metadata.get("deferrable_tools", [])

src/praisonai-agents/praisonaiagents/config/feature_configs.py

Lines changed: 33 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -1074,51 +1074,23 @@ def to_dict(self) -> Dict[str, Any]:
10741074
}
10751075

10761076

1077-
@dataclass
1078-
class ToolSearchConfig:
1079-
"""
1080-
Configuration for Tool Search feature - progressive MCP/plugin tool disclosure.
1081-
1082-
When deferrable tool schemas would consume a large share of the model context window,
1083-
replace them with bridge tools (tool_search, tool_describe, tool_call) and load
1084-
individual schemas on demand.
1085-
1086-
Consolidates: tool_search parameter
1087-
1088-
Usage:
1089-
# Simple enable (auto mode)
1090-
Agent(tool_search=True)
1091-
1092-
# Auto mode with custom threshold
1093-
Agent(tool_search=ToolSearchConfig(enabled="auto", threshold_pct=15))
1094-
1095-
# Always on mode
1096-
Agent(tool_search=ToolSearchConfig(enabled="on"))
1097-
"""
1098-
# Control mode: "auto" | "on" | "off"
1099-
enabled: Union[bool, str] = "auto"
1100-
1101-
# Percentage of context window for deferral threshold (auto mode)
1102-
threshold_pct: float = 10.0
1103-
1104-
# Default number of search results
1105-
search_default_limit: int = 5
1106-
1107-
# Maximum search results allowed
1108-
max_search_limit: int = 20
1109-
1110-
# Override core tools set (advanced usage)
1111-
core_tools: Optional[FrozenSet[str]] = None
1112-
1113-
def to_dict(self) -> Dict[str, Any]:
1114-
"""Convert to dictionary."""
1115-
return {
1116-
"enabled": self.enabled,
1117-
"threshold_pct": self.threshold_pct,
1118-
"search_default_limit": self.search_default_limit,
1119-
"max_search_limit": self.max_search_limit,
1120-
"core_tools": list(self.core_tools) if self.core_tools else None,
1121-
}
1077+
# Import ToolSearchConfig from tools module to avoid duplication
1078+
def __get_tool_search_config():
1079+
try:
1080+
from ..tools.tool_search import ToolSearchConfig as _ToolSearchConfig
1081+
return _ToolSearchConfig
1082+
except ImportError:
1083+
# Fallback minimal config if tools module not available
1084+
@dataclass
1085+
class FallbackToolSearchConfig:
1086+
enabled: Union[bool, str] = "auto"
1087+
threshold_pct: float = 10.0
1088+
search_default_limit: int = 5
1089+
max_search_limit: int = 20
1090+
core_tools: Optional[FrozenSet[str]] = None
1091+
return FallbackToolSearchConfig
1092+
1093+
ToolSearchConfig = __get_tool_search_config()
11221094

11231095

11241096
class AutonomyLevel(str, Enum):
@@ -1393,6 +1365,21 @@ def resolve_autonomy(value: AutonomyParam) -> Optional[AutonomyConfig]:
13931365
return value
13941366

13951367

1368+
def resolve_tool_search(value: ToolSearchParam) -> Optional[ToolSearchConfig]:
1369+
"""Resolve tool_search= parameter following precedence ladder."""
1370+
if value is None or value is False:
1371+
return None
1372+
if value is True:
1373+
return ToolSearchConfig()
1374+
if isinstance(value, str):
1375+
return ToolSearchConfig(enabled=value)
1376+
if isinstance(value, dict):
1377+
return ToolSearchConfig(**value)
1378+
if isinstance(value, ToolSearchConfig):
1379+
return value
1380+
return value
1381+
1382+
13961383
__all__ = [
13971384
# Enums
13981385
"MemoryBackend",
@@ -1451,4 +1438,5 @@ def resolve_autonomy(value: AutonomyParam) -> Optional[AutonomyConfig]:
14511438
"resolve_execution",
14521439
"resolve_caching",
14531440
"resolve_autonomy",
1441+
"resolve_tool_search",
14541442
]

src/praisonai-agents/praisonaiagents/tools/tool_search.py

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -139,9 +139,9 @@ def _is_tool_deferrable(tool_def: ToolDef) -> bool:
139139
if function_def.get("deferrable", False):
140140
return True
141141

142-
# Check tool name patterns for MCP tools
142+
# Check tool name patterns for MCP tools (prefix only, to avoid false-positives)
143143
tool_name = function_def.get("name", "")
144-
if tool_name.startswith("mcp_") or "mcp" in tool_name.lower():
144+
if tool_name.startswith("mcp_"):
145145
return True
146146

147147
return False
@@ -234,6 +234,7 @@ def _tokenize(self, text: str) -> List[str]:
234234

235235
def _build_index(self):
236236
"""Build BM25 index from catalog."""
237+
total_length = 0
237238
# Calculate term frequencies for each document
238239
for item in self.catalog:
239240
doc_text = f"{item['name']} {item['description']}"
@@ -245,11 +246,15 @@ def _build_index(self):
245246
tf[token] += 1
246247

247248
self.term_frequencies.append(dict(tf))
249+
total_length += len(tokens)
248250

249251
# Document frequency (how many docs contain each term)
250252
unique_tokens = set(tokens)
251253
for token in unique_tokens:
252254
self.doc_frequencies[token] += 1
255+
256+
# Cache average document length to avoid O(n) recomputation in hot path
257+
self._cached_avg_doc_length = total_length / len(self.catalog) if self.catalog else 0.0
253258

254259
def search(self, query: str, limit: int = 5) -> List[Dict[str, str]]:
255260
"""
@@ -269,7 +274,7 @@ def search(self, query: str, limit: int = 5) -> List[Dict[str, str]]:
269274
scores = []
270275
k1, b = 1.5, 0.75 # BM25 parameters
271276

272-
for i, (item, tf) in enumerate(zip(self.catalog, self.term_frequencies)):
277+
for item, tf in zip(self.catalog, self.term_frequencies, strict=True):
273278
score = 0.0
274279
doc_length = sum(tf.values())
275280

@@ -280,7 +285,7 @@ def search(self, query: str, limit: int = 5) -> List[Dict[str, str]]:
280285
(self.doc_frequencies[token] + 0.5))
281286
term_freq = tf[token]
282287
score += idf * (term_freq * (k1 + 1)) / (
283-
term_freq + k1 * (1 - b + b * (doc_length / self._avg_doc_length()))
288+
term_freq + k1 * (1 - b + b * (doc_length / self._cached_avg_doc_length))
284289
)
285290

286291
if score > 0:
@@ -290,12 +295,6 @@ def search(self, query: str, limit: int = 5) -> List[Dict[str, str]]:
290295
scores.sort(key=lambda x: x[0], reverse=True)
291296
return [item for _, item in scores[:limit]]
292297

293-
def _avg_doc_length(self) -> float:
294-
"""Calculate average document length."""
295-
if not self.term_frequencies:
296-
return 0.0
297-
total_length = sum(sum(tf.values()) for tf in self.term_frequencies)
298-
return total_length / len(self.term_frequencies)
299298

300299
def search_catalog(
301300
deferrable_tools: ToolDefList,
@@ -545,6 +544,13 @@ def resolve_underlying_call(tool_name: str, tool_args: Dict[str, Any]) -> Tuple[
545544
# Not a bridge call, return as-is
546545
return tool_name, tool_args
547546

547+
# Validate tool_args is a dict
548+
if not isinstance(tool_args, dict):
549+
raise TypeError(
550+
f"tool_call expects a dictionary for tool_args, got {type(tool_args).__name__}. "
551+
"Ensure the LLM output is properly formatted."
552+
)
553+
548554
# Extract real tool call from bridge args
549555
real_tool_name = tool_args.get("tool_name", "")
550556
real_args = tool_args.get("tool_args", {})

0 commit comments

Comments
 (0)