From 1b9ef4f23cce5e596c30cf8a72cb767af05abbae Mon Sep 17 00:00:00 2001 From: shaealh Date: Sun, 26 Apr 2026 19:24:19 -0700 Subject: [PATCH 1/3] Fix scheduler trigger deadlock for deferrable tasks --- .../src/airflow/jobs/scheduler_job_runner.py | 57 +++++++++++++------ airflow-core/src/airflow/models/trigger.py | 40 ++++++++++--- .../tests/unit/jobs/test_scheduler_job.py | 39 +++++++++++++ .../tests/unit/models/test_trigger.py | 33 +++++++++++ 4 files changed, 143 insertions(+), 26 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 67115ba757805..3d47c8ba7dbc9 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -127,6 +127,9 @@ TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT = "stuck in queued reschedule" """:meta private:""" +_TRIGGER_TIMEOUT_BATCH_SIZE = 1000 +"""Maximum number of task instances to lock per trigger-timeout batch.""" + def _eager_load_dag_run_for_validation() -> tuple[LoaderOption, LoaderOption]: """ @@ -2878,25 +2881,45 @@ def check_trigger_timeouts( self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION ) -> None: """Mark any "deferred" task as failed if the trigger or execution timeout has passed.""" - for attempt in run_with_db_retries(max_retries, logger=self.log): - with attempt: - result = session.execute( - update(TI) - .where( - TI.state == TaskInstanceState.DEFERRED, - TI.trigger_timeout < timezone.utcnow(), + while True: + task_instance_ids = [] + for attempt in run_with_db_retries(max_retries, logger=self.log): + with attempt: + now = timezone.utcnow() + candidates = ( + select(TI.id) + .where( + TI.state == TaskInstanceState.DEFERRED, + TI.trigger_timeout < now, + ) + .order_by(TI.id) + .limit(_TRIGGER_TIMEOUT_BATCH_SIZE) ) - .values( - state=TaskInstanceState.SCHEDULED, - next_method=TRIGGER_FAIL_REPR, - next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT}, - scheduled_dttm=timezone.utcnow(), - trigger_id=None, + task_instance_ids = list( + session.scalars( + with_row_locks(candidates, of=TI, session=session, skip_locked=True) + ).all() ) - ) - num_timed_out_tasks = getattr(result, "rowcount", 0) - if num_timed_out_tasks: - self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks) + if task_instance_ids: + result = session.execute( + update(TI) + .where(TI.id.in_(task_instance_ids)) + .values( + state=TaskInstanceState.SCHEDULED, + next_method=TRIGGER_FAIL_REPR, + next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT}, + scheduled_dttm=now, + trigger_id=None, + ) + .execution_options(synchronize_session=False) + ) + num_timed_out_tasks = getattr(result, "rowcount", 0) + if num_timed_out_tasks: + self.log.info( + "Timed out %i deferred tasks without fired triggers", num_timed_out_tasks + ) + if len(task_instance_ids) < _TRIGGER_TIMEOUT_BATCH_SIZE: + break # [START find_and_purge_task_instances_without_heartbeats] def _find_and_purge_task_instances_without_heartbeats(self) -> None: diff --git a/airflow-core/src/airflow/models/trigger.py b/airflow-core/src/airflow/models/trigger.py index 9707995288566..b8ab81c06c5dd 100644 --- a/airflow-core/src/airflow/models/trigger.py +++ b/airflow-core/src/airflow/models/trigger.py @@ -57,6 +57,9 @@ log = logging.getLogger(__name__) +_TRIGGER_ID_CLEANUP_BATCH_SIZE = 1000 +"""Maximum number of task instances to lock per trigger-id cleanup batch.""" + class TriggerFailureReason(str, Enum): """ @@ -226,16 +229,35 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None: Triggers have a one-to-many relationship to task instances, so we need to clean those up first. Afterward we can drop the triggers not referenced by anyone. """ - # Update all task instances with trigger IDs that are not DEFERRED to remove them - for attempt in run_with_db_retries(): - with attempt: - session.execute( - update(TaskInstance) - .where( - TaskInstance.state != TaskInstanceState.DEFERRED, TaskInstance.trigger_id.is_not(None) + # Clear task-instance trigger references in primary-key order to avoid locking the same rows in + # a different order than scheduler timeout handling. + while True: + task_instance_ids = [] + for attempt in run_with_db_retries(): + with attempt: + candidates = ( + select(TaskInstance.id) + .where( + TaskInstance.state != TaskInstanceState.DEFERRED, + TaskInstance.trigger_id.is_not(None), + ) + .order_by(TaskInstance.id) + .limit(_TRIGGER_ID_CLEANUP_BATCH_SIZE) + ) + task_instance_ids = list( + session.scalars( + with_row_locks(candidates, of=TaskInstance, session=session, skip_locked=True) + ).all() ) - .values(trigger_id=None) - ) + if task_instance_ids: + session.execute( + update(TaskInstance) + .where(TaskInstance.id.in_(task_instance_ids)) + .values(trigger_id=None) + .execution_options(synchronize_session=False) + ) + if len(task_instance_ids) < _TRIGGER_ID_CLEANUP_BATCH_SIZE: + break # Get all triggers that have no task instances, assets, or callbacks depending on them and delete them ids = ( diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 061a87e3aa420..e664f2df7dba9 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -6831,6 +6831,45 @@ def test_timeout_triggers(self, dag_maker): assert ti1.next_method == "__fail__" assert ti2.state == State.DEFERRED + def test_timeout_triggers_processes_more_than_one_batch(self, dag_maker, monkeypatch): + """Timed-out deferred task instances are all updated when they span multiple batches.""" + import airflow.jobs.scheduler_job_runner as scheduler_job_runner_module + + monkeypatch.setattr(scheduler_job_runner_module, "_TRIGGER_TIMEOUT_BATCH_SIZE", 2) + + session = settings.Session() + with dag_maker( + dag_id="test_timeout_triggers_processes_more_than_one_batch", + start_date=DEFAULT_DATE, + schedule="@once", + max_active_runs=5, + session=session, + ): + EmptyOperator(task_id="dummy1") + + past = timezone.utcnow() - datetime.timedelta(seconds=60) + task_instances = [] + for index in range(5): + dag_run = dag_maker.create_dagrun( + run_id=f"test_batch_{index}", + logical_date=DEFAULT_DATE + datetime.timedelta(seconds=index), + ) + task_instance = dag_run.get_task_instance("dummy1", session) + task_instance.state = State.DEFERRED + task_instance.trigger_timeout = past + task_instances.append(task_instance) + session.flush() + + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job) + + self.job_runner.check_trigger_timeouts(session=session) + + for task_instance in task_instances: + session.refresh(task_instance) + assert task_instance.state == State.SCHEDULED + assert task_instance.next_method == "__fail__" + def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker, testing_dag_bundle, session): """ Tests that it will retry on DB error like deadlock when updating timeout triggers. diff --git a/airflow-core/tests/unit/models/test_trigger.py b/airflow-core/tests/unit/models/test_trigger.py index dfd0f2e99cd0a..b1e4f0b1ffc48 100644 --- a/airflow-core/tests/unit/models/test_trigger.py +++ b/airflow-core/tests/unit/models/test_trigger.py @@ -164,6 +164,39 @@ def test_clean_unused(session, dag_maker): assert {result.id for result in results} == {trigger1.id, trigger4.id, trigger5.id, trigger6.id} +def test_clean_unused_clears_trigger_ids_in_batches(session, dag_maker, monkeypatch): + """Non-deferred task instances have trigger references cleared when they span multiple batches.""" + import airflow.models.trigger as trigger_module + + monkeypatch.setattr(trigger_module, "_TRIGGER_ID_CLEANUP_BATCH_SIZE", 2) + + triggers = [ + Trigger(classpath=f"airflow.triggers.testing.SuccessTrigger{index}", kwargs={}) + for index in range(5) + ] + session.add_all(triggers) + session.flush() + + with dag_maker(session=session, dag_id="test_clean_unused_clears_trigger_ids_in_batches"): + for index in range(5): + EmptyOperator(task_id=f"fake{index}") + + dag_run = dag_maker.create_dagrun(logical_date=timezone.utcnow()) + task_instances = {task_instance.task_id: task_instance for task_instance in dag_run.task_instances} + for index, trigger in enumerate(triggers): + task_instance = task_instances[f"fake{index}"] + task_instance.state = State.SUCCESS + task_instance.trigger_id = trigger.id + session.flush() + + Trigger.clean_unused(session=session) + + for task_instance in task_instances.values(): + session.refresh(task_instance) + assert task_instance.trigger_id is None + assert session.scalar(select(func.count()).select_from(Trigger)) == 0 + + @patch.object(TriggererCallback, "handle_event") def test_submit_event(mock_callback_handle_event, session, create_task_instance): """ From d759b877cd78597ba1b2c82c1d4a0d189656dfd6 Mon Sep 17 00:00:00 2001 From: shaealh Date: Sun, 26 Apr 2026 20:22:57 -0700 Subject: [PATCH 2/3] Format trigger cleanup regression test --- airflow-core/tests/unit/models/test_trigger.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/airflow-core/tests/unit/models/test_trigger.py b/airflow-core/tests/unit/models/test_trigger.py index b1e4f0b1ffc48..6846aa07914b6 100644 --- a/airflow-core/tests/unit/models/test_trigger.py +++ b/airflow-core/tests/unit/models/test_trigger.py @@ -171,8 +171,7 @@ def test_clean_unused_clears_trigger_ids_in_batches(session, dag_maker, monkeypa monkeypatch.setattr(trigger_module, "_TRIGGER_ID_CLEANUP_BATCH_SIZE", 2) triggers = [ - Trigger(classpath=f"airflow.triggers.testing.SuccessTrigger{index}", kwargs={}) - for index in range(5) + Trigger(classpath=f"airflow.triggers.testing.SuccessTrigger{index}", kwargs={}) for index in range(5) ] session.add_all(triggers) session.flush() From f72416c7e7597bec172cce4532f4791ba91eeb48 Mon Sep 17 00:00:00 2001 From: Vitaliy Isarev Date: Tue, 12 May 2026 13:47:35 +0300 Subject: [PATCH 3/3] fix api deadlock --- .../execution_api/routes/task_instances.py | 6 +-- .../versions/head/test_task_instances.py | 42 +++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 13d8245621d63..d465a1c71cfcb 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -36,7 +36,7 @@ from sqlalchemy import and_, func, or_, tuple_, update from sqlalchemy.engine import CursorResult from sqlalchemy.exc import NoResultFound, SQLAlchemyError -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import joinedload, lazyload from sqlalchemy.sql import select from structlog.contextvars import bind_contextvars @@ -425,7 +425,7 @@ def ti_update_state( "Error updating Task Instance state. Setting the task to failed.", payload=ti_patch_payload, ) - ti = session.get(TI, task_instance_id, with_for_update=True) + ti = session.get(TI, task_instance_id, options=[lazyload(TI.dag_run)], with_for_update=True) if session.bind is not None: query = TI.duration_expression_update(timezone.utcnow(), query, session.bind) query = query.values(state=(updated_state := TaskInstanceState.FAILED)) @@ -528,7 +528,7 @@ def _create_ti_state_update_query_and_update_state( dag_id: str, ) -> tuple[Update, TaskInstanceState]: if isinstance(ti_patch_payload, (TITerminalStatePayload, TIRetryStatePayload, TISuccessStatePayload)): - ti = session.get(TI, task_instance_id, with_for_update=True) + ti = session.get(TI, task_instance_id, options=[lazyload(TI.dag_run)], with_for_update=True) updated_state = TaskInstanceState(ti_patch_payload.state.value) if session.bind is not None: query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index d26bcf7bfd862..386155748c64e 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -1423,6 +1423,48 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta assert response.status_code == 500 assert response.json()["detail"] == "Database error occurred" + def test_ti_update_state_terminal_does_not_lock_dag_run(self, client, session, create_task_instance): + """ + Regression guard: session.get(TI, pk, with_for_update=True) must use + options=[lazyload(TI.dag_run)] to avoid inadvertently locking dag_run. + + TaskInstance.dag_run has lazy="joined", so without the lazyload override the + ORM emits FOR UPDATE on both task_instance and dag_run. The scheduler holds + a dag_run lock while bulk-updating task_instance rows in + _verify_integrity_if_dag_changed, producing a lock-order inversion deadlock. + """ + from sqlalchemy.orm import lazyload + + ti = create_task_instance( + task_id="test_ti_update_state_no_dag_run_lock", + state=State.RUNNING, + start_date=DEFAULT_START_DATE, + ) + session.commit() + + captured_for_update_calls: list[dict] = [] + real_get = Session.get + + def spy_get(self, entity, ident, **kwargs): + if kwargs.get("with_for_update"): + captured_for_update_calls.append({"entity": entity, "options": kwargs.get("options") or []}) + return real_get(self, entity, ident, **kwargs) + + with mock.patch.object(Session, "get", spy_get): + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={"state": State.SUCCESS, "end_date": DEFAULT_END_DATE.isoformat()}, + ) + assert response.status_code == 204 + + ti_for_update_calls = [c for c in captured_for_update_calls if c["entity"] is TaskInstance] + assert ti_for_update_calls, "Expected at least one session.get(TaskInstance, ..., with_for_update=True)" + for call in ti_for_update_calls: + assert any(isinstance(opt, lazyload) for opt in call["options"]), ( + "session.get(TaskInstance, ..., with_for_update=True) must pass " + "options=[lazyload(TI.dag_run)] to prevent inadvertent dag_run row lock" + ) + @pytest.mark.parametrize("queues_enabled", [False, True]) def test_ti_update_state_to_deferred( self, client, session, create_task_instance, time_machine, queues_enabled: bool