From 9070360cebf7e7adacbcb2f4442fed9b7efd5b41 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 13 May 2024 18:50:54 +0800 Subject: [PATCH 01/19] fix(baseoperator): change start_trigger into start_trigger_cls and start_trigger_kwargs --- airflow/decorators/base.py | 2 +- airflow/models/abstractoperator.py | 3 +- airflow/models/baseoperator.py | 6 +- airflow/models/dagrun.py | 11 +-- airflow/models/mappedoperator.py | 10 +-- airflow/models/taskinstance.py | 41 +++++++++++ airflow/serialization/serialized_objects.py | 23 ++---- tests/models/test_dagrun.py | 12 ++-- tests/serialization/test_dag_serialization.py | 70 ++++++++----------- 9 files changed, 105 insertions(+), 73 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 2ae85a9c435b2..79efae67bfb6d 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -509,7 +509,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: # task's expand() contribute to the op_kwargs operator argument, not # the operator arguments themselves, and should expand against it. 
expand_input_attr="op_kwargs_expand_input", - start_trigger=self.operator_class.start_trigger, + start_trigger_cls=self.operator_class.start_trigger_cls, next_method=self.operator_class.next_method, ) return XComArg(operator=operator) diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index b7160430e066a..c0024f44130ce 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -122,7 +122,8 @@ class AbstractOperator(Templater, DAGNode): "node_id", # Duplicates task_id "task_group", # Doesn't have a useful repr, no point showing in UI "inherits_from_empty_operator", # impl detail - "start_trigger", + "start_trigger_cls", + "start_trigger_kwargs", "next_method", # For compatibility with TG, for operators these are just the current task, no point showing "roots", diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 98532d90b0256..3fd15120d6cd5 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -819,7 +819,8 @@ def say_hello_world(**context): # Set to True for an operator instantiated by a mapped operator. 
__from_mapped = False - start_trigger: BaseTrigger | None = None + start_trigger_cls: str | None = None + start_trigger_kwargs: dict[str, Any] | None = None next_method: str | None = None def __init__( @@ -1679,7 +1680,8 @@ def get_serialized_fields(cls): "is_teardown", "on_failure_fail_dagrun", "map_index_template", - "start_trigger", + "start_trigger_cls", + "start_trigger_kwargs", "next_method", "_needs_expansion", } diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 008608b96d225..0be447d678139 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -51,7 +51,7 @@ from airflow.api_internal.internal_api_call import internal_api_call from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.configuration import conf as airflow_conf -from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred, TaskNotFound +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound from airflow.listeners.listener import get_listener_manager from airflow.models import Log from airflow.models.abstractoperator import NotMapped @@ -1539,7 +1539,8 @@ def schedule_tis( ): dummy_ti_ids.append((ti.task_id, ti.map_index)) elif ( - ti.task.start_trigger is not None + ti.task.start_trigger_cls is not None + and ti.task.start_trigger_kwargs is not None and ti.task.next_method is not None and not ti.task.on_execute_callback and not ti.task.on_success_callback @@ -1547,9 +1548,11 @@ def schedule_tis( ): if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: ti.try_number += 1 - ti.defer_task( - exception=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method), + ti.defer_task_from_start( session=session, + trigger_cls=ti.task.start_trigger_cls, + trigger_kwargs=ti.task.start_trigger_kwargs, + next_method=ti.task.next_method, ) else: schedulable_ti_ids.append((ti.task_id, ti.map_index)) diff --git a/airflow/models/mappedoperator.py 
b/airflow/models/mappedoperator.py index 27d0510c307c0..ecbf6d1799f92 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -81,7 +81,6 @@ from airflow.models.param import ParamsDict from airflow.models.xcom_arg import XComArg from airflow.ti_deps.deps.base_ti_dep import BaseTIDep - from airflow.triggers.base import BaseTrigger from airflow.utils.context import Context from airflow.utils.operator_resources import Resources from airflow.utils.task_group import TaskGroup @@ -237,7 +236,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: # For classic operators, this points to expand_input because kwargs # to BaseOperator.expand() contribute to operator arguments. expand_input_attr="expand_input", - start_trigger=self.operator_class.start_trigger, + start_trigger_cls=self.operator_class.start_trigger_cls, + start_trigger_kwargs=getattr(partial_kwargs, "start_trigger_kwargs", None), next_method=self.operator_class.next_method, ) return op @@ -281,7 +281,8 @@ class MappedOperator(AbstractOperator): _task_module: str _task_type: str _operator_name: str - start_trigger: BaseTrigger | None + start_trigger_cls: str | None + start_trigger_kwargs: dict[str, Any] | None next_method: str | None _needs_expansion: bool = True @@ -312,7 +313,8 @@ class MappedOperator(AbstractOperator): ( "parse_time_mapped_ti_count", "operator_class", - "start_trigger", + "start_trigger_cls", + "start_trigger_kwargs", "next_method", ) ) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index c1ace17cd5158..60cac8f270ac8 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -3007,6 +3007,47 @@ def defer_task(self, exception: TaskDeferred, session: Session) -> None: """ _defer_task(ti=self, exception=exception, session=session) + @provide_session + def defer_task_from_start( + self, session: Session, trigger_cls: str, trigger_kwargs: dict[str, Any], next_method: str + ) -> 
None: + """Mark the task as deferred and sets up the trigger that is needed to resume it. + + :meta: private + """ + from airflow.models.trigger import Trigger + + if TYPE_CHECKING: + assert self.task + + # First, make the trigger entry + trigger_row = Trigger(classpath=trigger_cls, kwargs=trigger_kwargs) + session.add(trigger_row) + session.flush() + + # Then, update ourselves so it matches the deferral request + # Keep an eye on the logic in `check_and_change_state_before_execution()` + # depending on self.next_method semantics + self.state = TaskInstanceState.DEFERRED + self.trigger_id = trigger_row.id + self.next_method = next_method + self.next_kwargs = trigger_kwargs or {} + + # Calculate timeout too if it was passed + # if defer.timeout is not None: + # self.trigger_timeout = timezone.utcnow() + defer.timeout + # else: + self.trigger_timeout = None + + # If an execution_timeout is set, set the timeout to the minimum of + # it and the trigger timeout + execution_timeout = self.task.execution_timeout + if execution_timeout: + if self.trigger_timeout: + self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout) + else: + self.trigger_timeout = self.start_date + execution_timeout + def _run_execute_callback(self, context: Context, task: BaseOperator) -> None: """Functions that need to be run before a Task is executed.""" if not (callbacks := task.on_execute_callback): diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 6e7f50a87c73d..da6f590e230d4 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -70,7 +70,6 @@ from airflow.utils.code_utils import get_python_source from airflow.utils.context import Context, OutletEventAccessor, OutletEventAccessors from airflow.utils.docs import get_docs_url -from airflow.utils.helpers import exactly_one from airflow.utils.module_loading import import_string, qualname from 
airflow.utils.operator_resources import Resources from airflow.utils.task_group import MappedTaskGroup, TaskGroup @@ -1018,10 +1017,8 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) # Used to determine if an Operator is inherited from EmptyOperator serialize_op["_is_empty"] = op.inherits_from_empty_operator - if exactly_one(op.start_trigger is not None, op.next_method is not None): - raise AirflowException("start_trigger and next_method should both be set.") - - serialize_op["start_trigger"] = op.start_trigger.serialize() if op.start_trigger else None + serialize_op["start_trigger_cls"] = op.start_trigger_cls + serialize_op["start_trigger_kwargs"] = op.start_trigger_kwargs serialize_op["next_method"] = op.next_method if op.operator_extra_links: @@ -1206,15 +1203,8 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: # Used to determine if an Operator is inherited from EmptyOperator setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False))) - # Deserialize start_trigger - serialized_start_trigger = encoded_op.get("start_trigger") - if serialized_start_trigger: - trigger_cls_name, trigger_kwargs = serialized_start_trigger - trigger_cls = import_string(trigger_cls_name) - start_trigger = trigger_cls(**trigger_kwargs) - setattr(op, "start_trigger", start_trigger) - else: - setattr(op, "start_trigger", None) + setattr(op, "start_trigger_cls", encoded_op.get("start_trigger_cls", None)) + setattr(op, "start_trigger_kwargs", encoded_op.get("start_trigger_kwargs", None)) setattr(op, "next_method", encoded_op.get("next_method", None)) @staticmethod @@ -1278,8 +1268,9 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: end_date=None, disallow_kwargs_override=encoded_op["_disallow_kwargs_override"], expand_input_attr=encoded_op["_expand_input_attr"], - start_trigger=None, - next_method=None, + start_trigger_cls=encoded_op.get("start_trigger_cls", None), + 
start_trigger_kwargs=encoded_op.get("start_trigger_kwargs", None), + next_method=encoded_op.get("next_method", None), ) else: op = SerializedBaseOperator(task_id=encoded_op["task_id"]) diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index fe1c2d58a29e3..09900ffdfff91 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -40,7 +40,6 @@ from airflow.operators.python import ShortCircuitOperator from airflow.serialization.serialized_objects import SerializedDAG from airflow.stats import Stats -from airflow.triggers.testing import SuccessTrigger from airflow.utils import timezone from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.trigger_rule import TriggerRule @@ -1989,16 +1988,17 @@ def test_schedule_tis_map_index(dag_maker, session): def test_schedule_tis_start_trigger(dag_maker, session): """ - Test that an operator with _start_trigger and _next_method set can be directly - deferred during scheduling. + Test that an operator with start_trigger_cls, start_trigger_kwargs and next_method set can be + directly deferred during scheduling. 
""" - trigger = SuccessTrigger() class TestOperator(BaseOperator): + start_trigger_cls = "airflow.triggers.testing.SuccessTrigger" + next_method = "execute_complete" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.start_trigger = trigger - self.next_method = "execute_complete" + self.start_trigger_kwargs = {} def execute_complete(self): pass diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 093b7fba7615e..ceb2a5566727d 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -72,7 +72,6 @@ from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.simple import NullTimetable, OnceTimetable -from airflow.triggers.testing import SuccessTrigger from airflow.utils import timezone from airflow.utils.operator_resources import Resources from airflow.utils.task_group import TaskGroup @@ -228,7 +227,8 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i "_needs_expansion": False, "weight_rule": "downstream", "next_method": None, - "start_trigger": None, + "start_trigger_cls": None, + "start_trigger_kwargs": None, }, }, ], @@ -2167,25 +2167,29 @@ def execute(self, context: Context): SerializedDAG.to_dict(dag) @pytest.mark.db_test - def test_start_trigger_and_next_method_in_serialized_dag(self): + def test_start_trigger_cls_kwargs_and_next_method_in_serialized_dag(self): """ - Test that when we provide start_trigger and next_method, the DAG can be correctly serialized. + Test that when we provide start_trigger_cls, start_trigger_kwargs and next_method, + the DAG can be correctly serialized. 
""" - trigger = SuccessTrigger() class TestOperator(BaseOperator): + start_trigger_cls = "airflow.triggers.testing.SuccessTrigger" + next_method = "execute_complete" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.start_trigger = trigger - self.next_method = "execute_complete" + self.start_trigger_kwargs = {} def execute_complete(self): pass class Test2Operator(BaseOperator): + start_trigger_cls = "airflow.triggers.testing.SuccessTrigger" + start_trigger_kwargs = {} + next_method = "execute_complete" + def __init__(self, *args, **kwargs): - self.start_trigger = trigger - self.next_method = "execute_complete" super().__init__(*args, **kwargs) def execute_complete(self): @@ -2200,30 +2204,10 @@ def execute_complete(self): serialized_obj = SerializedDAG.to_dict(dag) for task in serialized_obj["dag"]["tasks"]: - assert task["__var"]["start_trigger"] == trigger.serialize() + assert task["__var"]["start_trigger_cls"] == "airflow.triggers.testing.SuccessTrigger" + assert task["__var"]["start_trigger_kwargs"] == {} assert task["__var"]["next_method"] == "execute_complete" - @pytest.mark.db_test - def test_start_trigger_in_serialized_dag_but_no_next_method(self): - """ - Test that when we provide start_trigger without next_method, an AriflowException should be raised. 
- """ - - trigger = SuccessTrigger() - - class TestOperator(BaseOperator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.start_trigger = trigger - - dag = DAG(dag_id="test_dag", start_date=datetime(2023, 11, 9)) - - with dag: - TestOperator(task_id="test_task") - - with pytest.raises(AirflowException, match="start_trigger and next_method should both be set."): - SerializedDAG.to_dict(dag) - def test_kubernetes_optional(): """Serialisation / deserialisation continues to work without kubernetes installed""" @@ -2274,7 +2258,8 @@ def test_operator_expand_serde(): "_needs_expansion": True, "_task_module": "airflow.operators.bash", "_task_type": "BashOperator", - "start_trigger": None, + "start_trigger_cls": None, + "start_trigger_kwargs": None, "next_method": None, "downstream_task_ids": [], "expand_input": { @@ -2308,7 +2293,8 @@ def test_operator_expand_serde(): assert op.operator_class == { "_task_type": "BashOperator", "_needs_expansion": True, - "start_trigger": None, + "start_trigger_cls": None, + "start_trigger_kwargs": None, "next_method": None, "downstream_task_ids": [], "task_id": "a", @@ -2356,7 +2342,8 @@ def test_operator_expand_xcomarg_serde(): "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", "next_method": None, - "start_trigger": None, + "start_trigger_cls": None, + "start_trigger_kwargs": None, } op = BaseSerialization.deserialize(serialized) @@ -2414,7 +2401,8 @@ def test_operator_expand_kwargs_literal_serde(strict): "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", "next_method": None, - "start_trigger": None, + "start_trigger_cls": None, + "start_trigger_kwargs": None, } op = BaseSerialization.deserialize(serialized) @@ -2463,7 +2451,8 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", "next_method": None, - "start_trigger": None, + "start_trigger_cls": None, + "start_trigger_kwargs": 
None, } op = BaseSerialization.deserialize(serialized) @@ -2582,7 +2571,8 @@ def x(arg1, arg2, arg3): "_disallow_kwargs_override": False, "_expand_input_attr": "op_kwargs_expand_input", "next_method": None, - "start_trigger": None, + "start_trigger_cls": None, + "start_trigger_kwargs": None, } deserialized = BaseSerialization.deserialize(serialized) @@ -2649,7 +2639,8 @@ def x(arg1, arg2, arg3): "_task_type": "_PythonDecoratedOperator", "_operator_name": "@task", "next_method": None, - "start_trigger": None, + "start_trigger_cls": None, + "start_trigger_kwargs": None, "downstream_task_ids": [], "partial_kwargs": { "is_setup": False, @@ -2802,7 +2793,8 @@ def operator_extra_links(self): "_is_mapped": True, "_needs_expansion": True, "next_method": None, - "start_trigger": None, + "start_trigger_cls": None, + "start_trigger_kwargs": None, } deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR]) assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()] From 12905abf8343b6e3fec3cacbf2734d7e73d72446 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 14 May 2024 18:03:06 +0800 Subject: [PATCH 02/19] refactor(taskinstance): extract common logic in defer_task --- airflow/cli/commands/task_command.py | 2 +- airflow/models/dagrun.py | 2 +- airflow/models/taskinstance.py | 79 ++++++++++--------- .../serialization/pydantic/taskinstance.py | 3 +- 4 files changed, 43 insertions(+), 43 deletions(-) diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index ac9b211c21798..d2b0e90052743 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -673,7 +673,7 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> N else: ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True) except TaskDeferred as defer: - ti.defer_task(exception=defer, session=session) + ti.defer_task_from_task_deferred(defer=defer, 
session=session) log.info("[TASK TEST] running trigger in line") event = _run_inline_trigger(defer.trigger) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 0be447d678139..817c3d2cc195e 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1548,7 +1548,7 @@ def schedule_tis( ): if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: ti.try_number += 1 - ti.defer_task_from_start( + ti.defer_task_from_start_trigger( session=session, trigger_cls=ti.task.start_trigger_cls, trigger_kwargs=ti.task.start_trigger_kwargs, diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 60cac8f270ac8..e84b04dceb1a4 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -159,6 +159,7 @@ from airflow.models.dagrun import DagRun from airflow.models.dataset import DatasetEvent from airflow.models.operator import Operator + from airflow.models.trigger import Trigger from airflow.serialization.pydantic.dag import DagModelPydantic from airflow.serialization.pydantic.dataset import DatasetEventPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic @@ -282,7 +283,7 @@ def _run_raw_task( # a trigger. if raise_on_defer: raise - ti.defer_task(exception=defer, session=session) + ti.defer_task_from_task_deferred(exception=defer, session=session) ti.log.info( "Pausing task as DEFERRED. 
dag_id=%s, task_id=%s, run_id=%s, execution_date=%s, start_date=%s", ti.dag_id, @@ -1575,12 +1576,15 @@ def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Ses @internal_api_call @provide_session def _defer_task( - ti: TaskInstance | TaskInstancePydantic, exception: TaskDeferred, session: Session = NEW_SESSION + ti: TaskInstance | TaskInstancePydantic, + *, + trigger_row: Trigger, + trigger_kwargs: dict[str, Any] | None, + next_method: str, + timeout: timedelta | None = None, + session: Session = NEW_SESSION, ) -> TaskInstancePydantic | TaskInstance: - from airflow.models.trigger import Trigger - # First, make the trigger entry - trigger_row = Trigger.from_object(exception.trigger) session.add(trigger_row) session.flush() @@ -1594,12 +1598,12 @@ def _defer_task( # depending on self.next_method semantics ti.state = TaskInstanceState.DEFERRED ti.trigger_id = trigger_row.id - ti.next_method = exception.method_name - ti.next_kwargs = exception.kwargs or {} + ti.next_method = next_method + ti.next_kwargs = trigger_kwargs or {} # Calculate timeout too if it was passed - if exception.timeout is not None: - ti.trigger_timeout = timezone.utcnow() + exception.timeout + if timeout is not None: + ti.trigger_timeout = timezone.utcnow() + timeout else: ti.trigger_timeout = None @@ -3000,18 +3004,32 @@ def _execute_task(self, context: Context, task_orig: Operator): return _execute_task(self, context, task_orig) @provide_session - def defer_task(self, exception: TaskDeferred, session: Session) -> None: - """Mark the task as deferred and sets up the trigger that is needed to resume it. + def defer_task_from_task_deferred(self, session: Session, exception: TaskDeferred) -> None: + """Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised. 
:meta: private """ - _defer_task(ti=self, exception=exception, session=session) + from airflow.models.trigger import Trigger + + if TYPE_CHECKING: + assert self.task + + # First, make the trigger entry + trigger_row = Trigger.from_object(exception.trigger) + _defer_task( + ti=self, + session=session, + trigger_row=trigger_row, + trigger_kwargs=exception.kwargs, + next_method=exception.method_name, + timeout=exception.timeout, + ) @provide_session - def defer_task_from_start( + def defer_task_from_start_trigger( self, session: Session, trigger_cls: str, trigger_kwargs: dict[str, Any], next_method: str ) -> None: - """Mark the task as deferred and sets up the trigger that is needed to resume it. + """Mark the task as deferred and sets up the trigger that is needed to resume it when start_trigger arguments passed. :meta: private """ @@ -3022,31 +3040,14 @@ def defer_task_from_start( # First, make the trigger entry trigger_row = Trigger(classpath=trigger_cls, kwargs=trigger_kwargs) - session.add(trigger_row) - session.flush() - - # Then, update ourselves so it matches the deferral request - # Keep an eye on the logic in `check_and_change_state_before_execution()` - # depending on self.next_method semantics - self.state = TaskInstanceState.DEFERRED - self.trigger_id = trigger_row.id - self.next_method = next_method - self.next_kwargs = trigger_kwargs or {} - - # Calculate timeout too if it was passed - # if defer.timeout is not None: - # self.trigger_timeout = timezone.utcnow() + defer.timeout - # else: - self.trigger_timeout = None - - # If an execution_timeout is set, set the timeout to the minimum of - # it and the trigger timeout - execution_timeout = self.task.execution_timeout - if execution_timeout: - if self.trigger_timeout: - self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout) - else: - self.trigger_timeout = self.start_date + execution_timeout + _defer_task( + ti=self, + session=session, + trigger_row=trigger_row, + 
trigger_kwargs=trigger_kwargs, + next_method=next_method, + timeout=None, + ) def _run_execute_callback(self, context: Context, task: BaseOperator) -> None: """Functions that need to be run before a Task is executed.""" diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index e499a98691940..4a25989c1b126 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -27,7 +27,6 @@ from airflow.models.taskinstance import ( TaskInstance, TaskReturnCode, - _defer_task, _handle_reschedule, _run_raw_task, _set_ti_attrs, @@ -499,7 +498,7 @@ def _register_dataset_changes(self, *, events, session: Session | None = None) - def defer_task(self, exception: TaskDeferred, session: Session | None = None): """Defer task.""" - updated_ti = _defer_task(ti=self, exception=exception, session=session) + updated_ti = self.defer_task_from_task_deferred(exception=exception, session=session) _set_ti_attrs(self, updated_ti) def _handle_reschedule( From 22e67608293f768993e55a077a1747654afc0e8c Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 14 May 2024 19:09:41 +0800 Subject: [PATCH 03/19] refactor(baseoperator): refactor start trigger arguments as a dataclass StartTriggerArgs --- airflow/decorators/base.py | 3 +- airflow/models/abstractoperator.py | 24 +++++- airflow/models/baseoperator.py | 9 +-- airflow/models/dagrun.py | 10 +-- airflow/models/mappedoperator.py | 13 +--- airflow/models/taskinstance.py | 25 ++++--- airflow/serialization/serialized_objects.py | 19 +++-- tests/models/test_dagrun.py | 14 ++-- tests/serialization/test_dag_serialization.py | 73 ++++++++----------- 9 files changed, 95 insertions(+), 95 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 79efae67bfb6d..bd2df7494e893 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -509,8 +509,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: 
bool) -> XComArg: # task's expand() contribute to the op_kwargs operator argument, not # the operator arguments themselves, and should expand against it. expand_input_attr="op_kwargs_expand_input", - start_trigger_cls=self.operator_class.start_trigger_cls, - next_method=self.operator_class.next_method, + start_trigger_args=self.operator_class.start_trigger_args, ) return XComArg(operator=operator) diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index c0024f44130ce..7aa9848f1c536 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -20,6 +20,8 @@ import datetime import inspect from abc import abstractproperty +from dataclasses import dataclass +from datetime import timedelta from functools import cached_property from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence @@ -85,6 +87,24 @@ class NotMapped(Exception): """Raise if a task is neither mapped nor has any parent mapped groups.""" +@dataclass +class StartTriggerArgs: + """Arguments required for start task execution from triggerer.""" + + trigger_cls: str + next_method: str + trigger_kwargs: dict[str, Any] | None = None + timeout: timedelta | None = None + + def serialize(self): + return { + "trigger_cls": self.trigger_cls, + "trigger_kwargs": self.trigger_kwargs, + "next_method": self.next_method, + "timeout": self.timeout, + } + + class AbstractOperator(Templater, DAGNode): """Common implementation for operators, including unmapped and mapped. 
@@ -122,9 +142,7 @@ class AbstractOperator(Templater, DAGNode): "node_id", # Duplicates task_id "task_group", # Doesn't have a useful repr, no point showing in UI "inherits_from_empty_operator", # impl detail - "start_trigger_cls", - "start_trigger_kwargs", - "next_method", + "start_trigger_args", # For compatibility with TG, for operators these are just the current task, no point showing "roots", "leaves", diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 3fd15120d6cd5..1e835f868c521 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -77,6 +77,7 @@ DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, DEFAULT_WEIGHT_RULE, AbstractOperator, + StartTriggerArgs, ) from airflow.models.base import _sentinel from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs @@ -819,9 +820,7 @@ def say_hello_world(**context): # Set to True for an operator instantiated by a mapped operator. __from_mapped = False - start_trigger_cls: str | None = None - start_trigger_kwargs: dict[str, Any] | None = None - next_method: str | None = None + start_trigger_args: StartTriggerArgs | None = None def __init__( self, @@ -1680,9 +1679,7 @@ def get_serialized_fields(cls): "is_teardown", "on_failure_fail_dagrun", "map_index_template", - "start_trigger_cls", - "start_trigger_kwargs", - "next_method", + "start_trigger_args", "_needs_expansion", } ) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 817c3d2cc195e..2e6f5e21d3a38 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1539,9 +1539,8 @@ def schedule_tis( ): dummy_ti_ids.append((ti.task_id, ti.map_index)) elif ( - ti.task.start_trigger_cls is not None - and ti.task.start_trigger_kwargs is not None - and ti.task.next_method is not None + ti.task.start_trigger_args is not None + and ti.task.start_trigger_args.trigger_kwargs is not None and not ti.task.on_execute_callback and not ti.task.on_success_callback and not 
ti.task.outlets @@ -1549,10 +1548,7 @@ def schedule_tis( if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: ti.try_number += 1 ti.defer_task_from_start_trigger( - session=session, - trigger_cls=ti.task.start_trigger_cls, - trigger_kwargs=ti.task.start_trigger_kwargs, - next_method=ti.task.next_method, + session=session, start_trigger_args=ti.task.start_trigger_args ) else: schedulable_ti_ids.append((ti.task_id, ti.map_index)) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index ecbf6d1799f92..76a01f9f53b34 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -41,6 +41,7 @@ DEFAULT_WEIGHT_RULE, AbstractOperator, NotMapped, + StartTriggerArgs, ) from airflow.models.expandinput import ( DictOfListsExpandInput, @@ -236,9 +237,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: # For classic operators, this points to expand_input because kwargs # to BaseOperator.expand() contribute to operator arguments. 
expand_input_attr="expand_input", - start_trigger_cls=self.operator_class.start_trigger_cls, - start_trigger_kwargs=getattr(partial_kwargs, "start_trigger_kwargs", None), - next_method=self.operator_class.next_method, + start_trigger_args=self.operator_class.start_trigger_args, ) return op @@ -281,9 +280,7 @@ class MappedOperator(AbstractOperator): _task_module: str _task_type: str _operator_name: str - start_trigger_cls: str | None - start_trigger_kwargs: dict[str, Any] | None - next_method: str | None + start_trigger_args: StartTriggerArgs | None _needs_expansion: bool = True dag: DAG | None @@ -313,9 +310,7 @@ class MappedOperator(AbstractOperator): ( "parse_time_mapped_ti_count", "operator_class", - "start_trigger_cls", - "start_trigger_kwargs", - "next_method", + "start_trigger_args", ) ) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index e84b04dceb1a4..530493bedf44c 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -153,7 +153,7 @@ from sqlalchemy.sql.elements import BooleanClauseList from sqlalchemy.sql.expression import ColumnOperators - from airflow.models.abstractoperator import TaskStateChangeCallback + from airflow.models.abstractoperator import StartTriggerArgs, TaskStateChangeCallback from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun @@ -3011,9 +3011,6 @@ def defer_task_from_task_deferred(self, session: Session, exception: TaskDeferre """ from airflow.models.trigger import Trigger - if TYPE_CHECKING: - assert self.task - # First, make the trigger entry trigger_row = Trigger.from_object(exception.trigger) _defer_task( @@ -3027,26 +3024,30 @@ def defer_task_from_task_deferred(self, session: Session, exception: TaskDeferre @provide_session def defer_task_from_start_trigger( - self, session: Session, trigger_cls: str, trigger_kwargs: dict[str, Any], next_method: str + self, + session: Session, + 
start_trigger_args: StartTriggerArgs, ) -> None: """Mark the task as deferred and sets up the trigger that is needed to resume it when start_trigger arguments passed. :meta: private """ - from airflow.models.trigger import Trigger + if start_trigger_args.trigger_kwargs is None: + raise AirflowException("trigger_kwargs is required") - if TYPE_CHECKING: - assert self.task + from airflow.models.trigger import Trigger # First, make the trigger entry - trigger_row = Trigger(classpath=trigger_cls, kwargs=trigger_kwargs) + trigger_row = Trigger( + classpath=start_trigger_args.trigger_cls, kwargs=start_trigger_args.trigger_kwargs + ) _defer_task( ti=self, session=session, trigger_row=trigger_row, - trigger_kwargs=trigger_kwargs, - next_method=next_method, - timeout=None, + trigger_kwargs=start_trigger_args.trigger_kwargs, + next_method=start_trigger_args.next_method, + timeout=start_trigger_args.timeout, ) def _run_execute_callback(self, context: Context, task: BaseOperator) -> None: diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index da6f590e230d4..8e0ba07cf0bb8 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -40,7 +40,7 @@ from airflow.datasets import Dataset, DatasetAll, DatasetAny from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError, TaskDeferred from airflow.jobs.job import Job -from airflow.models.baseoperator import BaseOperator +from airflow.models.baseoperator import BaseOperator, StartTriggerArgs from airflow.models.connection import Connection from airflow.models.dag import DAG, DagModel, create_timetable from airflow.models.dagrun import DagRun @@ -1017,9 +1017,9 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) # Used to determine if an Operator is inherited from EmptyOperator serialize_op["_is_empty"] = op.inherits_from_empty_operator - serialize_op["start_trigger_cls"] 
= op.start_trigger_cls - serialize_op["start_trigger_kwargs"] = op.start_trigger_kwargs - serialize_op["next_method"] = op.next_method + serialize_op["start_trigger_args"] = ( + op.start_trigger_args.serialize() if op.start_trigger_args else None + ) if op.operator_extra_links: serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links( @@ -1203,9 +1203,10 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: # Used to determine if an Operator is inherited from EmptyOperator setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False))) - setattr(op, "start_trigger_cls", encoded_op.get("start_trigger_cls", None)) - setattr(op, "start_trigger_kwargs", encoded_op.get("start_trigger_kwargs", None)) - setattr(op, "next_method", encoded_op.get("next_method", None)) + start_trigger_args = None + if encoded_op.get("start_trigger_args", None): + start_trigger_args = StartTriggerArgs(**encoded_op.get("start_trigger_args", None)) + setattr(op, "start_trigger_args", start_trigger_args) @staticmethod def set_task_dag_references(task: Operator, dag: DAG) -> None: @@ -1268,9 +1269,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: end_date=None, disallow_kwargs_override=encoded_op["_disallow_kwargs_override"], expand_input_attr=encoded_op["_expand_input_attr"], - start_trigger_cls=encoded_op.get("start_trigger_cls", None), - start_trigger_kwargs=encoded_op.get("start_trigger_kwargs", None), - next_method=encoded_op.get("next_method", None), + start_trigger_args=encoded_op.get("start_trigger_args", None), ) else: op = SerializedBaseOperator(task_id=encoded_op["task_id"]) diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 09900ffdfff91..9782767f80863 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -30,6 +30,7 @@ from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.decorators import setup, task, task_group, teardown from 
airflow.exceptions import AirflowException +from airflow.models.abstractoperator import StartTriggerArgs from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun, DagRunNote @@ -1988,17 +1989,20 @@ def test_schedule_tis_map_index(dag_maker, session): def test_schedule_tis_start_trigger(dag_maker, session): """ - Test that an operator with start_trigger_cls, start_trigger_kwargs and next_method set can be - directly deferred during scheduling. + Test that an operator with start_trigger_args set can be directly deferred during scheduling. """ class TestOperator(BaseOperator): - start_trigger_cls = "airflow.triggers.testing.SuccessTrigger" - next_method = "execute_complete" + start_trigger_args = StartTriggerArgs( + trigger_cls="airflow.triggers.testing.SuccessTrigger", + trigger_kwargs=None, + next_method="execute_complete", + timeout=None, + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.start_trigger_kwargs = {} + self.start_trigger_args.trigger_kwargs = {} def execute_complete(self): pass diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index ceb2a5566727d..7afa19c9d702c 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -47,6 +47,7 @@ from airflow.decorators.base import DecoratedOperator from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError from airflow.hooks.base import BaseHook +from airflow.models.abstractoperator import StartTriggerArgs from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection from airflow.models.dag import DAG @@ -196,8 +197,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i "_log_config_logger_name": "airflow.task.operators", "_needs_expansion": False, "weight_rule": "downstream", - 
"next_method": None, - "start_trigger": None, + "start_trigger_args": None, }, }, { @@ -226,9 +226,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i "_log_config_logger_name": "airflow.task.operators", "_needs_expansion": False, "weight_rule": "downstream", - "next_method": None, - "start_trigger_cls": None, - "start_trigger_kwargs": None, + "start_trigger_args": None, }, }, ], @@ -2167,27 +2165,33 @@ def execute(self, context: Context): SerializedDAG.to_dict(dag) @pytest.mark.db_test - def test_start_trigger_cls_kwargs_and_next_method_in_serialized_dag(self): + def test_start_trigger_args_in_serialized_dag(self): """ - Test that when we provide start_trigger_cls, start_trigger_kwargs and next_method, - the DAG can be correctly serialized. + Test that when we provide start_trigger_args, the DAG can be correctly serialized. """ class TestOperator(BaseOperator): - start_trigger_cls = "airflow.triggers.testing.SuccessTrigger" - next_method = "execute_complete" + start_trigger_args = StartTriggerArgs( + trigger_cls="airflow.triggers.testing.SuccessTrigger", + trigger_kwargs=None, + next_method="execute_complete", + timeout=None, + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.start_trigger_kwargs = {} + self.start_trigger_args.trigger_kwargs = {} def execute_complete(self): pass class Test2Operator(BaseOperator): - start_trigger_cls = "airflow.triggers.testing.SuccessTrigger" - start_trigger_kwargs = {} - next_method = "execute_complete" + start_trigger_args = StartTriggerArgs( + trigger_cls="airflow.triggers.testing.SuccessTrigger", + trigger_kwargs={}, + next_method="execute_complete", + timeout=None, + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -2204,9 +2208,12 @@ def execute_complete(self): serialized_obj = SerializedDAG.to_dict(dag) for task in serialized_obj["dag"]["tasks"]: - assert task["__var"]["start_trigger_cls"] == "airflow.triggers.testing.SuccessTrigger" 
- assert task["__var"]["start_trigger_kwargs"] == {} - assert task["__var"]["next_method"] == "execute_complete" + assert task["__var"]["start_trigger_args"] == { + "trigger_cls": "airflow.triggers.testing.SuccessTrigger", + "trigger_kwargs": {}, + "next_method": "execute_complete", + "timeout": None, + } def test_kubernetes_optional(): @@ -2258,9 +2265,7 @@ def test_operator_expand_serde(): "_needs_expansion": True, "_task_module": "airflow.operators.bash", "_task_type": "BashOperator", - "start_trigger_cls": None, - "start_trigger_kwargs": None, - "next_method": None, + "start_trigger_args": None, "downstream_task_ids": [], "expand_input": { "type": "dict-of-lists", @@ -2293,9 +2298,7 @@ def test_operator_expand_serde(): assert op.operator_class == { "_task_type": "BashOperator", "_needs_expansion": True, - "start_trigger_cls": None, - "start_trigger_kwargs": None, - "next_method": None, + "start_trigger_args": None, "downstream_task_ids": [], "task_id": "a", "template_ext": [".sh", ".bash"], @@ -2341,9 +2344,7 @@ def test_operator_expand_xcomarg_serde(): "ui_fgcolor": "#000", "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", - "next_method": None, - "start_trigger_cls": None, - "start_trigger_kwargs": None, + "start_trigger_args": None, } op = BaseSerialization.deserialize(serialized) @@ -2400,9 +2401,7 @@ def test_operator_expand_kwargs_literal_serde(strict): "ui_fgcolor": "#000", "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", - "next_method": None, - "start_trigger_cls": None, - "start_trigger_kwargs": None, + "start_trigger_args": None, } op = BaseSerialization.deserialize(serialized) @@ -2450,9 +2449,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "ui_fgcolor": "#000", "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", - "next_method": None, - "start_trigger_cls": None, - "start_trigger_kwargs": None, + "start_trigger_args": None, } op = 
BaseSerialization.deserialize(serialized) @@ -2570,9 +2567,7 @@ def x(arg1, arg2, arg3): "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}, "_disallow_kwargs_override": False, "_expand_input_attr": "op_kwargs_expand_input", - "next_method": None, - "start_trigger_cls": None, - "start_trigger_kwargs": None, + "start_trigger_args": None, } deserialized = BaseSerialization.deserialize(serialized) @@ -2638,9 +2633,7 @@ def x(arg1, arg2, arg3): "_task_module": "airflow.decorators.python", "_task_type": "_PythonDecoratedOperator", "_operator_name": "@task", - "next_method": None, - "start_trigger_cls": None, - "start_trigger_kwargs": None, + "start_trigger_args": None, "downstream_task_ids": [], "partial_kwargs": { "is_setup": False, @@ -2792,9 +2785,7 @@ def operator_extra_links(self): "_is_empty": False, "_is_mapped": True, "_needs_expansion": True, - "next_method": None, - "start_trigger_cls": None, - "start_trigger_kwargs": None, + "start_trigger_args": None, } deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR]) assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()] From f88e60d2fc3249ef5689098f82219d7c894cfcaf Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 14 May 2024 19:20:27 +0800 Subject: [PATCH 04/19] docs(deferring): update docs for newly introduced StartTriggerArgs --- .../authoring-and-scheduling/deferring.rst | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst index 084a08f0ac020..4830da6521211 100644 --- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst +++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst @@ -143,10 +143,12 @@ The ``self.defer`` call raises the ``TaskDeferred`` exception, so it can work an Triggering Deferral from Start ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If you want to 
defer your task directly to the triggerer without going into the worker, you can add the class level attributes ``start_trigger`` and ``next_method`` to your deferrable operator. +If you want to defer your task directly to the triggerer without going into the worker, you can add the class level attributes ``start_trigger_args`` with the following 4 attributes to your deferrable operator. -* ``start_trigger``: An instance of a trigger you want to defer to. It will be serialized into the database. +* ``trigger_cls``: An importable path to your trigger class. +* ``trigger_kwargs``: Additional keyword arguments to pass to the method when it is called. * ``next_method``: The method name on your operator that you want Airflow to call when it resumes. +* ``timeout``: (Optional) A timedelta that specifies a timeout after which this deferral will fail, and fail the task instance. Defaults to ``None``, which means no timeout. This is particularly useful when deferring is the only thing the ``execute`` method does. Here's a basic refinement of the previous example. @@ -156,23 +158,27 @@ This is particularly useful when deferring is the only thing the ``execute`` met from datetime import timedelta from typing import Any + from airflow.models.abstractoperator import StartTriggerArgs from airflow.sensors.base import BaseSensorOperator - from airflow.triggers.temporal import TimeDeltaTrigger from airflow.utils.context import Context class WaitOneHourSensor(BaseSensorOperator): - start_trigger = TimeDeltaTrigger(timedelta(hours=1)) - next_method = "execute_complete" + start_trigger_args = StartTriggerArgs( + trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger", + trigger_kwargs={"moment": timedelta(hours=1)}, + next_method="execute_complete", + timeout=None, + ) def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: # We have no more work to do here. Mark as complete. 
return -``start_trigger`` and ``next_method`` can also be set at the instance level for more flexible configuration. +``trigger_kwargs`` can also be modified at the instance level for more flexible configuration. .. warning:: - Dynamic task mapping is not supported when ``start_trigger`` and ``next_method`` are assigned in instance level. + Dynamic task mapping is not supported when ``trigger_kwargs`` is modified at instance level. .. code-block:: python @@ -184,11 +190,17 @@ This is particularly useful when deferring is the only thing the ``execute`` met from airflow.utils.context import Context - class WaitOneHourSensor(BaseSensorOperator): + class WaitTwoHourSensor(BaseSensorOperator): + start_trigger_args = StartTriggerArgs( + trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger", + trigger_kwargs={}, + next_method="execute_complete", + timeout=None, + ) + def __init__(self, *args: list[Any], **kwargs: dict[str, Any]) -> None: super().__init__(*args, **kwargs) - self.start_trigger = TimeDeltaTrigger(timedelta(hours=1)) - self.next_method = "execute_complete" + self.start_trigger_args.trigger_kwargs = trigger_kwargs = ({"moment": timedelta(hours=1)},) def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: # We have no more work to do here. Mark as complete. 
From 14199cc22c6e66bbbbef23ccc69978e3fd4f605b Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 15 May 2024 11:17:19 +0800 Subject: [PATCH 05/19] feat(baseoperator): add start_from_trigger as the flag to decide whether to start task execution from triggerer --- airflow/decorators/base.py | 1 + airflow/models/abstractoperator.py | 2 ++ airflow/models/baseoperator.py | 2 ++ airflow/models/dagrun.py | 4 ++-- airflow/models/mappedoperator.py | 8 +++----- airflow/serialization/serialized_objects.py | 3 +++ .../authoring-and-scheduling/deferring.rst | 8 +++++--- tests/models/test_dagrun.py | 1 + tests/serialization/test_dag_serialization.py | 14 ++++++++++++++ 9 files changed, 33 insertions(+), 10 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index bd2df7494e893..74b44ffe23804 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -510,6 +510,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: # the operator arguments themselves, and should expand against it. 
expand_input_attr="op_kwargs_expand_input", start_trigger_args=self.operator_class.start_trigger_args, + start_from_trigger=self.operator_class.start_from_trigger, ) return XComArg(operator=operator) diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 7aa9848f1c536..dbec9543854eb 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -142,7 +142,9 @@ class AbstractOperator(Templater, DAGNode): "node_id", # Duplicates task_id "task_group", # Doesn't have a useful repr, no point showing in UI "inherits_from_empty_operator", # impl detail + # Decide whether to start task execution from triggerer "start_trigger_args", + "start_from_trigger", # For compatibility with TG, for operators these are just the current task, no point showing "roots", "leaves", diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 1e835f868c521..c0cc17e854761 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -821,6 +821,7 @@ def say_hello_world(**context): __from_mapped = False start_trigger_args: StartTriggerArgs | None = None + start_from_trigger: bool = False def __init__( self, @@ -1681,6 +1682,7 @@ def get_serialized_fields(cls): "map_index_template", "start_trigger_args", "_needs_expansion", + "start_from_trigger", } ) DagContext.pop_context_managed_dag() diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 2e6f5e21d3a38..b5277aeb5d2c9 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1539,8 +1539,8 @@ def schedule_tis( ): dummy_ti_ids.append((ti.task_id, ti.map_index)) elif ( - ti.task.start_trigger_args is not None - and ti.task.start_trigger_args.trigger_kwargs is not None + ti.task.start_from_trigger is True + and ti.task.start_trigger_args is not None and not ti.task.on_execute_callback and not ti.task.on_success_callback and not ti.task.outlets diff --git a/airflow/models/mappedoperator.py 
b/airflow/models/mappedoperator.py index 76a01f9f53b34..1ec66c6847fc1 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -238,6 +238,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: # to BaseOperator.expand() contribute to operator arguments. expand_input_attr="expand_input", start_trigger_args=self.operator_class.start_trigger_args, + start_from_trigger=self.operator_class.start_from_trigger, ) return op @@ -281,6 +282,7 @@ class MappedOperator(AbstractOperator): _task_type: str _operator_name: str start_trigger_args: StartTriggerArgs | None + start_from_trigger: bool _needs_expansion: bool = True dag: DAG | None @@ -307,11 +309,7 @@ class MappedOperator(AbstractOperator): supports_lineage: bool = False HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset( - ( - "parse_time_mapped_ti_count", - "operator_class", - "start_trigger_args", - ) + ("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger") ) def __hash__(self): diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 8e0ba07cf0bb8..eae53053ca8d6 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -1020,6 +1020,7 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) serialize_op["start_trigger_args"] = ( op.start_trigger_args.serialize() if op.start_trigger_args else None ) + serialize_op["start_from_trigger"] = op.start_from_trigger if op.operator_extra_links: serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links( @@ -1207,6 +1208,7 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: if encoded_op.get("start_trigger_args", None): start_trigger_args = StartTriggerArgs(**encoded_op.get("start_trigger_args", None)) setattr(op, "start_trigger_args", start_trigger_args) + 
setattr(op, "start_from_trigger", bool(encoded_op.get("start_from_trigger", False))) @staticmethod def set_task_dag_references(task: Operator, dag: DAG) -> None: @@ -1270,6 +1272,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: disallow_kwargs_override=encoded_op["_disallow_kwargs_override"], expand_input_attr=encoded_op["_expand_input_attr"], start_trigger_args=encoded_op.get("start_trigger_args", None), + start_from_trigger=encoded_op.get("start_from_trigger", False), ) else: op = SerializedBaseOperator(task_id=encoded_op["task_id"]) diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst index 4830da6521211..012d3dd6f043a 100644 --- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst +++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst @@ -143,7 +143,7 @@ The ``self.defer`` call raises the ``TaskDeferred`` exception, so it can work an Triggering Deferral from Start ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If you want to defer your task directly to the triggerer without going into the worker, you can add the class level attributes ``start_trigger_args`` with the following 4 attributes to your deferrable operator. +If you want to defer your task directly to the triggerer without going into the worker, you can set class level attribute ``start_with_trigger`` to ``True`` add add class level attribute ``start_trigger_args`` with the following 4 attributes to your deferrable operator. * ``trigger_cls``: An importable path to your trigger class. * ``trigger_kwargs``: Additional keyword arguments to pass to the method when it is called. @@ -170,12 +170,13 @@ This is particularly useful when deferring is the only thing the ``execute`` met next_method="execute_complete", timeout=None, ) + start_from_trigger = True def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: # We have no more work to do here. Mark as complete. 
return -``trigger_kwargs`` can also be modified at the instance level for more flexible configuration. +``start_from_trigger`` and ``trigger_kwargs`` can also be modified at the instance level for more flexible configuration. .. warning:: Dynamic task mapping is not supported when ``trigger_kwargs`` is modified at instance level. @@ -200,7 +201,8 @@ This is particularly useful when deferring is the only thing the ``execute`` met def __init__(self, *args: list[Any], **kwargs: dict[str, Any]) -> None: super().__init__(*args, **kwargs) - self.start_trigger_args.trigger_kwargs = trigger_kwargs = ({"moment": timedelta(hours=1)},) + self.start_trigger_args.trigger_kwargs = {"moment": timedelta(hours=1)} + self.start_from_trigger = True def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: # We have no more work to do here. Mark as complete. diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 9782767f80863..fd71313390154 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -1999,6 +1999,7 @@ class TestOperator(BaseOperator): next_method="execute_complete", timeout=None, ) + start_from_trigger = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 7afa19c9d702c..44284e3740b39 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -198,6 +198,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i "_needs_expansion": False, "weight_rule": "downstream", "start_trigger_args": None, + "start_from_trigger": False, }, }, { @@ -227,6 +228,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i "_needs_expansion": False, "weight_rule": "downstream", "start_trigger_args": None, + "start_from_trigger": False, }, }, ], @@ -2177,10 +2179,12 @@ class 
TestOperator(BaseOperator): next_method="execute_complete", timeout=None, ) + start_from_trigger = False def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.start_trigger_args.trigger_kwargs = {} + self.start_from_trigger = True def execute_complete(self): pass @@ -2192,6 +2196,7 @@ class Test2Operator(BaseOperator): next_method="execute_complete", timeout=None, ) + start_from_trigger = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -2214,6 +2219,7 @@ def execute_complete(self): "next_method": "execute_complete", "timeout": None, } + assert task["__var"]["start_from_trigger"] is True def test_kubernetes_optional(): @@ -2266,6 +2272,7 @@ def test_operator_expand_serde(): "_task_module": "airflow.operators.bash", "_task_type": "BashOperator", "start_trigger_args": None, + "start_from_trigger": False, "downstream_task_ids": [], "expand_input": { "type": "dict-of-lists", @@ -2299,6 +2306,7 @@ def test_operator_expand_serde(): "_task_type": "BashOperator", "_needs_expansion": True, "start_trigger_args": None, + "start_from_trigger": False, "downstream_task_ids": [], "task_id": "a", "template_ext": [".sh", ".bash"], @@ -2345,6 +2353,7 @@ def test_operator_expand_xcomarg_serde(): "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", "start_trigger_args": None, + "start_from_trigger": False, } op = BaseSerialization.deserialize(serialized) @@ -2402,6 +2411,7 @@ def test_operator_expand_kwargs_literal_serde(strict): "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", "start_trigger_args": None, + "start_from_trigger": False, } op = BaseSerialization.deserialize(serialized) @@ -2450,6 +2460,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", "start_trigger_args": None, + "start_from_trigger": False, } op = BaseSerialization.deserialize(serialized) @@ -2568,6 +2579,7 @@ def x(arg1, arg2, arg3): 
"_disallow_kwargs_override": False, "_expand_input_attr": "op_kwargs_expand_input", "start_trigger_args": None, + "start_from_trigger": False, } deserialized = BaseSerialization.deserialize(serialized) @@ -2634,6 +2646,7 @@ def x(arg1, arg2, arg3): "_task_type": "_PythonDecoratedOperator", "_operator_name": "@task", "start_trigger_args": None, + "start_from_trigger": False, "downstream_task_ids": [], "partial_kwargs": { "is_setup": False, @@ -2786,6 +2799,7 @@ def operator_extra_links(self): "_is_mapped": True, "_needs_expansion": True, "start_trigger_args": None, + "start_from_trigger": False, } deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag[Encoding.VAR]) assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()] From df71b44ca42b80e1a792abd31b35fdb590415fa1 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 28 May 2024 15:49:41 +0800 Subject: [PATCH 06/19] fix(taskinstance): fix unexpected commit --- airflow/models/taskinstance.py | 9 +++++---- airflow/serialization/pydantic/taskinstance.py | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 530493bedf44c..8cf109592906f 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1619,8 +1619,6 @@ def _defer_task( ti.trigger_timeout = ti.start_date + execution_timeout if ti.test_mode: _add_log(event=ti.state, task_instance=ti, session=session) - session.merge(ti) - session.commit() return ti @@ -3004,7 +3002,7 @@ def _execute_task(self, context: Context, task_orig: Operator): return _execute_task(self, context, task_orig) @provide_session - def defer_task_from_task_deferred(self, session: Session, exception: TaskDeferred) -> None: + def defer_task_from_task_deferred(self, exception: TaskDeferred, session: Session = NEW_SESSION) -> None: """Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised. 
:meta: private @@ -3022,11 +3020,14 @@ def defer_task_from_task_deferred(self, session: Session, exception: TaskDeferre timeout=exception.timeout, ) + session.merge(self) + session.commit() + @provide_session def defer_task_from_start_trigger( self, - session: Session, start_trigger_args: StartTriggerArgs, + session: Session = NEW_SESSION, ) -> None: """Mark the task as deferred and sets up the trigger that is needed to resume it when start_trigger arguments passed. diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index 4a25989c1b126..e6b8594816989 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -27,6 +27,7 @@ from airflow.models.taskinstance import ( TaskInstance, TaskReturnCode, + _defer_task, _handle_reschedule, _run_raw_task, _set_ti_attrs, @@ -498,7 +499,19 @@ def _register_dataset_changes(self, *, events, session: Session | None = None) - def defer_task(self, exception: TaskDeferred, session: Session | None = None): """Defer task.""" - updated_ti = self.defer_task_from_task_deferred(exception=exception, session=session) + from airflow.models.trigger import Trigger + + trigger_row = Trigger.from_object(exception.trigger) + updated_ti = _defer_task( + ti=self, + session=session, + trigger_row=trigger_row, + trigger_kwargs=exception.kwargs, + next_method=exception.method_name, + timeout=exception.timeout, + ) + session.merge(self) + session.commit() _set_ti_attrs(self, updated_ti) def _handle_reschedule( From a6690689e9432e2ac043762a8ea00d6ba301b072 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 28 May 2024 17:10:18 +0800 Subject: [PATCH 07/19] fix(taskinstance): add _defer_task_from_task_deferred --- .../endpoints/rpc_api_endpoint.py | 2 + airflow/cli/commands/task_command.py | 2 +- airflow/models/taskinstance.py | 39 ++++++++++++------- .../serialization/pydantic/taskinstance.py | 18 ++------- 
tests/serialization/test_pydantic_models.py | 4 +- 5 files changed, 32 insertions(+), 33 deletions(-) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 1820d63194106..c8ff457ae181c 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -51,6 +51,7 @@ def _initialize_map() -> dict[str, Callable]: TaskInstance, _add_log, _defer_task, + _defer_task_from_task_deferred, _get_template_context, _handle_failure, _handle_reschedule, @@ -64,6 +65,7 @@ def _initialize_map() -> dict[str, Callable]: functions: list[Callable] = [ _default_action_log_internal, _defer_task, + _defer_task_from_task_deferred, _get_template_context, _get_ti_db_access, _update_rtif, diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index d2b0e90052743..92e6542000165 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -673,7 +673,7 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> N else: ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True) except TaskDeferred as defer: - ti.defer_task_from_task_deferred(defer=defer, session=session) + ti.defer_task_from_task_deferred(exception=defer, session=session) log.info("[TASK TEST] running trigger in line") event = _run_inline_trigger(defer.trigger) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 8cf109592906f..9f7b9651cf0ae 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1573,6 +1573,29 @@ def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Ses return ti +@internal_api_call +@provide_session +def _defer_task_from_task_deferred( + ti: TaskInstance | TaskInstancePydantic, exception: TaskDeferred, session: Session = NEW_SESSION +) -> TaskInstancePydantic | TaskInstance: + from 
airflow.models.trigger import Trigger + + # First, make the trigger entry + trigger_row = Trigger.from_object(exception.trigger) + updated_ti = _defer_task( + ti=ti, + session=session, + trigger_row=trigger_row, + trigger_kwargs=exception.kwargs, + next_method=exception.method_name, + timeout=exception.timeout, + ) + + session.merge(updated_ti) + session.commit() + return updated_ti + + @internal_api_call @provide_session def _defer_task( @@ -3007,21 +3030,7 @@ def defer_task_from_task_deferred(self, exception: TaskDeferred, session: Sessio :meta: private """ - from airflow.models.trigger import Trigger - - # First, make the trigger entry - trigger_row = Trigger.from_object(exception.trigger) - _defer_task( - ti=self, - session=session, - trigger_row=trigger_row, - trigger_kwargs=exception.kwargs, - next_method=exception.method_name, - timeout=exception.timeout, - ) - - session.merge(self) - session.commit() + _defer_task_from_task_deferred(ti=self, session=session, exception=exception) @provide_session def defer_task_from_start_trigger( diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index e6b8594816989..5b7a8ca3d3052 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -27,7 +27,7 @@ from airflow.models.taskinstance import ( TaskInstance, TaskReturnCode, - _defer_task, + _defer_task_from_task_deferred, _handle_reschedule, _run_raw_task, _set_ti_attrs, @@ -497,21 +497,9 @@ def command_as_list( def _register_dataset_changes(self, *, events, session: Session | None = None) -> None: TaskInstance._register_dataset_changes(self=self, events=events, session=session) # type: ignore[arg-type] - def defer_task(self, exception: TaskDeferred, session: Session | None = None): + def defer_task_from_task_deferred(self, exception: TaskDeferred, session: Session | None = None): """Defer task.""" - from airflow.models.trigger import Trigger - - trigger_row = 
Trigger.from_object(exception.trigger) - updated_ti = _defer_task( - ti=self, - session=session, - trigger_row=trigger_row, - trigger_kwargs=exception.kwargs, - next_method=exception.method_name, - timeout=exception.timeout, - ) - session.merge(self) - session.commit() + updated_ti = _defer_task_from_task_deferred(ti=self, session=session, exception=exception) _set_ti_attrs(self, updated_ti) def _handle_reschedule( diff --git a/tests/serialization/test_pydantic_models.py b/tests/serialization/test_pydantic_models.py index 048faebf54d04..dae611e68bd24 100644 --- a/tests/serialization/test_pydantic_models.py +++ b/tests/serialization/test_pydantic_models.py @@ -78,8 +78,8 @@ def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, d "_needs_expansion": True, "_task_type": "_PythonDecoratedOperator", "downstream_task_ids": [], - "next_method": None, - "start_trigger": None, + "start_from_trigger": False, + "start_trigger_args": None, "_operator_name": "@task", "ui_fgcolor": "#000", "ui_color": "#ffefeb", From 6731ab4538efea1f07fa022ae5b1f65327fa68ef Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 29 May 2024 18:21:12 +0800 Subject: [PATCH 08/19] docs(deferring): update version added --- docs/apache-airflow/authoring-and-scheduling/deferring.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst index 012d3dd6f043a..28405783f07d3 100644 --- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst +++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst @@ -143,6 +143,8 @@ The ``self.defer`` call raises the ``TaskDeferred`` exception, so it can work an Triggering Deferral from Start ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + .. 
versionadded:: 2.10.0 + If you want to defer your task directly to the triggerer without going into the worker, you can set class level attribute ``start_with_trigger`` to ``True`` add add class level attribute ``start_trigger_args`` with the following 4 attributes to your deferrable operator. * ``trigger_cls``: An importable path to your trigger class. From e32ff040f808b8bd7b57024891387c10dbc9e9ff Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 30 May 2024 14:48:30 +0800 Subject: [PATCH 09/19] refactor: rename defer_task_from_* methods --- airflow/api_internal/endpoints/rpc_api_endpoint.py | 4 ++-- airflow/cli/commands/task_command.py | 2 +- airflow/models/dagrun.py | 4 +--- airflow/models/taskinstance.py | 10 +++++----- airflow/serialization/pydantic/taskinstance.py | 6 +++--- 5 files changed, 12 insertions(+), 14 deletions(-) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index c8ff457ae181c..8704cede18c91 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -51,7 +51,7 @@ def _initialize_map() -> dict[str, Callable]: TaskInstance, _add_log, _defer_task, - _defer_task_from_task_deferred, + _defer_task_from_exception, _get_template_context, _handle_failure, _handle_reschedule, @@ -65,7 +65,7 @@ def _initialize_map() -> dict[str, Callable]: functions: list[Callable] = [ _default_action_log_internal, _defer_task, - _defer_task_from_task_deferred, + _defer_task_from_exception, _get_template_context, _get_ti_db_access, _update_rtif, diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 92e6542000165..fd68e7124fabd 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -673,7 +673,7 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> N else: ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, 
raise_on_defer=True) except TaskDeferred as defer: - ti.defer_task_from_task_deferred(exception=defer, session=session) + ti.defer_task_from_exception(exception=defer, session=session) log.info("[TASK TEST] running trigger in line") event = _run_inline_trigger(defer.trigger) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index b5277aeb5d2c9..0f117c3c90857 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1547,9 +1547,7 @@ def schedule_tis( ): if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: ti.try_number += 1 - ti.defer_task_from_start_trigger( - session=session, start_trigger_args=ti.task.start_trigger_args - ) + ti.defer_task_from_scheduler(session=session, start_trigger_args=ti.task.start_trigger_args) else: schedulable_ti_ids.append((ti.task_id, ti.map_index)) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 9f7b9651cf0ae..09567e28e41a1 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -283,7 +283,7 @@ def _run_raw_task( # a trigger. if raise_on_defer: raise - ti.defer_task_from_task_deferred(exception=defer, session=session) + ti.defer_task_from_exception(exception=defer, session=session) ti.log.info( "Pausing task as DEFERRED. 
dag_id=%s, task_id=%s, run_id=%s, execution_date=%s, start_date=%s", ti.dag_id, @@ -1575,7 +1575,7 @@ def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Ses @internal_api_call @provide_session -def _defer_task_from_task_deferred( +def _defer_task_from_exception( ti: TaskInstance | TaskInstancePydantic, exception: TaskDeferred, session: Session = NEW_SESSION ) -> TaskInstancePydantic | TaskInstance: from airflow.models.trigger import Trigger @@ -3025,15 +3025,15 @@ def _execute_task(self, context: Context, task_orig: Operator): return _execute_task(self, context, task_orig) @provide_session - def defer_task_from_task_deferred(self, exception: TaskDeferred, session: Session = NEW_SESSION) -> None: + def defer_task_from_exception(self, exception: TaskDeferred, session: Session = NEW_SESSION) -> None: """Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised. :meta: private """ - _defer_task_from_task_deferred(ti=self, session=session, exception=exception) + _defer_task_from_exception(ti=self, session=session, exception=exception) @provide_session - def defer_task_from_start_trigger( + def defer_task_from_scheduler( self, start_trigger_args: StartTriggerArgs, session: Session = NEW_SESSION, diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index 5b7a8ca3d3052..b80ca199a087c 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -27,7 +27,7 @@ from airflow.models.taskinstance import ( TaskInstance, TaskReturnCode, - _defer_task_from_task_deferred, + _defer_task_from_exception, _handle_reschedule, _run_raw_task, _set_ti_attrs, @@ -497,9 +497,9 @@ def command_as_list( def _register_dataset_changes(self, *, events, session: Session | None = None) -> None: TaskInstance._register_dataset_changes(self=self, events=events, session=session) # type: ignore[arg-type] - def 
defer_task_from_task_deferred(self, exception: TaskDeferred, session: Session | None = None): + def defer_task_from_exception(self, exception: TaskDeferred, session: Session | None = None): """Defer task.""" - updated_ti = _defer_task_from_task_deferred(ti=self, session=session, exception=exception) + updated_ti = _defer_task_from_exception(ti=self, session=session, exception=exception) _set_ti_attrs(self, updated_ti) def _handle_reschedule( From 80ae6ebfd4746f303a7a3ea7c80d920be7a546ef Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 30 May 2024 15:44:42 +0800 Subject: [PATCH 10/19] fix(dagrun): remove unnecessary conditions on scheduling tasks --- airflow/models/dagrun.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 0f117c3c90857..ca21a51c80b18 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1538,13 +1538,7 @@ def schedule_tis( and not ti.task.outlets ): dummy_ti_ids.append((ti.task_id, ti.map_index)) - elif ( - ti.task.start_from_trigger is True - and ti.task.start_trigger_args is not None - and not ti.task.on_execute_callback - and not ti.task.on_success_callback - and not ti.task.outlets - ): + elif ti.task.start_from_trigger is True and ti.task.start_trigger_args is not None: if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: ti.try_number += 1 ti.defer_task_from_scheduler(session=session, start_trigger_args=ti.task.start_trigger_args) From 009bcd1f900ade96a2ea688e4845d231ec59a8df Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 30 May 2024 17:21:09 +0800 Subject: [PATCH 11/19] fix(taskinstance): remove unnecessary check on trigger_kwargs --- airflow/models/taskinstance.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 09567e28e41a1..d5e4afdcee734 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -3042,9 +3042,6 @@ def
defer_task_from_scheduler( :meta: private """ - if start_trigger_args.trigger_kwargs is None: - raise AirflowException("trigger_kwargs is required") - from airflow.models.trigger import Trigger # First, make the trigger entry From 758b2e8ec5a30ac0cfe2729a956bf3b48bb28607 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 30 May 2024 17:22:18 +0800 Subject: [PATCH 12/19] fix(dagrun): set start_date before deferring task from scheduler --- airflow/models/dagrun.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index ca21a51c80b18..576b527763df2 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1539,6 +1539,7 @@ def schedule_tis( ): dummy_ti_ids.append((ti.task_id, ti.map_index)) elif ti.task.start_from_trigger is True and ti.task.start_trigger_args is not None: + ti.start_date = timezone.utcnow() if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: ti.try_number += 1 ti.defer_task_from_scheduler(session=session, start_trigger_args=ti.task.start_trigger_args) From ed94608f83ed107dc7bc770aef43042be1acea67 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 30 May 2024 17:32:54 +0800 Subject: [PATCH 13/19] refactor(triggers): move StartTriggerArgs to airflow.triggers.base --- airflow/models/abstractoperator.py | 20 ------------------- airflow/models/baseoperator.py | 3 +-- airflow/models/mappedoperator.py | 2 +- airflow/models/taskinstance.py | 3 ++- airflow/serialization/serialized_objects.py | 4 ++-- airflow/triggers/base.py | 20 +++++++++++++++++++ .../authoring-and-scheduling/deferring.rst | 2 +- tests/models/test_dagrun.py | 2 +- tests/serialization/test_dag_serialization.py | 2 +- 9 files changed, 29 insertions(+), 29 deletions(-) diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index dbec9543854eb..1bb83a2dc0f89 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -20,8 +20,6 @@ import datetime import inspect from 
abc import abstractproperty -from dataclasses import dataclass -from datetime import timedelta from functools import cached_property from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence @@ -87,24 +85,6 @@ class NotMapped(Exception): """Raise if a task is neither mapped nor has any parent mapped groups.""" -@dataclass -class StartTriggerArgs: - """Arguments required for start task execution from triggerer.""" - - trigger_cls: str - next_method: str - trigger_kwargs: dict[str, Any] | None = None - timeout: timedelta | None = None - - def serialize(self): - return { - "trigger_cls": self.trigger_cls, - "trigger_kwargs": self.trigger_kwargs, - "next_method": self.next_method, - "timeout": self.timeout, - } - - class AbstractOperator(Templater, DAGNode): """Common implementation for operators, including unmapped and mapped. diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index c0cc17e854761..bbd629cfc156c 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -77,7 +77,6 @@ DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, DEFAULT_WEIGHT_RULE, AbstractOperator, - StartTriggerArgs, ) from airflow.models.base import _sentinel from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs @@ -117,7 +116,7 @@ from airflow.models.operator import Operator from airflow.models.xcom_arg import XComArg from airflow.ti_deps.deps.base_ti_dep import BaseTIDep - from airflow.triggers.base import BaseTrigger + from airflow.triggers.base import BaseTrigger, StartTriggerArgs from airflow.utils.task_group import TaskGroup from airflow.utils.types import ArgNotSet diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 1ec66c6847fc1..abbed3cfa934d 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -41,7 +41,6 @@ DEFAULT_WEIGHT_RULE, AbstractOperator, NotMapped, - StartTriggerArgs, ) from 
airflow.models.expandinput import ( DictOfListsExpandInput, @@ -82,6 +81,7 @@ from airflow.models.param import ParamsDict from airflow.models.xcom_arg import XComArg from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + from airflow.triggers.base import StartTriggerArgs from airflow.utils.context import Context from airflow.utils.operator_resources import Resources from airflow.utils.task_group import TaskGroup diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index d5e4afdcee734..e70b66c26fa0e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -153,7 +153,7 @@ from sqlalchemy.sql.elements import BooleanClauseList from sqlalchemy.sql.expression import ColumnOperators - from airflow.models.abstractoperator import StartTriggerArgs, TaskStateChangeCallback + from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun @@ -164,6 +164,7 @@ from airflow.serialization.pydantic.dataset import DatasetEventPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from airflow.timetables.base import DataInterval + from airflow.triggers.base import StartTriggerArgs from airflow.typing_compat import Literal, TypeGuard from airflow.utils.task_group import TaskGroup diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index eae53053ca8d6..eb9c15f43cdcd 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -40,7 +40,7 @@ from airflow.datasets import Dataset, DatasetAll, DatasetAny from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError, TaskDeferred from airflow.jobs.job import Job -from airflow.models.baseoperator import BaseOperator, StartTriggerArgs +from airflow.models.baseoperator import BaseOperator 
from airflow.models.connection import Connection from airflow.models.dag import DAG, DagModel, create_timetable from airflow.models.dagrun import DagRun @@ -66,7 +66,7 @@ airflow_priority_weight_strategies, airflow_priority_weight_strategies_classes, ) -from airflow.triggers.base import BaseTrigger +from airflow.triggers.base import BaseTrigger, StartTriggerArgs from airflow.utils.code_utils import get_python_source from airflow.utils.context import Context, OutletEventAccessor, OutletEventAccessors from airflow.utils.docs import get_docs_url diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py index 0d239af0cafd4..ac5727a1baa93 100644 --- a/airflow/triggers/base.py +++ b/airflow/triggers/base.py @@ -17,11 +17,31 @@ from __future__ import annotations import abc +from dataclasses import dataclass +from datetime import timedelta from typing import Any, AsyncIterator from airflow.utils.log.logging_mixin import LoggingMixin +@dataclass +class StartTriggerArgs: + """Arguments required for start task execution from triggerer.""" + + trigger_cls: str + next_method: str + trigger_kwargs: dict[str, Any] | None = None + timeout: timedelta | None = None + + def serialize(self): + return { + "trigger_cls": self.trigger_cls, + "trigger_kwargs": self.trigger_kwargs, + "next_method": self.next_method, + "timeout": self.timeout, + } + + class BaseTrigger(abc.ABC, LoggingMixin): """ Base class for all triggers. 
diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst index 28405783f07d3..9fee32e00b8f1 100644 --- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst +++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst @@ -160,7 +160,7 @@ This is particularly useful when deferring is the only thing the ``execute`` met from datetime import timedelta from typing import Any - from airflow.models.abstractoperator import StartTriggerArgs + from airflow.triggers.base import StartTriggerArgs from airflow.sensors.base import BaseSensorOperator from airflow.utils.context import Context diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index fd71313390154..93e0611243616 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -30,7 +30,6 @@ from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.decorators import setup, task, task_group, teardown from airflow.exceptions import AirflowException -from airflow.models.abstractoperator import StartTriggerArgs from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun, DagRunNote @@ -41,6 +40,7 @@ from airflow.operators.python import ShortCircuitOperator from airflow.serialization.serialized_objects import SerializedDAG from airflow.stats import Stats +from airflow.triggers.base import StartTriggerArgs from airflow.utils import timezone from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.trigger_rule import TriggerRule diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 44284e3740b39..16f6d3cb68fe6 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -47,7 +47,6 @@ from airflow.decorators.base import DecoratedOperator from airflow.exceptions 
import AirflowException, RemovedInAirflow3Warning, SerializationError from airflow.hooks.base import BaseHook -from airflow.models.abstractoperator import StartTriggerArgs from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection from airflow.models.dag import DAG @@ -73,6 +72,7 @@ from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.simple import NullTimetable, OnceTimetable +from airflow.triggers.base import StartTriggerArgs from airflow.utils import timezone from airflow.utils.operator_resources import Resources from airflow.utils.task_group import TaskGroup From dfa47ecb0317cf68c0fc6dcd62992e39f6e79fb5 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 31 May 2024 14:01:02 +0800 Subject: [PATCH 14/19] refactor: merge _defer_task* as _defer_task function --- .../endpoints/rpc_api_endpoint.py | 2 - airflow/models/taskinstance.py | 74 ++++++++----------- .../serialization/pydantic/taskinstance.py | 4 +- 3 files changed, 31 insertions(+), 49 deletions(-) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 8704cede18c91..1820d63194106 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -51,7 +51,6 @@ def _initialize_map() -> dict[str, Callable]: TaskInstance, _add_log, _defer_task, - _defer_task_from_exception, _get_template_context, _handle_failure, _handle_reschedule, @@ -65,7 +64,6 @@ def _initialize_map() -> dict[str, Callable]: functions: list[Callable] = [ _default_action_log_internal, _defer_task, - _defer_task_from_exception, _get_template_context, _get_ti_db_access, _update_rtif, diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index e70b66c26fa0e..9f853ca6915fe 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ 
-118,7 +118,7 @@ context_merge, ) from airflow.utils.email import send_email -from airflow.utils.helpers import prune_dict, render_template_to_string +from airflow.utils.helpers import exactly_one, prune_dict, render_template_to_string from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.operator_helpers import ExecutionCallableRunner, context_to_airflow_vars @@ -159,7 +159,6 @@ from airflow.models.dagrun import DagRun from airflow.models.dataset import DatasetEvent from airflow.models.operator import Operator - from airflow.models.trigger import Trigger from airflow.serialization.pydantic.dag import DagModelPydantic from airflow.serialization.pydantic.dataset import DatasetEventPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic @@ -1574,40 +1573,34 @@ def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Ses return ti -@internal_api_call -@provide_session -def _defer_task_from_exception( - ti: TaskInstance | TaskInstancePydantic, exception: TaskDeferred, session: Session = NEW_SESSION -) -> TaskInstancePydantic | TaskInstance: - from airflow.models.trigger import Trigger - - # First, make the trigger entry - trigger_row = Trigger.from_object(exception.trigger) - updated_ti = _defer_task( - ti=ti, - session=session, - trigger_row=trigger_row, - trigger_kwargs=exception.kwargs, - next_method=exception.method_name, - timeout=exception.timeout, - ) - - session.merge(updated_ti) - session.commit() - return updated_ti - - @internal_api_call @provide_session def _defer_task( ti: TaskInstance | TaskInstancePydantic, - *, - trigger_row: Trigger, - trigger_kwargs: dict[str, Any] | None, - next_method: str, - timeout: timedelta | None = None, session: Session = NEW_SESSION, + *, + exception: TaskDeferred | None = None, + start_trigger_args: StartTriggerArgs | None = None, ) -> TaskInstancePydantic | TaskInstance: + from airflow.models.trigger import 
Trigger + + if not exactly_one(exception, start_trigger_args): + raise AirflowException( + "One and only one of the arumgent exception and start_trigger_args is required" + ) + elif exception is not None: + trigger_row = Trigger.from_object(exception.trigger) + trigger_kwargs = exception.kwargs + next_method = exception.method_name + timeout = exception.timeout + elif start_trigger_args is not None: + trigger_row = Trigger( + classpath=start_trigger_args.trigger_cls, kwargs=start_trigger_args.trigger_kwargs + ) + trigger_kwargs = start_trigger_args.trigger_kwargs + next_method = start_trigger_args.next_method + timeout = start_trigger_args.timeout + # First, make the trigger entry session.add(trigger_row) session.flush() @@ -1643,6 +1636,10 @@ def _defer_task( ti.trigger_timeout = ti.start_date + execution_timeout if ti.test_mode: _add_log(event=ti.state, task_instance=ti, session=session) + + if exception is not None: + session.merge(ti) + session.commit() return ti @@ -3031,7 +3028,7 @@ def defer_task_from_exception(self, exception: TaskDeferred, session: Session = :meta: private """ - _defer_task_from_exception(ti=self, session=session, exception=exception) + _defer_task(ti=self, session=session, exception=exception) @provide_session def defer_task_from_scheduler( @@ -3043,20 +3040,7 @@ def defer_task_from_scheduler( :meta: private """ - from airflow.models.trigger import Trigger - - # First, make the trigger entry - trigger_row = Trigger( - classpath=start_trigger_args.trigger_cls, kwargs=start_trigger_args.trigger_kwargs - ) - _defer_task( - ti=self, - session=session, - trigger_row=trigger_row, - trigger_kwargs=start_trigger_args.trigger_kwargs, - next_method=start_trigger_args.next_method, - timeout=start_trigger_args.timeout, - ) + _defer_task(ti=self, session=session, start_trigger_args=start_trigger_args) def _run_execute_callback(self, context: Context, task: BaseOperator) -> None: """Functions that need to be run before a Task is executed.""" diff --git 
a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index b80ca199a087c..e67d93d90a6e3 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -27,7 +27,7 @@ from airflow.models.taskinstance import ( TaskInstance, TaskReturnCode, - _defer_task_from_exception, + _defer_task, _handle_reschedule, _run_raw_task, _set_ti_attrs, @@ -499,7 +499,7 @@ def _register_dataset_changes(self, *, events, session: Session | None = None) - def defer_task_from_exception(self, exception: TaskDeferred, session: Session | None = None): """Defer task.""" - updated_ti = _defer_task_from_exception(ti=self, session=session, exception=exception) + updated_ti = _defer_task(ti=self, session=session, exception=exception) _set_ti_attrs(self, updated_ti) def _handle_reschedule( From 4568a59dd3c2e8c98086f3212026f7326af1d91b Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 31 May 2024 14:45:39 +0800 Subject: [PATCH 15/19] refactor(taskinstance): deduplicate defer_task logic --- airflow/cli/commands/task_command.py | 2 +- airflow/models/dagrun.py | 2 +- airflow/models/taskinstance.py | 39 +++++-------------- .../serialization/pydantic/taskinstance.py | 2 +- 4 files changed, 13 insertions(+), 32 deletions(-) diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index fd68e7124fabd..ac9b211c21798 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -673,7 +673,7 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> N else: ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True) except TaskDeferred as defer: - ti.defer_task_from_exception(exception=defer, session=session) + ti.defer_task(exception=defer, session=session) log.info("[TASK TEST] running trigger in line") event = _run_inline_trigger(defer.trigger) diff --git a/airflow/models/dagrun.py 
b/airflow/models/dagrun.py index 576b527763df2..99111f36071c6 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1542,7 +1542,7 @@ def schedule_tis( ti.start_date = timezone.utcnow() if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: ti.try_number += 1 - ti.defer_task_from_scheduler(session=session, start_trigger_args=ti.task.start_trigger_args) + ti.defer_task(session=session) else: schedulable_ti_ids.append((ti.task_id, ti.map_index)) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 9f853ca6915fe..e73fc00afc2b1 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -118,7 +118,7 @@ context_merge, ) from airflow.utils.email import send_email -from airflow.utils.helpers import exactly_one, prune_dict, render_template_to_string +from airflow.utils.helpers import prune_dict, render_template_to_string from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.operator_helpers import ExecutionCallableRunner, context_to_airflow_vars @@ -163,7 +163,6 @@ from airflow.serialization.pydantic.dataset import DatasetEventPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from airflow.timetables.base import DataInterval - from airflow.triggers.base import StartTriggerArgs from airflow.typing_compat import Literal, TypeGuard from airflow.utils.task_group import TaskGroup @@ -283,7 +282,7 @@ def _run_raw_task( # a trigger. if raise_on_defer: raise - ti.defer_task_from_exception(exception=defer, session=session) + ti.defer_task(exception=defer, session=session) ti.log.info( "Pausing task as DEFERRED. 
dag_id=%s, task_id=%s, run_id=%s, execution_date=%s, start_date=%s", ti.dag_id, @@ -1577,29 +1576,23 @@ def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Ses @provide_session def _defer_task( ti: TaskInstance | TaskInstancePydantic, - session: Session = NEW_SESSION, - *, exception: TaskDeferred | None = None, - start_trigger_args: StartTriggerArgs | None = None, + session: Session = NEW_SESSION, ) -> TaskInstancePydantic | TaskInstance: from airflow.models.trigger import Trigger - if not exactly_one(exception, start_trigger_args): - raise AirflowException( - "One and only one of the arumgent exception and start_trigger_args is required" - ) - elif exception is not None: + if exception is not None: trigger_row = Trigger.from_object(exception.trigger) trigger_kwargs = exception.kwargs next_method = exception.method_name timeout = exception.timeout - elif start_trigger_args is not None: + else: trigger_row = Trigger( - classpath=start_trigger_args.trigger_cls, kwargs=start_trigger_args.trigger_kwargs + classpath=ti.task.start_trigger_args.trigger_cls, kwargs=ti.task.start_trigger_args.trigger_kwargs ) - trigger_kwargs = start_trigger_args.trigger_kwargs - next_method = start_trigger_args.next_method - timeout = start_trigger_args.timeout + trigger_kwargs = ti.task.start_trigger_args.trigger_kwargs + next_method = ti.task.start_trigger_args.next_method + timeout = ti.task.start_trigger_args.timeout # First, make the trigger entry session.add(trigger_row) @@ -3023,25 +3016,13 @@ def _execute_task(self, context: Context, task_orig: Operator): return _execute_task(self, context, task_orig) @provide_session - def defer_task_from_exception(self, exception: TaskDeferred, session: Session = NEW_SESSION) -> None: + def defer_task(self, exception: TaskDeferred | None = None, session: Session = NEW_SESSION) -> None: """Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised. 
:meta: private """ _defer_task(ti=self, session=session, exception=exception) - @provide_session - def defer_task_from_scheduler( - self, - start_trigger_args: StartTriggerArgs, - session: Session = NEW_SESSION, - ) -> None: - """Mark the task as deferred and sets up the trigger that is needed to resume it when start_trigger arguments passed. - - :meta: private - """ - _defer_task(ti=self, session=session, start_trigger_args=start_trigger_args) - def _run_execute_callback(self, context: Context, task: BaseOperator) -> None: """Functions that need to be run before a Task is executed.""" if not (callbacks := task.on_execute_callback): diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index e67d93d90a6e3..ec303cb13bf4e 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -497,7 +497,7 @@ def command_as_list( def _register_dataset_changes(self, *, events, session: Session | None = None) -> None: TaskInstance._register_dataset_changes(self=self, events=events, session=session) # type: ignore[arg-type] - def defer_task_from_exception(self, exception: TaskDeferred, session: Session | None = None): + def defer_task(self, exception: TaskDeferred, session: Session | None = None): """Defer task.""" updated_ti = _defer_task(ti=self, session=session, exception=exception) _set_ti_attrs(self, updated_ti) From f1e4462cb70cde278fa7ae235339168dd8fda2de Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 31 May 2024 15:01:34 +0800 Subject: [PATCH 16/19] style: fix mypy warning --- airflow/models/taskinstance.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index e73fc00afc2b1..9ac03117b8b39 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1586,13 +1586,16 @@ def _defer_task( trigger_kwargs = exception.kwargs next_method = 
exception.method_name timeout = exception.timeout - else: + elif ti.task is not None and ti.task.start_trigger_args is not None: trigger_row = Trigger( - classpath=ti.task.start_trigger_args.trigger_cls, kwargs=ti.task.start_trigger_args.trigger_kwargs + classpath=ti.task.start_trigger_args.trigger_cls, + kwargs=ti.task.start_trigger_args.trigger_kwargs or {}, ) trigger_kwargs = ti.task.start_trigger_args.trigger_kwargs next_method = ti.task.start_trigger_args.next_method timeout = ti.task.start_trigger_args.timeout + else: + raise AirflowException("exception and ti.task.start_trigger_args cannot both be None") # First, make the trigger entry session.add(trigger_row) From 415993a898351fba0c6914b38f6815cbf6839d5d Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 11 Jun 2024 17:08:52 +0800 Subject: [PATCH 17/19] refactor: reorder parameter as suggested --- airflow/models/taskinstance.py | 2 +- airflow/serialization/pydantic/taskinstance.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 9ac03117b8b39..485b382838e9d 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -3024,7 +3024,7 @@ def defer_task(self, exception: TaskDeferred | None = None, session: Session = N :meta: private """ - _defer_task(ti=self, session=session, exception=exception) + _defer_task(ti=self, exception=exception, session=session) def _run_execute_callback(self, context: Context, task: BaseOperator) -> None: """Functions that need to be run before a Task is executed.""" diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index ec303cb13bf4e..e499a98691940 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -499,7 +499,7 @@ def _register_dataset_changes(self, *, events, session: Session | None = None) - def defer_task(self, exception: TaskDeferred, session: 
Session | None = None): """Defer task.""" - updated_ti = _defer_task(ti=self, session=session, exception=exception) + updated_ti = _defer_task(ti=self, exception=exception, session=session) _set_ti_attrs(self, updated_ti) def _handle_reschedule( From 91f8b30aa04ce8ae81195314870b8cb7c68c6726 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 11 Jun 2024 17:09:15 +0800 Subject: [PATCH 18/19] docs(deferring): reword description as suggested --- docs/apache-airflow/authoring-and-scheduling/deferring.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst index 9fee32e00b8f1..a65e932a90ef6 100644 --- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst +++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst @@ -145,7 +145,7 @@ Triggering Deferral from Start .. versionadded:: 2.10.0 -If you want to defer your task directly to the triggerer without going into the worker, you can set class level attribute ``start_with_trigger`` to ``True`` add add class level attribute ``start_trigger_args`` with the following 4 attributes to your deferrable operator. +If you want to defer your task directly to the triggerer without going into the worker, you can set class level attribute ``start_with_trigger`` to ``True`` and add a class level attribute ``start_trigger_args`` with a ``StartTriggerArgs`` object with the following 4 attributes to your deferrable operator: * ``trigger_cls``: An importable path to your trigger class. * ``trigger_kwargs``: Additional keyword arguments to pass to the method when it is called. 
From e1805e50e620fa95a3a690677f92ad55dc5e13f4 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 11 Jun 2024 18:29:37 +0800 Subject: [PATCH 19/19] refactor: make argument "exception" in defer_task method required for explicitness --- airflow/models/dagrun.py | 2 +- airflow/models/taskinstance.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 99111f36071c6..4773a89d1dd16 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1542,7 +1542,7 @@ def schedule_tis( ti.start_date = timezone.utcnow() if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: ti.try_number += 1 - ti.defer_task(session=session) + ti.defer_task(exception=None, session=session) else: schedulable_ti_ids.append((ti.task_id, ti.map_index)) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 485b382838e9d..373ad108c29da 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -3019,7 +3019,7 @@ def _execute_task(self, context: Context, task_orig: Operator): return _execute_task(self, context, task_orig) @provide_session - def defer_task(self, exception: TaskDeferred | None = None, session: Session = NEW_SESSION) -> None: + def defer_task(self, exception: TaskDeferred | None, session: Session = NEW_SESSION) -> None: """Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised. :meta: private