From e8116e706ce055ea8d9e4c757efcdd1f9a652c3d Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 21 Jun 2024 22:51:02 +0800 Subject: [PATCH 1/2] fix(trigger): add next_kwargs to StartTriggerArgs --- airflow/models/taskinstance.py | 6 +++--- airflow/triggers/base.py | 2 ++ tests/serialization/test_dag_serialization.py | 3 +++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 39ae2cd606072..c62752b5a5617 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1606,15 +1606,15 @@ def _defer_task( if exception is not None: trigger_row = Trigger.from_object(exception.trigger) - trigger_kwargs = exception.kwargs next_method = exception.method_name + next_kwargs = exception.kwargs timeout = exception.timeout elif ti.task is not None and ti.task.start_trigger_args is not None: trigger_row = Trigger( classpath=ti.task.start_trigger_args.trigger_cls, kwargs=ti.task.start_trigger_args.trigger_kwargs or {}, ) - trigger_kwargs = ti.task.start_trigger_args.trigger_kwargs + next_kwargs = ti.task.start_trigger_args.next_kwargs next_method = ti.task.start_trigger_args.next_method timeout = ti.task.start_trigger_args.timeout else: @@ -1635,7 +1635,7 @@ def _defer_task( ti.state = TaskInstanceState.DEFERRED ti.trigger_id = trigger_row.id ti.next_method = next_method - ti.next_kwargs = trigger_kwargs or {} + ti.next_kwargs = next_kwargs or {} # Calculate timeout too if it was passed if timeout is not None: diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py index ac5727a1baa93..5dacee3364c54 100644 --- a/airflow/triggers/base.py +++ b/airflow/triggers/base.py @@ -31,6 +31,7 @@ class StartTriggerArgs: trigger_cls: str next_method: str trigger_kwargs: dict[str, Any] | None = None + next_kwargs: dict[str, Any] | None = None timeout: timedelta | None = None def serialize(self): @@ -38,6 +39,7 @@ def serialize(self): "trigger_cls": self.trigger_cls, "trigger_kwargs": 
self.trigger_kwargs, "next_method": self.next_method, + "next_kwargs": self.next_kwargs, "timeout": self.timeout, } diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index c0e1d69c09770..a1e76a13a8213 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -2199,6 +2199,7 @@ class TestOperator(BaseOperator): trigger_cls="airflow.triggers.testing.SuccessTrigger", trigger_kwargs=None, next_method="execute_complete", + next_kwargs=None, timeout=None, ) start_from_trigger = False @@ -2216,6 +2217,7 @@ class Test2Operator(BaseOperator): trigger_cls="airflow.triggers.testing.SuccessTrigger", trigger_kwargs={}, next_method="execute_complete", + next_kwargs=None, timeout=None, ) start_from_trigger = True @@ -2239,6 +2241,7 @@ def execute_complete(self): "trigger_cls": "airflow.triggers.testing.SuccessTrigger", "trigger_kwargs": {}, "next_method": "execute_complete", + "next_kwargs": None, "timeout": None, } assert task["__var"]["start_from_trigger"] is True From 3a3202be37ba313bcb6605e38e59704cde8222fd Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Sat, 22 Jun 2024 11:18:40 +0800 Subject: [PATCH 2/2] docs(deferring): update docs for next_kwargs added in start_trigger_args --- docs/apache-airflow/authoring-and-scheduling/deferring.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst index a65e932a90ef6..a9b26703bce63 100644 --- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst +++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst @@ -148,8 +148,9 @@ Triggering Deferral from Start If you want to defer your task directly to the triggerer without going into the worker, you can set class level attribute ``start_from_trigger`` to ``True`` and add a class level attribute ``start_trigger_args`` with a 
``StartTriggerArgs`` object with the following 5 attributes to your deferrable operator: * ``trigger_cls``: An importable path to your trigger class. -* ``trigger_kwargs``: Additional keyword arguments to pass to the method when it is called. +* ``trigger_kwargs``: Keyword arguments to pass to the ``trigger_cls`` when it's initialized. * ``next_method``: The method name on your operator that you want Airflow to call when it resumes. +* ``next_kwargs``: Additional keyword arguments to pass to the ``next_method`` when it is called. * ``timeout``: (Optional) A timedelta that specifies a timeout after which this deferral will fail, and fail the task instance. Defaults to ``None``, which means no timeout. @@ -170,6 +171,7 @@ This is particularly useful when deferring is the only thing the ``execute`` met trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger", trigger_kwargs={"moment": timedelta(hours=1)}, next_method="execute_complete", + next_kwargs=None, timeout=None, ) start_from_trigger = True @@ -198,6 +200,7 @@ This is particularly useful when deferring is the only thing the ``execute`` met trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger", trigger_kwargs={}, next_method="execute_complete", + next_kwargs=None, timeout=None, )