-
Notifications
You must be signed in to change notification settings - Fork 17.3k
Allow ExternalTaskSensor to wait for taskgroup #14640
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
aace321
6f74efb
2bfe78e
81c50e0
c3495aa
c31e8c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -829,10 +829,18 @@ def tasks(self, val): | |
| def task_ids(self) -> List[str]: | ||
| return list(self.task_dict.keys()) | ||
|
|
||
| @property | ||
| def task_group_dict(self) -> Dict[str, "TaskGroup"]: | ||
| return {k: v for k, v in self.task_group.get_task_group_dict().items() if k is not None} | ||
|
|
||
| @property | ||
| def task_group(self) -> "TaskGroup": | ||
| return self._task_group | ||
|
|
||
| @property | ||
| def task_groups(self) -> List["TaskGroup"]: | ||
| return list(self.task_group_dict.values()) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here for |
||
|
|
||
| @property | ||
| def filepath(self) -> str: | ||
| """:meta private:""" | ||
|
|
@@ -1883,8 +1891,11 @@ def filter_task_group(group, parent_group): | |
|
|
||
| return dag | ||
|
|
||
| def has_task_group(self, group_id: str) -> bool: | ||
| return group_id in self.task_group_dict | ||
|
|
||
| def has_task(self, task_id: str): | ||
| return task_id in (t.task_id for t in self.tasks) | ||
| return task_id in self.task_dict | ||
|
|
||
| def get_task(self, task_id: str, include_subdags: bool = False) -> BaseOperator: | ||
| if task_id in self.task_dict: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,7 +18,7 @@ | |
|
|
||
| import datetime | ||
| import os | ||
| from typing import Any, Callable, FrozenSet, Iterable, Optional, Union | ||
| from typing import TYPE_CHECKING, Any, Callable, FrozenSet, Iterable, Optional, Union | ||
|
|
||
| from sqlalchemy import func | ||
|
|
||
|
|
@@ -30,6 +30,11 @@ | |
| from airflow.utils.session import provide_session | ||
| from airflow.utils.state import State | ||
|
|
||
| if TYPE_CHECKING: | ||
| from sqlalchemy.orm import Query | ||
|
|
||
| from airflow.utils.task_group import TaskGroup | ||
|
|
||
|
|
||
| class ExternalTaskSensorLink(BaseOperatorLink): | ||
| """ | ||
|
|
@@ -46,20 +51,27 @@ def get_link(self, operator, dttm): | |
|
|
||
| class ExternalTaskSensor(BaseSensorOperator): | ||
| """ | ||
| Waits for a different DAG or a task in a different DAG to complete for a | ||
| Waits for a different DAG, a task group, or a task in a different DAG to complete for a | ||
| specific execution_date | ||
|
|
||
| :param external_dag_id: The dag_id that contains the task you want to | ||
| wait for | ||
| If both `external_task_group_id` and `external_task_id` are ``None`` (default), the sensor | ||
| waits for the DAG. | ||
|
|
||
| Values for `external_task_group_id` and `external_task_id` can't be set at the same time. | ||
|
|
||
| :param external_dag_id: The dag_id that contains the task you want to wait for | ||
| :type external_dag_id: str | ||
| :param external_task_id: The task_id that contains the task you want to | ||
| wait for. If ``None`` (default value) the sensor waits for the DAG | ||
| wait for. | ||
| :type external_task_id: str or None | ||
| :param external_task_ids: The list of task_ids that you want to wait for. | ||
| If ``None`` (default value) the sensor waits for the DAG. Either | ||
| external_task_id or external_task_ids can be passed to | ||
| ExternalTaskSensor, but not both. | ||
| :type external_task_ids: Iterable of task_ids or None, default is None | ||
| :param external_task_group_id: The task group_id that contains the tasks you want to | ||
| wait for. | ||
| :type external_task_group_id: str or None | ||
| :param allowed_states: Iterable of allowed states, default is ``['success']`` | ||
| :type allowed_states: Iterable | ||
| :param failed_states: Iterable of failed or dis-allowed states, default is ``None`` | ||
|
|
@@ -97,6 +109,7 @@ def __init__( | |
| external_dag_id: str, | ||
| external_task_id: Optional[str] = None, | ||
| external_task_ids: Optional[Iterable[str]] = None, | ||
| external_task_group_id: Optional[str] = None, | ||
| allowed_states: Optional[Iterable[str]] = None, | ||
| failed_states: Optional[Iterable[str]] = None, | ||
| execution_delta: Optional[datetime.timedelta] = None, | ||
|
|
@@ -125,6 +138,13 @@ def __init__( | |
|
|
||
| if external_task_id is not None: | ||
| external_task_ids = [external_task_id] | ||
|
|
||
|
|
||
| if external_task_group_id and external_task_ids: | ||
| raise ValueError( | ||
| "Values for `external_task_group_id` and `external_task_id` or `external_task_ids` " | ||
| "can't be set at the same time" | ||
| ) | ||
|
|
||
| if external_task_ids: | ||
| if not total_states <= set(State.task_states): | ||
|
|
@@ -149,21 +169,24 @@ def __init__( | |
| self.execution_delta = execution_delta | ||
| self.execution_date_fn = execution_date_fn | ||
| self.external_dag_id = external_dag_id | ||
| self.external_task_group_id = external_task_group_id | ||
| self.external_task_id = external_task_id | ||
| self.external_task_ids = external_task_ids | ||
| self.check_existence = check_existence | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I wonder if we can change the default to True or even have |
||
| self._has_checked_existence = False | ||
|
|
||
| @provide_session | ||
| def poke(self, context, session=None): | ||
| def _get_dttm_filter(self, context): | ||
| if self.execution_delta: | ||
| dttm = context['execution_date'] - self.execution_delta | ||
| elif self.execution_date_fn: | ||
| dttm = self._handle_execution_date_fn(context=context) | ||
| else: | ||
| dttm = context['execution_date'] | ||
| return dttm if isinstance(dttm, list) else [dttm] | ||
|
|
||
| dttm_filter = dttm if isinstance(dttm, list) else [dttm] | ||
| @provide_session | ||
| def poke(self, context, session=None): | ||
| dttm_filter = self._get_dttm_filter(context) | ||
| serialized_dttm_filter = ','.join(dt.isoformat() for dt in dttm_filter) | ||
|
|
||
| self.log.info( | ||
|
|
@@ -189,13 +212,18 @@ def poke(self, context, session=None): | |
| f'Some of the external tasks {self.external_task_ids} ' | ||
| f'in DAG {self.external_dag_id} failed.' | ||
| ) | ||
| elif self.external_task_group_id: | ||
| raise AirflowException( | ||
| f"f'The external task group {self.external_task_group_id} " | ||
| f"in DAG {self.external_dag_id} failed.'" | ||
| ) | ||
| else: | ||
| raise AirflowException(f'The external DAG {self.external_dag_id} failed.') | ||
|
|
||
| return count_allowed == len(dttm_filter) | ||
|
|
||
| def _check_for_existence(self, session) -> None: | ||
| dag_to_wait = session.query(DagModel).filter(DagModel.dag_id == self.external_dag_id).first() | ||
| dag_to_wait = DagModel.get_current(self.external_dag_id, session) | ||
|
|
||
| if not dag_to_wait: | ||
| raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.') | ||
|
|
@@ -227,30 +255,50 @@ def get_count(self, dttm_filter, session, states) -> int: | |
| """ | ||
| TI = TaskInstance | ||
| DR = DagRun | ||
|
|
||
| if self.external_task_ids: | ||
| count = ( | ||
| session.query(func.count()) # .count() is inefficient | ||
| .filter( | ||
| TI.dag_id == self.external_dag_id, | ||
| TI.task_id.in_(self.external_task_ids), | ||
| TI.state.in_(states), | ||
| TI.execution_date.in_(dttm_filter), | ||
| ) | ||
| self._count_query(TI, session, states, dttm_filter) | ||
| .filter(TI.task_id.in_(self.external_task_ids)) | ||
| .scalar() | ||
| ) | ||
| count = count / len(self.external_task_ids) | ||
| else: | ||
| ) / len(self.external_task_ids) | ||
| count /= len(self.external_task_ids) | ||
| elif self.external_task_group_id: | ||
| external_task_group_task_ids = self.get_external_task_group_task_ids(session) | ||
| count = ( | ||
| session.query(func.count()) | ||
| .filter( | ||
| DR.dag_id == self.external_dag_id, | ||
| DR.state.in_(states), | ||
| DR.execution_date.in_(dttm_filter), | ||
| ) | ||
| self._count_query(TI, session, states, dttm_filter) | ||
| .filter(TI.task_id.in_(external_task_group_task_ids)) | ||
| .scalar() | ||
| ) | ||
| count /= len(external_task_group_task_ids) | ||
| else: | ||
| count = self._count_query(DR, session, states, dttm_filter).scalar() | ||
|
|
||
| return count | ||
|
|
||
| def _count_query(self, model, session, states, dttm_filter) -> "Query": | ||
| query = session.query(func.count()).filter( # .count() is inefficient | ||
| model.dag_id == self.external_dag_id, | ||
| model.state.in_(states), # pylint: disable=no-member | ||
| model.execution_date.in_(dttm_filter), | ||
| ) | ||
|
|
||
| return query | ||
|
|
||
| def get_external_task_group_task_ids(self, session): | ||
| """Return task ids for the external TaskGroup""" | ||
| refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session) | ||
| task_group: Optional["TaskGroup"] = refreshed_dag_info.task_group_dict.get( | ||
| self.external_task_group_id | ||
| ) | ||
| if not task_group: | ||
| raise AirflowException( | ||
| f"The external task group {self.external_task_group_id} in " | ||
| f"DAG {self.external_dag_id} does not exist." | ||
| ) | ||
| task_ids = [task.task_id for task in task_group] | ||
| return task_ids | ||
|
|
||
|
Comment on lines
288
to
301
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The main piece that you retrieve a list of tasks for a TaskGroup. I believe that
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The existing task execution code is creating DagBag on its own instead of reading serialized dags from db. For example this line is creating a DagBag. I think we should do the same here. It's important for tasks to get the latest view of the dag during execution. |
||
| def _handle_execution_date_fn(self, context) -> Any: | ||
| """ | ||
| This function is to handle backwards compatibility with how this operator was | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_task_group_dict()is a recursive function that can be costly. I think we should keep it a method instead of making it a property (which tends to suggest to users that it's cheap to access).