From 34ccad10555db360e8d60541b39709a54d82ba9b Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Sun, 16 Nov 2025 20:56:30 +0900 Subject: [PATCH 1/8] fix local executor issue caused by cow --- .../src/airflow/executors/local_executor.py | 23 +++- airflow-core/src/airflow/utils/gc_utils.py | 44 ++++++++ .../unit/executors/test_local_executor.py | 103 ++++++++++++++---- .../test_local_executor_check_workers.py | 4 +- airflow-core/tests/unit/utils/test_gc_util.py | 72 ++++++++++++ 5 files changed, 218 insertions(+), 28 deletions(-) create mode 100644 airflow-core/src/airflow/utils/gc_utils.py create mode 100644 airflow-core/tests/unit/utils/test_gc_util.py diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 1df402213d39a..1b6272f4426a2 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -36,6 +36,7 @@ from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor +from airflow.utils.gc_utils import with_gc_freeze from airflow.utils.state import TaskInstanceState # add logger to parameter of setproctitle to support logging @@ -142,6 +143,7 @@ class LocalExecutor(BaseExecutor): """ is_local: bool = True + is_mp_using_fork: bool = multiprocessing.get_start_method() == "fork" serve_logs: bool = True @@ -163,6 +165,11 @@ def start(self) -> None: # (it looks like an int to python) self._unread_messages = multiprocessing.Value(ctypes.c_uint) + if self.is_mp_using_fork: + # This creates the maximum number of worker processes (parallelism) at once + # to minimize gc freeze/unfreeze cycles when using fork in multiprocessing + self._spawn_workers_with_gc_freeze(self.parallelism) + def _check_workers(self): # Reap any dead workers to_remove = set() @@ -186,9 +193,14 @@ def _check_workers(self): # via `sync()` a few times before the spawned process actually starts picking up messages. Try not to # create too much if num_outstanding and len(self.workers) < self.parallelism: - # This only creates one worker, which is fine as we call this directly after putting a message on - # activity_queue in execute_async - self._spawn_worker() + if self.is_mp_using_fork: + # This creates the maximum number of worker processes at once + # to minimize gc freeze/unfreeze cycles when using fork in multiprocessing + self._spawn_workers_with_gc_freeze(self.parallelism - len(self.workers)) + else: + # This only creates one worker, which is fine as we call this directly after putting a message on + # activity_queue in execute_async when using spawn in multiprocessing + self._spawn_worker() def _spawn_worker(self): p = multiprocessing.Process( @@ -205,6 +217,11 @@ def _spawn_worker(self): assert p.pid # Since we've called start self.workers[p.pid] = p + @with_gc_freeze + def _spawn_workers_with_gc_freeze(self, spawn_number): + for _ in range(spawn_number): + self._spawn_worker() + def sync(self) -> None: """Sync will get called periodically by the heartbeat method.""" self._read_results() diff --git a/airflow-core/src/airflow/utils/gc_utils.py b/airflow-core/src/airflow/utils/gc_utils.py new file mode 100644 index 0000000000000..70acdf29644c0 --- /dev/null +++ b/airflow-core/src/airflow/utils/gc_utils.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import gc +from functools import wraps + + +def with_gc_freeze(func): + """ + Freeze the GC before executing the function and unfreeze it after execution. + + This is done to prevent memory increase due to COW (Copy-on-Write) by moving all + existing objects to the permanent generation before forking the process. After the + function executes, unfreeze is called to ensure there is no impact on gc operations + in the original running process. + + Ref: https://docs.python.org/3/library/gc.html#gc.freeze + """ + + @wraps(func) + def wrapper(*args, **kwargs): + gc.freeze() + try: + return func(*args, **kwargs) + finally: + gc.unfreeze() + + return wrapper diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index b0261baf04e13..18e6c679e25db 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import gc import multiprocessing import os from unittest import mock @@ -55,6 +56,18 @@ def test_is_local_default_value(self): def test_serve_logs_default_value(self): assert LocalExecutor.serve_logs + @skip_spawn_mp_start + @mock.patch.object(gc, "unfreeze") + @mock.patch.object(gc, "freeze") + def test_executor_worker_spawned(self, mock_freeze, mock_unfreeze): + executor = LocalExecutor(parallelism=5) + executor.start() + + mock_freeze.assert_called_once() + mock_unfreeze.assert_called_once() + + assert len(executor.workers) == 5 + @skip_spawn_mp_start @mock.patch("airflow.sdk.execution_time.supervisor.supervise") def test_execution(self, mock_supervise): @@ -90,23 +103,11 @@ def fake_supervise(ti, **kwargs): assert executor.result_queue.empty() - with spy_on(executor._spawn_worker) as spawn_worker: - for ti in success_tis: - executor.queue_workload( - workloads.ExecuteTask( - token="", - ti=ti, - dag_rel_path="some/path", - log_path=None, - bundle_info=dict(name="hi", version="hi"), - ), - session=mock.MagicMock(spec=Session), - ) - + for ti in success_tis: executor.queue_workload( workloads.ExecuteTask( token="", - ti=fail_ti, + ti=ti, dag_rel_path="some/path", log_path=None, bundle_info=dict(name="hi", version="hi"), @@ -114,14 +115,21 @@ def fake_supervise(ti, **kwargs): session=mock.MagicMock(spec=Session), ) - # Process queued workloads to trigger worker spawning - executor._process_workloads(list(executor.queued_tasks.values())) + executor.queue_workload( + workloads.ExecuteTask( + token="", + ti=fail_ti, + dag_rel_path="some/path", + log_path=None, + bundle_info=dict(name="hi", version="hi"), + ), + session=mock.MagicMock(spec=Session), + ) - executor.end() + # Process queued workloads to trigger worker spawning + executor._process_workloads(list(executor.queued_tasks.values())) - expected = 2 - # Depending on how quickly the tasks run, we might not need to create all the workers we could - assert 1 <= len(spawn_worker.calls) <= expected + executor.end() # By that time Queues are already shutdown so we cannot check if they are empty assert len(executor.running) == 0 @@ -158,9 +166,6 @@ def test_clean_stop_on_signal(self): executor = LocalExecutor(parallelism=2) executor.start() - # We want to ensure we start a worker process, as we now only create them on demand - executor._spawn_worker() - try: os.kill(os.getpid(), signal.SIGINT) except KeyboardInterrupt: @@ -168,6 +173,58 @@ def test_clean_stop_on_signal(self): finally: executor.end() + @skip_spawn_mp_start + def test_worker_process_revive(self): + import signal + + executor = LocalExecutor(parallelism=2) + executor.start() + + worker_pid = list(executor.workers.keys()) + for killed_pid in worker_pid: + os.kill(killed_pid, signal.SIGTERM) + + # wait until worker is terminated + for killed_pid in worker_pid: + executor.workers[killed_pid].join(timeout=3) + + success_tis = [ + workloads.TaskInstance( + id=uuid7(), + dag_version_id=uuid7(), + task_id=f"success_{i}", + dag_id="mydag", + run_id="run1", + try_number=1, + state="queued", + pool_slots=1, + queue="default", + priority_weight=1, + map_index=-1, + start_date=timezone.utcnow(), + ) + for i in range(self.TEST_SUCCESS_COMMANDS) + ] + + for ti in success_tis: + executor.queue_workload( + workloads.ExecuteTask( + token="", + ti=ti, + dag_rel_path="some/path", + log_path=None, + bundle_info=dict(name="hi", version="hi"), + ), + session=mock.MagicMock(spec=Session), + ) + + with spy_on(executor._spawn_worker) as spawn_worker: + executor._process_workloads(list(executor.queued_tasks.values())) + if executor.is_mp_using_fork: + assert len(spawn_worker.calls) == 2 + else: + assert len(spawn_worker.calls) == 1 + @pytest.mark.parametrize( ("conf_values", "expected_server"), [ diff --git a/airflow-core/tests/unit/executors/test_local_executor_check_workers.py b/airflow-core/tests/unit/executors/test_local_executor_check_workers.py index b0adfe5b9e53a..557ff4bbcbf20 100644 --- a/airflow-core/tests/unit/executors/test_local_executor_check_workers.py +++ b/airflow-core/tests/unit/executors/test_local_executor_check_workers.py @@ -101,7 +101,7 @@ def test_spawn_worker_when_needed(setup_executor): executor.activity_queue.empty.return_value = False executor.workers = {} executor._check_workers() - executor._spawn_worker.assert_called_once() + executor._spawn_worker.assert_called() def test_no_spawn_if_parallelism_reached(setup_executor): @@ -133,4 +133,4 @@ def test_spawn_worker_when_we_have_parallelism_left(setup_executor): executor.activity_queue.empty.return_value = False executor._spawn_worker.reset_mock() executor._check_workers() - executor._spawn_worker.assert_called_once() + executor._spawn_worker.assert_called() diff --git a/airflow-core/tests/unit/utils/test_gc_util.py b/airflow-core/tests/unit/utils/test_gc_util.py new file mode 100644 index 0000000000000..67a35bcda9031 --- /dev/null +++ b/airflow-core/tests/unit/utils/test_gc_util.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import gc +from unittest import mock + +import pytest + +from airflow.utils.gc_utils import with_gc_freeze + + +class TestWithGcFreeze: + @mock.patch.object(gc, "unfreeze") + @mock.patch.object(gc, "freeze") + def test_gc_freeze_and_unfreeze_called(self, freeze_mock, unfreeze_mock): + """Test that gc.freeze and gc.unfreeze are called""" + + @with_gc_freeze + def dummy_function(): + return "success" + + result = dummy_function() + + assert result == "success" + freeze_mock.assert_called_once() + unfreeze_mock.assert_called_once() + + @mock.patch.object(gc, "unfreeze") + @mock.patch.object(gc, "freeze") + def test_unfreeze_called_even_on_exception(self, freeze_mock, unfreeze_mock): + """Test that gc.unfreeze is called even when an exception occurs""" + + @with_gc_freeze + def failing_function(): + raise ValueError("test error") + + with pytest.raises(ValueError, match="test error"): + failing_function() + + freeze_mock.assert_called_once() + unfreeze_mock.assert_called_once() + + @mock.patch.object(gc, "unfreeze") + @mock.patch.object(gc, "freeze") + def test_function_arguments_passed_correctly(self, freeze_mock, unfreeze_mock): + """Test that function arguments are passed correctly""" + + @with_gc_freeze + def function_with_args(a, b, c=None): + return a + b + (c or 0) + + result = function_with_args(1, 2, c=3) + + assert result == 6 + freeze_mock.assert_called_once() + unfreeze_mock.assert_called_once() From 838167f44b193aded7ff5cf7e71064cd37df9e72 Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Sun, 16 Nov 2025 23:33:46 +0900 Subject: [PATCH 2/8] fix test --- .../tests/unit/executors/test_local_executor.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 18e6c679e25db..b2bfb1c24379e 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -175,18 +175,16 @@ def test_clean_stop_on_signal(self): @skip_spawn_mp_start def test_worker_process_revive(self): - import signal - executor = LocalExecutor(parallelism=2) executor.start() worker_pid = list(executor.workers.keys()) for killed_pid in worker_pid: - os.kill(killed_pid, signal.SIGTERM) + # killing the worker process + proc = mock.MagicMock() + proc.is_alive.return_value = True - # wait until worker is terminated - for killed_pid in worker_pid: - executor.workers[killed_pid].join(timeout=3) + executor.workers[killed_pid] = proc success_tis = [ workloads.TaskInstance( From e246aa4abba07694868610a2f1bfa69792e6785e Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Mon, 17 Nov 2025 10:39:05 +0900 Subject: [PATCH 3/8] fix test --- airflow-core/tests/unit/executors/test_local_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index b2bfb1c24379e..f3562efcdb1f1 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -182,7 +182,7 @@ def test_worker_process_revive(self): for killed_pid in worker_pid: # killing the worker process proc = mock.MagicMock() - proc.is_alive.return_value = True + proc.is_alive.return_value = False executor.workers[killed_pid] = proc From d32bda4203ced3e52412535d5e8cc69b8d260614 Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Fri, 21 Nov 2025 15:15:53 +0900 Subject: [PATCH 4/8] remove gc utils --- .../src/airflow/executors/local_executor.py | 22 ++++-- airflow-core/src/airflow/utils/gc_utils.py | 44 ------------ airflow-core/tests/unit/utils/test_gc_util.py | 72 ------------------- 3 files changed, 18 insertions(+), 120 deletions(-) delete mode 100644 airflow-core/src/airflow/utils/gc_utils.py delete mode 100644 airflow-core/tests/unit/utils/test_gc_util.py diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 1b6272f4426a2..b624e6f1bdad0 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -36,7 +36,6 @@ from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor -from airflow.utils.gc_utils import with_gc_freeze from airflow.utils.state import TaskInstanceState # add logger to parameter of setproctitle to support logging @@ -217,10 +216,25 @@ def _spawn_worker(self): assert p.pid # Since we've called start self.workers[p.pid] = p - @with_gc_freeze def _spawn_workers_with_gc_freeze(self, spawn_number): - for _ in range(spawn_number): - self._spawn_worker() + """ + Freeze the GC before forking worker process and unfreeze it after forking. + + This is done to prevent memory increase due to COW (Copy-on-Write) by moving all + existing objects to the permanent generation before forking the process. After forking, + unfreeze is called to ensure there is no impact on gc operations + in the original running process. + + Ref: https://docs.python.org/3/library/gc.html#gc.freeze + """ + import gc + + gc.freeze() + try: + for _ in range(spawn_number): + self._spawn_worker() + finally: + gc.unfreeze() def sync(self) -> None: """Sync will get called periodically by the heartbeat method.""" diff --git a/airflow-core/src/airflow/utils/gc_utils.py b/airflow-core/src/airflow/utils/gc_utils.py deleted file mode 100644 index 70acdf29644c0..0000000000000 --- a/airflow-core/src/airflow/utils/gc_utils.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import gc -from functools import wraps - - -def with_gc_freeze(func): - """ - Freeze the GC before executing the function and unfreeze it after execution. - - This is done to prevent memory increase due to COW (Copy-on-Write) by moving all - existing objects to the permanent generation before forking the process. After the - function executes, unfreeze is called to ensure there is no impact on gc operations - in the original running process. - - Ref: https://docs.python.org/3/library/gc.html#gc.freeze - """ - - @wraps(func) - def wrapper(*args, **kwargs): - gc.freeze() - try: - return func(*args, **kwargs) - finally: - gc.unfreeze() - - return wrapper diff --git a/airflow-core/tests/unit/utils/test_gc_util.py b/airflow-core/tests/unit/utils/test_gc_util.py deleted file mode 100644 index 67a35bcda9031..0000000000000 --- a/airflow-core/tests/unit/utils/test_gc_util.py +++ /dev/null @@ -1,72 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import gc -from unittest import mock - -import pytest - -from airflow.utils.gc_utils import with_gc_freeze - - -class TestWithGcFreeze: - @mock.patch.object(gc, "unfreeze") - @mock.patch.object(gc, "freeze") - def test_gc_freeze_and_unfreeze_called(self, freeze_mock, unfreeze_mock): - """Test that gc.freeze and gc.unfreeze are called""" - - @with_gc_freeze - def dummy_function(): - return "success" - - result = dummy_function() - - assert result == "success" - freeze_mock.assert_called_once() - unfreeze_mock.assert_called_once() - - @mock.patch.object(gc, "unfreeze") - @mock.patch.object(gc, "freeze") - def test_unfreeze_called_even_on_exception(self, freeze_mock, unfreeze_mock): - """Test that gc.unfreeze is called even when an exception occurs""" - - @with_gc_freeze - def failing_function(): - raise ValueError("test error") - - with pytest.raises(ValueError, match="test error"): - failing_function() - - freeze_mock.assert_called_once() - unfreeze_mock.assert_called_once() - - @mock.patch.object(gc, "unfreeze") - @mock.patch.object(gc, "freeze") - def test_function_arguments_passed_correctly(self, freeze_mock, unfreeze_mock): - """Test that function arguments are passed correctly""" - - @with_gc_freeze - def function_with_args(a, b, c=None): - return a + b + (c or 0) - - result = function_with_args(1, 2, c=3) - - assert result == 6 - freeze_mock.assert_called_once() - unfreeze_mock.assert_called_once() From 88b1153f021cd1d8c3f44fb63e78da89ed29de17 Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Mon, 24 Nov 2025 16:59:57 +0900 Subject: [PATCH 5/8] fix test to prevent timeout --- .../unit/executors/test_local_executor.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index f3562efcdb1f1..8007b6faf419e 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -45,6 +45,11 @@ class TestLocalExecutor: + """ + When the executor is started, end() must be called before the test finishes. + Otherwise, subprocesses will remain running, preventing the test from terminating and causing a timeout. + """ + TEST_SUCCESS_COMMANDS = 5 def test_sentry_integration(self): @@ -68,6 +73,8 @@ def test_executor_worker_spawned(self, mock_freeze, mock_unfreeze): assert len(executor.workers) == 5 + executor.end() + @skip_spawn_mp_start @mock.patch("airflow.sdk.execution_time.supervisor.supervise") def test_execution(self, mock_supervise): @@ -178,12 +185,16 @@ def test_worker_process_revive(self): executor = LocalExecutor(parallelism=2) executor.start() - worker_pid = list(executor.workers.keys()) - for killed_pid in worker_pid: - # killing the worker process + # Mock the process to make it appear dead. + # However, the processes that lost their references must be included in end() before termination. + # Otherwise, the test will not finish and a timeout will occur. + dead_process = {} + + for killed_pid, killed_proc in executor.workers.items(): proc = mock.MagicMock() proc.is_alive.return_value = False + dead_process[killed_pid] = killed_proc executor.workers[killed_pid] = proc success_tis = [ @@ -218,11 +229,17 @@ def test_worker_process_revive(self): with spy_on(executor._spawn_worker) as spawn_worker: executor._process_workloads(list(executor.queued_tasks.values())) + if executor.is_mp_using_fork: assert len(spawn_worker.calls) == 2 else: assert len(spawn_worker.calls) == 1 + for killed_pid, killed_proc in dead_process.items(): + executor.workers[killed_pid] = killed_proc + + executor.end() + @pytest.mark.parametrize( ("conf_values", "expected_server"), [ From 0866c9f04c6f9175c8e6107b020ecea6df6a5abc Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Tue, 25 Nov 2025 23:46:20 +0900 Subject: [PATCH 6/8] fix tests --- .../unit/executors/test_local_executor.py | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 8007b6faf419e..dfd3e2e3dc055 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -106,15 +106,28 @@ def fake_supervise(ti, **kwargs): mock_supervise.side_effect = fake_supervise executor = LocalExecutor(parallelism=2) - executor.start() - assert executor.result_queue.empty() + with spy_on(executor._spawn_worker) as spawn_worker: + executor.start() + + assert executor.result_queue.empty() + + for ti in success_tis: + executor.queue_workload( + workloads.ExecuteTask( + token="", + ti=ti, + dag_rel_path="some/path", + log_path=None, + bundle_info=dict(name="hi", version="hi"), + ), + session=mock.MagicMock(spec=Session), + ) - for ti in success_tis: executor.queue_workload( workloads.ExecuteTask( token="", - ti=ti, + ti=fail_ti, dag_rel_path="some/path", log_path=None, bundle_info=dict(name="hi", version="hi"), @@ -122,21 +135,14 @@ def fake_supervise(ti, **kwargs): session=mock.MagicMock(spec=Session), ) - executor.queue_workload( - workloads.ExecuteTask( - token="", - ti=fail_ti, - dag_rel_path="some/path", - log_path=None, - bundle_info=dict(name="hi", version="hi"), - ), - session=mock.MagicMock(spec=Session), - ) + # Process queued workloads to trigger worker spawning + executor._process_workloads(list(executor.queued_tasks.values())) - # Process queued workloads to trigger worker spawning - executor._process_workloads(list(executor.queued_tasks.values())) + executor.end() - executor.end() + expected = 2 + # Depending on how quickly the tasks run, we might not need to create all the workers we could + assert 1 <= len(spawn_worker.calls) <= expected # By that time Queues are already shutdown so we cannot check if they are empty assert len(executor.running) == 0 @@ -181,7 +187,7 @@ def test_clean_stop_on_signal(self): executor.end() @skip_spawn_mp_start - def test_worker_process_revive(self): + def test_executor_replace_dead_workers(self): executor = LocalExecutor(parallelism=2) executor.start() @@ -230,10 +236,7 @@ def test_worker_process_revive(self): with spy_on(executor._spawn_worker) as spawn_worker: executor._process_workloads(list(executor.queued_tasks.values())) - if executor.is_mp_using_fork: - assert len(spawn_worker.calls) == 2 - else: - assert len(spawn_worker.calls) == 1 + assert 1 <= len(spawn_worker.calls) <= 2 for killed_pid, killed_proc in dead_process.items(): executor.workers[killed_pid] = killed_proc From 8dcd87069549549a88a199f400d206ad4096e8ca Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Tue, 25 Nov 2025 23:54:21 +0900 Subject: [PATCH 7/8] fix tests --- airflow-core/tests/unit/executors/test_local_executor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index dfd3e2e3dc055..d9c9454dff33a 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -179,6 +179,9 @@ def test_clean_stop_on_signal(self): executor = LocalExecutor(parallelism=2) executor.start() + # We want to ensure we start a worker process, as we now only create them on demand + executor._spawn_worker() + try: os.kill(os.getpid(), signal.SIGINT) except KeyboardInterrupt: From 1e249e890d90115b19c1059a5f371d7d19009457 Mon Sep 17 00:00:00 2001 From: wjddn279 Date: Wed, 26 Nov 2025 00:43:53 +0900 Subject: [PATCH 8/8] fix tests --- .../unit/executors/test_local_executor.py | 57 ------------------- 1 file changed, 57 deletions(-) diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index d9c9454dff33a..5f0420cf98014 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -189,63 +189,6 @@ def test_clean_stop_on_signal(self): finally: executor.end() - @skip_spawn_mp_start - def test_executor_replace_dead_workers(self): - executor = LocalExecutor(parallelism=2) - executor.start() - - # Mock the process to make it appear dead. - # However, the processes that lost their references must be included in end() before termination. - # Otherwise, the test will not finish and a timeout will occur. - dead_process = {} - - for killed_pid, killed_proc in executor.workers.items(): - proc = mock.MagicMock() - proc.is_alive.return_value = False - - dead_process[killed_pid] = killed_proc - executor.workers[killed_pid] = proc - - success_tis = [ - workloads.TaskInstance( - id=uuid7(), - dag_version_id=uuid7(), - task_id=f"success_{i}", - dag_id="mydag", - run_id="run1", - try_number=1, - state="queued", - pool_slots=1, - queue="default", - priority_weight=1, - map_index=-1, - start_date=timezone.utcnow(), - ) - for i in range(self.TEST_SUCCESS_COMMANDS) - ] - - for ti in success_tis: - executor.queue_workload( - workloads.ExecuteTask( - token="", - ti=ti, - dag_rel_path="some/path", - log_path=None, - bundle_info=dict(name="hi", version="hi"), - ), - session=mock.MagicMock(spec=Session), - ) - - with spy_on(executor._spawn_worker) as spawn_worker: - executor._process_workloads(list(executor.queued_tasks.values())) - - assert 1 <= len(spawn_worker.calls) <= 2 - - for killed_pid, killed_proc in dead_process.items(): - executor.workers[killed_pid] = killed_proc - - executor.end() - @pytest.mark.parametrize( ("conf_values", "expected_server"), [