diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 781e98ae2f67f..12d64dcf3f7cb 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1391,10 +1391,31 @@ def wait(self) -> int: if self._exit_code is not None: return self._exit_code + # Forward termination signals to the task subprocess so the operator's + # on_kill() hook runs on graceful shutdown (e.g. K8s pod SIGTERM). + # Without this the supervisor exits on SIGTERM without notifying the + # child, leaving spawned resources (pods, subprocesses, etc.) running. + prev_sigterm = signal.getsignal(signal.SIGTERM) + prev_sigint = signal.getsignal(signal.SIGINT) + + def _forward_signal(signum, frame): + log.info( + "Received signal, forwarding to task subprocess", + signal=signal.Signals(signum).name, + pid=self.pid, + ) + with suppress(ProcessLookupError): + os.kill(self.pid, signum) + + signal.signal(signal.SIGTERM, _forward_signal) + signal.signal(signal.SIGINT, _forward_signal) + try: self._monitor_subprocess() finally: self.selector.close() + signal.signal(signal.SIGTERM, prev_sigterm) + signal.signal(signal.SIGINT, prev_sigint) # self._monitor_subprocess() will set the exit code when the process has finished # If it hasn't, assume it's failed diff --git a/task-sdk/tests/task_sdk/dags/signal_forward_test.py b/task-sdk/tests/task_sdk/dags/signal_forward_test.py new file mode 100644 index 0000000000000..ff85eda41a646 --- /dev/null +++ b/task-sdk/tests/task_sdk/dags/signal_forward_test.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import signal +import time + +from airflow.sdk.bases.operator import BaseOperator +from airflow.sdk.definitions.dag import dag + + +class SignalForwardOperator(BaseOperator): + """Send SIGTERM to the supervisor parent process to exercise signal forwarding.""" + + def execute(self, context): + print("EXECUTE_STARTED", flush=True) + os.kill(os.getppid(), signal.SIGTERM) + time.sleep(2) + + def on_kill(self) -> None: + print("ON_KILL_CALLED_VIA_SIGNAL_FORWARDING", flush=True) + + +@dag() +def signal_forward_test(): + SignalForwardOperator(task_id="signal_task") + + +signal_forward_test() diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 798500d381c45..a9ce41bd29a62 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -27,8 +27,9 @@ import socket import subprocess import sys +import threading import time -from contextlib import nullcontext +from contextlib import nullcontext, suppress from dataclasses import dataclass, field from datetime import datetime, timezone as dt_timezone from operator import attrgetter @@ -262,6 +263,54 @@ def test_supervise( with expectation: supervise_task(**kw) + def test_on_kill_hook_called_when_supervisor_receives_sigterm( + self, + test_dags_dir, + captured_logs, + client_with_ti_start, + ): + """SIGTERM to the supervisor process is forwarded to the task subprocess.""" + ti = TaskInstanceDTO( + id=uuid7(), + task_id="signal_task", + dag_id="signal_forward_test", + run_id="r", + try_number=1, + dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, + ) + bundle_info = BundleInfo(name="my-bundle", version=None) + + supervisor_pid = os.getpid() + + def _kill_children(): + for child in psutil.Process(supervisor_pid).children(recursive=True): + with suppress(psutil.NoSuchProcess): + child.kill() + + watchdog = threading.Timer(20.0, _kill_children) + watchdog.daemon = True + watchdog.start() + + try: + with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): + supervise_task( + ti=ti, + dag_rel_path="signal_forward_test.py", + token="", + dry_run=True, + client=client_with_ti_start, + bundle_info=bundle_info, + ) + finally: + watchdog.cancel() + + stdout_events = [entry["event"] for entry in captured_logs if entry.get("logger") == "task.stdout"] + assert "EXECUTE_STARTED" in stdout_events + assert "ON_KILL_CALLED_VIA_SIGNAL_FORWARDING" in stdout_events + @pytest.mark.usefixtures("disable_capturing") class TestWatchedSubprocess: