Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion providers/common/ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ requires-python = ">=3.10"
# After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
dependencies = [
"apache-airflow>=3.0.0",
"apache-airflow-providers-common-compat>=1.14.1",
"apache-airflow-providers-common-compat>=1.14.1", # use next version
"apache-airflow-providers-standard>=1.12.1",
"pydantic-ai-slim>=1.99.0",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@

from pydantic import BaseModel

from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_3_PLUS

if AIRFLOW_V_3_3_PLUS:
# On Airflow 3.3+ the review parks the task in the first-class AWAITING_INPUT state instead
# of deferring to a trigger. On older cores this name is absent and defer() is used.
from airflow.sdk.exceptions import TaskAwaitingInput

log = logging.getLogger(__name__)

if TYPE_CHECKING:
Expand All @@ -45,7 +52,8 @@ class LLMApprovalMixin:

When ``require_approval=True`` on the operator, the generated output is
presented to a human reviewer via the Airflow Human-in-the-Loop (HITL)
interface. The task defers until the reviewer approves or rejects.
interface. The task waits (``awaiting_input`` on Airflow 3.3+, deferred on
older versions) until the reviewer approves or rejects.

If ``allow_modifications=True``, the reviewer can also edit the output
before approving. The (possibly modified) output is then returned as the
Expand All @@ -71,7 +79,11 @@ def defer_for_approval(
body: str | None = None,
) -> None:
"""
Write HITL detail, then defer to HITLTrigger for human review.
Write HITL detail, then pause the task for human review.

On Airflow 3.3+ the task parks in the ``awaiting_input`` state (no trigger or triggerer
involved); on older versions it defers to :class:`HITLTrigger`. Either way it resumes in
``execute_complete`` once a response (or timeout default) arrives.

:param context: Airflow task context.
:param output: The generated output to present for review.
Expand Down Expand Up @@ -100,7 +112,6 @@ def defer_for_approval(
output = str(output)

ti_id = context["task_instance"].id
timeout_datetime = utcnow() + self.approval_timeout if self.approval_timeout else None

if subject is None:
subject = f"Review output for task `{self.task_id}`"
Expand Down Expand Up @@ -128,14 +139,24 @@ def defer_for_approval(
params=hitl_params,
)

if AIRFLOW_V_3_3_PLUS:
# New core (3.3+): park the task in AWAITING_INPUT -- no trigger, no triggerer. The
# task is resumed by the Core API response handler or the scheduler timeout sweep.
raise TaskAwaitingInput(
method_name="execute_complete",
kwargs={"generated_output": output},
timeout=self.approval_timeout,
)

# Fallback for cores < 3.3: defer the response check to HITLTrigger on the triggerer.
self.defer(
trigger=HITLTrigger(
ti_id=ti_id,
options=[LLMApprovalMixin.APPROVE, LLMApprovalMixin.REJECT],
defaults=None,
params=hitl_params,
multiple=False,
timeout_datetime=timeout_datetime,
timeout_datetime=utcnow() + self.approval_timeout if self.approval_timeout else None,
),
method_name="execute_complete",
kwargs={"generated_output": output},
Expand Down Expand Up @@ -182,6 +203,16 @@ def execute_complete(self, context: Context, generated_output: str, event: dict[
# when allow_modifications=False, bypassing the read-only approval flow.
if getattr(self, "allow_modifications", False) and params_input:
modified = params_input.get("output")
if modified is not None and not isinstance(modified, str):
# On the awaiting_input path nothing upstream schema-validates params_input
# (HITLTrigger did on the legacy path), so enforce the string contract here
# rather than returning a non-string as the task's output.
raise HITLTriggerEventError(
{
"error": f"Modified output must be a string, got {type(modified).__name__}.",
"error_type": "validation",
}
)
if modified is not None and modified != generated_output:
log.info("output=%s modified by the reviewer=%s ", modified, responded_by_user)
return modified
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import pytest

from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS

if not AIRFLOW_V_3_1_PLUS:
pytest.skip("Human in the loop is only compatible with Airflow >= 3.1.0", allow_module_level=True)
Expand All @@ -34,9 +34,13 @@
)
from airflow.providers.standard.exceptions import HITLRejectException, HITLTriggerEventError

if AIRFLOW_V_3_3_PLUS:
from airflow.sdk.exceptions import TaskAwaitingInput

HITL_TRIGGER_PATH = "airflow.providers.standard.triggers.hitl.HITLTrigger"
UPSERT_HITL_PATH = "airflow.sdk.execution_time.hitl.upsert_hitl_detail"
UTCNOW_PATH = "airflow.sdk.timezone.utcnow"
AWAIT_INPUT_FLAG_PATH = "airflow.providers.common.ai.mixins.approval.AIRFLOW_V_3_3_PLUS"


class FakeOperator(LLMApprovalMixin):
Expand Down Expand Up @@ -76,6 +80,9 @@ def context():
return MagicMock(**{"__getitem__": lambda self, key: {"task_instance": ti}[key]})


# The legacy trigger path is taken on cores < 3.3; pin the flag so these tests keep
# exercising the defer() fallback when run against newer cores.
@patch(AWAIT_INPUT_FLAG_PATH, False)
class TestDeferForApproval:
@patch(HITL_TRIGGER_PATH, autospec=True)
@patch(UPSERT_HITL_PATH)
Expand Down Expand Up @@ -253,6 +260,21 @@ def test_approved_with_modified_output(self, approval_op_with_modifications):

assert result == "modified output"

def test_approved_with_non_string_modified_output_raises(self, approval_op_with_modifications):
# On the awaiting_input path nothing upstream schema-validates params_input
# (HITLTrigger did on the legacy path), so execute_complete must enforce the
# string contract instead of returning a dict as the task's output.
event = {
"chosen_options": ["Approve"],
"responded_by_user": "editor",
"params_input": {"output": {"sneaky": "dict"}},
}

with pytest.raises(HITLTriggerEventError, match="must be a string"):
approval_op_with_modifications.execute_complete(
{}, generated_output="original output", event=event
)

def test_approved_with_unmodified_output(self, approval_op_with_modifications):
event = {
"chosen_options": ["Approve"],
Expand Down Expand Up @@ -324,3 +346,41 @@ def test_rejection_message_includes_username(self, approval_op):

with pytest.raises(HITLRejectException, match="alice"):
approval_op.execute_complete({}, generated_output="output", event=event)


@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="awaiting_input path requires Airflow 3.3+")
class TestAwaitInputForApproval:
"""On Airflow 3.3+ the review parks the task in AWAITING_INPUT instead of deferring."""

@patch(UPSERT_HITL_PATH)
def test_parks_task_in_awaiting_input(self, mock_upsert, approval_op, context):
with pytest.raises(TaskAwaitingInput) as exc_info:
approval_op.defer_for_approval(context, "some LLM output")

assert exc_info.value.method_name == "execute_complete"
assert exc_info.value.kwargs == {"generated_output": "some LLM output"}
assert exc_info.value.timeout is None
mock_upsert.assert_called_once()
assert mock_upsert.call_args[1]["options"] == ["Approve", "Reject"]
approval_op.defer.assert_not_called()

@patch(UPSERT_HITL_PATH)
def test_approval_timeout_carried_on_await(self, mock_upsert, context):
timeout = timedelta(hours=2)
op = FakeOperator(approval_timeout=timeout)

with pytest.raises(TaskAwaitingInput) as exc_info:
op.defer_for_approval(context, "output")

assert exc_info.value.timeout == timeout

@patch(UPSERT_HITL_PATH)
def test_pydantic_output_stringified_on_await(self, mock_upsert, approval_op, context):
class Answer(BaseModel):
text: str
confidence: float

with pytest.raises(TaskAwaitingInput) as exc_info:
approval_op.defer_for_approval(context, Answer(text="Paris", confidence=0.95))

assert exc_info.value.kwargs == {"generated_output": '{"text":"Paris","confidence":0.95}'}
55 changes: 42 additions & 13 deletions providers/common/ai/tests/unit/common/ai/operators/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,25 @@
)
from airflow.providers.common.ai.operators.llm import LLMOperator

from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS

try:
from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
except ImportError:
_CORE_WALKER = False

from airflow.providers.common.compat.sdk import TaskDeferred

if AIRFLOW_V_3_3_PLUS:
# On 3.3+ cores require_approval pauses the task in AWAITING_INPUT; older cores defer
# to HITLTrigger. Both exceptions carry method_name/kwargs/timeout, so the approval
# tests assert against whichever pause signal the running core uses.
from airflow.sdk.exceptions import TaskAwaitingInput as ApprovalPauseSignal
else:
ApprovalPauseSignal = TaskDeferred # type: ignore[assignment, misc]

AWAIT_INPUT_FLAG_PATH = "airflow.providers.common.ai.mixins.approval.AIRFLOW_V_3_3_PLUS"

# Returning the Pydantic instance through XCom (rather than a dict) only happens
# on cores that register declared ``output_type`` classes from the worker-side
# DAG walk. On older cores the operator dumps to a dict, so these tests skip.
Expand Down Expand Up @@ -187,8 +199,6 @@ def test_default_approval_flags(self):
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
def test_execute_with_approval_defers(self, mock_hook_cls, mock_upsert, mock_trigger_cls):
"""When require_approval=True, execute() defers instead of returning output."""
from airflow.providers.common.compat.sdk import TaskDeferred

mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("LLM response")
mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
Expand All @@ -201,20 +211,43 @@ def test_execute_with_approval_defers(self, mock_hook_cls, mock_upsert, mock_tri
)
ctx = _make_context()

with pytest.raises(TaskDeferred) as exc_info:
with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=ctx)

assert exc_info.value.method_name == "execute_complete"
assert exc_info.value.kwargs["generated_output"] == "LLM response"
mock_upsert.assert_called_once()

@patch(AWAIT_INPUT_FLAG_PATH, False)
@patch("airflow.providers.standard.triggers.hitl.HITLTrigger", autospec=True)
@patch("airflow.sdk.execution_time.hitl.upsert_hitl_detail")
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
def test_execute_with_approval_defers_on_legacy_core(self, mock_hook_cls, mock_upsert, mock_trigger_cls):
"""On cores < 3.3 (flag pinned), execute() falls back to deferring to HITLTrigger."""
mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("LLM response")
mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent

op = LLMOperator(
task_id="legacy_approval_test",
prompt="Summarize this",
llm_conn_id="my_llm",
require_approval=True,
)

with pytest.raises(TaskDeferred) as exc_info:
op.execute(context=_make_context())

assert exc_info.value.method_name == "execute_complete"
assert exc_info.value.kwargs["generated_output"] == "LLM response"
mock_trigger_cls.assert_called_once()
mock_upsert.assert_called_once()

@patch("airflow.providers.standard.triggers.hitl.HITLTrigger", autospec=True)
@patch("airflow.sdk.execution_time.hitl.upsert_hitl_detail")
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
def test_execute_with_approval_and_modifications(self, mock_hook_cls, mock_upsert, mock_trigger_cls):
"""allow_modifications=True passes an editable 'output' param."""
from airflow.providers.common.compat.sdk import TaskDeferred

mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("draft output")
mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
Expand All @@ -228,7 +261,7 @@ def test_execute_with_approval_and_modifications(self, mock_hook_cls, mock_upser
)
ctx = _make_context()

with pytest.raises(TaskDeferred):
with pytest.raises(ApprovalPauseSignal):
op.execute(context=ctx)

upsert_kwargs = mock_upsert.call_args[1]
Expand All @@ -239,8 +272,6 @@ def test_execute_with_approval_and_modifications(self, mock_hook_cls, mock_upser
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
def test_execute_with_approval_and_timeout(self, mock_hook_cls, mock_upsert, mock_trigger_cls):
"""approval_timeout is passed to the trigger."""
from airflow.providers.common.compat.sdk import TaskDeferred

mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result("output")
mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
Expand All @@ -255,7 +286,7 @@ def test_execute_with_approval_and_timeout(self, mock_hook_cls, mock_upsert, moc
)
ctx = _make_context()

with pytest.raises(TaskDeferred) as exc_info:
with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=ctx)

assert exc_info.value.timeout == timeout
Expand All @@ -265,8 +296,6 @@ def test_execute_with_approval_and_timeout(self, mock_hook_cls, mock_upsert, moc
@patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
def test_execute_with_approval_structured_output(self, mock_hook_cls, mock_upsert, mock_trigger_cls):
"""Structured (BaseModel) output is serialized before deferring."""
from airflow.providers.common.compat.sdk import TaskDeferred

mock_agent = MagicMock(spec=["run_sync"])
mock_agent.run_sync.return_value = _make_mock_run_result(Summary(text="hello"))
mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
Expand All @@ -280,7 +309,7 @@ def test_execute_with_approval_structured_output(self, mock_hook_cls, mock_upser
)
ctx = _make_context()

with pytest.raises(TaskDeferred) as exc_info:
with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=ctx)

assert exc_info.value.kwargs["generated_output"] == '{"text":"hello"}'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,17 @@

from airflow.providers.common.ai.operators.llm_file_analysis import LLMFileAnalysisOperator
from airflow.providers.common.ai.utils.file_analysis import FileAnalysisRequest
from airflow.providers.common.compat.sdk import TaskDeferred

from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS

if AIRFLOW_V_3_3_PLUS:
# On 3.3+ cores require_approval pauses the task in AWAITING_INPUT; older cores defer to
# HITLTrigger. Both signals carry method_name/kwargs/timeout, so the approval tests assert
# against whichever pause signal the running core uses.
from airflow.sdk.exceptions import TaskAwaitingInput as ApprovalPauseSignal
else:
ApprovalPauseSignal = TaskDeferred # type: ignore[assignment, misc]

try:
from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER
Expand Down Expand Up @@ -208,8 +217,6 @@ class TestLLMFileAnalysisOperatorApproval:
def test_execute_with_approval_defers(
self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls
):
from airflow.providers.common.compat.sdk import TaskDeferred

mock_build_request.return_value = FileAnalysisRequest(
user_content="prepared prompt",
resolved_paths=["/tmp/app.log"],
Expand All @@ -228,7 +235,7 @@ def test_execute_with_approval_defers(
)
ctx = _make_context()

with pytest.raises(TaskDeferred) as exc_info:
with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=ctx)

assert exc_info.value.method_name == "execute_complete"
Expand All @@ -244,8 +251,6 @@ def test_execute_with_approval_defers(
def test_execute_with_approval_defers_structured_output_as_json(
self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls
):
from airflow.providers.common.compat.sdk import TaskDeferred

mock_build_request.return_value = FileAnalysisRequest(
user_content="prepared prompt",
resolved_paths=["/tmp/app.log"],
Expand All @@ -264,7 +269,7 @@ def test_execute_with_approval_defers_structured_output_as_json(
require_approval=True,
)

with pytest.raises(TaskDeferred) as exc_info:
with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=_make_context())

assert exc_info.value.kwargs["generated_output"] == '{"findings":["error spike"]}'
Expand Down Expand Up @@ -318,8 +323,6 @@ def test_execute_complete_with_approval_restores_modified_structured_output(self
def test_execute_with_approval_timeout(
self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls
):
from airflow.providers.common.compat.sdk import TaskDeferred

mock_build_request.return_value = FileAnalysisRequest(
user_content="prepared prompt",
resolved_paths=["/tmp/app.log"],
Expand All @@ -339,7 +342,7 @@ def test_execute_with_approval_timeout(
approval_timeout=timeout,
)

with pytest.raises(TaskDeferred) as exc_info:
with pytest.raises(ApprovalPauseSignal) as exc_info:
op.execute(context=_make_context())

assert exc_info.value.timeout == timeout
Loading
Loading