Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def post_clear_task_instances(
if task_ids is not None:
task_id = [task[0] if isinstance(task, tuple) else task for task in task_ids]
dag = dag.partial_subset(
task_ids_or_regex=task_id,
task_ids=task_id,
include_downstream=downstream,
include_upstream=upstream,
)
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/routes/ui/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def grid_data(
if root:
task_node_map_exclude = get_task_group_map(
dag=dag.partial_subset(
task_ids_or_regex=root,
task_ids=root,
include_upstream=include_upstream,
include_downstream=include_downstream,
)
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/routes/ui/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def structure_data(

if root:
dag = dag.partial_subset(
task_ids_or_regex=root, include_upstream=include_upstream, include_downstream=include_downstream
task_ids=root, include_upstream=include_upstream, include_downstream=include_downstream
)

nodes = [task_group_to_dict(child) for child in dag.task_group.topological_sort()]
Expand Down
2 changes: 1 addition & 1 deletion airflow/cli/commands/remote_commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def task_clear(args) -> None:
if args.task_regex:
for idx, dag in enumerate(dags):
dags[idx] = dag.partial_subset(
task_ids_or_regex=args.task_regex,
task_ids=args.task_regex,
include_downstream=args.downstream,
include_upstream=args.upstream,
)
Expand Down
6 changes: 3 additions & 3 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,7 @@ def _get_task_instances(
if not external_dag:
raise AirflowException(f"Could not find dag {tii.dag_id}")
downstream = external_dag.partial_subset(
task_ids_or_regex=[tii.task_id],
task_ids=[tii.task_id],
include_upstream=False,
include_downstream=True,
)
Expand Down Expand Up @@ -1248,7 +1248,7 @@ def set_task_instance_state(
# Flush the session so that the tasks marked success are reflected in the db.
session.flush()
subdag = self.partial_subset(
task_ids_or_regex={task_id},
task_ids={task_id},
include_downstream=True,
include_upstream=False,
)
Expand Down Expand Up @@ -1360,7 +1360,7 @@ def get_logical_date() -> datetime:
# Flush the session so that the tasks marked success are reflected in the db.
session.flush()
task_subset = self.partial_subset(
task_ids_or_regex=task_ids,
task_ids=task_ids,
include_downstream=True,
include_upstream=False,
)
Expand Down
13 changes: 5 additions & 8 deletions task-sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from collections.abc import Collection, Iterable, MutableSet
from datetime import datetime, timedelta
from inspect import signature
from re import Pattern
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -41,7 +40,6 @@

import attrs
import jinja2
import re2
from dateutil.relativedelta import relativedelta

from airflow import settings
Expand Down Expand Up @@ -741,7 +739,7 @@ def __deepcopy__(self, memo: dict[int, Any]):

def partial_subset(
self,
task_ids_or_regex: str | Pattern | Iterable[str],
task_ids: str | Iterable[str],
include_downstream=False,
include_upstream=True,
include_direct_upstream=False,
Expand All @@ -753,8 +751,7 @@ def partial_subset(
based on a regex that should match one or many tasks, and includes
upstream and downstream neighbours based on the flag passed.

:param task_ids_or_regex: Either a list of task_ids, or a regex to
match against task ids (as a string, or compiled regex pattern).
:param task_ids: Either a list of task_ids, or a string task_id
:param include_downstream: Include all downstream tasks of matched
tasks, in addition to matched tasks.
:param include_upstream: Include all upstream tasks of matched tasks,
Expand All @@ -769,10 +766,10 @@ def partial_subset(
memo = {id(self.task_dict): None, id(self.task_group): None}
dag = copy.deepcopy(self, memo) # type: ignore

if isinstance(task_ids_or_regex, (str, Pattern)):
matched_tasks = [t for t in self.tasks if re2.findall(task_ids_or_regex, t.task_id)]
if isinstance(task_ids, str):
matched_tasks = [t for t in self.tasks if task_ids in t.task_id]
else:
matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex]
matched_tasks = [t for t in self.tasks if t.task_id in task_ids]

also_include_ids: set[str] = set()
for t in matched_tasks:
Expand Down
6 changes: 3 additions & 3 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3046,7 +3046,7 @@ def cleared_downstream(task):
upstream = False
return set(
task.dag.partial_subset(
task_ids_or_regex=[task.task_id],
task_ids=[task.task_id],
include_downstream=not upstream,
include_upstream=upstream,
).tasks
Expand All @@ -3058,7 +3058,7 @@ def cleared_upstream(task):
upstream = True
return set(
task.dag.partial_subset(
task_ids_or_regex=task.task_id,
task_ids=task.task_id,
include_downstream=not upstream,
include_upstream=upstream,
).tasks
Expand All @@ -3069,7 +3069,7 @@ def cleared_neither(task):
"""Helper to return tasks that would be cleared if **upstream** selected."""
return set(
task.dag.partial_subset(
task_ids_or_regex=[task.task_id],
task_ids=[task.task_id],
include_downstream=False,
include_upstream=False,
).tasks
Expand Down
4 changes: 2 additions & 2 deletions tests/sensors/test_external_task_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,7 @@ def clear_tasks(
"""
Clear the task and its downstream tasks recursively for the dag in the given dagbag.
"""
partial: DAG = dag.partial_subset(task_ids_or_regex=[task.task_id], include_downstream=True)
partial: DAG = dag.partial_subset(task_ids=[task.task_id], include_downstream=True)
return partial.clear(
start_date=start_date,
end_date=end_date,
Expand Down Expand Up @@ -1719,7 +1719,7 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m
session.flush()

dag = dag.partial_subset(
task_ids_or_regex=["head"],
task_ids=["head"],
include_downstream=True,
include_upstream=False,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def test_sub_dag_task_group():
group234 >> group6
group234 >> task7

subdag = dag.partial_subset(task_ids_or_regex="task5", include_upstream=True, include_downstream=False)
subdag = dag.partial_subset(task_ids="task5", include_upstream=True, include_downstream=False)

expected_node_id = {
"id": None,
Expand Down