diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 21f3d4eaf3a88..b2b3b2f8c0e27 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -829,10 +829,18 @@ def tasks(self, val): def task_ids(self) -> List[str]: return list(self.task_dict.keys()) + @property + def task_group_dict(self) -> Dict[str, "TaskGroup"]: + return {k: v for k, v in self.task_group.get_task_group_dict().items() if k is not None} + @property def task_group(self) -> "TaskGroup": return self._task_group + @property + def task_groups(self) -> List["TaskGroup"]: + return list(self.task_group_dict.values()) + @property def filepath(self) -> str: """:meta private:""" @@ -1883,8 +1891,11 @@ def filter_task_group(group, parent_group): return dag + def has_task_group(self, group_id: str) -> bool: + return group_id in self.task_group_dict + def has_task(self, task_id: str): - return task_id in (t.task_id for t in self.tasks) + return task_id in self.task_dict def get_task(self, task_id: str, include_subdags: bool = False) -> BaseOperator: if task_id in self.task_dict: diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index c4510015138e0..c9b312f8eaf90 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -18,7 +18,7 @@ import datetime import os -from typing import Any, Callable, FrozenSet, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, FrozenSet, Iterable, Optional, Union from sqlalchemy import func @@ -30,6 +30,11 @@ from airflow.utils.session import provide_session from airflow.utils.state import State +if TYPE_CHECKING: + from sqlalchemy.orm import Query + + from airflow.utils.task_group import TaskGroup + class ExternalTaskSensorLink(BaseOperatorLink): """ @@ -46,20 +51,27 @@ def get_link(self, operator, dttm): class ExternalTaskSensor(BaseSensorOperator): """ - Waits for a different DAG or a task in a different DAG to complete for a + Waits for a different DAG, a task group, or a task 
in a different DAG to complete for a specific execution_date - :param external_dag_id: The dag_id that contains the task you want to - wait for + If both `external_task_group_id` and `external_task_id` are ``None`` (default), the sensor + waits for the DAG. + + Values for `external_task_group_id` and `external_task_id` can't be set at the same time. + + :param external_dag_id: The dag_id that contains the task you want to wait for :type external_dag_id: str :param external_task_id: The task_id that contains the task you want to - wait for. If ``None`` (default value) the sensor waits for the DAG + wait for. :type external_task_id: str or None :param external_task_ids: The list of task_ids that you want to wait for. If ``None`` (default value) the sensor waits for the DAG. Either external_task_id or external_task_ids can be passed to ExternalTaskSensor, but not both. :type external_task_ids: Iterable of task_ids or None, default is None + :param external_task_group_id: The task group_id that contains the tasks you want to + wait for. 
+ :type external_task_group_id: str or None :param allowed_states: Iterable of allowed states, default is ``['success']`` :type allowed_states: Iterable :param failed_states: Iterable of failed or dis-allowed states, default is ``None`` @@ -97,6 +109,7 @@ def __init__( external_dag_id: str, external_task_id: Optional[str] = None, external_task_ids: Optional[Iterable[str]] = None, + external_task_group_id: Optional[str] = None, allowed_states: Optional[Iterable[str]] = None, failed_states: Optional[Iterable[str]] = None, execution_delta: Optional[datetime.timedelta] = None, @@ -125,6 +138,13 @@ def __init__( if external_task_id is not None: external_task_ids = [external_task_id] + + + if external_task_group_id and external_task_ids: + raise ValueError( + "Values for `external_task_group_id` and `external_task_id` or `external_task_ids` " + "can't be set at the same time" + ) if external_task_ids: if not total_states <= set(State.task_states): @@ -149,21 +169,24 @@ def __init__( self.execution_delta = execution_delta self.execution_date_fn = execution_date_fn self.external_dag_id = external_dag_id + self.external_task_group_id = external_task_group_id self.external_task_id = external_task_id self.external_task_ids = external_task_ids self.check_existence = check_existence self._has_checked_existence = False - @provide_session - def poke(self, context, session=None): + def _get_dttm_filter(self, context): if self.execution_delta: dttm = context['execution_date'] - self.execution_delta elif self.execution_date_fn: dttm = self._handle_execution_date_fn(context=context) else: dttm = context['execution_date'] + return dttm if isinstance(dttm, list) else [dttm] - dttm_filter = dttm if isinstance(dttm, list) else [dttm] + @provide_session + def poke(self, context, session=None): + dttm_filter = self._get_dttm_filter(context) serialized_dttm_filter = ','.join(dt.isoformat() for dt in dttm_filter) self.log.info( @@ -189,13 +212,18 @@ def poke(self, context, session=None): 
f'Some of the external tasks {self.external_task_ids} ' f'in DAG {self.external_dag_id} failed.' ) + elif self.external_task_group_id: + raise AirflowException( + f"The external task group {self.external_task_group_id} " + f"in DAG {self.external_dag_id} failed." + ) else: raise AirflowException(f'The external DAG {self.external_dag_id} failed.') return count_allowed == len(dttm_filter) def _check_for_existence(self, session) -> None: - dag_to_wait = session.query(DagModel).filter(DagModel.dag_id == self.external_dag_id).first() + dag_to_wait = DagModel.get_current(self.external_dag_id, session) if not dag_to_wait: raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.') @@ -227,30 +255,50 @@ def get_count(self, dttm_filter, session, states) -> int: """ TI = TaskInstance DR = DagRun + if self.external_task_ids: count = ( - session.query(func.count()) # .count() is inefficient - .filter( - TI.dag_id == self.external_dag_id, - TI.task_id.in_(self.external_task_ids), - TI.state.in_(states), - TI.execution_date.in_(dttm_filter), - ) + self._count_query(TI, session, states, dttm_filter) + .filter(TI.task_id.in_(self.external_task_ids)) .scalar() - ) - count = count / len(self.external_task_ids) - else: + ) + count /= len(self.external_task_ids) # normalize to a per-execution-date count + elif self.external_task_group_id: + external_task_group_task_ids = self.get_external_task_group_task_ids(session) count = ( - session.query(func.count()) - .filter( - DR.dag_id == self.external_dag_id, - DR.state.in_(states), - DR.execution_date.in_(dttm_filter), - ) + self._count_query(TI, session, states, dttm_filter) + .filter(TI.task_id.in_(external_task_group_task_ids)) .scalar() ) + count /= len(external_task_group_task_ids) + else: + count = self._count_query(DR, session, states, dttm_filter).scalar() + return count + def _count_query(self, model, session, states, dttm_filter) -> "Query": + query = session.query(func.count()).filter( # .count() is inefficient +
model.dag_id == self.external_dag_id, + model.state.in_(states), # pylint: disable=no-member + model.execution_date.in_(dttm_filter), + ) + + return query + + def get_external_task_group_task_ids(self, session): + """Return task ids for the external TaskGroup""" + refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session) + task_group: Optional["TaskGroup"] = refreshed_dag_info.task_group_dict.get( + self.external_task_group_id + ) + if not task_group: + raise AirflowException( + f"The external task group {self.external_task_group_id} in " + f"DAG {self.external_dag_id} does not exist." + ) + task_ids = [task.task_id for task in task_group] + return task_ids + def _handle_execution_date_fn(self, context) -> Any: """ This function is to handle backwards compatibility with how this operator was diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index e2e85a76c8c3f..f33bc8fbd6971 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -160,7 +160,7 @@ def is_root(self) -> bool: """Returns True if this TaskGroup is the root TaskGroup. Otherwise False""" return not self.group_id - def __iter__(self): + def __iter__(self) -> "Generator[BaseOperator, None, None]": for child in self.children.values(): if isinstance(child, TaskGroup): yield from child @@ -343,6 +343,9 @@ def get_child_by_label(self, label: str) -> Union["BaseOperator", "TaskGroup"]: """Get a child task/TaskGroup by its label (i.e. 
task_id/group_id without the group_id prefix)""" return self.children[self.child_id(label)] + def __repr__(self): + return f"<{self.__class__.__name__}: {self.group_id}>" + class TaskGroupContext: """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager.""" diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task.py similarity index 87% rename from tests/sensors/test_external_task_sensor.py rename to tests/sensors/test_external_task.py index 5b50fc22bd4ca..568bd62cbbc46 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task.py @@ -15,8 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import logging -import unittest from datetime import time, timedelta import pytest @@ -25,12 +23,14 @@ from airflow.exceptions import AirflowException, AirflowSensorTimeout from airflow.models import DagBag, TaskInstance from airflow.models.dag import DAG +from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.bash import BashOperator from airflow.operators.dummy import DummyOperator from airflow.sensors.external_task import ExternalTaskMarker, ExternalTaskSensor from airflow.sensors.time_sensor import TimeSensor from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.state import State +from airflow.utils.task_group import TaskGroup from airflow.utils.timezone import datetime from tests.test_utils.db import clear_db_runs @@ -38,26 +38,35 @@ TEST_DAG_ID = 'unit_test_dag' TEST_TASK_ID = 'time_sensor_check' TEST_TASK_ID_ALTERNATE = 'time_sensor_check_alternate' +TEST_TASK_GROUP_ID = 'dummy_task_group' DEV_NULL = '/dev/null' -@pytest.fixture(autouse=True) -def clean_db(): - clear_db_runs() - - -class TestExternalTaskSensor(unittest.TestCase): - def setUp(self): - self.dagbag = DagBag(dag_folder=DEV_NULL, 
include_examples=True) +class TestExternalTaskSensor: + def setup_method(self): self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=self.args) + SerializedDagModel.write_dag(self.dag) - def test_time_sensor(self, task_id=TEST_TASK_ID): - op = TimeSensor(task_id=task_id, target_time=time(0), dag=self.dag) + def run_time_sensor(self): + op = TimeSensor(task_id=TEST_TASK_ID, target_time=time(0), dag=self.dag) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + def run_task_group(self, target_states=None): + target_states = [State.SUCCESS] * 3 if target_states is None else target_states + + with self.dag as dag: + with TaskGroup(TEST_TASK_GROUP_ID) as task_group: + _ = [DummyOperator(task_id=f"task{i}") for i in range(len(target_states))] + SerializedDagModel.write_dag(dag) + + for idx, task in enumerate(task_group): + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti.run(ignore_ti_state=True, mark_success=True) + ti.set_state(target_states[idx]) + def test_external_task_sensor(self): - self.test_time_sensor() + self.run_time_sensor() op = ExternalTaskSensor( task_id='test_external_task_sensor_check', external_dag_id=TEST_DAG_ID, @@ -99,7 +108,7 @@ def test_external_task_sensor_wrong_failed_states(self): ) def test_external_task_sensor_failed_states(self): - self.test_time_sensor() + self.run_time_sensor() op = ExternalTaskSensor( task_id='test_external_task_sensor_check', external_dag_id=TEST_DAG_ID, @@ -110,7 +119,7 @@ def test_external_task_sensor_failed_states(self): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_failed_states_as_success(self): - self.test_time_sensor() + self.run_time_sensor() op = ExternalTaskSensor( task_id='test_external_task_sensor_check', external_dag_id=TEST_DAG_ID, @@ -157,8 +166,92 @@ def test_external_task_sensor_failed_states_as_success_mulitple_task_ids(self): "unit_test_dag failed." 
) + def test_raise_with_external_task_id_and_external_task_group_id(self): + with pytest.raises( + ValueError, + match=r"`external_task_group_id` and `external_task_id` or `external_task_ids`", + ): + ExternalTaskSensor( + task_id='test_external_task_sensor_check', + external_dag_id=TEST_DAG_ID, + external_task_id=TEST_TASK_ID, + external_task_group_id=TEST_TASK_GROUP_ID, + dag=self.dag, + ) + + def test_external_task_group_not_exists(self): + with pytest.raises(AirflowException, match=r"The external task group .* does not exist"): + op = ExternalTaskSensor( + task_id='test_external_task_sensor_check', + external_dag_id=TEST_DAG_ID, + external_task_group_id='fake-task-group', + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + def test_external_task_group_sensor_success(self): + self.run_task_group() + op = ExternalTaskSensor( + task_id='test_external_task_sensor_check', + external_dag_id=TEST_DAG_ID, + external_task_group_id=TEST_TASK_GROUP_ID, + failed_states=[State.FAILED], + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + @pytest.mark.parametrize( + "ti_states", + [ + (State.SUCCESS, State.FAILED, State.SUCCESS), + (State.FAILED, State.SKIPPED, State.FAILED), + ], + ) + def test_external_task_group_sensor_failed_states(self, ti_states): + self.run_task_group(ti_states) + op = ExternalTaskSensor( + task_id='test_external_task_sensor_check', + external_dag_id=TEST_DAG_ID, + external_task_group_id=TEST_TASK_GROUP_ID, + failed_states=[State.FAILED], + dag=self.dag, + ) + with pytest.raises(AirflowException, match=r"The external task group .* failed."): + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + def test_external_task_group_sensor_multiple_execution_dates(self): + dag_external_id = TEST_DAG_ID + '_external' + dag_external = DAG(dag_external_id, default_args=self.args, schedule_interval=timedelta(seconds=1)) +
with dag_external: + with TaskGroup(TEST_TASK_GROUP_ID) as task_group: + _ = [DummyOperator(task_id=f"task{i}") for i in range(3)] + + SerializedDagModel.write_dag(dag_external) + + for task in task_group: + task.run( + start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(seconds=2), ignore_ti_state=True + ) + + dag_id = TEST_DAG_ID + dag = DAG(dag_id, default_args=self.args, schedule_interval=timedelta(minutes=1)) + task_group_sensor = ExternalTaskSensor( + task_id='task_group_external', + external_dag_id=dag_external_id, + external_task_group_id=TEST_TASK_GROUP_ID, + execution_date_fn=lambda dt: [dt + timedelta(seconds=i) for i in range(3)], + retries=0, + timeout=1, + poke_interval=1, + dag=dag, + ) + + task_group_sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + def test_external_dag_sensor(self): other_dag = DAG('other_dag', default_args=self.args, end_date=DEFAULT_DATE, schedule_interval='@once') + + clear_db_runs() other_dag.create_dagrun( run_id='test', start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE, state=State.SUCCESS ) @@ -260,7 +353,7 @@ def test_external_task_sensor_fn_multiple_execution_dates(self): task_with_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_delta(self): - self.test_time_sensor() + self.run_time_sensor() op = ExternalTaskSensor( task_id='test_external_task_sensor_check_delta', external_dag_id=TEST_DAG_ID, @@ -272,7 +365,7 @@ def test_external_task_sensor_delta(self): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_fn(self): - self.test_time_sensor() + self.run_time_sensor() # check that the execution_fn works op1 = ExternalTaskSensor( task_id='test_external_task_sensor_check_delta_1', @@ -299,7 +392,7 @@ def test_external_task_sensor_fn(self): def test_external_task_sensor_fn_multiple_args(self): """Check this task sensor passes multiple args with full context. 
If no failure, means clean run.""" - self.test_time_sensor() + self.run_time_sensor() def my_func(dt, context): assert context['execution_date'] == dt @@ -317,7 +410,7 @@ def my_func(dt, context): def test_external_task_sensor_fn_kwargs(self): """Check this task sensor passes multiple args with full context. If no failure, means clean run.""" - self.test_time_sensor() + self.run_time_sensor() def my_func(dt, ds_nodash, tomorrow_ds_nodash): assert ds_nodash == dt.strftime("%Y%m%d") @@ -335,7 +428,7 @@ def my_func(dt, ds_nodash, tomorrow_ds_nodash): op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_external_task_sensor_error_delta_and_fn(self): - self.test_time_sensor() + self.run_time_sensor() # Test that providing execution_delta and a function raises an error with pytest.raises(ValueError): ExternalTaskSensor( @@ -417,7 +510,7 @@ def test_external_task_sensor_waits_for_dag_check_existence(self): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) -class TestExternalTaskMarker(unittest.TestCase): +class TestExternalTaskMarker: def test_serialized_fields(self): assert {"recursion_depth"}.issubset(ExternalTaskMarker.get_serialized_fields())