Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,10 +829,18 @@ def tasks(self, val):
def task_ids(self) -> List[str]:
return list(self.task_dict.keys())

@property
def task_group_dict(self) -> Dict[str, "TaskGroup"]:
return {k: v for k, v in self.task_group.get_task_group_dict().items() if k is not None}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_task_group_dict() is a recursive function that can be costly. I think we should keep it a method instead of making it a property (which tends to suggest to users that it's cheap to access).


@property
def task_group(self) -> "TaskGroup":
return self._task_group

@property
def task_groups(self) -> List["TaskGroup"]:
return list(self.task_group_dict.values())

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here for task_groups


@property
def filepath(self) -> str:
""":meta private:"""
Expand Down Expand Up @@ -1883,8 +1891,11 @@ def filter_task_group(group, parent_group):

return dag

def has_task_group(self, group_id: str) -> bool:
return group_id in self.task_group_dict

def has_task(self, task_id: str):
return task_id in (t.task_id for t in self.tasks)
return task_id in self.task_dict

def get_task(self, task_id: str, include_subdags: bool = False) -> BaseOperator:
if task_id in self.task_dict:
Expand Down
98 changes: 73 additions & 25 deletions airflow/sensors/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import datetime
import os
from typing import Any, Callable, FrozenSet, Iterable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, FrozenSet, Iterable, Optional, Union

from sqlalchemy import func

Expand All @@ -30,6 +30,11 @@
from airflow.utils.session import provide_session
from airflow.utils.state import State

if TYPE_CHECKING:
from sqlalchemy.orm import Query

from airflow.utils.task_group import TaskGroup


class ExternalTaskSensorLink(BaseOperatorLink):
"""
Expand All @@ -46,20 +51,27 @@ def get_link(self, operator, dttm):

class ExternalTaskSensor(BaseSensorOperator):
"""
Waits for a different DAG or a task in a different DAG to complete for a
Waits for a different DAG, a task group, or a task in a different DAG to complete for a
specific execution_date

:param external_dag_id: The dag_id that contains the task you want to
wait for
If both `external_task_group_id` and `external_task_id` are ``None`` (default), the sensor
waits for the DAG.

Values for `external_task_group_id` and `external_task_id` can't be set at the same time.

:param external_dag_id: The dag_id that contains the task you want to wait for
:type external_dag_id: str
:param external_task_id: The task_id that contains the task you want to
wait for. If ``None`` (default value) the sensor waits for the DAG
wait for.
:type external_task_id: str or None
:param external_task_ids: The list of task_ids that you want to wait for.
If ``None`` (default value) the sensor waits for the DAG. Either
external_task_id or external_task_ids can be passed to
ExternalTaskSensor, but not both.
:type external_task_ids: Iterable of task_ids or None, default is None
:param external_task_group_id: The task group_id that contains the tasks you want to
wait for.
:type external_task_group_id: str or None
:param allowed_states: Iterable of allowed states, default is ``['success']``
:type allowed_states: Iterable
:param failed_states: Iterable of failed or dis-allowed states, default is ``None``
Expand Down Expand Up @@ -97,6 +109,7 @@ def __init__(
external_dag_id: str,
external_task_id: Optional[str] = None,
external_task_ids: Optional[Iterable[str]] = None,
external_task_group_id: Optional[str] = None,
allowed_states: Optional[Iterable[str]] = None,
failed_states: Optional[Iterable[str]] = None,
execution_delta: Optional[datetime.timedelta] = None,
Expand Down Expand Up @@ -125,6 +138,13 @@ def __init__(

if external_task_id is not None:
external_task_ids = [external_task_id]


if external_task_group_id and external_task_ids:
raise ValueError(
"Values for `external_task_group_id` and `external_task_id` or `external_task_ids` "
"can't be set at the same time"
)

if external_task_ids:
if not total_states <= set(State.task_states):
Expand All @@ -149,21 +169,24 @@ def __init__(
self.execution_delta = execution_delta
self.execution_date_fn = execution_date_fn
self.external_dag_id = external_dag_id
self.external_task_group_id = external_task_group_id
self.external_task_id = external_task_id
self.external_task_ids = external_task_ids
self.check_existence = check_existence

@xinbinhuang xinbinhuang Mar 6, 2021

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.check_existence = check_existence is False by default, which maybe make sense for external_dag or external_task. But external_task_group has to check and get an existing dag in order to get the list of task_ids.

https://github.com/apache/airflow/blob/fce49402461ee4e7a5f6ffd18cee3121f3496a39/airflow/sensors/external_task.py#L174-L180

I wonder if we can change the default to True or even have check_existence enabled required? This can give more useful errors if the external task/dag does not exist as well as having a consistent behavior as external_task_group. Also, what would be use case to have a Sensor waiting for an object that doesn't exist until it times out?

self._has_checked_existence = False

@provide_session
def poke(self, context, session=None):
def _get_dttm_filter(self, context):
if self.execution_delta:
dttm = context['execution_date'] - self.execution_delta
elif self.execution_date_fn:
dttm = self._handle_execution_date_fn(context=context)
else:
dttm = context['execution_date']
return dttm if isinstance(dttm, list) else [dttm]

dttm_filter = dttm if isinstance(dttm, list) else [dttm]
@provide_session
def poke(self, context, session=None):
dttm_filter = self._get_dttm_filter(context)
serialized_dttm_filter = ','.join(dt.isoformat() for dt in dttm_filter)

self.log.info(
Expand All @@ -189,13 +212,18 @@ def poke(self, context, session=None):
f'Some of the external tasks {self.external_task_ids} '
f'in DAG {self.external_dag_id} failed.'
)
elif self.external_task_group_id:
raise AirflowException(
f"f'The external task group {self.external_task_group_id} "
f"in DAG {self.external_dag_id} failed.'"
)
else:
raise AirflowException(f'The external DAG {self.external_dag_id} failed.')

return count_allowed == len(dttm_filter)

def _check_for_existence(self, session) -> None:
dag_to_wait = session.query(DagModel).filter(DagModel.dag_id == self.external_dag_id).first()
dag_to_wait = DagModel.get_current(self.external_dag_id, session)

if not dag_to_wait:
raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.')
Expand Down Expand Up @@ -227,30 +255,50 @@ def get_count(self, dttm_filter, session, states) -> int:
"""
TI = TaskInstance
DR = DagRun

if self.external_task_ids:
count = (
session.query(func.count()) # .count() is inefficient
.filter(
TI.dag_id == self.external_dag_id,
TI.task_id.in_(self.external_task_ids),
TI.state.in_(states),
TI.execution_date.in_(dttm_filter),
)
self._count_query(TI, session, states, dttm_filter)
.filter(TI.task_id.in_(self.external_task_ids))
.scalar()
)
count = count / len(self.external_task_ids)
else:
) / len(self.external_task_ids)
count /= len(self.external_task_ids)
elif self.external_task_group_id:
external_task_group_task_ids = self.get_external_task_group_task_ids(session)
count = (
session.query(func.count())
.filter(
DR.dag_id == self.external_dag_id,
DR.state.in_(states),
DR.execution_date.in_(dttm_filter),
)
self._count_query(TI, session, states, dttm_filter)
.filter(TI.task_id.in_(external_task_group_task_ids))
.scalar()
)
count /= len(external_task_group_task_ids)
else:
count = self._count_query(DR, session, states, dttm_filter).scalar()

return count

def _count_query(self, model, session, states, dttm_filter) -> "Query":
query = session.query(func.count()).filter( # .count() is inefficient
model.dag_id == self.external_dag_id,
model.state.in_(states), # pylint: disable=no-member
model.execution_date.in_(dttm_filter),
)

return query

def get_external_task_group_task_ids(self, session):
"""Return task ids for the external TaskGroup"""
refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session)
task_group: Optional["TaskGroup"] = refreshed_dag_info.task_group_dict.get(
self.external_task_group_id
)
if not task_group:
raise AirflowException(
f"The external task group {self.external_task_group_id} in "
f"DAG {self.external_dag_id} does not exist."
)
task_ids = [task.task_id for task in task_group]
return task_ids

Comment on lines 288 to 301

@xinbinhuang xinbinhuang Mar 30, 2021

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main piece that you retrieve a list of tasks for a TaskGroup. I believe that read_dags_from_db=True is safe to use here because serialized dag is enabled by default in 2.0.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing task execution code is creating DagBag on its own instead of reading serialized dags from db. For example this line is creating a DagBag. I think we should do the same here. It's important for tasks to get the latest view of the dag during execution.

https://github.com/apache/airflow/blob/f1edc220d3f9cb050016d23246a682276bd09eee/airflow/sensors/external_task.py#L213

def _handle_execution_date_fn(self, context) -> Any:
"""
This function is to handle backwards compatibility with how this operator was
Expand Down
5 changes: 4 additions & 1 deletion airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def is_root(self) -> bool:
"""Returns True if this TaskGroup is the root TaskGroup. Otherwise False"""
return not self.group_id

def __iter__(self):
def __iter__(self) -> "BaseOperator":
for child in self.children.values():
if isinstance(child, TaskGroup):
yield from child
Expand Down Expand Up @@ -343,6 +343,9 @@ def get_child_by_label(self, label: str) -> Union["BaseOperator", "TaskGroup"]:
"""Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)"""
return self.children[self.child_id(label)]

def __repr__(self):
return f"<{self.__class__.__name__}: {self.group_id}>"


class TaskGroupContext:
"""TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""
Expand Down
Loading