Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 6 additions & 22 deletions livekit-agents/livekit/agents/beta/toolsets/tool_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import json
from typing import Any

from pydantic import ValidationError
from typing_extensions import Self

from ...llm.tool_context import (
Expand All @@ -16,7 +15,6 @@
function_tool,
)
from ...llm.utils import function_arguments_to_pydantic_model, prepare_function_arguments
from ...log import logger
from ...types import NOT_GIVEN, NotGivenOr
from ...voice.events import RunContext
from .tool_search import SearchStrategy, ToolSearchToolset
Expand Down Expand Up @@ -129,26 +127,12 @@ async def _handle_call(self, ctx: RunContext[Any], raw_arguments: dict[str, obje
if fnc_tool is None:
raise ToolError(f"unknown tool '{name}', use search_tools to discover available tools")

try:
json_args = json.dumps(parameters) if isinstance(parameters, dict) else str(parameters)
fnc_args, fnc_kwargs = prepare_function_arguments(
fnc=fnc_tool,
json_arguments=json_args,
call_ctx=ctx,
)
except ValidationError as e:
raise ToolError(
f"invalid parameters for tool '{name}': {e.json(include_url=False)}"
) from e
except ToolError:
raise
except Exception as e:
logger.exception(
f"error parsing arguments for tool '{name}'",
extra={"tool": name, "arguments": parameters},
)
raise ToolError(f"error calling '{name}': {e}") from e

json_args = json.dumps(parameters) if isinstance(parameters, dict) else str(parameters)
fnc_args, fnc_kwargs = prepare_function_arguments(
fnc=fnc_tool,
json_arguments=json_args,
call_ctx=ctx,
)
return await fnc_tool(*fnc_args, **fnc_kwargs)


Expand Down
59 changes: 35 additions & 24 deletions livekit-agents/livekit/agents/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,39 @@ def prepare_function_arguments(
json_arguments: str | dict[str, Any],
call_ctx: RunContext[Any] | None = None,
) -> tuple[tuple[Any, ...], dict[str, Any]]: # returns args, kwargs
"""
Create the positional and keyword arguments to call a function tool from
"""Create the positional and keyword arguments to call a function tool from
the raw function output from the LLM.

Argument-validation failures (bad JSON, pydantic ValidationError, missing
required params) are surfaced as :class:`ToolError` so the LLM gets a
concrete error message and can self-correct on its next turn.
"""
try:
return _prepare_function_arguments(
fnc=fnc, json_arguments=json_arguments, call_ctx=call_ctx
)
except ToolError:
raise
except (pydantic.ValidationError, ValueError, TypeError) as e:
logger.error(
f"error parsing arguments for `{fnc.info.name}`",
extra={"function": fnc.info.name, "arguments": json_arguments},
)
raise ToolError(f"Error parsing arguments for `{fnc.info.name}`: {e}") from e
except Exception:
logger.exception(
f"error parsing arguments for `{fnc.info.name}`",
extra={"function": fnc.info.name, "arguments": json_arguments},
)
raise


def _prepare_function_arguments(
*,
fnc: FunctionTool | RawFunctionTool,
json_arguments: str | dict[str, Any],
call_ctx: RunContext[Any] | None,
) -> tuple[tuple[Any, ...], dict[str, Any]]:
signature = inspect.signature(fnc)
type_hints = get_type_hints(fnc, include_extras=True)

Expand Down Expand Up @@ -636,33 +664,16 @@ async def execute_function_call(
json_arguments=tool_call.arguments or "{}",
call_ctx=call_ctx,
)
except (pydantic.ValidationError, ValueError) as e:
# Surface argument validation errors to the LLM so it can self-correct.
# Without this, the LLM only sees "An internal error occurred" and has
# no signal about what was wrong with its arguments.
logger.warning(
f"invalid arguments for AI function `{tool_call.name}`: {e}",
extra={"call_id": tool_call.call_id, "arguments": tool_call.arguments},
)
tool_error = ToolError(f"Error parsing arguments for `{tool_call.name}`: {e}")
return make_function_call_output(fnc_call=fnc_call, output=None, exception=tool_error)
except Exception as e:
logger.exception(
f"exception preparing arguments for AI function `{tool_call.name}`",
extra={"call_id": tool_call.call_id, "arguments": tool_call.arguments},
)
return make_function_call_output(fnc_call=fnc_call, output=None, exception=e)

try:
result = function_tool(*fnc_args, **fnc_kwargs)
if asyncio.iscoroutine(result):
result = await result

return make_function_call_output(fnc_call=fnc_call, output=result, exception=None)

except Exception as e:
logger.exception(
f"exception executing AI function `{tool_call.name}`",
extra={"call_id": tool_call.call_id, "arguments": tool_call.arguments},
)
if not isinstance(e, ToolError):
logger.exception(
f"exception executing AI function `{tool_call.name}`",
extra={"call_id": tool_call.call_id, "arguments": tool_call.arguments},
)
return make_function_call_output(fnc_call=fnc_call, output=None, exception=e)
143 changes: 64 additions & 79 deletions livekit-agents/livekit/agents/voice/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable

from opentelemetry import trace
from pydantic import ValidationError

from livekit import rtc

Expand Down Expand Up @@ -507,6 +506,39 @@ def _tool_completed(out: ToolExecutionOutput) -> None:
tool_execution_completed_cb(out)
tool_output.output.append(out)

async def _run_mock(mock: Callable, *fnc_args: Any, **fnc_kwargs: Any) -> Any:
    """Call ``mock`` with only the arguments its signature can accept.

    The arguments prepared for the real tool may include values the mock does
    not declare. Surplus positional arguments and unknown keyword arguments
    are dropped, unless the mock declares ``*args`` / ``**kwargs``, in which
    case they are forwarded untouched.

    Args:
        mock: The replacement callable; may be sync or async.
        *fnc_args: Positional arguments prepared for the real tool.
        **fnc_kwargs: Keyword arguments prepared for the real tool.

    Returns:
        The mock's result. If the call yields an awaitable (an ``async def``
        mock, or a sync wrapper that returns a coroutine), it is awaited first.

    Raises:
        TypeError: If the trimmed arguments still cannot be bound to the
            mock's signature (e.g. one parameter supplied both positionally
            and by keyword).
    """
    sig = inspect.signature(mock)

    max_positional = 0
    keyword_names: set[str] = set()
    takes_var_positional = False
    takes_var_keyword = False
    for name, param in sig.parameters.items():
        kind = param.kind
        if kind is inspect.Parameter.VAR_POSITIONAL:
            takes_var_positional = True
        elif kind is inspect.Parameter.VAR_KEYWORD:
            takes_var_keyword = True
        else:
            if kind is not inspect.Parameter.KEYWORD_ONLY:
                max_positional += 1
            if kind is not inspect.Parameter.POSITIONAL_ONLY:
                keyword_names.add(name)

    # A mock with *args / **kwargs explicitly opts in to receiving everything.
    trimmed_args = fnc_args if takes_var_positional else fnc_args[:max_positional]
    if takes_var_keyword:
        trimmed_kwargs = dict(fnc_kwargs)
    else:
        trimmed_kwargs = {k: v for k, v in fnc_kwargs.items() if k in keyword_names}

    bound = sig.bind_partial(*trimmed_args, **trimmed_kwargs)
    bound.apply_defaults()

    # Decide on the *result*, not on how the mock was declared: a lambda or a
    # callable object can hand back a coroutine without being a coroutine
    # function itself.
    result = mock(*bound.args, **bound.kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result

tasks: list[asyncio.Task[Any]] = []
try:
async for fnc_call in function_stream:
Expand Down Expand Up @@ -556,30 +588,6 @@ def _tool_completed(out: ToolExecutionOutput) -> None:
)
continue

try:
json_args = fnc_call.arguments or "{}"
fnc_args, fnc_kwargs = llm_utils.prepare_function_arguments(
fnc=function_tool,
json_arguments=json_args,
call_ctx=RunContext(
session=session,
speech_handle=speech_handle,
function_call=fnc_call,
),
)

except (ValidationError, ValueError) as e:
logger.exception(
f"tried to call AI function `{fnc_call.name}` with invalid arguments",
extra={
"function": fnc_call.name,
"arguments": fnc_call.arguments,
"speech_id": speech_handle.id,
},
)
_tool_completed(make_tool_output(fnc_call=fnc_call, output=None, exception=e))
continue

if not tool_output.first_tool_started_fut.done():
tool_output.first_tool_started_fut.set_result(None)

Expand All @@ -590,63 +598,40 @@ def _tool_completed(out: ToolExecutionOutput) -> None:
mock_tools: dict[str, Callable] = _MockToolsContextVar.get({}).get(
type(session.current_agent), {}
)
mock = mock_tools.get(fnc_call.name)
mocked = mock is not None

if mock := mock_tools.get(fnc_call.name):
logger.debug(
"executing mock tool",
extra={
"function": fnc_call.name,
"arguments": fnc_call.arguments,
"speech_id": speech_handle.id,
},
run_ctx = RunContext(
session=session, speech_handle=speech_handle, function_call=fnc_call
)
json_args = fnc_call.arguments or "{}"
_bound_tool: llm.FunctionTool | llm.RawFunctionTool = function_tool

# unified body: arg prep + invocation in one closure so a ToolError
# from prepare_function_arguments is routed through the existing
# per-tool handler below
async def _execute(
    ctx: RunContext[Any],
    fnc: llm.FunctionTool | llm.RawFunctionTool = _bound_tool,
    raw_args: str = json_args,
    bound_mock: Callable | None = mock,
) -> Any:
    """Prepare the tool-call arguments and run the tool (or its mock).

    ``fnc``, ``raw_args`` and ``bound_mock`` take their values from default
    arguments, so they are fixed when this closure is defined rather than
    looked up when it eventually runs.

    Args:
        ctx: Run context passed to ``prepare_function_arguments`` as ``call_ctx``.
        fnc: The function tool to invoke.
        raw_args: Raw JSON argument string for the call.
        bound_mock: Optional mock to invoke in place of ``fnc``.

    Returns:
        Whatever the tool, or the mock when one is bound, returns.
    """
    prepared = llm_utils.prepare_function_arguments(
        fnc=fnc, json_arguments=raw_args, call_ctx=ctx
    )
    call_args, call_kwargs = prepared

    if bound_mock is None:
        return await fnc(*call_args, **call_kwargs)
    return await _run_mock(bound_mock, *call_args, **call_kwargs)

async def _run_mock(mock: Callable, *fnc_args: Any, **fnc_kwargs: Any) -> Any:
sig = inspect.signature(mock)

pos_param_names = [
name
for name, param in sig.parameters.items()
if param.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
]
max_positional = len(pos_param_names)
trimmed_args = fnc_args[:max_positional]
kw_param_names = [
name
for name, param in sig.parameters.items()
if param.kind
in (
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
]
trimmed_kwargs = {
k: v for k, v in fnc_kwargs.items() if k in kw_param_names
}

bound = sig.bind_partial(*trimmed_args, **trimmed_kwargs)
bound.apply_defaults()

if inspect.iscoroutinefunction(mock):
return await mock(*bound.args, **bound.kwargs)
else:
return mock(*bound.args, **bound.kwargs)

function_callable = functools.partial(_run_mock, mock, *fnc_args, **fnc_kwargs)
else:
logger.debug(
"executing tool",
extra={
"function": fnc_call.name,
"arguments": fnc_call.arguments,
"speech_id": speech_handle.id,
},
)
function_callable = functools.partial(function_tool, *fnc_args, **fnc_kwargs)
logger.debug(
"executing mock tool" if mocked else "executing tool",
extra={
"function": fnc_call.name,
"arguments": fnc_call.arguments,
"speech_id": speech_handle.id,
},
)
function_callable = functools.partial(_execute, run_ctx)

@tracer.start_as_current_span("function_tool")
async def _traceable_fnc_tool(
Expand Down
16 changes: 6 additions & 10 deletions tests/test_tool_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,15 @@ async def test_call_missing_required_arg_raises_tool_error(self, capsys):
)
await ts.setup()

with pytest.raises(ToolError, match="invalid parameters") as exc_info:
with pytest.raises(ToolError, match="Error parsing arguments") as exc_info:
await ts._handle_call(_mock_ctx(), {"name": "weather_tool", "parameters": {}})

error_msg = exc_info.value.message
print(f"Missing arg error: {error_msg}")

# Error message should contain the missing field name and indicate it's required
error_data = json.loads(error_msg.split(":", 1)[1].strip())
field_names = [err["loc"][0] for err in error_data]
assert "city" in field_names
assert any(err["type"] == "missing" for err in error_data)
assert "city" in error_msg
assert "type=missing" in error_msg

async def test_call_wrong_type_arg_raises_tool_error(self, capsys):
"""Wrong argument type produces a detailed ToolError for the LLM."""
Expand All @@ -235,7 +233,7 @@ async def test_call_wrong_type_arg_raises_tool_error(self, capsys):
)
await ts.setup()

with pytest.raises(ToolError, match="invalid parameters") as exc_info:
with pytest.raises(ToolError, match="Error parsing arguments") as exc_info:
await ts._handle_call(
_mock_ctx(),
{"name": "forecast_tool", "parameters": {"city": "Tokyo", "days": "not_a_number"}},
Expand All @@ -245,10 +243,8 @@ async def test_call_wrong_type_arg_raises_tool_error(self, capsys):
print(f"Wrong type error: {error_msg}")

# Error message should contain the field name and indicate a type parsing issue
error_data = json.loads(error_msg.split(":", 1)[1].strip())
field_names = [err["loc"][0] for err in error_data]
assert "days" in field_names
assert any("int" in err["type"] for err in error_data)
assert "days" in error_msg
assert "int_parsing" in error_msg

async def test_call_tool_propagates_tool_error(self):
"""ToolError raised inside a tool is re-raised as-is."""
Expand Down
10 changes: 5 additions & 5 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel, Field

from livekit.agents import Agent
from livekit.agents.llm import ProviderTool, Tool, ToolContext, Toolset, function_tool
from livekit.agents.llm import ProviderTool, Tool, ToolContext, ToolError, Toolset, function_tool
from livekit.agents.llm._strict import to_strict_json_schema
from livekit.agents.llm.utils import (
build_legacy_openai_schema,
Expand Down Expand Up @@ -390,17 +390,17 @@ async def test_tool_execution_with_default_value(self):
assert output == {"arg1": "test", "opt_arg2": None}

def test_unexpected_arguments(self):
with pytest.raises(ValueError, match="validation error"):
with pytest.raises(ToolError, match="validation error"):
prepare_function_arguments(fnc=mock_tool_1, json_arguments='{"opt_arg2": "test2"}')

with pytest.raises(ValueError, match="Received no value for required parameter"):
with pytest.raises(ToolError, match="Received no value for required parameter"):
prepare_function_arguments(fnc=mock_tool_2, json_arguments='{"arg1": null}')

with pytest.raises(ValueError, match="validation error"):
with pytest.raises(ToolError, match="validation error"):
prepare_function_arguments(fnc=mock_tool_2, json_arguments='{"arg1": "d"}')

agent = DummyAgent()
with pytest.raises(ValueError, match="validation error"):
with pytest.raises(ToolError, match="validation error"):
prepare_function_arguments(
fnc=agent.mock_tool_in_agent, json_arguments='{"opt_arg2": "test2"}'
)
Expand Down
Loading