From c7509c5cb92420dc0c2f1a806a0c308d061acf4b Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 13 Jun 2026 00:47:51 +0100 Subject: [PATCH] common.ai: Park approval reviews in awaiting_input on Airflow 3.3+ LLMApprovalMixin (require_approval=True on LLMOperator/AgentOperator) now raises TaskAwaitingInput on Airflow 3.3+ so the task parks in the first-class awaiting_input state -- no trigger or triggerer involved -- matching the standard provider's HITLOperator. On older cores it falls back to deferring to HITLTrigger as before. The response deadline is enforced by the scheduler's awaiting_input timeout sweep on 3.3+. Because nothing upstream schema-validates params_input on the awaiting_input path (HITLTrigger did on the legacy path), execute_complete now enforces the string contract for reviewer-modified output and raises HITLTriggerEventError for non-string values. The AIRFLOW_V_3_3_PLUS flag this uses was added in apache-airflow-providers-common-compat 1.15.0; the dependency line is marked "# use next version" so the release manager bumps the floor at release time. --- providers/common/ai/pyproject.toml | 2 +- .../providers/common/ai/mixins/approval.py | 39 ++++++++++-- .../unit/common/ai/mixins/test_approval.py | 62 ++++++++++++++++++- .../unit/common/ai/operators/test_llm.py | 55 ++++++++++++---- .../ai/operators/test_llm_file_analysis.py | 23 ++++--- .../unit/common/ai/operators/test_llm_sql.py | 18 ++++-- 6 files changed, 165 insertions(+), 34 deletions(-) diff --git a/providers/common/ai/pyproject.toml b/providers/common/ai/pyproject.toml index db08fab7374d3..e569ae88f913f 100644 --- a/providers/common/ai/pyproject.toml +++ b/providers/common/ai/pyproject.toml @@ -67,7 +67,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=3.0.0", - "apache-airflow-providers-common-compat>=1.14.1", + "apache-airflow-providers-common-compat>=1.14.1", # use next version "apache-airflow-providers-standard>=1.12.1", "pydantic-ai-slim>=1.99.0", ] diff --git a/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py b/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py index 07855340c4b33..5ebd679efcdcf 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py +++ b/providers/common/ai/src/airflow/providers/common/ai/mixins/approval.py @@ -23,6 +23,13 @@ from pydantic import BaseModel +from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_3_PLUS + +if AIRFLOW_V_3_3_PLUS: + # On Airflow 3.3+ the review parks the task in the first-class AWAITING_INPUT state instead + # of deferring to a trigger. On older cores this name is absent and defer() is used. + from airflow.sdk.exceptions import TaskAwaitingInput + log = logging.getLogger(__name__) if TYPE_CHECKING: @@ -45,7 +52,8 @@ class LLMApprovalMixin: When ``require_approval=True`` on the operator, the generated output is presented to a human reviewer via the Airflow Human-in-the-Loop (HITL) - interface. The task defers until the reviewer approves or rejects. + interface. The task waits (``awaiting_input`` on Airflow 3.3+, deferred on + older versions) until the reviewer approves or rejects. If ``allow_modifications=True``, the reviewer can also edit the output before approving. The (possibly modified) output is then returned as the @@ -71,7 +79,11 @@ def defer_for_approval( body: str | None = None, ) -> None: """ - Write HITL detail, then defer to HITLTrigger for human review. + Write HITL detail, then pause the task for human review. + + On Airflow 3.3+ the task parks in the ``awaiting_input`` state (no trigger or triggerer + involved); on older versions it defers to :class:`HITLTrigger`. Either way it resumes in + ``execute_complete`` once a response (or timeout default) arrives. :param context: Airflow task context. :param output: The generated output to present for review. @@ -100,7 +112,6 @@ def defer_for_approval( output = str(output) ti_id = context["task_instance"].id - timeout_datetime = utcnow() + self.approval_timeout if self.approval_timeout else None if subject is None: subject = f"Review output for task `{self.task_id}`" @@ -128,6 +139,16 @@ def defer_for_approval( params=hitl_params, ) + if AIRFLOW_V_3_3_PLUS: + # New core (3.3+): park the task in AWAITING_INPUT -- no trigger, no triggerer. The + # task is resumed by the Core API response handler or the scheduler timeout sweep. + raise TaskAwaitingInput( + method_name="execute_complete", + kwargs={"generated_output": output}, + timeout=self.approval_timeout, + ) + + # Fallback for cores < 3.3: defer the response check to HITLTrigger on the triggerer. self.defer( trigger=HITLTrigger( ti_id=ti_id, @@ -135,7 +156,7 @@ def defer_for_approval( defaults=None, params=hitl_params, multiple=False, - timeout_datetime=timeout_datetime, + timeout_datetime=utcnow() + self.approval_timeout if self.approval_timeout else None, ), method_name="execute_complete", kwargs={"generated_output": output}, @@ -182,6 +203,16 @@ def execute_complete(self, context: Context, generated_output: str, event: dict[ # when allow_modifications=False, bypassing the read-only approval flow. if getattr(self, "allow_modifications", False) and params_input: modified = params_input.get("output") + if modified is not None and not isinstance(modified, str): + # On the awaiting_input path nothing upstream schema-validates params_input + # (HITLTrigger did on the legacy path), so enforce the string contract here + # rather than returning a non-string as the task's output. + raise HITLTriggerEventError( + { + "error": f"Modified output must be a string, got {type(modified).__name__}.", + "error_type": "validation", + } + ) if modified is not None and modified != generated_output: log.info("output=%s modified by the reviewer=%s ", modified, responded_by_user) return modified diff --git a/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py b/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py index 464dfe38986a1..54b675da723a9 100644 --- a/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py +++ b/providers/common/ai/tests/unit/common/ai/mixins/test_approval.py @@ -18,7 +18,7 @@ import pytest -from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS if not AIRFLOW_V_3_1_PLUS: pytest.skip("Human in the loop is only compatible with Airflow >= 3.1.0", allow_module_level=True) @@ -34,9 +34,13 @@ ) from airflow.providers.standard.exceptions import HITLRejectException, HITLTriggerEventError +if AIRFLOW_V_3_3_PLUS: + from airflow.sdk.exceptions import TaskAwaitingInput + HITL_TRIGGER_PATH = "airflow.providers.standard.triggers.hitl.HITLTrigger" UPSERT_HITL_PATH = "airflow.sdk.execution_time.hitl.upsert_hitl_detail" UTCNOW_PATH = "airflow.sdk.timezone.utcnow" +AWAIT_INPUT_FLAG_PATH = "airflow.providers.common.ai.mixins.approval.AIRFLOW_V_3_3_PLUS" class FakeOperator(LLMApprovalMixin): @@ -76,6 +80,9 @@ def context(): return MagicMock(**{"__getitem__": lambda self, key: {"task_instance": ti}[key]}) +# The legacy trigger path is taken on cores < 3.3; pin the flag so these tests keep +# exercising the defer() fallback when run against newer cores. +@patch(AWAIT_INPUT_FLAG_PATH, False) class TestDeferForApproval: @patch(HITL_TRIGGER_PATH, autospec=True) @patch(UPSERT_HITL_PATH) @@ -253,6 +260,21 @@ def test_approved_with_modified_output(self, approval_op_with_modifications): assert result == "modified output" + def test_approved_with_non_string_modified_output_raises(self, approval_op_with_modifications): + # On the awaiting_input path nothing upstream schema-validates params_input + # (HITLTrigger did on the legacy path), so execute_complete must enforce the + # string contract instead of returning a dict as the task's output. + event = { + "chosen_options": ["Approve"], + "responded_by_user": "editor", + "params_input": {"output": {"sneaky": "dict"}}, + } + + with pytest.raises(HITLTriggerEventError, match="must be a string"): + approval_op_with_modifications.execute_complete( + {}, generated_output="original output", event=event + ) + def test_approved_with_unmodified_output(self, approval_op_with_modifications): event = { "chosen_options": ["Approve"], @@ -324,3 +346,41 @@ def test_rejection_message_includes_username(self, approval_op): with pytest.raises(HITLRejectException, match="alice"): approval_op.execute_complete({}, generated_output="output", event=event) + + +@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="awaiting_input path requires Airflow 3.3+") +class TestAwaitInputForApproval: + """On Airflow 3.3+ the review parks the task in AWAITING_INPUT instead of deferring.""" + + @patch(UPSERT_HITL_PATH) + def test_parks_task_in_awaiting_input(self, mock_upsert, approval_op, context): + with pytest.raises(TaskAwaitingInput) as exc_info: + approval_op.defer_for_approval(context, "some LLM output") + + assert exc_info.value.method_name == "execute_complete" + assert exc_info.value.kwargs == {"generated_output": "some LLM output"} + assert exc_info.value.timeout is None + mock_upsert.assert_called_once() + assert mock_upsert.call_args[1]["options"] == ["Approve", "Reject"] + approval_op.defer.assert_not_called() + + @patch(UPSERT_HITL_PATH) + def test_approval_timeout_carried_on_await(self, mock_upsert, context): + timeout = timedelta(hours=2) + op = FakeOperator(approval_timeout=timeout) + + with pytest.raises(TaskAwaitingInput) as exc_info: + op.defer_for_approval(context, "output") + + assert exc_info.value.timeout == timeout + + @patch(UPSERT_HITL_PATH) + def test_pydantic_output_stringified_on_await(self, mock_upsert, approval_op, context): + class Answer(BaseModel): + text: str + confidence: float + + with pytest.raises(TaskAwaitingInput) as exc_info: + approval_op.defer_for_approval(context, Answer(text="Paris", confidence=0.95)) + + assert exc_info.value.kwargs == {"generated_output": '{"text":"Paris","confidence":0.95}'} diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py index d5ef8228d3528..f9f3bf0909924 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py @@ -29,13 +29,25 @@ ) from airflow.providers.common.ai.operators.llm import LLMOperator -from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS try: from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER except ImportError: _CORE_WALKER = False +from airflow.providers.common.compat.sdk import TaskDeferred + +if AIRFLOW_V_3_3_PLUS: + # On 3.3+ cores require_approval pauses the task in AWAITING_INPUT; older cores defer + # to HITLTrigger. Both exceptions carry method_name/kwargs/timeout, so the approval + # tests assert against whichever pause signal the running core uses. + from airflow.sdk.exceptions import TaskAwaitingInput as ApprovalPauseSignal +else: + ApprovalPauseSignal = TaskDeferred # type: ignore[assignment, misc] + +AWAIT_INPUT_FLAG_PATH = "airflow.providers.common.ai.mixins.approval.AIRFLOW_V_3_3_PLUS" + # Returning the Pydantic instance through XCom (rather than a dict) only happens # on cores that register declared ``output_type`` classes from the worker-side # DAG walk. On older cores the operator dumps to a dict, so these tests skip. @@ -187,8 +199,6 @@ def test_default_approval_flags(self): @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True) def test_execute_with_approval_defers(self, mock_hook_cls, mock_upsert, mock_trigger_cls): """When require_approval=True, execute() defers instead of returning output.""" - from airflow.providers.common.compat.sdk import TaskDeferred - mock_agent = MagicMock(spec=["run_sync"]) mock_agent.run_sync.return_value = _make_mock_run_result("LLM response") mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent @@ -201,20 +211,43 @@ def test_execute_with_approval_defers(self, mock_hook_cls, mock_upsert, mock_tri ) ctx = _make_context() - with pytest.raises(TaskDeferred) as exc_info: + with pytest.raises(ApprovalPauseSignal) as exc_info: op.execute(context=ctx) assert exc_info.value.method_name == "execute_complete" assert exc_info.value.kwargs["generated_output"] == "LLM response" mock_upsert.assert_called_once() + @patch(AWAIT_INPUT_FLAG_PATH, False) + @patch("airflow.providers.standard.triggers.hitl.HITLTrigger", autospec=True) + @patch("airflow.sdk.execution_time.hitl.upsert_hitl_detail") + @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True) + def test_execute_with_approval_defers_on_legacy_core(self, mock_hook_cls, mock_upsert, mock_trigger_cls): + """On cores < 3.3 (flag pinned), execute() falls back to deferring to HITLTrigger.""" + mock_agent = MagicMock(spec=["run_sync"]) + mock_agent.run_sync.return_value = _make_mock_run_result("LLM response") + mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent + + op = LLMOperator( + task_id="legacy_approval_test", + prompt="Summarize this", + llm_conn_id="my_llm", + require_approval=True, + ) + + with pytest.raises(TaskDeferred) as exc_info: + op.execute(context=_make_context()) + + assert exc_info.value.method_name == "execute_complete" + assert exc_info.value.kwargs["generated_output"] == "LLM response" + mock_trigger_cls.assert_called_once() + mock_upsert.assert_called_once() + @patch("airflow.providers.standard.triggers.hitl.HITLTrigger", autospec=True) @patch("airflow.sdk.execution_time.hitl.upsert_hitl_detail") @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True) def test_execute_with_approval_and_modifications(self, mock_hook_cls, mock_upsert, mock_trigger_cls): """allow_modifications=True passes an editable 'output' param.""" - from airflow.providers.common.compat.sdk import TaskDeferred - mock_agent = MagicMock(spec=["run_sync"]) mock_agent.run_sync.return_value = _make_mock_run_result("draft output") mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent @@ -228,7 +261,7 @@ def test_execute_with_approval_and_modifications(self, mock_hook_cls, mock_upser ) ctx = _make_context() - with pytest.raises(TaskDeferred): + with pytest.raises(ApprovalPauseSignal): op.execute(context=ctx) upsert_kwargs = mock_upsert.call_args[1] @@ -239,8 +272,6 @@ def test_execute_with_approval_and_modifications(self, mock_hook_cls, mock_upser @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True) def test_execute_with_approval_and_timeout(self, mock_hook_cls, mock_upsert, mock_trigger_cls): """approval_timeout is passed to the trigger.""" - from airflow.providers.common.compat.sdk import TaskDeferred - mock_agent = MagicMock(spec=["run_sync"]) mock_agent.run_sync.return_value = _make_mock_run_result("output") mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent @@ -255,7 +286,7 @@ def test_execute_with_approval_and_timeout(self, mock_hook_cls, mock_upsert, moc ) ctx = _make_context() - with pytest.raises(TaskDeferred) as exc_info: + with pytest.raises(ApprovalPauseSignal) as exc_info: op.execute(context=ctx) assert exc_info.value.timeout == timeout @@ -265,8 +296,6 @@ def test_execute_with_approval_and_timeout(self, mock_hook_cls, mock_upsert, moc @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True) def test_execute_with_approval_structured_output(self, mock_hook_cls, mock_upsert, mock_trigger_cls): """Structured (BaseModel) output is serialized before deferring.""" - from airflow.providers.common.compat.sdk import TaskDeferred - mock_agent = MagicMock(spec=["run_sync"]) mock_agent.run_sync.return_value = _make_mock_run_result(Summary(text="hello")) mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent @@ -280,7 +309,7 @@ def test_execute_with_approval_structured_output(self, mock_hook_cls, mock_upser ) ctx = _make_context() - with pytest.raises(TaskDeferred) as exc_info: + with pytest.raises(ApprovalPauseSignal) as exc_info: op.execute(context=ctx) assert exc_info.value.kwargs["generated_output"] == '{"text":"hello"}' diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py b/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py index 7c955a160b4f4..9e692b420f9ea 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py @@ -25,8 +25,17 @@ from airflow.providers.common.ai.operators.llm_file_analysis import LLMFileAnalysisOperator from airflow.providers.common.ai.utils.file_analysis import FileAnalysisRequest +from airflow.providers.common.compat.sdk import TaskDeferred -from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS + +if AIRFLOW_V_3_3_PLUS: + # On 3.3+ cores require_approval pauses the task in AWAITING_INPUT; older cores defer to + # HITLTrigger. Both signals carry method_name/kwargs/timeout, so the approval tests assert + # against whichever pause signal the running core uses. + from airflow.sdk.exceptions import TaskAwaitingInput as ApprovalPauseSignal +else: + ApprovalPauseSignal = TaskDeferred # type: ignore[assignment, misc] try: from airflow.sdk.serde import SUPPORTS_OPERATOR_DESERIALIZATION_WALKER as _CORE_WALKER @@ -208,8 +217,6 @@ class TestLLMFileAnalysisOperatorApproval: def test_execute_with_approval_defers( self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls ): - from airflow.providers.common.compat.sdk import TaskDeferred - mock_build_request.return_value = FileAnalysisRequest( user_content="prepared prompt", resolved_paths=["/tmp/app.log"], @@ -228,7 +235,7 @@ def test_execute_with_approval_defers( ) ctx = _make_context() - with pytest.raises(TaskDeferred) as exc_info: + with pytest.raises(ApprovalPauseSignal) as exc_info: op.execute(context=ctx) assert exc_info.value.method_name == "execute_complete" @@ -244,8 +251,6 @@ def test_execute_with_approval_defers( def test_execute_with_approval_defers_structured_output_as_json( self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls ): - from airflow.providers.common.compat.sdk import TaskDeferred - mock_build_request.return_value = FileAnalysisRequest( user_content="prepared prompt", resolved_paths=["/tmp/app.log"], @@ -264,7 +269,7 @@ def test_execute_with_approval_defers_structured_output_as_json( require_approval=True, ) - with pytest.raises(TaskDeferred) as exc_info: + with pytest.raises(ApprovalPauseSignal) as exc_info: op.execute(context=_make_context()) assert exc_info.value.kwargs["generated_output"] == '{"findings":["error spike"]}' @@ -318,8 +323,6 @@ def test_execute_complete_with_approval_restores_modified_structured_output(self def test_execute_with_approval_timeout( self, mock_build_request, mock_hook_cls, mock_upsert, mock_trigger_cls ): - from airflow.providers.common.compat.sdk import TaskDeferred - mock_build_request.return_value = FileAnalysisRequest( user_content="prepared prompt", resolved_paths=["/tmp/app.log"], @@ -339,7 +342,7 @@ def test_execute_with_approval_timeout( approval_timeout=timeout, ) - with pytest.raises(TaskDeferred) as exc_info: + with pytest.raises(ApprovalPauseSignal) as exc_info: op.execute(context=_make_context()) assert exc_info.value.timeout == timeout diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py index a994ae3d1cd2f..1862971c9539d 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py @@ -30,7 +30,15 @@ from airflow.providers.common.compat.sdk import TaskDeferred from airflow.providers.common.sql.config import DataSourceConfig -from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS + +if AIRFLOW_V_3_3_PLUS: + # On 3.3+ cores require_approval pauses the task in AWAITING_INPUT; older cores defer to + # HITLTrigger. Both signals carry method_name/kwargs/timeout, so the approval tests assert + # against whichever pause signal the running core uses. + from airflow.sdk.exceptions import TaskAwaitingInput as ApprovalPauseSignal +else: + ApprovalPauseSignal = TaskDeferred # type: ignore[assignment, misc] def _make_mock_run_result(output): @@ -475,7 +483,7 @@ def test_execute_with_approval_defers(self, mock_hook_cls, mock_upsert, mock_tri ) ctx = _make_context() - with pytest.raises(TaskDeferred) as exc_info: + with pytest.raises(ApprovalPauseSignal) as exc_info: op.execute(context=ctx) assert exc_info.value.method_name == "execute_complete" @@ -521,7 +529,7 @@ def test_execute_with_approval_and_modifications(self, mock_hook_cls, mock_upser ) ctx = _make_context() - with pytest.raises(TaskDeferred): + with pytest.raises(ApprovalPauseSignal): op.execute(context=ctx) upsert_kwargs = mock_upsert.call_args[1] @@ -545,7 +553,7 @@ def test_execute_with_approval_and_timeout(self, mock_hook_cls, mock_upsert, moc ) ctx = _make_context() - with pytest.raises(TaskDeferred) as exc_info: + with pytest.raises(ApprovalPauseSignal) as exc_info: op.execute(context=ctx) assert exc_info.value.timeout == timeout @@ -583,7 +591,7 @@ def test_execute_strips_code_fences_before_deferring(self, mock_hook_cls, mock_u ) ctx = _make_context() - with pytest.raises(TaskDeferred) as exc_info: + with pytest.raises(ApprovalPauseSignal) as exc_info: op.execute(context=ctx) assert exc_info.value.kwargs["generated_output"] == "SELECT 1"