diff --git a/dreadnode/agent/agent.py b/dreadnode/agent/agent.py index a1a7dd1d..5ba40642 100644 --- a/dreadnode/agent/agent.py +++ b/dreadnode/agent/agent.py @@ -5,7 +5,8 @@ import rigging as rg from pydantic import ConfigDict, Field, PrivateAttr, SkipValidation, field_validator -from rigging.message import inject_system_content # can't access via rg +from rigging.message import inject_system_content +from ulid import ULID # can't access via rg from dreadnode.agent.error import MaxStepsError from dreadnode.agent.events import ( @@ -275,6 +276,7 @@ async def _stream( # noqa: PLR0912, PLR0915 ) -> t.AsyncGenerator[AgentEvent, None]: events: list[AgentEvent] = [] stop_conditions = self.stop_conditions + session_id = ULID() # Event dispatcher @@ -368,6 +370,7 @@ async def _dispatch(event: AgentEvent) -> t.AsyncIterator[AgentEvent]: "unknown", ) reacted_event = Reacted( + session_id=session_id, agent=self, thread=thread, messages=messages, @@ -395,6 +398,7 @@ async def _process_tool_call( ) -> t.AsyncGenerator[AgentEvent, None]: async for event in _dispatch( ToolStart( + session_id=session_id, agent=self, thread=thread, messages=messages, @@ -416,6 +420,7 @@ async def _process_tool_call( except Exception as e: async for event in _dispatch( AgentError( + session_id=session_id, agent=self, thread=thread, messages=messages, @@ -432,6 +437,7 @@ async def _process_tool_call( async for event in _dispatch( ToolEnd( + session_id=session_id, agent=self, thread=thread, messages=messages, @@ -447,6 +453,7 @@ async def _process_tool_call( async for event in _dispatch( AgentStart( + session_id=session_id, agent=self, thread=thread, messages=messages, @@ -464,6 +471,7 @@ async def _process_tool_call( try: async for event in _dispatch( StepStart( + session_id=session_id, agent=self, thread=thread, messages=messages, @@ -479,6 +487,7 @@ async def _process_tool_call( if step_chat.failed and step_chat.error: async for event in _dispatch( AgentError( + session_id=session_id, agent=self, thread=thread, messages=messages, @@ -493,6 +502,7 @@ async def _process_tool_call( async for event in _dispatch( GenerationEnd( + session_id=session_id, agent=self, thread=thread, messages=messages, @@ -516,6 +526,7 @@ async def _process_tool_call( async for event in _dispatch( AgentStalled( + session_id=session_id, agent=self, thread=thread, messages=messages, @@ -579,6 +590,7 @@ async def _process_tool_call( thread.events.extend(events) yield AgentEnd( + session_id=session_id, agent=self, thread=thread, messages=messages, diff --git a/dreadnode/agent/events.py b/dreadnode/agent/events.py index 9e1a85d8..d7e4a301 100644 --- a/dreadnode/agent/events.py +++ b/dreadnode/agent/events.py @@ -9,6 +9,7 @@ from rich.rule import Rule from rich.table import Table from rich.text import Text +from ulid import ULID from dreadnode.agent.format import format_message from dreadnode.agent.reactions import ( @@ -39,6 +40,8 @@ class AgentEvent: ) """The timestamp of when the event occurred (UTC).""" + session_id: ULID = field(repr=False) + """The unique identifier for the agent run session.""" agent: "Agent" = field(repr=False) """The agent associated with this event.""" thread: "Thread" = field(repr=False) diff --git a/dreadnode/agent/hooks/__init__.py b/dreadnode/agent/hooks/__init__.py index 0a582e2b..65fcbb28 100644 --- a/dreadnode/agent/hooks/__init__.py +++ b/dreadnode/agent/hooks/__init__.py @@ -1,3 +1,4 @@ +from dreadnode.agent.hooks.backoff import backoff_on_error, backoff_on_ratelimit from dreadnode.agent.hooks.base import ( Hook, retry_with_feedback, @@ -6,6 +7,8 @@ __all__ = [ "Hook", + "backoff_on_error", + "backoff_on_ratelimit", "retry_with_feedback", "summarize_when_long", ] diff --git a/dreadnode/agent/hooks/backoff.py b/dreadnode/agent/hooks/backoff.py new file mode 100644 index 00000000..cd3880f9 --- /dev/null +++ b/dreadnode/agent/hooks/backoff.py @@ -0,0 +1,134 @@ +import asyncio +import random +import time +import typing as t +from dataclasses import dataclass + +from loguru import logger + +from dreadnode.agent.events import AgentError, AgentEvent, StepStart +from dreadnode.agent.reactions import Reaction, Retry + +if t.TYPE_CHECKING: + from ulid import ULID + + from dreadnode.agent.hooks.base import Hook + + +@dataclass +class BackoffState: + tries: int = 0 + start_time: float | None = None + last_step_seen: int = -1 + + def reset(self, step: int = -1) -> None: + self.tries = 0 + self.start_time = None + self.last_step_seen = step + + +def backoff_on_error( + exception_types: type[Exception] | t.Iterable[type[Exception]], + *, + max_tries: int = 8, + max_time: float = 300.0, + base_factor: float = 1.0, + jitter: bool = True, +) -> "Hook": + """ + Creates a hook that retries with exponential backoff when specific errors occur. + + It listens for `AgentError` events and, if the error matches, waits for an + exponentially increasing duration before issuing a `Retry` reaction. + + Args: + exception_types: An exception type or iterable of types to catch. + max_tries: The maximum number of retries before giving up. + max_time: The maximum total time in seconds to wait before giving up. + base_factor: The base duration (in seconds) for the backoff calculation. + jitter: If True, adds a random jitter to the wait time to prevent synchronized retries. + + Returns: + An agent hook that implements the backoff logic. + """ + exceptions = ( + tuple(exception_types) if isinstance(exception_types, t.Iterable) else (exception_types,) + ) + + session_states: dict[ULID, BackoffState] = {} + + async def backoff_hook(event: "AgentEvent") -> "Reaction | None": + state = session_states.setdefault(event.session_id, BackoffState()) + + if isinstance(event, StepStart): + if event.step > state.last_step_seen: + state.reset(event.step) + return None + + if not isinstance(event, AgentError) or not isinstance(event.error, exceptions): + return None + + if state.start_time is None: + state.start_time = time.monotonic() + + if state.tries >= max_tries: + logger.warning( + f"Backoff aborted for session {event.session_id}: maximum tries ({max_tries}) exceeded." + ) + return None + + if (time.monotonic() - state.start_time) >= max_time: + logger.warning( + f"Backoff aborted for session {event.session_id}: maximum time ({max_time:.2f}s) exceeded." + ) + return None + + state.tries += 1 + + seconds = base_factor * (2 ** (state.tries - 1)) + if jitter: + seconds += random.uniform(0, base_factor) # noqa: S311 # nosec + + logger.warning( + f"Backing off for {seconds:.2f}s (try {state.tries}/{max_tries}) on session {event.session_id} due to error: {event.error}" + ) + + await asyncio.sleep(seconds) + return Retry() + + return backoff_hook + + +def backoff_on_ratelimit( + *, + max_tries: int = 8, + max_time: float = 300.0, + base_factor: float = 1.0, + jitter: bool = True, +) -> "Hook": + """ + A convenient default backoff hook for common, ephemeral LLM errors. + + This hook retries on `litellm.exceptions.RateLimitError` and `litellm.exceptions.APIError` + with an exponential backoff strategy for up to 5 minutes. + + See `backoff_on_error` for more details. + + Args: + max_tries: The maximum number of retries before giving up. + max_time: The maximum total time in seconds to wait before giving up. + base_factor: The base duration (in seconds) for the backoff calculation. + jitter: If True, adds a random jitter to the wait time to prevent synchronized retries. + + Returns: + An agent hook that implements the backoff logic. + """ + import litellm.exceptions + + return backoff_on_error( + (litellm.exceptions.RateLimitError, litellm.exceptions.APIError), + max_time=max_time, + max_tries=max_tries, + base_factor=base_factor, + jitter=jitter, + ) diff --git a/dreadnode/agent/reactions.py b/dreadnode/agent/reactions.py index c47fc45a..02378d38 100644 --- a/dreadnode/agent/reactions.py +++ b/dreadnode/agent/reactions.py @@ -19,7 +19,7 @@ class Continue(Reaction): @dataclass class Retry(Reaction): - messages: list[rg.Message] | None = Field(None, repr=False) + messages: list[rg.Message] | None = Field(default=None, repr=False) @dataclass