diff --git a/distributed/nanny.py b/distributed/nanny.py index 29e71f14c8e..9e7be25c3ce 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -1,8 +1,11 @@ from __future__ import annotations import asyncio +import contextlib import errno +import functools import logging +import multiprocessing import os import shutil import tempfile @@ -11,14 +14,12 @@ import warnings import weakref from collections.abc import Collection -from contextlib import suppress from inspect import isawaitable from queue import Empty from time import sleep as sync_sleep -from typing import TYPE_CHECKING, ClassVar, Literal +from typing import TYPE_CHECKING, Callable, ClassVar, Literal from toolz import merge -from tornado import gen from tornado.ioloop import IOLoop import dask @@ -45,7 +46,6 @@ from distributed.protocol import pickle from distributed.security import Security from distributed.utils import ( - TimeoutError, get_ip, get_mp_context, json_load_robust, @@ -303,14 +303,15 @@ async def _unregister(self, timeout=10): if worker_address is None: return - allowed_errors = (TimeoutError, CommClosedError, EnvironmentError, RPCClosed) - with suppress(allowed_errors): + try: await asyncio.wait_for( self.scheduler.unregister( address=self.worker_address, stimulus_id=f"nanny-close-{time()}" ), timeout, ) + except (asyncio.TimeoutError, CommClosedError, OSError, RPCClosed): + pass @property def worker_address(self): @@ -425,7 +426,7 @@ async def instantiate(self) -> Status: result = await asyncio.wait_for( self.process.start(), self.death_timeout ) - except TimeoutError: + except asyncio.TimeoutError: logger.error( "Timed out connecting Nanny '%s' to scheduler '%s'", self, @@ -496,7 +497,7 @@ async def _(): try: await asyncio.wait_for(_(), timeout) - except TimeoutError: + except asyncio.TimeoutError: logger.error( f"Restart timed out after {timeout}s; returning before finished" ) @@ -679,18 +680,18 @@ async def start(self) -> Status: uid = uuid.uuid4().hex self.process = AsyncProcess( - target=self._run, - name="Dask Worker process (from Nanny)", - kwargs=dict( - worker_kwargs=self.worker_kwargs, + target=functools.partial( + self._run, silence_logs=self.silence_logs, init_result_q=self.init_result_q, child_stop_q=self.child_stop_q, uid=uid, - Worker=self.Worker, + worker_factory=functools.partial(self.Worker, **self.worker_kwargs), env=self.env, config=self.config, ), + name="Dask Worker process (from Nanny)", + kwargs=dict(), ) self.process.daemon = dask.config.get("distributed.worker.daemon", default=True) self.process.set_exit_callback(self._on_exit) @@ -860,86 +861,66 @@ async def _wait_until_connected(self, uid): @classmethod def _run( cls, - worker_kwargs, - silence_logs, - init_result_q, - child_stop_q, - uid, - env, - config, - Worker, - ): # pragma: no cover - try: - os.environ.update(env) - dask.config.refresh() - dask.config.set(config) - - from dask.multiprocessing import default_initializer - - default_initializer() - - if silence_logs: - logger.setLevel(silence_logs) - - IOLoop.clear_instance() - loop = IOLoop() - loop.make_current() - worker = Worker(**worker_kwargs) - - async def do_stop( - timeout=5, executor_wait=True, reason="workerprocess-stop" - ): - try: - await worker.close( - nanny=False, - executor_wait=executor_wait, - timeout=timeout, - reason=reason, - ) - finally: - loop.stop() - - def watch_stop_q(): - """ - Wait for an incoming stop message and then stop the - worker cleanly. - """ - try: - msg = child_stop_q.get() - except (TypeError, OSError, EOFError): - logger.error("Worker process died unexpectedly") - msg = {"op": "stop"} - finally: - child_stop_q.close() - assert msg["op"] == "stop", msg - del msg["op"] - loop.add_callback(do_stop, **msg) - - thread = threading.Thread( - target=watch_stop_q, name="Nanny stop queue watch" + silence_logs: bool, + init_result_q: multiprocessing.Queue, + child_stop_q: multiprocessing.Queue, + uid: str, + env: dict, + config: dict, + worker_factory: Callable[[], Worker], + ) -> None: # pragma: no cover + async def do_stop( + *, + worker: Worker, + timeout: float = 5, + executor_wait: bool = True, + reason: str = "workerprocess-stop", + ) -> None: + await worker.close( + nanny=False, + executor_wait=executor_wait, + timeout=timeout, + reason=reason, ) - thread.daemon = True - thread.start() - async def run(): - """ - Try to start worker and inform parent of outcome. - """ - try: - await worker - except Exception as e: - logger.exception("Failed to start worker") - init_result_q.put({"uid": uid, "exception": e}) - init_result_q.close() - # If we hit an exception here we need to wait for a least - # one interval for the outside to pick up this message. - # Otherwise we arrive in a race condition where the process - # cleanup wipes the queue before the exception can be - # properly handled. See also - # WorkerProcess._wait_until_connected (the 2 is for good - # measure) - sync_sleep(cls._init_msg_interval * 2) - else: + def watch_stop_q(loop: IOLoop, worker: Worker) -> None: + """ + Wait for an incoming stop message and then stop the + worker cleanly. + """ + try: + msg = child_stop_q.get() + except (TypeError, OSError, EOFError): + logger.error("Worker process died unexpectedly") + msg = {"op": "stop"} + finally: + child_stop_q.close() + assert msg["op"] == "stop", msg + del msg["op"] + loop.add_callback(do_stop, worker=worker, **msg) + + async def run() -> None: + """ + Try to start worker and inform parent of outcome. + """ + failure_type: str | None = "initialize" + try: + worker = worker_factory() + failure_type = "start" + thread = threading.Thread( + target=functools.partial( + watch_stop_q, + worker=worker, + loop=IOLoop.current(), + ), + name="Nanny stop queue watch", + daemon=True, + ) + thread.start() + stack.callback(thread.join, timeout=2) + async with worker: + failure_type = None + try: assert worker.address except ValueError: @@ -955,34 +936,49 @@ async def run(): init_result_q.close() await worker.finished() logger.info("Worker closed") + except Exception as e: + if failure_type is None: + raise - except Exception as e: - logger.exception("Failed to initialize Worker") - init_result_q.put({"uid": uid, "exception": e}) - init_result_q.close() - # If we hit an exception here we need to wait for a least one - # interval for the outside to pick up this message. Otherwise we - # arrive in a race condition where the process cleanup wipes the - # queue before the exception can be properly handled. See also - # WorkerProcess._wait_until_connected (the 2 is for good measure) - sync_sleep(cls._init_msg_interval * 2) - else: - try: - loop.run_sync(run) - except (TimeoutError, gen.TimeoutError): - # Loop was stopped before wait_until_closed() returned, ignore - pass - except KeyboardInterrupt: - # At this point the loop is not running thus we have to run - # do_stop() explicitly. - loop.run_sync(do_stop) - finally: - with suppress(ValueError): + logger.exception(f"Failed to {failure_type} worker") + init_result_q.put({"uid": uid, "exception": e}) + init_result_q.close() + # If we hit an exception here we need to wait for a least + # one interval for the outside to pick up this message. + # Otherwise we arrive in a race condition where the process + # cleanup wipes the queue before the exception can be + # properly handled. See also + # WorkerProcess._wait_until_connected (the 3 is for good + # measure) + sync_sleep(cls._init_msg_interval * 3) + + with contextlib.ExitStack() as stack: + + @stack.callback + def close_stop_q() -> None: + try: child_stop_q.put({"op": "stop"}) # usually redundant - with suppress(ValueError): + except ValueError: + pass + + try: child_stop_q.close() # usually redundant + except ValueError: + pass child_stop_q.join_thread() - thread.join(timeout=2) + + os.environ.update(env) + dask.config.refresh() + dask.config.set(config) + + from dask.multiprocessing import default_initializer + + default_initializer() + + if silence_logs: + logger.setLevel(silence_logs) + + asyncio.run(run()) def _get_env_variables(config_key: str) -> dict[str, str]: diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c0d5b716220..26748dedbc4 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -106,7 +106,6 @@ map_varying, nodebug, popen, - pristine_loop, randominc, save_sys_modules, slowadd, @@ -2207,8 +2206,26 @@ async def test_multi_client(s, a, b): await asyncio.sleep(0.01) +@contextmanager +def _pristine_loop(): + IOLoop.clear_instance() + IOLoop.clear_current() + loop = IOLoop() + loop.make_current() + assert IOLoop.current() is loop + try: + yield loop + finally: + try: + loop.close(all_fds=True) + except (KeyError, ValueError): + pass + IOLoop.clear_instance() + IOLoop.clear_current() + + def long_running_client_connection(address): - with pristine_loop(): + with _pristine_loop(): c = Client(address) x = c.submit(lambda x: x + 1, 10) x.result() @@ -5602,7 +5619,7 @@ async def close(): async with client: pass - with pristine_loop() as loop: + with _pristine_loop() as loop: with pytest.warns( DeprecationWarning, match=r"Constructing LoopRunner\(loop=loop\) without a running loop is deprecated", diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 3bc471de19e..bd2e54fb04d 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -160,24 +160,6 @@ async def run(): return -@contextmanager -def pristine_loop(): - IOLoop.clear_instance() - IOLoop.clear_current() - loop = IOLoop() - loop.make_current() - assert IOLoop.current() is loop - try: - yield loop - finally: - try: - loop.close(all_fds=True) - except (KeyError, ValueError): - pass - IOLoop.clear_instance() - IOLoop.clear_current() - - original_config = copy.deepcopy(dask.config.config)