diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py index 79a0f977916da..915054cdf792e 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -65,11 +65,12 @@ from airflow.api_fastapi.core_api.security import GetUserDep, ReadableTIFilterDep, requires_access_dag from airflow.api_fastapi.core_api.services.public.task_instances import ( BulkTaskInstanceService, + _patch_task_instance_note, + _patch_task_instance_state, _patch_ti_validate_request, ) from airflow.api_fastapi.logging.decorators import action_logging from airflow.exceptions import TaskNotFound -from airflow.listeners.listener import get_listener_manager from airflow.models import Base, DagRun from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances from airflow.models.taskinstancehistory import TaskInstanceHistory as TIH @@ -848,46 +849,22 @@ def patch_task_instance( for key, _ in data.items(): if key == "new_state": - tis = dag.set_task_instance_state( + _patch_task_instance_state( task_id=task_id, - run_id=dag_run_id, - map_indexes=[map_index] if map_index is not None else None, - state=data["new_state"], - upstream=body.include_upstream, - downstream=body.include_downstream, - future=body.include_future, - past=body.include_past, - commit=True, + dag_run_id=dag_run_id, + dag=dag, + task_instance_body=body, + data=data, session=session, ) - if not tis: - raise HTTPException( - status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state" - ) - - for ti in tis: - try: - if data["new_state"] == TaskInstanceState.SUCCESS: - get_listener_manager().hook.on_task_instance_success( - previous_state=None, task_instance=ti - ) - elif data["new_state"] == TaskInstanceState.FAILED: - get_listener_manager().hook.on_task_instance_failed( - previous_state=None, - task_instance=ti, - error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.", - ) - except Exception: - log.exception("error calling listener") elif key == "note": - for ti in tis: - if update_mask or body.note is not None: - if ti.task_instance_note is None: - ti.note = (body.note, user.get_id()) - else: - ti.task_instance_note.content = body.note - ti.task_instance_note.user_id = user.get_id() + _patch_task_instance_note( + task_instance_body=body, + tis=tis, + user=user, + update_mask=update_mask, + ) return TaskInstanceCollectionResponse( task_instances=[ diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py index a99e802278f4e..4a25e07340bf8 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py @@ -94,6 +94,64 @@ def _patch_ti_validate_request( return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True) +def _patch_task_instance_state( + task_id: str, + dag_run_id: str, + dag: DAG, + task_instance_body: BulkTaskInstanceBody | PatchTaskInstanceBody, + data: dict, + session: Session, +) -> None: + map_index = getattr(task_instance_body, "map_index", None) + map_indexes = None if map_index is None else [map_index] + + updated_tis = dag.set_task_instance_state( + task_id=task_id, + run_id=dag_run_id, + map_indexes=map_indexes, + state=data["new_state"], + upstream=task_instance_body.include_upstream, + downstream=task_instance_body.include_downstream, + future=task_instance_body.include_future, + past=task_instance_body.include_past, + commit=True, + session=session, + ) + if not updated_tis: + raise HTTPException( + status.HTTP_409_CONFLICT, + f"Task id {task_id} is already in {data['new_state']} state", + ) + + for ti in updated_tis: + try: + if data["new_state"] == TaskInstanceState.SUCCESS: + get_listener_manager().hook.on_task_instance_success(previous_state=None, task_instance=ti) + elif data["new_state"] == TaskInstanceState.FAILED: + get_listener_manager().hook.on_task_instance_failed( + previous_state=None, + task_instance=ti, + error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.", + ) + except Exception: + log.exception("error calling listener") + + +def _patch_task_instance_note( + task_instance_body: BulkTaskInstanceBody | PatchTaskInstanceBody, + tis: list[TI], + user: GetUserDep, + update_mask: list[str] | None = Query(None), +) -> None: + for ti in tis: + if update_mask or task_instance_body.note is not None: + if ti.task_instance_note is None: + ti.note = (task_instance_body.note, user.get_id()) + else: + ti.task_instance_note.content = task_instance_body.note + ti.task_instance_note.user_id = user.get_id() + + class BulkTaskInstanceService(BulkService[BulkTaskInstanceBody]): """Service for handling bulk operations on task instances.""" @@ -134,55 +192,6 @@ def categorize_task_instances( not_found_task_keys = {(task_id, map_index) for task_id, map_index in task_ids} - matched_task_keys return task_instances_map, matched_task_keys, not_found_task_keys - def _patch_task_instance_state( - self, - dag: DAG, - task_instance_body: BulkTaskInstanceBody, - data: dict, - ) -> None: - map_indexes = None if task_instance_body.map_index is None else [task_instance_body.map_index] - - updated_tis = dag.set_task_instance_state( - task_id=task_instance_body.task_id, - run_id=self.dag_run_id, - map_indexes=map_indexes, - state=data["new_state"], - upstream=task_instance_body.include_upstream, - downstream=task_instance_body.include_downstream, - future=task_instance_body.include_future, - past=task_instance_body.include_past, - commit=True, - session=self.session, - ) - if not updated_tis: - raise HTTPException( - status.HTTP_409_CONFLICT, - f"Task id {task_instance_body.task_id} is already in {data['new_state']} state", - ) - for ti in updated_tis: - try: - if data["new_state"] == TaskInstanceState.SUCCESS: - get_listener_manager().hook.on_task_instance_success( - previous_state=None, task_instance=ti - ) - elif data["new_state"] == TaskInstanceState.FAILED: - get_listener_manager().hook.on_task_instance_failed( - previous_state=None, - task_instance=ti, - error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.", - ) - except Exception: - log.exception("error calling listener") - - def _patch_task_instance_note(self, task_instance_body: BulkTaskInstanceBody, tis: list[TI]) -> None: - for ti in tis: - if task_instance_body.note is not None: - if ti.task_instance_note is None: - ti.note = (task_instance_body.note, self.user.get_id()) - else: - ti.task_instance_note.content = task_instance_body.note - ti.task_instance_note.user_id = self.user.get_id() - def handle_bulk_create( self, action: BulkCreateAction[BulkTaskInstanceBody], results: BulkActionResponse ) -> None: @@ -232,13 +241,18 @@ def handle_bulk_update( for key, _ in data.items(): if key == "new_state": - self._patch_task_instance_state( + _patch_task_instance_state( + task_id=task_instance_body.task_id, + dag_run_id=self.dag_run_id, dag=dag, task_instance_body=task_instance_body, + session=self.session, data=data, ) elif key == "note": - self._patch_task_instance_note(task_instance_body=task_instance_body, tis=tis) + _patch_task_instance_note( + task_instance_body=task_instance_body, tis=tis, user=self.user + ) results.success.append(task_instance_body.task_id) except ValidationError as e: