Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,12 @@
from airflow.api_fastapi.core_api.security import GetUserDep, ReadableTIFilterDep, requires_access_dag
from airflow.api_fastapi.core_api.services.public.task_instances import (
BulkTaskInstanceService,
_patch_task_instance_note,
_patch_task_instance_state,
_patch_ti_validate_request,
)
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.exceptions import TaskNotFound
from airflow.listeners.listener import get_listener_manager
from airflow.models import Base, DagRun
from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances
from airflow.models.taskinstancehistory import TaskInstanceHistory as TIH
Expand Down Expand Up @@ -848,46 +849,22 @@ def patch_task_instance(

for key, _ in data.items():
if key == "new_state":
tis = dag.set_task_instance_state(
_patch_task_instance_state(
task_id=task_id,
run_id=dag_run_id,
map_indexes=[map_index] if map_index is not None else None,
state=data["new_state"],
upstream=body.include_upstream,
downstream=body.include_downstream,
future=body.include_future,
past=body.include_past,
commit=True,
dag_run_id=dag_run_id,
dag=dag,
task_instance_body=body,
data=data,
session=session,
)
if not tis:
raise HTTPException(
status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state"
)

for ti in tis:
try:
if data["new_state"] == TaskInstanceState.SUCCESS:
get_listener_manager().hook.on_task_instance_success(
previous_state=None, task_instance=ti
)
elif data["new_state"] == TaskInstanceState.FAILED:
get_listener_manager().hook.on_task_instance_failed(
previous_state=None,
task_instance=ti,
error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
)
except Exception:
log.exception("error calling listener")

elif key == "note":
for ti in tis:
if update_mask or body.note is not None:
if ti.task_instance_note is None:
ti.note = (body.note, user.get_id())
else:
ti.task_instance_note.content = body.note
ti.task_instance_note.user_id = user.get_id()
_patch_task_instance_note(
task_instance_body=body,
tis=tis,
user=user,
update_mask=update_mask,
)

return TaskInstanceCollectionResponse(
task_instances=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,64 @@ def _patch_ti_validate_request(
return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)


def _patch_task_instance_state(
task_id: str,
dag_run_id: str,
dag: DAG,
task_instance_body: BulkTaskInstanceBody | PatchTaskInstanceBody,
data: dict,
session: Session,
) -> None:
map_index = getattr(task_instance_body, "map_index", None)
map_indexes = None if map_index is None else [map_index]

updated_tis = dag.set_task_instance_state(
task_id=task_id,
run_id=dag_run_id,
map_indexes=map_indexes,
state=data["new_state"],
upstream=task_instance_body.include_upstream,
downstream=task_instance_body.include_downstream,
future=task_instance_body.include_future,
past=task_instance_body.include_past,
commit=True,
session=session,
)
if not updated_tis:
raise HTTPException(
status.HTTP_409_CONFLICT,
f"Task id {task_id} is already in {data['new_state']} state",
)

for ti in updated_tis:
try:
if data["new_state"] == TaskInstanceState.SUCCESS:
get_listener_manager().hook.on_task_instance_success(previous_state=None, task_instance=ti)
elif data["new_state"] == TaskInstanceState.FAILED:
get_listener_manager().hook.on_task_instance_failed(
previous_state=None,
task_instance=ti,
error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
)
except Exception:
log.exception("error calling listener")


def _patch_task_instance_note(
task_instance_body: BulkTaskInstanceBody | PatchTaskInstanceBody,
tis: list[TI],
user: GetUserDep,
update_mask: list[str] | None = Query(None),
) -> None:
for ti in tis:
if update_mask or task_instance_body.note is not None:
if ti.task_instance_note is None:
ti.note = (task_instance_body.note, user.get_id())
else:
ti.task_instance_note.content = task_instance_body.note
ti.task_instance_note.user_id = user.get_id()


class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]):
"""Service for handling bulk operations on task instances."""

Expand Down Expand Up @@ -134,55 +192,6 @@ def categorize_task_instances(
not_found_task_keys = {(task_id, map_index) for task_id, map_index in task_ids} - matched_task_keys
return task_instances_map, matched_task_keys, not_found_task_keys

def _patch_task_instance_state(
self,
dag: DAG,
task_instance_body: BulkTaskInstanceBody,
data: dict,
) -> None:
map_indexes = None if task_instance_body.map_index is None else [task_instance_body.map_index]

updated_tis = dag.set_task_instance_state(
task_id=task_instance_body.task_id,
run_id=self.dag_run_id,
map_indexes=map_indexes,
state=data["new_state"],
upstream=task_instance_body.include_upstream,
downstream=task_instance_body.include_downstream,
future=task_instance_body.include_future,
past=task_instance_body.include_past,
commit=True,
session=self.session,
)
if not updated_tis:
raise HTTPException(
status.HTTP_409_CONFLICT,
f"Task id {task_instance_body.task_id} is already in {data['new_state']} state",
)
for ti in updated_tis:
try:
if data["new_state"] == TaskInstanceState.SUCCESS:
get_listener_manager().hook.on_task_instance_success(
previous_state=None, task_instance=ti
)
elif data["new_state"] == TaskInstanceState.FAILED:
get_listener_manager().hook.on_task_instance_failed(
previous_state=None,
task_instance=ti,
error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
)
except Exception:
log.exception("error calling listener")

def _patch_task_instance_note(self, task_instance_body: BulkTaskInstanceBody, tis: list[TI]) -> None:
for ti in tis:
if task_instance_body.note is not None:
if ti.task_instance_note is None:
ti.note = (task_instance_body.note, self.user.get_id())
else:
ti.task_instance_note.content = task_instance_body.note
ti.task_instance_note.user_id = self.user.get_id()

def handle_bulk_create(
self, action: BulkCreateAction[BulkTaskInstanceBody], results: BulkActionResponse
) -> None:
Expand Down Expand Up @@ -232,13 +241,18 @@ def handle_bulk_update(

for key, _ in data.items():
if key == "new_state":
self._patch_task_instance_state(
_patch_task_instance_state(
task_id=task_instance_body.task_id,
dag_run_id=self.dag_run_id,
dag=dag,
task_instance_body=task_instance_body,
session=self.session,
data=data,
)
elif key == "note":
self._patch_task_instance_note(task_instance_body=task_instance_body, tis=tis)
_patch_task_instance_note(
task_instance_body=task_instance_body, tis=tis, user=self.user
)

results.success.append(task_instance_body.task_id)
except ValidationError as e:
Expand Down