From a33d007abe25d75f33d37008046bb49b09639ec8 Mon Sep 17 00:00:00 2001 From: "bach.ab" Date: Fri, 26 Jun 2026 11:33:56 +0200 Subject: [PATCH] Warm-shutdown supervisor on SIGTERM instead of killing the running task MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #61627 installed a SIGTERM/SIGINT handler in `ActivitySubprocess.wait()` that forwards the signal to the task subprocess via `os.kill(self.pid, ...)`, so the operator's `on_kill()` runs when the supervisor is signalled. On KubernetesExecutor this kills the task on every routine pod shutdown. `dumb-init` runs as PID 1 with DUMB_INIT_SETSID=1 (the image default), so it broadcasts SIGTERM to the whole process group — the task subprocess already receives it directly. After #61627 the supervisor forwards it a second time, so the task gets SIGTERM twice; either delivery makes the child handler call `on_kill()` and abort the task the moment the pod starts terminating. That is wrong for a graceful pod shutdown (spot interruption, scale-down, rolling update). For operators that release external resources in `on_kill()` (Databricks tearing down a job cluster, ADF cancelling a pipeline, etc.) this destroys in-flight work we want to let finish. This change makes the supervisor perform a warm shutdown instead, letting the running task finish, in two parts: 1. supervisor: replace the forward-and-kill handler with a *warm-shutdown* handler that does NOT signal the task. It logs the received signal and keeps the supervisor in its monitoring loop so the running task continues to completion. The handler is installed in `_PythonCoordinator.execute_task` around both `start()` (the RUNNING transition) and `wait()` (run + terminal-state report + log upload), so no window is left where the default SIGTERM disposition could tear the supervisor down. Previous dispositions are restored on exit so a reused Celery prefork worker doesn't leak the handler. 2. 
KubernetesExecutor pod template: set DUMB_INIT_SETSID=0, so dumb-init delivers pod-termination signals only to its direct child (the supervisor), not to the whole process group. Without this the task subprocess would still receive the group-broadcast SIGTERM directly and call `on_kill()`, defeating the warm shutdown. This mirrors what Celery worker pods already do. `on_kill()` is still reached through the legitimate kill paths (heartbeat-failure / server-terminated kills, overtime termination, or an explicit success/failure transition) — all of which already drive `self.kill(...)` and are unchanged. --- .../pod-template-file.kubernetes-helm-yaml | 10 ++ chart/newsfragments/69034.significant.rst | 7 + .../airflow_aux/test_pod_template_file.py | 11 ++ docker-stack-docs/entrypoint.rst | 9 ++ .../airflow/sdk/execution_time/coordinator.py | 68 +++++++-- .../airflow/sdk/execution_time/supervisor.py | 21 --- .../task_sdk/dags/signal_forward_test.py | 44 ------ .../execution_time/test_coordinator.py | 137 ++++++++++++++++++ .../execution_time/test_supervisor.py | 49 +------ 9 files changed, 231 insertions(+), 125 deletions(-) create mode 100644 chart/newsfragments/69034.significant.rst delete mode 100644 task-sdk/tests/task_sdk/dags/signal_forward_test.py diff --git a/chart/files/pod-template-file.kubernetes-helm-yaml b/chart/files/pod-template-file.kubernetes-helm-yaml index a7c779d963576..76d7316ac150e 100644 --- a/chart/files/pod-template-file.kubernetes-helm-yaml +++ b/chart/files/pod-template-file.kubernetes-helm-yaml @@ -106,6 +106,16 @@ spec: env: - name: AIRFLOW__CORE__EXECUTOR value: {{ .Values.executor | quote }} + # Deliver pod-termination signals only to the task supervisor (dumb-init's + # direct child) instead of broadcasting them to the whole process group. 
+ # On graceful pod shutdown the supervisor's warm-shutdown handler then lets + # the running task finish -- the same mechanism Celery workers use -- rather + # than the task subprocess being killed directly by dumb-init's group-wide + # SIGTERM (which could also reach the subprocess before it installs its own + # signal handler). Hard kills (heartbeat loss / overtime / the post-grace + # SIGKILL) are unaffected. + - name: DUMB_INIT_SETSID + value: "0" {{- if or .Values.workers.kubernetes.kerberosSidecar.enabled .Values.workers.kubernetes.kerberosInitContainer.enabled }} - name: KRB5_CONFIG value: {{ .Values.kerberos.configPath | quote }} diff --git a/chart/newsfragments/69034.significant.rst b/chart/newsfragments/69034.significant.rst new file mode 100644 index 0000000000000..800d4c8d68ca6 --- /dev/null +++ b/chart/newsfragments/69034.significant.rst @@ -0,0 +1,7 @@ +Default ``DUMB_INIT_SETSID`` changed to ``"0"`` for KubernetesExecutor task pods. + +Pod-termination signals (e.g. SIGTERM on graceful shutdown) are now delivered only to the +task supervisor (``dumb-init``'s direct child) instead of being broadcast to the whole +process group. This lets a running task finish via the supervisor's warm-shutdown handler -- +the same behaviour Celery worker pods already had -- rather than the task subprocess being +killed directly. Hard kills (heartbeat loss / overtime / the post-grace SIGKILL) are unaffected. 
diff --git a/chart/tests/helm_tests/airflow_aux/test_pod_template_file.py b/chart/tests/helm_tests/airflow_aux/test_pod_template_file.py index 95ea707f0a4d5..34d03fe7762d6 100644 --- a/chart/tests/helm_tests/airflow_aux/test_pod_template_file.py +++ b/chart/tests/helm_tests/airflow_aux/test_pod_template_file.py @@ -1111,6 +1111,17 @@ def test_should_add_extraEnvs(self): "valueFrom": {"configMapKeyRef": {"name": "my-config-map", "key": "my-key"}}, } in jmespath.search("spec.containers[0].env", docs[0]) + def test_should_set_dumb_init_setsid_for_warm_shutdown(self): + """Pod-termination signals must reach only the supervisor so a running task can warm-shut-down.""" + docs = render_chart( + show_only=["templates/pod-template-file.yaml"], + chart_dir=self.temp_chart_dir, + ) + + assert {"name": "DUMB_INIT_SETSID", "value": "0"} in jmespath.search( + "spec.containers[0].env", docs[0] + ) + def test_should_add_component_specific_labels(self): docs = render_chart( values={ diff --git a/docker-stack-docs/entrypoint.rst b/docker-stack-docs/entrypoint.rst index 9835b154083fd..e64077d6dddf7 100644 --- a/docker-stack-docs/entrypoint.rst +++ b/docker-stack-docs/entrypoint.rst @@ -300,6 +300,15 @@ The table below summarizes ``DUMB_INIT_SETSID`` possible values and their use ca | | If you are running it through ``["bash", "-c"]`` command, | | | you need to start the worker via ``exec airflow celery worker`` | | | as the last command executed. | +| | | +| | The same applies to KubernetesExecutor task pods. Here ``dumb-init`` | +| | runs as the init process and its direct child is the task | +| | *supervisor*, which supervises a single task subprocess. Setting the | +| | variable to 0 propagates a graceful SIGTERM only to the supervisor, | +| | which then performs a warm shutdown and waits for the running task | +| | to finish, instead of the signal being broadcast to the whole | +| | process group and killing the task subprocess directly. 
The Airflow | +| | Helm chart sets this on the KubernetesExecutor pod template for you. | +----------------+----------------------------------------------------------------------+ Additional quick test options diff --git a/task-sdk/src/airflow/sdk/execution_time/coordinator.py b/task-sdk/src/airflow/sdk/execution_time/coordinator.py index 55d1654a7feb4..1469b27888c51 100644 --- a/task-sdk/src/airflow/sdk/execution_time/coordinator.py +++ b/task-sdk/src/airflow/sdk/execution_time/coordinator.py @@ -40,6 +40,8 @@ import contextlib import functools +import os +import signal from typing import TYPE_CHECKING, Any import attrs @@ -50,7 +52,7 @@ from airflow.sdk.configuration import conf if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Generator, Mapping from os import PathLike from structlog.typing import FilteringBoundLogger @@ -114,6 +116,43 @@ class _CoordinatorSpec(pydantic.BaseModel): extra: dict[str, Any] | None = None +@contextlib.contextmanager +def _warm_shutdown_signals() -> Generator[None, None, None]: + """ + Install SIGTERM/SIGINT warm-shutdown handlers for the duration of task supervision. + + While supervising a task the supervisor must not be torn down by a + termination signal; instead it keeps running so the task can finish (or be + shut down gracefully) and its terminal state and logs are reported. The + handlers are installed around BOTH ``start()`` (which transitions the TI to + RUNNING) and ``wait()`` (which runs the task and then reports the terminal + state / uploads logs), so there is no window where Python's default SIGTERM + disposition could kill the supervisor and tear the just-started task down + with it. + + The previous dispositions are restored on exit so a long-lived supervisor + process (e.g. a reused Celery prefork worker) does not leak the handler into + later tasks or clobber the worker's own signal handling. 
+ """ + + def _warm_shutdown(signum, frame): + log.info( + "Received signal; warm shutdown in progress, waiting for the running task to complete.", + signal=signal.Signals(signum).name, + pid=os.getpid(), + ) + + prev_sigterm = signal.getsignal(signal.SIGTERM) + prev_sigint = signal.getsignal(signal.SIGINT) + signal.signal(signal.SIGTERM, _warm_shutdown) + signal.signal(signal.SIGINT, _warm_shutdown) + try: + yield + finally: + signal.signal(signal.SIGTERM, prev_sigterm) + signal.signal(signal.SIGINT, prev_sigint) + + class _PythonCoordinator(BaseCoordinator): """ Coordinator implementation to execute Python tasks. @@ -140,17 +179,22 @@ def execute_task( # process handling. from airflow.sdk.execution_time.supervisor import ActivitySubprocess - process = ActivitySubprocess.start( - dag_rel_path=dag_rel_path, - what=what, - client=client, - logger=logger, - bundle_info=bundle_info, - subprocess_logs_to_stdout=subprocess_logs_to_stdout, - sentry_integration=sentry_integration, - ) - exit_code = process.wait() - return self.ExecutionResult(exit_code, process.final_state) + # Keep the warm-shutdown handlers installed across both start() (which + # transitions the TI to RUNNING) and wait() (which runs the task and + # reports its terminal state / uploads logs) so a SIGTERM at any point + # in this window can't kill the supervisor and tear the task down. 
+ with _warm_shutdown_signals(): + process = ActivitySubprocess.start( + dag_rel_path=dag_rel_path, + what=what, + client=client, + logger=logger, + bundle_info=bundle_info, + subprocess_logs_to_stdout=subprocess_logs_to_stdout, + sentry_integration=sentry_integration, + ) + exit_code = process.wait() + return self.ExecutionResult(exit_code, process.final_state) @functools.cache diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 76dfd2009ac18..02124bb09e2ec 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1392,31 +1392,10 @@ def wait(self) -> int: if self._exit_code is not None: return self._exit_code - # Forward termination signals to the task subprocess so the operator's - # on_kill() hook runs on graceful shutdown (e.g. K8s pod SIGTERM). - # Without this the supervisor exits on SIGTERM without notifying the - # child, leaving spawned resources (pods, subprocesses, etc.) running. 
- prev_sigterm = signal.getsignal(signal.SIGTERM) - prev_sigint = signal.getsignal(signal.SIGINT) - - def _forward_signal(signum, frame): - log.info( - "Received signal, forwarding to task subprocess", - signal=signal.Signals(signum).name, - pid=self.pid, - ) - with suppress(ProcessLookupError): - os.kill(self.pid, signum) - - signal.signal(signal.SIGTERM, _forward_signal) - signal.signal(signal.SIGINT, _forward_signal) - try: self._monitor_subprocess() finally: self.selector.close() - signal.signal(signal.SIGTERM, prev_sigterm) - signal.signal(signal.SIGINT, prev_sigint) # self._monitor_subprocess() will set the exit code when the process has finished # If it hasn't, assume it's failed diff --git a/task-sdk/tests/task_sdk/dags/signal_forward_test.py b/task-sdk/tests/task_sdk/dags/signal_forward_test.py deleted file mode 100644 index ff85eda41a646..0000000000000 --- a/task-sdk/tests/task_sdk/dags/signal_forward_test.py +++ /dev/null @@ -1,44 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import os -import signal -import time - -from airflow.sdk.bases.operator import BaseOperator -from airflow.sdk.definitions.dag import dag - - -class SignalForwardOperator(BaseOperator): - """Send SIGTERM to the supervisor parent process to exercise signal forwarding.""" - - def execute(self, context): - print("EXECUTE_STARTED", flush=True) - os.kill(os.getppid(), signal.SIGTERM) - time.sleep(2) - - def on_kill(self) -> None: - print("ON_KILL_CALLED_VIA_SIGNAL_FORWARDING", flush=True) - - -@dag() -def signal_forward_test(): - SignalForwardOperator(task_id="signal_task") - - -signal_forward_test() diff --git a/task-sdk/tests/task_sdk/execution_time/test_coordinator.py b/task-sdk/tests/task_sdk/execution_time/test_coordinator.py index 40db5378afccf..62ee8fb4471bd 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_coordinator.py +++ b/task-sdk/tests/task_sdk/execution_time/test_coordinator.py @@ -19,6 +19,9 @@ from __future__ import annotations import json +import os +import signal +from unittest import mock import pytest @@ -28,6 +31,7 @@ BaseCoordinator, CoordinatorManager, _PythonCoordinator, + _warm_shutdown_signals, get_coordinator_manager, reset_coordinator_manager, ) @@ -229,3 +233,136 @@ def test_every_example_coordinator_constructs(self, sdk_config): for queue, key in queue_to_coordinator.items(): coordinator = manager.for_queue(queue) assert isinstance(coordinator, import_string(specs[key]["classpath"])) + + +class TestWarmShutdownSignals: + """Tests for the warm-shutdown signal handling that wraps task supervision.""" + + @pytest.fixture(autouse=True) + def _restore_disposition(self): + """Guarantee SIGTERM/SIGINT dispositions are restored even if a test leaks one.""" + original_term = signal.getsignal(signal.SIGTERM) + original_int = signal.getsignal(signal.SIGINT) + yield + signal.signal(signal.SIGTERM, original_term) + signal.signal(signal.SIGINT, original_int) + + def 
test_installs_handlers_inside_context(self): + """While the context is active a warm-shutdown handler is installed for both signals.""" + sentinel_term = signal.getsignal(signal.SIGTERM) + sentinel_int = signal.getsignal(signal.SIGINT) + + with _warm_shutdown_signals(): + inside_term = signal.getsignal(signal.SIGTERM) + inside_int = signal.getsignal(signal.SIGINT) + + assert callable(inside_term) + assert callable(inside_int) + # The installed handler is the warm-shutdown closure, not the previous disposition. + assert inside_term is not sentinel_term + assert inside_int is not sentinel_int + # Both signals share the same warm-shutdown closure. + assert inside_term is inside_int + + def test_restores_previous_dispositions_on_exit(self): + """The exact previous dispositions are restored when the context exits normally.""" + + def _prev_term(signum, frame): # pragma: no cover - never invoked + pass + + def _prev_int(signum, frame): # pragma: no cover - never invoked + pass + + signal.signal(signal.SIGTERM, _prev_term) + signal.signal(signal.SIGINT, _prev_int) + + with _warm_shutdown_signals(): + pass + + assert signal.getsignal(signal.SIGTERM) is _prev_term + assert signal.getsignal(signal.SIGINT) is _prev_int + + def test_restores_previous_dispositions_on_exception(self): + """Dispositions are restored even if the wrapped body raises.""" + + def _prev_term(signum, frame): # pragma: no cover - never invoked + pass + + signal.signal(signal.SIGTERM, _prev_term) + + with pytest.raises(RuntimeError, match="boom"), _warm_shutdown_signals(): + raise RuntimeError("boom") + + assert signal.getsignal(signal.SIGTERM) is _prev_term + + def test_sigterm_inside_context_does_not_kill(self): + """ + A SIGTERM delivered while supervising must be swallowed, not kill the process. + + This is the regression guard: with the default SIGTERM disposition (SIG_DFL) + in place as the *previous* handler, sending SIGTERM to ourselves would + terminate the process. 
The warm-shutdown handler installed by the context + manager must absorb it so the running task is allowed to finish. + """ + # Make the pre-context disposition the default so a missing warm-shutdown + # handler would actually kill this process (and fail the test by dying). + signal.signal(signal.SIGTERM, signal.SIG_DFL) + + reached_after_signal = False + with _warm_shutdown_signals(): + os.kill(os.getpid(), signal.SIGTERM) + # If the handler did not absorb the signal, we never get here. + reached_after_signal = True + + assert reached_after_signal + # And the default disposition is put back afterwards. + assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL + + +class TestPythonCoordinatorWarmShutdown: + """The Python coordinator must wrap start() and wait() in the warm-shutdown handlers.""" + + def test_execute_task_wraps_start_and_wait(self, monkeypatch): + """ + Handlers are installed for the whole start()+wait() window and restored after. + + Capturing the SIGTERM disposition at the moment ``start()`` and ``wait()`` + run proves the handler spans the RUNNING transition (start) and the + terminal-state report / log upload (wait), with no window left uncovered. 
+ """ + original_term = signal.getsignal(signal.SIGTERM) + captured: dict[str, object] = {} + + class _FakeProcess: + final_state = "success" + + def wait(self_inner): + captured["wait"] = signal.getsignal(signal.SIGTERM) + return 0 + + def _fake_start(*args, **kwargs): + captured["start"] = signal.getsignal(signal.SIGTERM) + return _FakeProcess() + + import airflow.sdk.execution_time.supervisor as supervisor_mod + + monkeypatch.setattr(supervisor_mod.ActivitySubprocess, "start", staticmethod(_fake_start)) + + coordinator = _PythonCoordinator() + result = coordinator.execute_task( + what=mock.MagicMock(), + dag_rel_path="some_dag.py", + bundle_info=mock.MagicMock(), + client=mock.MagicMock(), + subprocess_logs_to_stdout=False, + ) + + assert result.exit_code == 0 + assert result.final_state == "success" + # During both start() and wait() a warm-shutdown handler was installed... + assert callable(captured["start"]) + assert callable(captured["wait"]) + assert captured["start"] is not original_term + assert captured["start"] is captured["wait"] + # ...and the original disposition is restored once execute_task returns. 
+ assert signal.getsignal(signal.SIGTERM) is original_term diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 8dea3d0793f49..55fce707097d6 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -27,9 +27,8 @@ import socket import subprocess import sys -import threading import time -from contextlib import nullcontext, suppress +from contextlib import nullcontext from dataclasses import dataclass, field from datetime import datetime, timezone as dt_timezone from operator import attrgetter @@ -263,52 +262,6 @@ def test_supervise( with expectation: supervise_task(**kw) - def test_on_kill_hook_called_when_supervisor_receives_sigterm( - self, - test_dags_dir, - captured_logs, - client_with_ti_start, - ): - """SIGTERM to the supervisor process is forwarded to the task subprocess.""" - ti = TaskInstance( - id=uuid7(), - task_id="signal_task", - dag_id="signal_forward_test", - run_id="r", - try_number=1, - dag_version_id=uuid7(), - queue="default", - ) - bundle_info = BundleInfo(name="my-bundle", version=None) - - supervisor_pid = os.getpid() - - def _kill_children(): - for child in psutil.Process(supervisor_pid).children(recursive=True): - with suppress(psutil.NoSuchProcess): - child.kill() - - watchdog = threading.Timer(20.0, _kill_children) - watchdog.daemon = True - watchdog.start() - - try: - with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): - supervise_task( - ti=ti, - dag_rel_path="signal_forward_test.py", - token="", - dry_run=True, - client=client_with_ti_start, - bundle_info=bundle_info, - ) - finally: - watchdog.cancel() - - stdout_events = [entry["event"] for entry in captured_logs if entry.get("logger") == "task.stdout"] - assert "EXECUTE_STARTED" in stdout_events - assert "ON_KILL_CALLED_VIA_SIGNAL_FORWARDING" in stdout_events - 
@pytest.mark.usefixtures("disable_capturing") class TestWatchedSubprocess: