Merged
48 changes: 48 additions & 0 deletions providers/common/ai/docs/operators/agent.rst
Original file line number Diff line number Diff line change
@@ -156,6 +156,49 @@ tasks can consume it.
    :end-before: [END howto_agent_chain]


Multi-turn Sessions
-------------------

By default each agent run is a cold, single-turn conversation. To carry a
conversation across runs -- a chat or iterative agent where "and the third one?"
must resolve against an earlier answer -- pass ``message_history``.

When ``message_history`` is set, the operator seeds the run with those prior
turns and, after the run, pushes the full updated transcript
(``result.all_messages()``) to XCom under the key ``message_history``. The next
run reads it back to resume the conversation. ``None`` (the default) keeps the
single-turn behavior unchanged.
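The round trip described above can be sketched without Airflow or pydantic-ai. This is only an illustration under stated assumptions, not the operator's implementation: plain dicts stand in for ``ModelMessage`` objects, and ``fake_agent_run`` is a hypothetical stand-in for the agent run that returns the prior turns plus the new exchange, much as ``result.all_messages()`` is described to do.

```python
import json


def fake_agent_run(prompt: str, message_history: list[dict]) -> list[dict]:
    """Hypothetical stand-in for an agent run: prior turns plus the new exchange."""
    reply = f"(reply to: {prompt})"
    return [
        *message_history,
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": reply},
    ]


def run_turn(prompt: str, history_json: str) -> str:
    """Seed a run with a JSON transcript and return the updated JSON transcript."""
    history = json.loads(history_json)
    all_messages = fake_agent_run(prompt, history)
    # In the operator, this string is what would be pushed to XCom
    # under the key "message_history".
    return json.dumps(all_messages)


first = run_turn("List three fruits.", "[]")  # fresh session
second = run_turn("And the third one?", first)  # resumes from the first transcript
```

The second call sees the first exchange because the whole transcript, not just the last answer, is fed back in.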

The operator does **not** decide *where* a session is stored -- that keying is
deployment-specific. The pattern is three tasks: load the prior transcript for
the session, run the agent, store the updated transcript. The example keys a
JSON file in object storage by ``session_id`` (use ``s3://`` / ``gs://`` in a
deployment); the first run starts from an empty ``"[]"``.
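The load / run / store pattern can be tried locally with the standard library alone. In this sketch ``pathlib.Path`` stands in for object storage and ``run_agent`` is a hypothetical placeholder for the agent step; none of these helper names come from the provider.

```python
import json
import tempfile
from pathlib import Path


def load_history(root: Path, session_id: str) -> str:
    """Return the stored transcript for a session, or an empty one on the first run."""
    path = root / f"{session_id}.json"
    return path.read_text() if path.exists() else "[]"


def run_agent(question: str, history_json: str) -> str:
    """Hypothetical placeholder for the agent step: appends one exchange."""
    messages = json.loads(history_json)
    messages.append({"role": "user", "content": question})
    messages.append({"role": "assistant", "content": "(model reply)"})
    return json.dumps(messages)


def save_history(root: Path, session_id: str, transcript: str) -> None:
    """Store the updated transcript under the session key."""
    root.mkdir(parents=True, exist_ok=True)
    (root / f"{session_id}.json").write_text(transcript)


def run_once(root: Path, session_id: str, question: str) -> int:
    """One load -> run -> store cycle; returns the transcript length."""
    history = load_history(root, session_id)
    transcript = run_agent(question, history)
    save_history(root, session_id, transcript)
    return len(json.loads(transcript))


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "sessions"
    turns_after_first = run_once(root, "demo-session", "What is XCom?")
    turns_after_second = run_once(root, "demo-session", "And what did I just ask?")
    other_session = run_once(root, "other-session", "Hello")
```

Because the file name is derived from ``session_id``, separate sessions never see each other's transcripts.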

.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
    :language: python
    :start-after: [START howto_agent_session]
    :end-before: [END howto_agent_session]

``message_history`` accepts a list of pydantic-ai ``ModelMessage`` objects or
their JSON form (``str`` / ``bytes``), so the value emitted to XCom feeds
straight back in on the next run. When pulling it via a template instead of a
loader task, pass ``default='[]'`` so the first run -- which has no XCom yet --
starts a fresh session instead of trying to parse the string ``"None"``.
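The reason for the default can be seen with the standard ``json`` module, used here as a stand-in for the operator's JSON validation. The string ``"None"`` is not valid JSON, whereas ``"[]"`` parses to an empty list.

```python
import json


def parse_history(rendered: str) -> list:
    """Parse a rendered template value as a JSON transcript."""
    return json.loads(rendered)


# What default='[]' renders to on the first run: an empty transcript.
fresh = parse_history("[]")

# What a missing XCom renders to without a default, per the text above.
try:
    parse_history("None")
    none_parsed = True
except json.JSONDecodeError:
    none_parsed = False
```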

The transcript is **cumulative**: each turn appends to it, so it grows for the
life of the session. For long sessions, configure an object-storage XCom backend
or trim older turns before the next run rather than feeding the whole history
back unbounded.
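One possible trimming step, shown as a rough sketch: keep only the most recent entries of the JSON transcript before storing it or feeding it back. ``trim_transcript`` and ``max_messages`` are hypothetical names, and the transcript entries are simplified dicts. A naive slice like this may separate a tool call from its result or drop the opening instructions, so a real policy likely needs to be aware of message types.

```python
import json


def trim_transcript(transcript_json: str, max_messages: int) -> str:
    """Keep only the most recent ``max_messages`` entries of a JSON transcript."""
    if max_messages <= 0:
        raise ValueError("max_messages must be positive")
    messages = json.loads(transcript_json)
    return json.dumps(messages[-max_messages:])


# Simplified stand-in transcript with ten entries.
transcript = json.dumps([{"turn": i} for i in range(10)])
trimmed = json.loads(trim_transcript(transcript, max_messages=4))
```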

.. note::

    ``message_history`` cannot be combined with ``enable_hitl_review`` -- the
    operator raises at construction. The post-review (human-approved) transcript
    is not recoverable today, so emitting the pre-review transcript would
    silently drop the reviewed turns.
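The construction-time check can be illustrated with a simplified stand-in class. ``SketchOperator`` is hypothetical and models only the two relevant arguments. Note that the check is on ``is not None``, so even an empty ``"[]"`` history counts as set.

```python
class SketchOperator:
    """Hypothetical, simplified stand-in for the operator's constructor check."""

    def __init__(self, message_history=None, enable_hitl_review=False):
        if message_history is not None and enable_hitl_review:
            raise ValueError(
                "message_history and enable_hitl_review=True cannot be used together."
            )
        self.message_history = message_history
        self.enable_hitl_review = enable_hitl_review


# Either option alone is accepted.
ok = SketchOperator(message_history="[]")

# Both together are rejected when the object is built, before any run.
try:
    SketchOperator(message_history="[]", enable_hitl_review=True)
    rejected = False
except ValueError:
    rejected = True
```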


Durable Execution
-----------------

@@ -406,6 +449,11 @@ Parameters
- ``code_mode``: When ``True``, wraps the agent's tools in a single ``run_code``
  tool that the model drives by writing Python, executed in the Monty sandbox.
  Requires the ``code-mode`` extra. Default ``False``. See :ref:`code-mode`.
- ``message_history``: Prior conversation to seed a multi-turn session, as a list
  of pydantic-ai ``ModelMessage`` objects or their JSON form (``str`` / ``bytes``).
  When set, the post-run transcript is pushed to XCom under the key
  ``message_history`` for the next run to resume. Default ``None`` (single-turn).
  See `Multi-turn Sessions`_.
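How the accepted forms of ``message_history`` might be normalised is sketched below. ``resolve_history`` is a hypothetical helper and the standard ``json`` module stands in for pydantic-ai validation, so this shows only the branching, not the real message types.

```python
import json


def resolve_history(raw):
    """Normalise the accepted input forms into a list, or None for single-turn."""
    if raw is None:
        return None  # single-turn: no history is passed to the run
    if isinstance(raw, (str, bytes)):
        if not raw.strip():
            return []  # a template that rendered to empty starts a fresh session
        return json.loads(raw)  # JSON form, as emitted to XCom
    return list(raw)  # already a list of messages


single_turn = resolve_history(None)
fresh_session = resolve_history("")
from_json = resolve_history('[{"role": "user", "content": "hi"}]')
from_bytes = resolve_history(b"[]")
from_list = resolve_history([{"role": "user", "content": "hi"}])
```

The distinction between ``None`` and an empty value matters: ``None`` leaves the run single-turn, while an empty string or list starts a session.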


Logging
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@

from airflow.providers.common.ai.operators.agent import AgentOperator
from airflow.providers.common.ai.toolsets.hook import HookToolset
-from airflow.providers.common.compat.sdk import dag, task
+from airflow.providers.common.compat.sdk import ObjectStoragePath, dag, task

try:
    from airflow.providers.common.ai.toolsets.sql import SQLToolset
@@ -247,3 +247,57 @@ def example_agent_operator_code_mode():
# [END howto_operator_agent_code_mode]

example_agent_operator_code_mode()


# ---------------------------------------------------------------------------
# 8. Multi-turn session — resume a conversation across DAG runs
# ---------------------------------------------------------------------------


# [START howto_agent_session]
@dag(tags=["example"], params={"session_id": "demo-session"})
def example_agent_session():
    """Resume a conversation across runs via ``message_history``.

    The agent step seeds itself with the prior transcript and re-emits the
    updated transcript to XCom (key ``message_history``). Loading and storing
    that transcript under a session key is the DAG's job -- here, a JSON file in
    object storage keyed by ``session_id``. Swap the path for ``s3://`` /
    ``gs://`` in a deployment.
    """
    sessions_root = ObjectStoragePath("file:///tmp/airflow_agent_sessions")

    @task
    def load_history(session_id: str) -> str:
        path = sessions_root / f"{session_id}.json"
        # First turn: no file yet -> start a fresh session (empty transcript).
        return path.read_text() if path.exists() else "[]"

    @task.agent(
        llm_conn_id="pydanticai_default",
        system_prompt="You are a helpful assistant. Use the earlier turns for context.",
        # The XComArg both wires the dependency and resolves to the JSON transcript.
        message_history=load_history("{{ params.session_id }}"),
    )
    def ask(question: str) -> str:
        return question

    @task
    def save_history(session_id: str, transcript: str) -> None:
        # Local/fsspec object storage does not auto-create parent dirs on write.
        sessions_root.mkdir(parents=True, exist_ok=True)
        (sessions_root / f"{session_id}.json").write_text(transcript)

    answer = ask("And what did I ask you a moment ago?")
    saved = save_history(
        "{{ params.session_id }}",
        # The agent step pushes the post-run transcript under this XCom key.
        "{{ ti.xcom_pull(task_ids='ask', key='message_history') }}",
    )
    # save runs after the agent so the pulled transcript is the fresh one.
    answer >> saved


# [END howto_agent_session]

example_agent_session()
Original file line number Diff line number Diff line change
@@ -48,6 +48,7 @@

if TYPE_CHECKING:
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage
from pydantic_ai.toolsets.abstract import AbstractToolset
from pydantic_ai.usage import UsageLimits

@@ -166,6 +167,22 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
Cannot be combined with ``durable=True`` (durable replay assumes a
stable per-step call order that code mode does not guarantee).
Default ``False``.
:param message_history: Prior conversation to seed the run with, for
multi-turn sessions that span task runs. Accepts a ``list`` of
pydantic-ai ``ModelMessage`` objects, or their JSON form as ``str`` /
``bytes`` -- e.g.
``"{{ ti.xcom_pull(task_ids='ask', key='message_history', default='[]') }}"``
(pass ``default='[]'`` so the first run, with no XCom yet, starts a fresh
session instead of failing to parse the string ``"None"``). ``None``
(default) is a single-turn run -- no behavior change. When set (an empty
``[]`` / ``""`` starts a fresh session), the full transcript after the run
-- ``result.all_messages()`` -- is pushed to XCom under the key
``message_history`` so the next run can resume. Persisting that transcript
under a session key (e.g. in object storage) is the DAG's responsibility.
The transcript is cumulative and grows each turn; for long sessions use an
object-storage XCom backend or trim old turns. Not supported together with
``enable_hitl_review`` (raises) -- the post-review transcript is not yet
recoverable.

**HITL Review parameters** (requires the ``hitl_review`` plugin):

@@ -199,6 +216,7 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
"model_id",
"system_prompt",
"agent_params",
"message_history",
)

operator_extra_links = (HITLReviewLink(),)
@@ -217,6 +235,7 @@ def __init__(
usage_limits: UsageLimits | None = None,
durable: bool = False,
code_mode: bool = False,
message_history: list[ModelMessage] | str | bytes | None = None,
# Agent feedback parameters
enable_hitl_review: bool = False,
max_hitl_iterations: int = 5,
@@ -240,6 +259,7 @@ def __init__(
self.enable_tool_logging = enable_tool_logging
self.agent_params = agent_params or {}
self.usage_limits = usage_limits
self.message_history = message_history

self.durable = durable
self.code_mode = code_mode
@@ -256,6 +276,13 @@ def __init__(
            # replay. Reject the combination rather than silently mis-replaying.
            raise ValueError("durable=True and code_mode=True cannot be used together.")

        if message_history is not None and enable_hitl_review:
            # The post-review transcript is not recoverable today (run_hitl_review
            # returns only the final string), so emitting the pre-review transcript
            # would silently drop the human-approved turns. Block until HITL can
            # surface the final message history.
            raise ValueError("message_history and enable_hitl_review=True cannot be used together.")

self.enable_hitl_review = enable_hitl_review
self.max_hitl_iterations = max_hitl_iterations
self.hitl_timeout = hitl_timeout
@@ -331,6 +358,11 @@ def execute(self, context: Context) -> Any:

        agent = self._build_agent()

        run_kwargs: dict[str, Any] = {"usage_limits": self.usage_limits}
        history = self._resolve_message_history()
        if history is not None:
            run_kwargs["message_history"] = history

        storage = self._durable_storage
        counter = self._durable_counter
        if self.durable and storage is not None and counter is not None:
@@ -343,9 +375,9 @@ def execute(self, context: Context) -> Any:
             resolved_model = infer_model(agent.model)
             caching_model = CachingModel(resolved_model, storage=storage, counter=counter)
             with agent.override(model=caching_model):
-                result = agent.run_sync(self.prompt, usage_limits=self.usage_limits)
+                result = agent.run_sync(self.prompt, **run_kwargs)
         else:
-            result = agent.run_sync(self.prompt, usage_limits=self.usage_limits)
+            result = agent.run_sync(self.prompt, **run_kwargs)

log_run_summary(self.log, result)

@@ -368,6 +400,9 @@ def execute(self, context: Context) -> Any:
if self._durable_storage is not None:
self._durable_storage.cleanup()

        if self.message_history is not None:
            self._emit_message_history(context, result)

output = result.output

if self.enable_hitl_review:
@@ -391,6 +426,36 @@ def execute(self, context: Context) -> Any:
output = output.model_dump()
return output

    def _resolve_message_history(self) -> list[ModelMessage] | None:
        """
        Deserialize :attr:`message_history` into a list of pydantic-ai messages.

        ``None`` means single-turn (no history passed to the run). A ``str`` /
        ``bytes`` value is parsed as the JSON the operator emits to XCom; a list
        (of ``ModelMessage`` objects or their dict form) is validated as-is.
        """
        raw = self.message_history
        if raw is None:
            return None
        if isinstance(raw, (str, bytes)) and not raw.strip():
            # A template that renders to empty (no prior XCom) starts a fresh session.
            return []
        # pydantic-ai is imported lazily here to match this module's pattern of
        # keeping pydantic-ai out of DAG-parse-time imports.
        from pydantic_ai.messages import ModelMessagesTypeAdapter

        if isinstance(raw, (str, bytes)):
            return ModelMessagesTypeAdapter.validate_json(raw)
        return ModelMessagesTypeAdapter.validate_python(raw)

    def _emit_message_history(self, context: Context, result: Any) -> None:
        """Push the full post-run transcript to XCom for the next turn to resume."""
        # Lazy import: see _resolve_message_history.
        from pydantic_ai.messages import ModelMessagesTypeAdapter

        transcript = ModelMessagesTypeAdapter.dump_json(result.all_messages()).decode()
        context["task_instance"].xcom_push(key="message_history", value=transcript)

def regenerate_with_feedback(self, *, feedback: str, message_history: Any) -> tuple[str, Any]:
"""Re-run the agent with *feedback* appended to the conversation history."""
agent = self._build_agent()