From df806caf4470cb6104726b043aff0366bb69715b Mon Sep 17 00:00:00 2001 From: mingshi Date: Mon, 27 Dec 2021 02:22:18 -0800 Subject: [PATCH] Add basic sensor decorator --- airflow/decorators/__init__.py | 9 +- airflow/decorators/sensor.py | 89 ++++++++++++ docs/apache-airflow/tutorial_taskflow_api.rst | 45 ++++++ tests/decorators/test_sensor.py | 131 ++++++++++++++++++ 4 files changed, 271 insertions(+), 3 deletions(-) create mode 100644 airflow/decorators/sensor.py create mode 100644 tests/decorators/test_sensor.py diff --git a/airflow/decorators/__init__.py b/airflow/decorators/__init__.py index 47a20d47826ef..d83e341e61ba3 100644 --- a/airflow/decorators/__init__.py +++ b/airflow/decorators/__init__.py @@ -19,6 +19,7 @@ from airflow.decorators.python import PythonDecoratorMixin, python_task # noqa from airflow.decorators.python_virtualenv import PythonVirtualenvDecoratorMixin +from airflow.decorators.sensor import sensor from airflow.decorators.task_group import task_group # noqa from airflow.models.dag import dag # noqa from airflow.providers_manager import ProvidersManager @@ -29,9 +30,11 @@ def __getattr__(self, name): if name.startswith("__"): raise AttributeError(f'{type(self).__name__} has no attribute {name!r}') decorators = ProvidersManager().taskflow_decorators - if name not in decorators: - raise AttributeError(f"task decorator {name!r} not found") - return decorators[name] + if name in decorators: + return decorators[name] + if name == "sensor": + return sensor + raise AttributeError(f"task decorator {name!r} not found") # [START mixin_for_autocomplete] diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py new file mode 100644 index 0000000000000..fa4404684126c --- /dev/null +++ b/airflow/decorators/sensor.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from inspect import signature
+from typing import Any, Callable, Collection, Dict, Iterable, Mapping, Optional, Tuple
+
+from airflow.decorators.base import get_unique_task_id, task_decorator_factory
+from airflow.models.taskinstance import Context
+from airflow.sensors.base import BaseSensorOperator
+
+
+class DecoratedSensorOperator(BaseSensorOperator):
+    """
+    Wraps a Python callable as a sensor: every poke calls it with the captured args/kwargs.
+
+    :param python_callable: A reference to an object that is callable
+    :type python_callable: python callable
+    :param task_id: task Id
+    :type task_id: str
+    :param op_args: a list of positional arguments that will get unpacked when
+        calling your callable (templated)
+    :type op_args: list
+    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+        in your function (templated)
+    :type op_kwargs: dict
+    :param kwargs: remaining keyword arguments; ``multiple_outputs`` is discarded and
+        the rest are passed on to ``BaseSensorOperator``
+    :type kwargs: dict
+    """
+
+    template_fields: Iterable[str] = ('op_args', 'op_kwargs')
+    template_fields_renderers: Dict[str, str] = {"op_args": "py", "op_kwargs": "py"}
+
+    # since we won't mutate the arguments, we should just do the shallow copy
+    # there are some cases we can't deepcopy the objects (e.g protobuf).
+    shallow_copy_attrs: Tuple[str, ...] = ('python_callable',)
+
+    def __init__(
+        self,
+        *,
+        python_callable: Callable,
+        task_id: str,
+        op_args: Collection[Any],
+        op_kwargs: Mapping[str, Any],
+        **kwargs,
+    ) -> None:
+        kwargs.pop('multiple_outputs', None)  # discarded; the default avoids a KeyError when it is absent
+        kwargs['task_id'] = get_unique_task_id(task_id, kwargs.get('dag'), kwargs.get('task_group'))
+        self.python_callable = python_callable
+        # Check that the arguments can be bound to the callable's signature
+        signature(python_callable).bind(*op_args, **op_kwargs)
+        self.op_args = op_args
+        self.op_kwargs = op_kwargs
+        super().__init__(**kwargs)
+
+    def poke(self, context: Context) -> bool:
+        """Call the wrapped callable with the captured arguments and return its result."""
+        return self.python_callable(*self.op_args, **self.op_kwargs)
+
+
+def sensor(python_callable: Optional[Callable] = None, **kwargs):
+    """
+    Wraps a function into an Airflow sensor operator (``DecoratedSensorOperator``).
+
+    Accepts kwargs for operator kwarg. Can be reused in a single DAG.
+
+    :param python_callable: Function to decorate
+    :type python_callable: Optional[Callable]
+    """
+    return task_decorator_factory(
+        python_callable=python_callable,
+        multiple_outputs=False,
+        decorated_operator_class=DecoratedSensorOperator,
+        **kwargs,
+    )
diff --git a/docs/apache-airflow/tutorial_taskflow_api.rst b/docs/apache-airflow/tutorial_taskflow_api.rst
index 3d2c4f1e68e07..6d9a69582553f 100644
--- a/docs/apache-airflow/tutorial_taskflow_api.rst
+++ b/docs/apache-airflow/tutorial_taskflow_api.rst
@@ -208,6 +208,51 @@ Python version to run your function.
 These two options should allow for far greater flexibility for users who wish to keep their workflows more simple
 and Pythonic.
 
+Using the TaskFlow API for Sensor operators
+-------------------------------------------
+You can apply the ``@task.sensor`` decorator to convert a regular Python function to an instance of the
+``BaseSensorOperator`` class. The Python function implements the poke logic and returns a Boolean value,
+just as the ``poke()`` method of ``BaseSensorOperator`` does.
+
+.. code-block:: python
+
+    # Using a sensor operator to wait for the upstream data to be ready.
+    @task.sensor(poke_interval=60, timeout=3600, mode="reschedule")
+    def wait_for_upstream() -> bool:
+        upstream_data_available = ...  # custom logic to check the upstream data
+        return upstream_data_available
+
+
+    @task
+    def custom_operator() -> None:
+        # do something
+        do_some_thing()
+
+
+    wait = wait_for_upstream()
+    op = custom_operator()
+    wait >> op
+
+
+.. code-block:: python
+
+    # Using a sensor operator to wait for the upstream Spark job to be done.
+    @task
+    def start_spark_job() -> str:
+        # start a Spark job and return the job ID
+        job_id = submit_spark_job()
+        return job_id
+
+
+    @task.sensor(poke_interval=60, timeout=3600, mode="reschedule")
+    def wait_for_job(job_id: str) -> bool:
+        # check if the upstream Spark job is done
+        return check_spark_job_done(job_id)
+
+
+    wait_for_job(start_spark_job())
+
+
 Multiple outputs inference
 --------------------------
 Tasks can also infer multiple outputs by using dict Python typing.
diff --git a/tests/decorators/test_sensor.py b/tests/decorators/test_sensor.py
new file mode 100644
index 0000000000000..6a57b65437154
--- /dev/null
+++ b/tests/decorators/test_sensor.py
@@ -0,0 +1,131 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+from airflow.decorators import task
+from airflow.exceptions import AirflowException, AirflowSensorTimeout
+from airflow.utils.state import State
+
+
+class TestSensorDecorator:
+    def test_sensor_fails_on_none_python_callable(self, dag_maker):
+        not_callable = {}  # a dict is not callable
+        with pytest.raises(AirflowException):
+            task.sensor(not_callable)
+
+    def test_basic_sensor_success(self, dag_maker):
+        @task.sensor()
+        def sensor_f():
+            return True
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.SUCCESS
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE  # only the sensor was run
+
+    def test_basic_sensor_failure(self, dag_maker):
+        @task.sensor(timeout=0)  # timeout=0: a falsy poke is expected to time out right away
+        def sensor_f():
+            return False
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        with pytest.raises(AirflowSensorTimeout):
+            sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.FAILED
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE  # only the sensor was run
+
+    def test_basic_sensor_soft_fail(self, dag_maker):
+        @task.sensor(timeout=0, soft_fail=True)  # soft_fail: timing out should skip, not fail
+        def sensor_f():
+            return False
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.SKIPPED
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE  # only the sensor was run
+
+    def test_basic_sensor_get_upstream_output(self, dag_maker):
+        ret_val = 100
+
+        @task
+        def upstream_f() -> int:
+            return ret_val
+
+        @task.sensor()
+        def sensor_f(n: int):
+            assert n == ret_val  # upstream_f's return value must arrive as the argument
+            return True
+
+        with dag_maker():
+            uf = upstream_f()
+            sf = sensor_f(uf)
+
+        dr = dag_maker.create_dagrun()
+        uf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+        sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.SUCCESS
+            if ti.task_id == "upstream_f":
+                assert ti.state == State.SUCCESS