diff --git a/chart/files/pod-template-file.kubernetes-helm-yaml b/chart/files/pod-template-file.kubernetes-helm-yaml index a7c779d963576..76d7316ac150e 100644 --- a/chart/files/pod-template-file.kubernetes-helm-yaml +++ b/chart/files/pod-template-file.kubernetes-helm-yaml @@ -106,6 +106,16 @@ spec: env: - name: AIRFLOW__CORE__EXECUTOR value: {{ .Values.executor | quote }} + # Deliver pod-termination signals only to the task supervisor (dumb-init's + # direct child) instead of broadcasting them to the whole process group. + # On graceful pod shutdown the supervisor's warm-shutdown handler then lets + # the running task finish -- the same mechanism Celery workers use -- rather + # than the task subprocess being killed directly by dumb-init's group-wide + # SIGTERM (which could also reach the subprocess before it installs its own + # signal handler). Hard kills (heartbeat loss / overtime / the post-grace + # SIGKILL) are unaffected. + - name: DUMB_INIT_SETSID + value: "0" {{- if or .Values.workers.kubernetes.kerberosSidecar.enabled .Values.workers.kubernetes.kerberosInitContainer.enabled }} - name: KRB5_CONFIG value: {{ .Values.kerberos.configPath | quote }} diff --git a/chart/newsfragments/69034.significant.rst b/chart/newsfragments/69034.significant.rst new file mode 100644 index 0000000000000..800d4c8d68ca6 --- /dev/null +++ b/chart/newsfragments/69034.significant.rst @@ -0,0 +1,7 @@ +Default ``DUMB_INIT_SETSID`` changed to ``"0"`` for KubernetesExecutor task pods. + +Pod-termination signals (e.g. SIGTERM on graceful shutdown) are now delivered only to the +task supervisor (``dumb-init``'s direct child) instead of being broadcast to the whole +process group. This lets a running task finish via the supervisor's warm-shutdown handler -- +the same behaviour Celery worker pods already had -- rather than the task subprocess being +killed directly. Hard kills (heartbeat loss / overtime / the post-grace SIGKILL) are unaffected. 
diff --git a/chart/tests/helm_tests/airflow_aux/test_pod_template_file.py b/chart/tests/helm_tests/airflow_aux/test_pod_template_file.py index 95ea707f0a4d5..34d03fe7762d6 100644 --- a/chart/tests/helm_tests/airflow_aux/test_pod_template_file.py +++ b/chart/tests/helm_tests/airflow_aux/test_pod_template_file.py @@ -1111,6 +1111,17 @@ def test_should_add_extraEnvs(self): "valueFrom": {"configMapKeyRef": {"name": "my-config-map", "key": "my-key"}}, } in jmespath.search("spec.containers[0].env", docs[0]) + def test_should_set_dumb_init_setsid_for_warm_shutdown(self): + """Pod-termination signals must reach only the supervisor so a running task can warm-shut-down.""" + docs = render_chart( + show_only=["templates/pod-template-file.yaml"], + chart_dir=self.temp_chart_dir, + ) + + assert {"name": "DUMB_INIT_SETSID", "value": "0"} in jmespath.search( + "spec.containers[0].env", docs[0] + ) + def test_should_add_component_specific_labels(self): docs = render_chart( values={ diff --git a/docker-stack-docs/entrypoint.rst b/docker-stack-docs/entrypoint.rst index 9835b154083fd..e64077d6dddf7 100644 --- a/docker-stack-docs/entrypoint.rst +++ b/docker-stack-docs/entrypoint.rst @@ -300,6 +300,15 @@ The table below summarizes ``DUMB_INIT_SETSID`` possible values and their use ca | | If you are running it through ``["bash", "-c"]`` command, | | | you need to start the worker via ``exec airflow celery worker`` | | | as the last command executed. | +| | | +| | The same applies to KubernetesExecutor task pods. Here ``dumb-init`` | +| | runs as the init process and its direct child is the task | +| | *supervisor*, which supervises a single task subprocess. Setting the | +| | variable to 0 propagates a graceful SIGTERM only to the supervisor, | +| | which then performs a warm shutdown and waits for the running task | +| | to finish, instead of the signal being broadcast to the whole | +| | process group and killing the task subprocess directly. 
The Airflow | +| | Helm chart sets this on the KubernetesExecutor pod template for you. | +----------------+----------------------------------------------------------------------+ Additional quick test options diff --git a/task-sdk/src/airflow/sdk/execution_time/coordinator.py b/task-sdk/src/airflow/sdk/execution_time/coordinator.py index 55d1654a7feb4..1469b27888c51 100644 --- a/task-sdk/src/airflow/sdk/execution_time/coordinator.py +++ b/task-sdk/src/airflow/sdk/execution_time/coordinator.py @@ -40,6 +40,8 @@ import contextlib import functools +import os +import signal from typing import TYPE_CHECKING, Any import attrs @@ -50,7 +52,7 @@ from airflow.sdk.configuration import conf if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Generator, Mapping from os import PathLike from structlog.typing import FilteringBoundLogger @@ -114,6 +116,43 @@ class _CoordinatorSpec(pydantic.BaseModel): extra: dict[str, Any] | None = None +@contextlib.contextmanager +def _warm_shutdown_signals() -> Generator[None, None, None]: + """ + Install SIGTERM/SIGINT warm-shutdown handlers for the duration of task supervision. + + While supervising a task the supervisor must not be torn down by a + termination signal; instead it keeps running so the task can finish (or be + shut down gracefully) and its terminal state and logs are reported. The + handlers are installed around BOTH ``start()`` (which transitions the TI to + RUNNING) and ``wait()`` (which runs the task and then reports the terminal + state / uploads logs), so there is no window where Python's default SIGTERM + disposition could kill the supervisor and tear the just-started task down + with it. + + The previous dispositions are restored on exit so a long-lived supervisor + process (e.g. a reused Celery prefork worker) does not leak the handler into + later tasks or clobber the worker's own signal handling. 
+ """ + + def _warm_shutdown(signum, frame): + log.info( + "Received signal; warm shutdown in progress, waiting for the running task to complete.", + signal=signal.Signals(signum).name, + pid=os.getpid(), + ) + + previous = {signum: signal.getsignal(signum) for signum in (signal.SIGTERM, signal.SIGINT)} + for signum in previous: + signal.signal(signum, _warm_shutdown) + try: + yield + finally: + for signum, handler in previous.items(): + # None means the handler was not installed from Python; signal.signal() rejects it. + signal.signal(signum, signal.SIG_DFL if handler is None else handler) + + + class _PythonCoordinator(BaseCoordinator): """ Coordinator implementation to execute Python tasks. @@ -140,17 +179,22 @@ def execute_task( # process handling. from airflow.sdk.execution_time.supervisor import ActivitySubprocess - process = ActivitySubprocess.start( - dag_rel_path=dag_rel_path, - what=what, - client=client, - logger=logger, - bundle_info=bundle_info, - subprocess_logs_to_stdout=subprocess_logs_to_stdout, - sentry_integration=sentry_integration, - ) - exit_code = process.wait() - return self.ExecutionResult(exit_code, process.final_state) + # Keep the warm-shutdown handlers installed across both start() (which + # transitions the TI to RUNNING) and wait() (which runs the task and + # reports its terminal state / uploads logs) so a SIGTERM at any point + # in this window can't kill the supervisor and tear the task down. 
+ with _warm_shutdown_signals(): + process = ActivitySubprocess.start( + dag_rel_path=dag_rel_path, + what=what, + client=client, + logger=logger, + bundle_info=bundle_info, + subprocess_logs_to_stdout=subprocess_logs_to_stdout, + sentry_integration=sentry_integration, + ) + exit_code = process.wait() + return self.ExecutionResult(exit_code, process.final_state) @functools.cache diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 76dfd2009ac18..02124bb09e2ec 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1392,31 +1392,10 @@ def wait(self) -> int: if self._exit_code is not None: return self._exit_code - # Forward termination signals to the task subprocess so the operator's - # on_kill() hook runs on graceful shutdown (e.g. K8s pod SIGTERM). - # Without this the supervisor exits on SIGTERM without notifying the - # child, leaving spawned resources (pods, subprocesses, etc.) running. 
- prev_sigterm = signal.getsignal(signal.SIGTERM) - prev_sigint = signal.getsignal(signal.SIGINT) - - def _forward_signal(signum, frame): - log.info( - "Received signal, forwarding to task subprocess", - signal=signal.Signals(signum).name, - pid=self.pid, - ) - with suppress(ProcessLookupError): - os.kill(self.pid, signum) - - signal.signal(signal.SIGTERM, _forward_signal) - signal.signal(signal.SIGINT, _forward_signal) - try: self._monitor_subprocess() finally: self.selector.close() - signal.signal(signal.SIGTERM, prev_sigterm) - signal.signal(signal.SIGINT, prev_sigint) # self._monitor_subprocess() will set the exit code when the process has finished # If it hasn't, assume it's failed diff --git a/task-sdk/tests/task_sdk/dags/signal_forward_test.py b/task-sdk/tests/task_sdk/dags/signal_forward_test.py deleted file mode 100644 index ff85eda41a646..0000000000000 --- a/task-sdk/tests/task_sdk/dags/signal_forward_test.py +++ /dev/null @@ -1,44 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import os -import signal -import time - -from airflow.sdk.bases.operator import BaseOperator -from airflow.sdk.definitions.dag import dag - - -class SignalForwardOperator(BaseOperator): - """Send SIGTERM to the supervisor parent process to exercise signal forwarding.""" - - def execute(self, context): - print("EXECUTE_STARTED", flush=True) - os.kill(os.getppid(), signal.SIGTERM) - time.sleep(2) - - def on_kill(self) -> None: - print("ON_KILL_CALLED_VIA_SIGNAL_FORWARDING", flush=True) - - -@dag() -def signal_forward_test(): - SignalForwardOperator(task_id="signal_task") - - -signal_forward_test() diff --git a/task-sdk/tests/task_sdk/execution_time/test_coordinator.py b/task-sdk/tests/task_sdk/execution_time/test_coordinator.py index 40db5378afccf..62ee8fb4471bd 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_coordinator.py +++ b/task-sdk/tests/task_sdk/execution_time/test_coordinator.py @@ -19,6 +19,9 @@ from __future__ import annotations import json +import os +import signal +from unittest import mock import pytest @@ -28,6 +31,7 @@ BaseCoordinator, CoordinatorManager, _PythonCoordinator, + _warm_shutdown_signals, get_coordinator_manager, reset_coordinator_manager, ) @@ -229,3 +233,136 @@ def test_every_example_coordinator_constructs(self, sdk_config): for queue, key in queue_to_coordinator.items(): coordinator = manager.for_queue(queue) assert isinstance(coordinator, import_string(specs[key]["classpath"])) + + +class TestWarmShutdownSignals: + """Tests for the warm-shutdown signal handling that wraps task supervision.""" + + @pytest.fixture(autouse=True) + def _restore_disposition(self): + """Guarantee SIGTERM/SIGINT dispositions are restored even if a test leaks one.""" + original_term = signal.getsignal(signal.SIGTERM) + original_int = signal.getsignal(signal.SIGINT) + yield + signal.signal(signal.SIGTERM, original_term) + signal.signal(signal.SIGINT, original_int) + + def 
test_installs_handlers_inside_context(self): + """While the context is active a warm-shutdown handler is installed for both signals.""" + sentinel_term = signal.getsignal(signal.SIGTERM) + sentinel_int = signal.getsignal(signal.SIGINT) + + with _warm_shutdown_signals(): + inside_term = signal.getsignal(signal.SIGTERM) + inside_int = signal.getsignal(signal.SIGINT) + + assert callable(inside_term) + assert callable(inside_int) + # The installed handler is the warm-shutdown closure, not the previous disposition. + assert inside_term is not sentinel_term + assert inside_int is not sentinel_int + # Both signals share the same warm-shutdown closure. + assert inside_term is inside_int + + def test_restores_previous_dispositions_on_exit(self): + """The exact previous dispositions are restored when the context exits normally.""" + + def _prev_term(signum, frame): # pragma: no cover - never invoked + pass + + def _prev_int(signum, frame): # pragma: no cover - never invoked + pass + + signal.signal(signal.SIGTERM, _prev_term) + signal.signal(signal.SIGINT, _prev_int) + + with _warm_shutdown_signals(): + pass + + assert signal.getsignal(signal.SIGTERM) is _prev_term + assert signal.getsignal(signal.SIGINT) is _prev_int + + def test_restores_previous_dispositions_on_exception(self): + """Dispositions are restored even if the wrapped body raises.""" + + def _prev_term(signum, frame): # pragma: no cover - never invoked + pass + + signal.signal(signal.SIGTERM, _prev_term) + + with pytest.raises(RuntimeError, match="boom"), _warm_shutdown_signals(): + raise RuntimeError("boom") + + assert signal.getsignal(signal.SIGTERM) is _prev_term + + def test_sigterm_inside_context_does_not_kill(self): + """ + A SIGTERM delivered while supervising must be swallowed, not kill the process. + + This is the regression guard: with the default SIGTERM disposition (SIG_DFL) + in place as the *previous* handler, sending SIGTERM to ourselves would + terminate the process. 
The warm-shutdown handler installed by the context + manager must absorb it so the running task is allowed to finish. + """ + # Make the pre-context disposition the default so a missing warm-shutdown + # handler would actually kill this process (and fail the test by dying). + signal.signal(signal.SIGTERM, signal.SIG_DFL) + + reached_after_signal = False + with _warm_shutdown_signals(): + os.kill(os.getpid(), signal.SIGTERM) + # If the handler did not absorb the signal, we never get here. + reached_after_signal = True + + assert reached_after_signal + # And the default disposition is put back afterwards. + assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL + + +class TestPythonCoordinatorWarmShutdown: + """The Python coordinator must wrap start() and wait() in the warm-shutdown handlers.""" + + def test_execute_task_wraps_start_and_wait(self, monkeypatch): + """ + Handlers are installed for the whole start()+wait() window and restored after. + + Capturing the SIGTERM disposition at the moment ``start()`` and ``wait()`` + run proves the handler spans the RUNNING transition (start) and the + terminal-state report / log upload (wait), with no window left uncovered. 
+ """ + original_term = signal.getsignal(signal.SIGTERM) + captured: dict[str, object] = {} + + class _FakeProcess: + final_state = "success" + + def wait(self_inner): + captured["wait"] = signal.getsignal(signal.SIGTERM) + return 0 + + def _fake_start(*args, **kwargs): + captured["start"] = signal.getsignal(signal.SIGTERM) + return _FakeProcess() + + import airflow.sdk.execution_time.supervisor as supervisor_mod + + monkeypatch.setattr(supervisor_mod.ActivitySubprocess, "start", staticmethod(_fake_start)) + + coordinator = _PythonCoordinator() + result = coordinator.execute_task( + what=mock.MagicMock(), + dag_rel_path="some_dag.py", + bundle_info=mock.MagicMock(), + client=mock.MagicMock(), + subprocess_logs_to_stdout=False, + ) + + assert result.exit_code == 0 + assert result.final_state == "success" + # During both start() and wait() a warm-shutdown handler was installed... + assert callable(captured["start"]) + assert callable(captured["wait"]) + assert captured["start"] is not original_term + assert captured["start"] is captured["wait"] + # ...and the original disposition is restored once execute_task returns. 
+ assert signal.getsignal(signal.SIGTERM) is original_term diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 8dea3d0793f49..55fce707097d6 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -27,9 +27,8 @@ import socket import subprocess import sys -import threading import time -from contextlib import nullcontext, suppress +from contextlib import nullcontext from dataclasses import dataclass, field from datetime import datetime, timezone as dt_timezone from operator import attrgetter @@ -263,52 +262,6 @@ def test_supervise( with expectation: supervise_task(**kw) - def test_on_kill_hook_called_when_supervisor_receives_sigterm( - self, - test_dags_dir, - captured_logs, - client_with_ti_start, - ): - """SIGTERM to the supervisor process is forwarded to the task subprocess.""" - ti = TaskInstance( - id=uuid7(), - task_id="signal_task", - dag_id="signal_forward_test", - run_id="r", - try_number=1, - dag_version_id=uuid7(), - queue="default", - ) - bundle_info = BundleInfo(name="my-bundle", version=None) - - supervisor_pid = os.getpid() - - def _kill_children(): - for child in psutil.Process(supervisor_pid).children(recursive=True): - with suppress(psutil.NoSuchProcess): - child.kill() - - watchdog = threading.Timer(20.0, _kill_children) - watchdog.daemon = True - watchdog.start() - - try: - with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): - supervise_task( - ti=ti, - dag_rel_path="signal_forward_test.py", - token="", - dry_run=True, - client=client_with_ti_start, - bundle_info=bundle_info, - ) - finally: - watchdog.cancel() - - stdout_events = [entry["event"] for entry in captured_logs if entry.get("logger") == "task.stdout"] - assert "EXECUTE_STARTED" in stdout_events - assert "ON_KILL_CALLED_VIA_SIGNAL_FORWARDING" in stdout_events - 
@pytest.mark.usefixtures("disable_capturing") class TestWatchedSubprocess: