From 8ab8b52b174de0156af6d7d430cae9528c11afed Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 16 Jun 2026 23:59:50 +0100 Subject: [PATCH] Add message_history to AgentOperator for multi-turn agent sessions AgentOperator and @task.agent ran a fresh single-turn conversation every time. Add an opt-in message_history parameter that seeds the run with prior turns and pushes the post-run transcript to XCom (key 'message_history') so the next run can resume. Default None keeps single-turn behavior unchanged. Storing the transcript under a session key stays the DAG's responsibility. --- providers/common/ai/docs/operators/agent.rst | 48 ++++++ .../common/ai/example_dags/example_agent.py | 56 ++++++- .../providers/common/ai/operators/agent.py | 69 +++++++- .../unit/common/ai/operators/test_agent.py | 149 +++++++++++++++++- 4 files changed, 318 insertions(+), 4 deletions(-) diff --git a/providers/common/ai/docs/operators/agent.rst b/providers/common/ai/docs/operators/agent.rst index a79e5110048b4..b3805aa34b765 100644 --- a/providers/common/ai/docs/operators/agent.rst +++ b/providers/common/ai/docs/operators/agent.rst @@ -156,6 +156,49 @@ tasks can consume it. :end-before: [END howto_agent_chain] +Multi-turn Sessions +------------------- + +By default each agent run is a cold, single-turn conversation. To carry a +conversation across runs -- a chat or iterative agent where "and the third one?" +must resolve against an earlier answer -- pass ``message_history``. + +When ``message_history`` is set, the operator seeds the run with those prior +turns and, after the run, pushes the full updated transcript +(``result.all_messages()``) to XCom under the key ``message_history``. The next +run reads it back to resume the conversation. ``None`` (the default) keeps the +single-turn behavior unchanged. + +The operator does **not** decide *where* a session is stored -- that keying is +deployment-specific. 
The pattern is three tasks: load the prior transcript for +the session, run the agent, store the updated transcript. The example keys a +JSON file in object storage by ``session_id`` (use ``s3://`` / ``gs://`` in a +deployment); the first run starts from an empty ``"[]"``. + +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py + :language: python + :start-after: [START howto_agent_session] + :end-before: [END howto_agent_session] + +``message_history`` accepts a list of pydantic-ai ``ModelMessage`` objects or +their JSON form (``str`` / ``bytes``), so the value emitted to XCom feeds +straight back in on the next run. When pulling it via a template, pass +``default='[]'`` to ``xcom_pull`` so the first run -- which has no XCom yet -- starts a +fresh session instead of trying to parse the string ``"None"``. + +The transcript is **cumulative**: each turn appends to it, so it grows for the +life of the session. For long sessions, configure an object-storage XCom backend +or trim older turns before the next run rather than feeding the whole history +back unbounded. + +.. note:: + + ``message_history`` cannot be combined with ``enable_hitl_review`` -- the + operator raises at construction. The post-review (human-approved) transcript + is not recoverable today, so emitting the pre-review transcript would + silently drop the reviewed turns. + + Durable Execution ----------------- @@ -406,6 +449,11 @@ Parameters - ``code_mode``: When ``True``, wraps the agent's tools in a single ``run_code`` tool that the model drives by writing Python, executed in the Monty sandbox. Requires the ``code-mode`` extra. Default ``False``. See :ref:`code-mode`. +- ``message_history``: Prior conversation to seed a multi-turn session, as a list + of pydantic-ai ``ModelMessage`` objects or their JSON form (``str`` / ``bytes``). + When set, the post-run transcript is pushed to XCom under the key + ``message_history`` for the next run to resume. 
Default ``None`` (single-turn). + See `Multi-turn Sessions`_. Logging diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py index 6fced224c92f9..787b0d6dce2e1 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py +++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py @@ -24,7 +24,7 @@ from airflow.providers.common.ai.operators.agent import AgentOperator from airflow.providers.common.ai.toolsets.hook import HookToolset -from airflow.providers.common.compat.sdk import dag, task +from airflow.providers.common.compat.sdk import ObjectStoragePath, dag, task try: from airflow.providers.common.ai.toolsets.sql import SQLToolset @@ -247,3 +247,57 @@ def example_agent_operator_code_mode(): # [END howto_operator_agent_code_mode] example_agent_operator_code_mode() + + +# --------------------------------------------------------------------------- +# 8. Multi-turn session — resume a conversation across DAG runs +# --------------------------------------------------------------------------- + + +# [START howto_agent_session] +@dag(tags=["example"], params={"session_id": "demo-session"}) +def example_agent_session(): + """Resume a conversation across runs via ``message_history``. + + The agent step seeds itself with the prior transcript and re-emits the + updated transcript to XCom (key ``message_history``). Loading and storing + that transcript under a session key is the DAG's job -- here, a JSON file in + object storage keyed by ``session_id``. Swap the path for ``s3://`` / + ``gs://`` in a deployment. + """ + sessions_root = ObjectStoragePath("file:///tmp/airflow_agent_sessions") + + @task + def load_history(session_id: str) -> str: + path = sessions_root / f"{session_id}.json" + # First turn: no file yet -> start a fresh session (empty transcript). 
+ return path.read_text() if path.exists() else "[]" + + @task.agent( + llm_conn_id="pydanticai_default", + system_prompt="You are a helpful assistant. Use the earlier turns for context.", + # The XComArg both wires the dependency and resolves to the JSON transcript. + message_history=load_history("{{ params.session_id }}"), + ) + def ask(question: str) -> str: + return question + + @task + def save_history(session_id: str, transcript: str) -> None: + # Local/fsspec object storage does not auto-create parent dirs on write. + sessions_root.mkdir(parents=True, exist_ok=True) + (sessions_root / f"{session_id}.json").write_text(transcript) + + answer = ask("And what did I ask you a moment ago?") + saved = save_history( + "{{ params.session_id }}", + # The agent step pushes the post-run transcript under this XCom key. + "{{ ti.xcom_pull(task_ids='ask', key='message_history') }}", + ) + # save runs after the agent so the pulled transcript is the fresh one. + answer >> saved + + +# [END howto_agent_session] + +example_agent_session() diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py index bda06ea1f5648..56c9ec5bbb65a 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py @@ -48,6 +48,7 @@ if TYPE_CHECKING: from pydantic_ai import Agent + from pydantic_ai.messages import ModelMessage from pydantic_ai.toolsets.abstract import AbstractToolset from pydantic_ai.usage import UsageLimits @@ -166,6 +167,22 @@ class AgentOperator(BaseOperator, HITLReviewMixin): Cannot be combined with ``durable=True`` (durable replay assumes a stable per-step call order that code mode does not guarantee). Default ``False``. + :param message_history: Prior conversation to seed the run with, for + multi-turn sessions that span task runs. 
Accepts a ``list`` of + pydantic-ai ``ModelMessage`` objects, or their JSON form as ``str`` / + ``bytes`` -- e.g. + ``"{{ ti.xcom_pull(task_ids='ask', key='message_history', default='[]') }}"`` + (pass ``default='[]'`` so the first run, with no XCom yet, starts a fresh + session instead of failing to parse the string ``"None"``). ``None`` + (default) is a single-turn run -- no behavior change. When set (an empty + ``[]`` / ``""`` starts a fresh session), the full transcript after the run + -- ``result.all_messages()`` -- is pushed to XCom under the key + ``message_history`` so the next run can resume. Persisting that transcript + under a session key (e.g. in object storage) is the DAG's responsibility. + The transcript is cumulative and grows each turn; for long sessions use an + object-storage XCom backend or trim old turns. Not supported together with + ``enable_hitl_review`` (raises) -- the post-review transcript is not yet + recoverable. **HITL Review parameters** (requires the ``hitl_review`` plugin): @@ -199,6 +216,7 @@ class AgentOperator(BaseOperator, HITLReviewMixin): "model_id", "system_prompt", "agent_params", + "message_history", ) operator_extra_links = (HITLReviewLink(),) @@ -217,6 +235,7 @@ def __init__( usage_limits: UsageLimits | None = None, durable: bool = False, code_mode: bool = False, + message_history: list[ModelMessage] | str | bytes | None = None, # Agent feedback parameters enable_hitl_review: bool = False, max_hitl_iterations: int = 5, @@ -240,6 +259,7 @@ def __init__( self.enable_tool_logging = enable_tool_logging self.agent_params = agent_params or {} self.usage_limits = usage_limits + self.message_history = message_history self.durable = durable self.code_mode = code_mode @@ -256,6 +276,13 @@ def __init__( # replay. Reject the combination rather than silently mis-replaying. 
raise ValueError("durable=True and code_mode=True cannot be used together.") + if message_history is not None and enable_hitl_review: + # The post-review transcript is not recoverable today (run_hitl_review + # returns only the final string), so emitting the pre-review transcript + # would silently drop the human-approved turns. Block until HITL can + # surface the final message history. + raise ValueError("message_history and enable_hitl_review=True cannot be used together.") + self.enable_hitl_review = enable_hitl_review self.max_hitl_iterations = max_hitl_iterations self.hitl_timeout = hitl_timeout @@ -331,6 +358,11 @@ def execute(self, context: Context) -> Any: agent = self._build_agent() + run_kwargs: dict[str, Any] = {"usage_limits": self.usage_limits} + history = self._resolve_message_history() + if history is not None: + run_kwargs["message_history"] = history + storage = self._durable_storage counter = self._durable_counter if self.durable and storage is not None and counter is not None: @@ -343,9 +375,9 @@ def execute(self, context: Context) -> Any: resolved_model = infer_model(agent.model) caching_model = CachingModel(resolved_model, storage=storage, counter=counter) with agent.override(model=caching_model): - result = agent.run_sync(self.prompt, usage_limits=self.usage_limits) + result = agent.run_sync(self.prompt, **run_kwargs) else: - result = agent.run_sync(self.prompt, usage_limits=self.usage_limits) + result = agent.run_sync(self.prompt, **run_kwargs) log_run_summary(self.log, result) @@ -368,6 +400,9 @@ def execute(self, context: Context) -> Any: if self._durable_storage is not None: self._durable_storage.cleanup() + if self.message_history is not None: + self._emit_message_history(context, result) + output = result.output if self.enable_hitl_review: @@ -391,6 +426,36 @@ def execute(self, context: Context) -> Any: output = output.model_dump() return output + def _resolve_message_history(self) -> list[ModelMessage] | None: + """ + Deserialize 
:attr:`message_history` into a list of pydantic-ai messages. + + ``None`` means single-turn (no history passed to the run). A ``str`` / + ``bytes`` value is parsed as the JSON the operator emits to XCom; a list + (of ``ModelMessage`` objects or their dict form) is validated as-is. + """ + raw = self.message_history + if raw is None: + return None + if isinstance(raw, (str, bytes)) and not raw.strip(): + # A template that renders to empty (no prior XCom) starts a fresh session. + return [] + # pydantic-ai is imported lazily here to match this module's pattern of + # keeping pydantic-ai out of DAG-parse-time imports. + from pydantic_ai.messages import ModelMessagesTypeAdapter + + if isinstance(raw, (str, bytes)): + return ModelMessagesTypeAdapter.validate_json(raw) + return ModelMessagesTypeAdapter.validate_python(raw) + + def _emit_message_history(self, context: Context, result: Any) -> None: + """Push the full post-run transcript to XCom for the next turn to resume.""" + # Lazy import: see _resolve_message_history. 
+ from pydantic_ai.messages import ModelMessagesTypeAdapter + + transcript = ModelMessagesTypeAdapter.dump_json(result.all_messages()).decode() + context["task_instance"].xcom_push(key="message_history", value=transcript) + def regenerate_with_feedback(self, *, feedback: str, message_history: Any) -> tuple[str, Any]: """Re-run the agent with *feedback* appended to the conversation history.""" agent = self._build_agent() diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py index a9f017b94ee4c..1288dbbe6525b 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py @@ -22,6 +22,13 @@ import pytest from pydantic import BaseModel +from pydantic_ai.messages import ( + ModelMessagesTypeAdapter, + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, +) from pydantic_ai.usage import UsageLimits from airflow.providers.common.ai.operators.agent import AgentOperator, HITLReviewLink, _build_code_mode @@ -90,7 +97,14 @@ def test_hitl_params_stored(self): class TestAgentOperatorTemplateFields: def test_template_fields(self): - expected = {"prompt", "llm_conn_id", "model_id", "system_prompt", "agent_params"} + expected = { + "prompt", + "llm_conn_id", + "model_id", + "system_prompt", + "agent_params", + "message_history", + } assert set(AgentOperator.template_fields) == expected @@ -617,3 +631,136 @@ def test_execute_rejects_sequence_prompt_with_hitl_review(self, mock_hook_cls): op.execute(context=MagicMock()) mock_agent.run_sync.assert_not_called() + + +def _sample_history(): + """A minimal two-message pydantic-ai conversation for round-trip tests.""" + return [ + ModelRequest(parts=[UserPromptPart(content="first question")]), + ModelResponse(parts=[TextPart(content="first answer")]), + ] + + +# The accepted input forms for ``message_history``, computed once at collection time. 
+_SAMPLE_HISTORY_JSON = ModelMessagesTypeAdapter.dump_json(_sample_history()).decode() +_SAMPLE_HISTORY_DICTS = ModelMessagesTypeAdapter.dump_python(_sample_history(), mode="json") + + +class TestAgentOperatorMessageHistory: + """Multi-turn session support: seed run_sync with prior history, emit the transcript.""" + + @pytest.mark.parametrize( + ("raw", "expected_len"), + [ + pytest.param([], 0, id="empty-list"), + pytest.param("", 0, id="empty-str"), + pytest.param(" ", 0, id="blank-str"), + pytest.param(_SAMPLE_HISTORY_JSON, 2, id="json-str"), + pytest.param(_SAMPLE_HISTORY_DICTS, 2, id="list-of-dicts"), + pytest.param(_sample_history(), 2, id="list-of-objects"), + ], + ) + @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True) + def test_message_history_seeds_run_sync(self, mock_hook_cls, raw, expected_len): + """Every accepted input form is deserialized and passed to run_sync; blank/empty start fresh.""" + mock_agent = _make_mock_agent("ok") + mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent + + op = AgentOperator(task_id="t", prompt="run", llm_conn_id="c", message_history=raw) + op.execute(context=MagicMock()) + + passed = mock_agent.run_sync.call_args.kwargs["message_history"] + assert len(passed) == expected_len + assert all(isinstance(m, (ModelRequest, ModelResponse)) for m in passed) + + @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True) + def test_none_is_single_turn_no_history_no_emit(self, mock_hook_cls): + """Default message_history=None passes no history and pushes no transcript XCom.""" + mock_agent = _make_mock_agent("ok") + mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent + + op = AgentOperator(task_id="t", prompt="run", llm_conn_id="c") + context = MagicMock() + op.execute(context=context) + + assert "message_history" not in mock_agent.run_sync.call_args.kwargs + context["task_instance"].xcom_push.assert_not_called() + + 
@patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True) + def test_transcript_emitted_to_xcom_when_history_set(self, mock_hook_cls): + """When message_history is set, the post-run transcript is pushed to XCom and round-trips.""" + mock_agent = _make_mock_agent("ok") + mock_agent.run_sync.return_value.all_messages.return_value = _sample_history() + mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent + + op = AgentOperator(task_id="t", prompt="run", llm_conn_id="c", message_history=[]) + context = MagicMock() + op.execute(context=context) + + ti = context["task_instance"] + ti.xcom_push.assert_called_once() + push_kwargs = ti.xcom_push.call_args.kwargs + assert push_kwargs["key"] == "message_history" + restored = ModelMessagesTypeAdapter.validate_json(push_kwargs["value"]) + assert len(restored) == 2 + + @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True) + def test_usage_limits_still_forwarded_with_history(self, mock_hook_cls): + """Adding message_history does not drop usage_limits from the run_sync call.""" + mock_agent = _make_mock_agent("ok") + mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent + + limits = UsageLimits(request_limit=2) + op = AgentOperator( + task_id="t", prompt="run", llm_conn_id="c", usage_limits=limits, message_history=[] + ) + op.execute(context=MagicMock()) + + kwargs = mock_agent.run_sync.call_args.kwargs + assert kwargs["usage_limits"] is limits + assert kwargs["message_history"] == [] + + def test_message_history_with_hitl_review_raises(self): + """message_history cannot be combined with HITL review (post-review transcript is lost).""" + with pytest.raises(ValueError, match="message_history and enable_hitl_review"): + AgentOperator( + task_id="t", + prompt="run", + llm_conn_id="c", + message_history=[], + enable_hitl_review=True, + ) + + @patch("pydantic_ai.models.wrapper.infer_model", side_effect=lambda m: m) + 
@patch("pydantic_ai.models.infer_model", autospec=True) + @patch("airflow.providers.common.ai.durable.storage._get_base_path") + @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True) + def test_durable_path_also_seeds_message_history( + self, mock_hook_cls, mock_base_path, mock_infer, _, tmp_path + ): + """The durable branch forwards message_history into the cached run too.""" + from airflow.sdk import ObjectStoragePath + + mock_base_path.return_value = ObjectStoragePath(f"file://{tmp_path.as_posix()}") + + mock_agent = MagicMock(spec=["run_sync", "model", "override"]) + mock_agent.run_sync.return_value = _make_mock_run_result("ok") + mock_agent.model = "test-model" + mock_agent.override.return_value.__enter__ = MagicMock(return_value=None) + mock_agent.override.return_value.__exit__ = MagicMock(return_value=False) + mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent + mock_infer.return_value = MagicMock() + + context = MagicMock() + context.__getitem__ = MagicMock( + return_value=MagicMock(dag_id="d", task_id="t", run_id="r", map_index=-1) + ) + + history_json = ModelMessagesTypeAdapter.dump_json(_sample_history()).decode() + op = AgentOperator( + task_id="test", prompt="test", llm_conn_id="my_llm", durable=True, message_history=history_json + ) + op.execute(context=context) + + passed = mock_agent.run_sync.call_args.kwargs["message_history"] + assert len(passed) == 2