diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index cc6d7e5e8cb85..a43c3c7078083 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -649,7 +649,7 @@ def post_clear_task_instances( if task_ids is not None: task_id = [task[0] if isinstance(task, tuple) else task for task in task_ids] dag = dag.partial_subset( - task_ids_or_regex=task_id, + task_ids=task_id, include_downstream=downstream, include_upstream=upstream, ) diff --git a/airflow/api_fastapi/core_api/routes/ui/grid.py b/airflow/api_fastapi/core_api/routes/ui/grid.py index c8f4440ec6952..10705bb7c938f 100644 --- a/airflow/api_fastapi/core_api/routes/ui/grid.py +++ b/airflow/api_fastapi/core_api/routes/ui/grid.py @@ -134,7 +134,7 @@ def grid_data( if root: task_node_map_exclude = get_task_group_map( dag=dag.partial_subset( - task_ids_or_regex=root, + task_ids=root, include_upstream=include_upstream, include_downstream=include_downstream, ) diff --git a/airflow/api_fastapi/core_api/routes/ui/structure.py b/airflow/api_fastapi/core_api/routes/ui/structure.py index 774952149cd9d..f3d13d871e484 100644 --- a/airflow/api_fastapi/core_api/routes/ui/structure.py +++ b/airflow/api_fastapi/core_api/routes/ui/structure.py @@ -70,7 +70,7 @@ def structure_data( if root: dag = dag.partial_subset( - task_ids_or_regex=root, include_upstream=include_upstream, include_downstream=include_downstream + task_ids=root, include_upstream=include_upstream, include_downstream=include_downstream ) nodes = [task_group_to_dict(child) for child in dag.task_group.topological_sort()] diff --git a/airflow/cli/commands/remote_commands/task_command.py b/airflow/cli/commands/remote_commands/task_command.py index 687a3bdcef35d..f6da076fa3643 100644 --- a/airflow/cli/commands/remote_commands/task_command.py +++ b/airflow/cli/commands/remote_commands/task_command.py @@ -698,7 +698,7 @@ def task_clear(args) -> None: if args.task_regex: for idx, dag in enumerate(dags): dags[idx] = dag.partial_subset( - task_ids_or_regex=args.task_regex, + task_ids=args.task_regex, include_downstream=args.downstream, include_upstream=args.upstream, ) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 16cbf28d64706..5d67faf7f1636 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -1134,7 +1134,7 @@ def _get_task_instances( if not external_dag: raise AirflowException(f"Could not find dag {tii.dag_id}") downstream = external_dag.partial_subset( - task_ids_or_regex=[tii.task_id], + task_ids=[tii.task_id], include_upstream=False, include_downstream=True, ) @@ -1248,7 +1248,7 @@ def set_task_instance_state( # Flush the session so that the tasks marked success are reflected in the db. session.flush() subdag = self.partial_subset( - task_ids_or_regex={task_id}, + task_ids={task_id}, include_downstream=True, include_upstream=False, ) @@ -1360,7 +1360,7 @@ def get_logical_date() -> datetime: # Flush the session so that the tasks marked success are reflected in the db. session.flush() task_subset = self.partial_subset( - task_ids_or_regex=task_ids, + task_ids=task_ids, include_downstream=True, include_upstream=False, ) diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index d716daa0f4697..9a8f93fb9afd9 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -28,7 +28,6 @@ from collections.abc import Collection, Iterable, MutableSet from datetime import datetime, timedelta from inspect import signature -from re import Pattern from typing import ( TYPE_CHECKING, Any, @@ -41,7 +40,6 @@ import attrs import jinja2 -import re2 from dateutil.relativedelta import relativedelta from airflow import settings @@ -741,7 +739,7 @@ def __deepcopy__(self, memo: dict[int, Any]): def partial_subset( self, - task_ids_or_regex: str | Pattern | Iterable[str], + task_ids: str | Iterable[str], include_downstream=False, include_upstream=True, include_direct_upstream=False, @@ -753,8 +751,7 @@ def partial_subset( based on a regex that should match one or many tasks, and includes upstream and downstream neighbours based on the flag passed. - :param task_ids_or_regex: Either a list of task_ids, or a regex to - match against task ids (as a string, or compiled regex pattern). + :param task_ids: Either a list of task_ids, or a string task_id :param include_downstream: Include all downstream tasks of matched tasks, in addition to matched tasks. :param include_upstream: Include all upstream tasks of matched tasks, @@ -769,10 +766,10 @@ def partial_subset( memo = {id(self.task_dict): None, id(self.task_group): None} dag = copy.deepcopy(self, memo) # type: ignore - if isinstance(task_ids_or_regex, (str, Pattern)): - matched_tasks = [t for t in self.tasks if re2.findall(task_ids_or_regex, t.task_id)] + if isinstance(task_ids, str): + matched_tasks = [t for t in self.tasks if task_ids in t.task_id] else: - matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex] + matched_tasks = [t for t in self.tasks if t.task_id in task_ids] also_include_ids: set[str] = set() for t in matched_tasks: diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 01befbc488f63..fd77b62e4911b 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -3046,7 +3046,7 @@ def cleared_downstream(task): upstream = False return set( task.dag.partial_subset( - task_ids_or_regex=[task.task_id], + task_ids=[task.task_id], include_downstream=not upstream, include_upstream=upstream, ).tasks @@ -3058,7 +3058,7 @@ def cleared_upstream(task): upstream = True return set( task.dag.partial_subset( - task_ids_or_regex=task.task_id, + task_ids=task.task_id, include_downstream=not upstream, include_upstream=upstream, ).tasks @@ -3069,7 +3069,7 @@ def cleared_neither(task): """Helper to return tasks that would be cleared if **upstream** selected.""" return set( task.dag.partial_subset( - task_ids_or_regex=[task.task_id], + task_ids=[task.task_id], include_downstream=False, include_upstream=False, ).tasks diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 1b347526ec4f8..9bce3646c00fc 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -1286,7 +1286,7 @@ def clear_tasks( """ Clear the task and its downstream tasks recursively for the dag in the given dagbag. """ - partial: DAG = dag.partial_subset(task_ids_or_regex=[task.task_id], include_downstream=True) + partial: DAG = dag.partial_subset(task_ids=[task.task_id], include_downstream=True) return partial.clear( start_date=start_date, end_date=end_date, @@ -1719,7 +1719,7 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m session.flush() dag = dag.partial_subset( - task_ids_or_regex=["head"], + task_ids=["head"], include_downstream=True, include_upstream=False, ) diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index b6643348a4eb5..cfaeaf8288b61 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -455,7 +455,7 @@ def test_sub_dag_task_group(): group234 >> group6 group234 >> task7 - subdag = dag.partial_subset(task_ids_or_regex="task5", include_upstream=True, include_downstream=False) + subdag = dag.partial_subset(task_ids="task5", include_upstream=True, include_downstream=False) expected_node_id = { "id": None,