diff --git a/airflow-core/docs/tutorial/hitl.rst b/airflow-core/docs/tutorial/hitl.rst index 8e89b49d88de4..aea843486f404 100644 --- a/airflow-core/docs/tutorial/hitl.rst +++ b/airflow-core/docs/tutorial/hitl.rst @@ -196,6 +196,37 @@ When the operator creates an HITL request that is waiting for a human response, :end-before: [END howto_hitl_entry_operator] +Testing HITL Dags locally +------------------------- + +``airflow dags test`` (and the underlying ``dag.test()``) supports HITL tasks. A task that reaches +the ``awaiting_input`` state stays parked -- the test run never resolves it itself -- and the run +waits, logging which tasks await input, until a response is recorded from outside. The response +goes through the same channels as on a real deployment: the Required Actions page or the HITL REST +API (``PATCH .../hitlDetails``) of an api-server sharing the metadata database (for example +``airflow standalone``, or a separately started ``airflow api-server``). Once the response lands, +the test run resumes the task and continues with downstream tasks. + +This also lets AI agents drive a HITL pipeline end-to-end locally: run ``airflow dags test``, watch +for the ``Waiting for Human-in-the-loop input for tasks: ...`` log line, ask the human, and submit their answer through the HITL REST API. The two +calls involved are shown below (``~`` works as a wildcard for ``dag_id`` and ``dag_run_id``): + +.. code-block:: text + + # Discover pending requests (subject, options, params, run/task identifiers) + GET /api/v2/dags/~/dagRuns/~/hitlDetails?response_received=false + + # Submit the response; the test run resumes the task on its next poll. + # map_index is -1 for non-mapped tasks. + PATCH /api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/hitlDetails + {"chosen_options": ["Approve"], "params_input": {}} + +.. note:: + + ``response_timeout`` and timeout defaults are enforced by the scheduler, which does not run + under ``airflow dags test``. 
A parked task therefore waits indefinitely for a response; supply + one through the UI or REST API to let the run finish. + Benefits and Common Use Cases ----------------------------- diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index ad4ed2c134d13..81198d8e41810 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -22,6 +22,7 @@ import os import pickle import re +import time from contextlib import nullcontext from datetime import timedelta from pathlib import Path @@ -62,8 +63,10 @@ from airflow.models.dagbundle import DagBundleModel from airflow.models.dagrun import DagRun from airflow.models.deadline_alert import DeadlineAlert as DeadlineAlertModel +from airflow.models.hitl import HITLDetail from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance as TI +from airflow.models.trigger import handle_event_submit from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator @@ -83,6 +86,8 @@ from airflow.sdk.definitions.callback import AsyncCallback from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference from airflow.sdk.definitions.param import Param +from airflow.sdk.exceptions import TaskAwaitingInput +from airflow.sdk.execution_time.hitl import upsert_hitl_detail from airflow.serialization.definitions.dag import SerializedDAG from airflow.serialization.encoders import coerce_to_core_timetable from airflow.serialization.serialized_objects import LazyDeserializedDAG @@ -93,6 +98,7 @@ NullTimetable, OnceTimetable, ) +from airflow.triggers.base import TriggerEvent from airflow.utils.file import list_py_file_paths from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State, TaskInstanceState @@ 
-1826,6 +1832,101 @@ def check_task_2(my_input): mock_task_object_1.assert_called() mock_task_object_2.assert_not_called() + @staticmethod + def _make_awaiting_input_dag(dag_id, resume_calls): + """Build and sync a Dag whose single ``ask`` task parks in AWAITING_INPUT (Human-in-the-loop).""" + + class AskOperator(BaseOperator): + def execute(self, context): + upsert_hitl_detail( + ti_id=context["task_instance"].id, + options=["Approve", "Reject"], + subject="Deploy?", + multiple=False, + params={}, + ) + raise TaskAwaitingInput(method_name="execute_complete") + + def execute_complete(self, context, event): + resume_calls.append((event["chosen_options"], event["params_input"])) + return event["chosen_options"] + + dag = DAG(dag_id=dag_id, schedule=None, start_date=DEFAULT_DATE) + with dag: + AskOperator(task_id="ask") + sync_dag_to_db(dag) + return dag + + @pytest.mark.execution_timeout(60) + def test_dag_test_hitl_task_stays_parked_until_external_response( + self, testing_dag_bundle, monkeypatch, caplog + ): + """ + The dag.test() contract for Human-in-the-loop: a task that parks in AWAITING_INPUT is + never resolved by dag.test() itself. The run waits until a response recorded from + outside (here through an independent session, the way the API response handler does) + flips it back to SCHEDULED, at which point the loop resumes it. + + The loop's time.sleep is the synchronization point: patching it to deliver the + external response keeps the test deterministic, with no real-time waits. 
+ """ + resume_calls: list = [] + dag = self._make_awaiting_input_dag("test_dag_test_hitl_external", resume_calls) + + parked_states_seen = [] + spins = 0 + + def deliver_external_response(): + """Record an Approve through an independent session, as the API handler would.""" + with create_session(scoped=False) as external_session: + parked_ti = external_session.scalar( + select(TI).where( + TI.dag_id == "test_dag_test_hitl_external", + TI.task_id == "ask", + TI.state == TaskInstanceState.AWAITING_INPUT, + ) + ) + if parked_ti is None: + return + parked_states_seen.append(parked_ti.state) + detail = external_session.get(HITLDetail, parked_ti.id) + detail.chosen_options = ["Approve"] + detail.params_input = {} + detail.responded_at = timezone.utcnow() + detail.responded_by = {"id": "external", "name": "external"} + external_session.add(detail) + handle_event_submit( + TriggerEvent(detail.as_resume_event_payload()), + task_instance=parked_ti, + session=external_session, + ) + + def respond_once_waiting(seconds): + # Replaces the loop's real sleep to keep the test fast, and delivers the response + # only once dag.test() has logged that it is parked on the HITL task -- so the + # assertions below prove the task resumed through the new awaiting_input branch + # rather than any other path. + nonlocal spins + spins += 1 + assert spins < 50, "dag.test() never logged that it was waiting on the parked task" + if "Waiting for Human-in-the-loop input" in caplog.text: + deliver_external_response() + + monkeypatch.setattr(time, "sleep", respond_once_waiting) + + with caplog.at_level(logging.INFO, logger="airflow.sdk.definitions.dag"): + dr = dag.test() + + # The task must take the new "waiting for input" branch, not the old "unrunnable" one. 
+ assert "Waiting for Human-in-the-loop input" in caplog.text + assert "No tasks to run" not in caplog.text + + ti = dr.get_task_instance("ask") + assert ti is not None + assert ti.state == TaskInstanceState.SUCCESS + assert parked_states_seen == [TaskInstanceState.AWAITING_INPUT] + assert resume_calls == [(["Approve"], {})] + def test_dag_connection_file(self, tmp_path, testing_dag_bundle): test_connections_string = """ --- diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 46794edbe4a68..42924fb6aaec8 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1404,8 +1404,23 @@ def test( # triggerer may mark tasks scheduled so we read from DB all_tis = set(dr.get_task_instances(session=session)) scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED} - ids_unrunnable = {x for x in all_tis if x.state not in FINISHED_STATES} - scheduled_tis - if not scheduled_tis and ids_unrunnable: + awaiting_input_tis = {x for x in all_tis if x.state == TaskInstanceState.AWAITING_INPUT} + ids_unrunnable = ( + {x for x in all_tis if x.state not in FINISHED_STATES} + - scheduled_tis + - awaiting_input_tis + ) + if not scheduled_tis and awaiting_input_tis: + # Human-in-the-loop tasks stay parked in AWAITING_INPUT: dag.test() never + # resolves them itself. Keep the run alive until a response recorded from + # outside -- the Required Actions UI or the HITL REST API of an api-server + # sharing this metadata DB -- flips them back to SCHEDULED; sleep 1s between checks. + log.info( + "Waiting for Human-in-the-loop input for tasks: %s", + sorted(x.task_id for x in awaiting_input_tis), + ) + time.sleep(1) + elif not scheduled_tis and ids_unrunnable: log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable) time.sleep(1)