From f34887355e1475dcf3fbea877edf736d36c0cf14 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 3 Nov 2025 18:08:58 +0800 Subject: [PATCH] Simplify typing in TriggerRuleDep Instead of doing checks on ti.task over and over again, we can just do it once and reuse the task variable everywhere. This should be safe since nothing leaks out of the function's local scope. --- .../airflow/ti_deps/deps/trigger_rule_dep.py | 62 ++++++++----------- 1 file changed, 26 insertions(+), 36 deletions(-) diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py index 670d038a70dd8..7bac5adb00adb 100644 --- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py @@ -35,10 +35,10 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnOperators - from airflow.models.baseoperator import BaseOperator from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup + from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.base_ti_dep import TIDepStatus @@ -134,6 +134,10 @@ def _evaluate_trigger_rule( from airflow.models.mappedoperator import is_mapped from airflow.models.taskinstance import TaskInstance + task = ti.task + if TYPE_CHECKING: + assert task + @functools.lru_cache def _get_expanded_ti_count() -> int: """ @@ -144,13 +148,10 @@ def _get_expanded_ti_count() -> int: """ from airflow.models.mappedoperator import get_mapped_ti_count - if TYPE_CHECKING: - assert ti.task - - return get_mapped_ti_count(ti.task, ti.run_id, session=session) + return get_mapped_ti_count(task, ti.run_id, session=session) def _iter_expansion_dependencies(task_group: SerializedMappedTaskGroup | None) -> Iterator[str]: - if (task := ti.task) is not None and is_mapped(task): + if is_mapped(task): for op in task.iter_mapped_dependencies(): yield op.task_id if task_group and task_group.iter_mapped_task_groups(): @@ -170,14 +171,13 @@ def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: task instance of the same task). """ if TYPE_CHECKING: - assert ti.task - assert ti.task.dag - assert ti.task.task_group + assert task.dag + assert task.task_group - if is_mapped(ti.task.task_group): - is_fast_triggered = ti.task.trigger_rule in (TR.ONE_SUCCESS, TR.ONE_FAILED, TR.ONE_DONE) + if is_mapped(task.task_group): + is_fast_triggered = task.trigger_rule in (TR.ONE_SUCCESS, TR.ONE_FAILED, TR.ONE_DONE) if is_fast_triggered and upstream_id not in set( - _iter_expansion_dependencies(task_group=ti.task.task_group) + _iter_expansion_dependencies(task_group=task.task_group) ): return None @@ -186,7 +186,7 @@ def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: except (NotFullyPopulated, NotMapped): return None return ti.get_relevant_upstream_map_indexes( - upstream=ti.task.dag.task_dict[upstream_id], + upstream=task.dag.task_dict[upstream_id], ti_count=expanded_ti_count, session=session, ) @@ -201,15 +201,12 @@ def _is_relevant_upstream(upstream: TaskInstance, relevant_ids: set[str] | KeysV 2. ti is in a mapped task group and upstream has a map index that ti does not depend on. """ - if TYPE_CHECKING: - assert ti.task - # Not actually an upstream task. if upstream.task_id not in relevant_ids: return False # The current task is not in a mapped task group. All tis from an # upstream task are relevant. - if ti.task.get_closest_mapped_task_group() is None: + if task.get_closest_mapped_task_group() is None: return True # The upstream ti is not expanded. The upstream may be mapped or # not, but the ti is relevant either way. @@ -231,10 +228,7 @@ def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators] # it depends on all upstream task instances. from airflow.models.taskinstance import TaskInstance - if TYPE_CHECKING: - assert ti.task - - if ti.task.get_closest_mapped_task_group() is None: + if task.get_closest_mapped_task_group() is None: yield TaskInstance.task_id.in_(relevant_tasks.keys()) return # Otherwise we need to figure out which map indexes are depended on @@ -261,17 +255,15 @@ def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators] yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes) def _evaluate_setup_constraint( - *, relevant_setups: Mapping[str, BaseOperator | MappedOperator] + *, relevant_setups: Mapping[str, SerializedBaseOperator | MappedOperator] ) -> Iterator[tuple[TIDepStatus, bool]]: """ Evaluate whether ``ti``'s trigger rule was met as part of the setup constraint. :param relevant_setups: Relevant setups for the current task instance. """ - if TYPE_CHECKING: - assert ti.task - - task = ti.task + if not relevant_setups: + return indirect_setups = {k: v for k, v in relevant_setups.items() if k not in task.upstream_task_ids} finished_upstream_tis = ( @@ -353,10 +345,6 @@ def _evaluate_setup_constraint( def _evaluate_direct_relatives() -> Iterator[TIDepStatus]: """Evaluate whether ``ti``'s trigger rule in direct relatives was met.""" - if TYPE_CHECKING: - assert ti.task - - task = ti.task upstream_tasks = {t.task_id: t for t in task.upstream_list} trigger_rule = task.trigger_rule trigger_rule_str = getattr(trigger_rule, "value", trigger_rule) @@ -364,7 +352,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]: finished_upstream_tis = ( finished_ti for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session) - if _is_relevant_upstream(upstream=finished_ti, relevant_ids=ti.task.upstream_task_ids) + if _is_relevant_upstream(upstream=finished_ti, relevant_ids=task.upstream_task_ids) ) upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis) @@ -629,12 +617,14 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]: reason=f"No strategy to evaluate trigger rule '{trigger_rule_str}'." ) - if TYPE_CHECKING: - assert ti.task - - if not ti.task.is_teardown: + if not task.is_teardown: # a teardown cannot have any indirect setups - relevant_setups = {t.task_id: t for t in ti.task.get_upstreams_only_setups()} + relevant_setups: dict[str, MappedOperator | SerializedBaseOperator] = { + # TODO (GH-52141): This should return scheduler types, but + # currently we reuse logic in SDK DAGNode. + t.task_id: t # type: ignore[misc] + for t in task.get_upstreams_only_setups() + } if relevant_setups: for status, changed in _evaluate_setup_constraint(relevant_setups=relevant_setups): yield status