From 6fa046bc42399cda1cb63922b604c29c3532bc45 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 9 Apr 2024 17:48:34 +0200 Subject: [PATCH 1/5] ensure workers do not kill on restart --- distributed/comm/tcp.py | 4 +- distributed/deploy/spec.py | 49 +++- distributed/deploy/tests/test_local.py | 15 ++ distributed/nanny.py | 279 +++++++++++------------ distributed/process.py | 9 +- distributed/scheduler.py | 162 ++++++------- distributed/shuffle/_scheduler_plugin.py | 4 +- distributed/tests/test_client.py | 30 +-- distributed/tests/test_failed_workers.py | 18 +- distributed/tests/test_nanny.py | 78 +++++-- distributed/tests/test_scheduler.py | 85 +------ distributed/utils_test.py | 16 +- distributed/worker.py | 5 +- 13 files changed, 381 insertions(+), 373 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index c741aba83e9..b395ff52378 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -665,7 +665,9 @@ async def _handle_stream(self, stream, address): try: await self.on_connection(comm) except CommClosedError: - logger.info("Connection from %s closed before handshake completed", address) + logger.debug( + "Connection from %s closed before handshake completed", address + ) return await self.comm_handler(comm) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 8189cfdd7dc..86d4cb2728d 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -10,6 +10,7 @@ from collections.abc import Awaitable, Generator from contextlib import suppress from inspect import isawaitable +from time import time from typing import TYPE_CHECKING, Any, ClassVar, TypeVar from tornado import gen @@ -389,28 +390,64 @@ async def _correct_state_internal(self) -> None: # proper teardown. 
await asyncio.gather(*worker_futs) - def _update_worker_status(self, op, msg): + def _update_worker_status(self, op, worker_addr): if op == "remove": - name = self.scheduler_info["workers"][msg]["name"] + worker_info = self.scheduler_info["workers"][worker_addr].copy() + name = worker_info["name"] + + from distributed import Nanny, Worker def f(): + # FIXME: SpecCluster is tracking workers by `name` which are + # not necessarily unique. + # Clusters with Nannies (default) are susceptible to falsely + # removing the Nannies on restart due to this logic since the + # restart emits an op==remove signal on the worker address but + # the SpecCluster only tracks the names, i.e. after + # `lost-worker-timeout` the Nanny is still around and this logic + # could trigger a false close. The below code should handle this + # but it would be cleaner if the cluster tracked by address + # instead of name just like the scheduler does if ( name in self.workers - and msg not in self.scheduler_info["workers"] + and worker_addr not in self.scheduler_info["workers"] and not any( d["name"] == name for d in self.scheduler_info["workers"].values() ) ): - self._futures.add(asyncio.ensure_future(self.workers[name].close())) - del self.workers[name] + w = self.workers[name] + + async def remove_worker(): + await w.close(reason=f"lost-worker-timeout-{time()}") + self.workers.pop(name, None) + + if ( + worker_info["type"] == "Worker" + and (isinstance(w, Nanny) and w.worker_address == worker_addr) + or (isinstance(w, Worker) and w.address == worker_addr) + ): + self._futures.add( + asyncio.create_task( + remove_worker(), + name="remove-worker-lost-worker-timeout", + ) + ) + elif worker_info["type"] == "Nanny": + # This should never happen + logger.critical( + "Unespected signal encountered. WorkerStatusPlugin " + "emitted a op==remove signal for a Nanny which " + "should not happen. This might cause a lingering " + "Nanny process." 
+ ) delay = parse_timedelta( dask.config.get("distributed.deploy.lost-worker-timeout") ) asyncio.get_running_loop().call_later(delay, f) - super()._update_worker_status(op, msg) + super()._update_worker_status(op, worker_addr) def __await__(self: Self) -> Generator[Any, Any, Self]: async def _() -> Self: diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 1594bc4db5f..d6acf5780f8 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -11,6 +11,7 @@ import pytest from tornado.httpclient import AsyncHTTPClient +import dask from dask.system import CPU_COUNT from distributed import Client, LocalCluster, Nanny, Worker, get_client @@ -1285,3 +1286,17 @@ def test_localcluster_get_client(loop): with Client(cluster) as client2: assert client1 != client2 assert client2 == cluster.get_client() + + +@pytest.mark.slow() +def test_localcluster_restart(loop): + with ( + dask.config.set({"distributed.deploy.lost-worker-timeout": "0.5s"}), + LocalCluster(asynchronous=False, dashboard_address=":0", loop=loop) as cluster, + cluster.get_client() as client, + ): + nworkers = len(client.run(lambda: None)) + for _ in range(10): + assert len(client.run(lambda: None)) == nworkers + client.restart() + assert len(client.run(lambda: None)) == nworkers diff --git a/distributed/nanny.py b/distributed/nanny.py index 99644e9292d..1e980767e18 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -163,7 +163,6 @@ def __init__( # type: ignore[no-untyped-def] ) self.__exit_stack = stack = contextlib.ExitStack() - self.process = None self._setup_logging(logger) self.loop = self.io_loop = IOLoop.current() @@ -186,7 +185,6 @@ def __init__( # type: ignore[no-untyped-def] preload_nanny_argv = dask.config.get("distributed.nanny.preload-argv") handlers = { - "instantiate": self.instantiate, "kill": self.kill, "restart": self.restart, "get_logs": self.get_logs, @@ -282,7 +280,11 @@ def __init__( # type: 
ignore[no-untyped-def] self._start_host = host self._interface = interface self._protocol = protocol - + self._lifecycle_lock = asyncio.Lock() + self._proc_changed = asyncio.Event() + self.process = None + self._reconciler_task = None + self._reconciler_queue: asyncio.Queue[dict] = asyncio.Queue() self._listen_address = listen_address Nanny._instances.add(self) @@ -294,30 +296,25 @@ def __init__( # type: ignore[no-untyped-def] def __repr__(self): return "" % (self.worker_address, self.nthreads) - async def _unregister(self, timeout=10): - if self.process is None: - return - worker_address = self.process.worker_address + async def _unregister(self, proc): + worker_address = proc.worker_address if worker_address is None: return try: - await wait_for( - self.scheduler.unregister( - address=self.worker_address, stimulus_id=f"nanny-close-{time()}" - ), - timeout, + await self.scheduler.unregister( + address=worker_address, stimulus_id=f"nanny-close-{time()}" ) except (asyncio.TimeoutError, CommClosedError, OSError, RPCClosed): pass @property def worker_address(self): - return None if self.process is None else self.process.worker_address + return None if not self.process.is_alive() else self.process.worker_address @property def worker_dir(self): - return None if self.process is None else self.process.worker_dir + return None if not self.process.is_alive() else self.process.worker_dir async def start_unsafe(self): """Start nanny, start local process, start watching""" @@ -362,14 +359,15 @@ async def start_unsafe(self): await comm.write({"op": "register_nanny", "address": self.address}) msg = await comm.read() try: - for name, plugin in msg["nanny-plugins"].items(): - await self.plugin_add(plugin=plugin, name=name) - logger.info(" Start Nanny at: %r", self.address) - response = await self.instantiate() + self.process = self.instantiate_worker_process() + async with self.successful_instantiation(): + self._reconciler_task = asyncio.create_task( + self._reconciler(), 
name="WorkerProcess reconciler" + ) - if response != Status.running: - raise RuntimeError("Nanny failed to start worker process") + for name, plugin in msg["nanny-plugins"].items(): + await self.plugin_add(plugin=plugin, name=name) except Exception: try: await comm.write({"status": "error"}) @@ -396,68 +394,62 @@ async def kill(self, timeout: float = 2, reason: str = "nanny-kill") -> None: Blocks until both the process is down and the scheduler is properly informed """ - if self.process is None: + + if self.process is None or not self.process.is_alive(): return deadline = time() + timeout await self.process.kill(reason=reason, timeout=0.8 * (deadline - time())) - async def instantiate(self) -> Status: + def instantiate_worker_process(self) -> WorkerProcess: + worker_kwargs = dict( + scheduler_ip=self.scheduler_addr, + nthreads=self.nthreads, + local_directory=self._original_local_dir, + services=self.services, + nanny=self.address, + name=self.name, + memory_limit=self.memory_manager.memory_limit, + resources=self.resources, + validate=self.validate, + silence_logs=self.silence_logs, + death_timeout=self.death_timeout, + preload=self.preload, + preload_argv=self.preload_argv, + security=self.security, + contact_address=self.contact_address, + ) + worker_kwargs.update(self.worker_kwargs) + + return WorkerProcess( + worker_kwargs=worker_kwargs, + silence_logs=self.silence_logs, + on_exit=self._on_worker_exit_sync, + worker=self.Worker, + env=self.env, + pre_spawn_env=self.pre_spawn_env, + config=self.config, + ) + + async def start_worker_process(self) -> dict[str, Status]: """Start a local worker process Blocks until the process is up and the scheduler is properly informed """ - if self.process is None: - worker_kwargs = dict( - scheduler_ip=self.scheduler_addr, - nthreads=self.nthreads, - local_directory=self._original_local_dir, - services=self.services, - nanny=self.address, - name=self.name, - memory_limit=self.memory_manager.memory_limit, - 
resources=self.resources, - validate=self.validate, - silence_logs=self.silence_logs, - death_timeout=self.death_timeout, - preload=self.preload, - preload_argv=self.preload_argv, - security=self.security, - contact_address=self.contact_address, - ) - worker_kwargs.update(self.worker_kwargs) - self.process = WorkerProcess( - worker_kwargs=worker_kwargs, - silence_logs=self.silence_logs, - on_exit=self._on_worker_exit_sync, - worker=self.Worker, - env=self.env, - pre_spawn_env=self.pre_spawn_env, - config=self.config, + assert self.process is not None + try: + result = await wait_for(self.process.start(), self.death_timeout) + except asyncio.TimeoutError: + logger.error( + "Timed out connecting Nanny '%s' to scheduler '%s'", + self, + self.scheduler_addr, ) - - if self.death_timeout: - try: - result = await wait_for(self.process.start(), self.death_timeout) - except asyncio.TimeoutError: - logger.error( - "Timed out connecting Nanny '%s' to scheduler '%s'", - self, - self.scheduler_addr, - ) - await self.close( - timeout=self.death_timeout, reason="nanny-instantiate-timeout" - ) - raise - - else: - try: - result = await self.process.start() - except Exception: - logger.error("Failed to start process", exc_info=True) - await self.close(reason="nanny-instantiate-failed") - raise - return result + raise + except Exception: + logger.error("Failed to start process", exc_info=True) + raise + return {"status": result} @log_errors async def plugin_add( @@ -508,28 +500,33 @@ async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage: return {"status": "OK"} + @contextlib.asynccontextmanager + async def successful_instantiation(self): + async with self._lifecycle_lock: + while not self._reconciler_queue.empty(): + self._reconciler_queue.get_nowait() + self._proc_changed.set() + try: + yield + finally: + response = await self._reconciler_queue.get() + if response["status"] != Status.running: + raise response["exception"] + async def restart( self, timeout: float = 
30, reason: str = "nanny-restart" - ) -> Literal["OK", "timed out"] | ErrorMessage: - async def _(): - if self.process is not None: - await self.kill(reason=reason) - await self.instantiate() - + ) -> ErrorMessage | OKMessage: try: - await wait_for(_(), timeout) - except asyncio.TimeoutError: - logger.error( - f"Restart timed out after {timeout}s; returning before finished" - ) - return "timed out" + async with self.successful_instantiation(): + await self.kill(reason=reason, timeout=timeout) except Exception as e: + logger.exception("Failed to kill worker", exc_info=True) return error_message(e) else: - return "OK" + return {"status": "OK"} def is_alive(self): - return self.process is not None and self.process.is_alive() + return self.process.is_alive() def run(self, comm, *args, **kwargs): return run(self, comm, *args, **kwargs) @@ -542,43 +539,44 @@ def _on_worker_exit_sync(self, exitcode): ): # Async task group has already been closed, so the nanny is already clos(ed|ing). pass + async def _reconciler(self): + while self.status in (Status.init, Status.running): + if not self.process.is_alive(): + if self.status == Status.running: + logger.warning("Restarting Worker on Nanny %s", self.address) + try: + result = await self.start_worker_process() + except Exception as e: + logger.exception("Failed to start worker", exc_info=True) + self.status = Status.failed + result = {"status": Status.failed, "exception": e} + self._reconciler_queue.put_nowait(result) + self._proc_changed.clear() + await self._proc_changed.wait() + await self.close(reason="nanny-close-gracefully") + @log_errors async def _on_worker_exit(self, exitcode): - if self.status not in ( - Status.init, - Status.closing, - Status.closed, - Status.closing_gracefully, - Status.failed, - ): - try: - await self._unregister() - except OSError: - logger.exception("Failed to unregister") - if not self.reconnect: - await self.close(reason="nanny-unregister-failed") - return - + # NOTE: This coroutine method 
should only be scheduled as an isolated + # task but not called directly since it is catching an + # asyncio.CancelledError. + old_proc = self.process + self.process = self.instantiate_worker_process() try: - if self.status not in ( - Status.closing, - Status.closed, - Status.closing_gracefully, - Status.failed, - ): - logger.warning("Restarting worker") - await self.instantiate() - elif self.status == Status.closing_gracefully: - await self.close(reason="nanny-close-gracefully") - - except Exception: - logger.error( - "Failed to restart worker after its process exited", exc_info=True - ) + if self.status in (Status.starting, Status.running): + await self._unregister(old_proc) + except OSError: + self.status = Status.failed + logger.exception("Failed to unregister") + except asyncio.CancelledError: + # Can happen during teardown. + pass + finally: + self._proc_changed.set() @property def pid(self): - return self.process and self.process.pid + return self.process.pid def _close(self, *args, **kwargs): warnings.warn("Worker._close has moved to Worker.close", stacklevel=2) @@ -622,12 +620,12 @@ async def close( # type:ignore[override] await asyncio.gather(*(td for td in teardowns if isawaitable(td))) self.stop() - if self.process is not None: + if self.process is not None and self.process.is_alive(): await self.kill(timeout=timeout, reason=reason) - self.process = None await self.rpc.close() self.status = Status.closed + await super().close() self.__exit_stack.__exit__(None, None, None) return "OK" @@ -744,6 +742,7 @@ async def start(self) -> Status: self.process.daemon = dask.config.get("distributed.worker.daemon", default=True) self.process.set_exit_callback(self._on_exit) self.running = asyncio.Event() + self.process_up = asyncio.Event() self.stopped = asyncio.Event() self.status = Status.starting @@ -758,6 +757,8 @@ async def start(self) -> Status: # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent await self.process.terminate() 
self.status = Status.failed + finally: + self.process_up.set() try: msg = await self._wait_until_connected(uid) except Exception: @@ -828,28 +829,28 @@ async def kill( ) -> None: """ Ensure the worker process is stopped, waiting at most - ``timeout * 0.8`` seconds before killing it abruptly. + ``timeout`` seconds before killing it abruptly. When `kill` returns, the worker process has been joined. If the worker process does not terminate within ``timeout`` seconds, even after being killed, `asyncio.TimeoutError` is raised. """ - deadline = time() + timeout + # If the process is not properly up it will not watch the closing queue + # and we may end up leaking this process + # Therefore wait for it to be properly started before killing it + if self.status == Status.starting: + await self.process_up.wait() if self.status == Status.stopped: return if self.status == Status.stopping: await self.stopped.wait() return - # If the process is not properly up it will not watch the closing queue - # and we may end up leaking this process - # Therefore wait for it to be properly started before killing it - if self.status == Status.starting: - await self.running.wait() assert self.status in ( Status.running, + Status.starting, Status.failed, # process failed to start, but hasn't been joined yet Status.closing_gracefully, ), self.status @@ -857,11 +858,10 @@ async def kill( logger.info("Nanny asking worker to close. 
Reason: %s", reason) process = self.process - wait_timeout = timeout * 0.8 self.child_stop_q.put( { "op": "stop", - "timeout": wait_timeout, + "timeout": timeout, "executor_wait": executor_wait, "reason": reason, } @@ -869,21 +869,13 @@ async def kill( self.child_stop_q.close() assert process is not None try: - try: - await process.join(wait_timeout) - return - except asyncio.TimeoutError: - pass - + await process.join(timeout) + except asyncio.TimeoutError: logger.warning( - f"Worker process still alive after {wait_timeout} seconds, killing" + f"Worker process still alive after {timeout} seconds, killing" ) await process.kill() - await process.join(max(0, deadline - time())) - except ValueError as e: - if "invalid operation on closed AsyncProcess" in str(e): - return - raise + await process.join() async def _wait_until_connected(self, uid): while True: @@ -965,7 +957,12 @@ async def run() -> None: ) thread.start() stack.callback(thread.join, timeout=2) - async with worker: + try: + await worker + except Exception: + await worker.close(nanny=False) + raise + else: failure_type = None try: diff --git a/distributed/process.py b/distributed/process.py index 150948fd5c7..8ae494306a0 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -325,9 +325,12 @@ async def join(self, timeout=None): assert self._state.pid is not None, "can only join a started process" if self._state.exitcode is not None: return - # Shield otherwise the timeout cancels the future and our - # on_exit callback will try to set a result on a canceled future - await wait_for(asyncio.shield(self._exit_future), timeout) + if timeout: + # Shield otherwise the timeout cancels the future and our + # on_exit callback will try to set a result on a canceled future + await wait_for(asyncio.shield(self._exit_future), timeout) + else: + await self._exit_future def close(self): """ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 0251e0c310d..33540b7f92f 100644 --- 
a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import contextlib import dataclasses import heapq import inspect @@ -91,7 +90,6 @@ Status, clean_exception, error_message, - rpc, send_recv, ) from distributed.diagnostics.memory_sampler import MemorySamplerExtension @@ -126,7 +124,6 @@ log_errors, offload, recursive_to_dict, - wait_for, ) from distributed.utils_comm import ( gather_from_workers, @@ -5222,7 +5219,7 @@ async def remove_worker( if not dh_addresses: del self.host_info[host] - self.rpc.remove(address) + self.rpc.remove(address, reason=f"Worker {address} removed by {stimulus_id}.") del self.stream_comms[address] del self.aliases[ws.name] self.idle.pop(ws.address, None) @@ -6250,7 +6247,7 @@ async def restart_workers( wait_for_workers: bool = True, on_error: Literal["raise", "return"] = "raise", stimulus_id: str, - ) -> dict[str, Literal["OK", "removed", "timed out"]]: + ) -> dict[str, Literal["OK", "removed", "error"]]: """Restart selected workers. Optionally wait for workers to return. Workers without nannies are shut down, hoping an external deployment system @@ -6258,9 +6255,7 @@ async def restart_workers( does not automatically restart workers, ``restart`` will just shut down all workers, then time out! - After ``restart``, all connected workers are new, regardless of whether - ``TimeoutError`` was raised. Any workers that failed to shut down in time are - removed, and may or may not shut down on their own in the future. + After ``restart``, all connected workers are new. Parameters ---------- @@ -6269,19 +6264,19 @@ async def restart_workers( timeout: How long to wait for workers to shut down and come back, if ``wait_for_workers`` is True, otherwise just how long to wait for workers to shut down. - Raises ``asyncio.TimeoutError`` if this is exceeded. + Raises ``TimeoutError`` if this is exceeded. 
wait_for_workers: Whether to wait for all workers to reconnect, or just for them to shut down (default True). Use ``restart(wait_for_workers=False)`` combined with :meth:`Client.wait_for_workers` for granular control over how many workers to wait for. on_error: - If 'raise' (the default), raise if any nanny times out while restarting the - worker. If 'return', return error messages. + If 'raise' (the default), raise if any errors occur during restart. + If 'return', return error messages. Returns ------- - {worker address: "OK", "no nanny", or "timed out" or error message} + {worker address: "OK", "removed", or "error" or error message} See also -------- @@ -6297,6 +6292,8 @@ async def restart_workers( workers = list(set(workers).intersection(self.workers)) logger.info(f"Restarting {len(workers)} workers: {workers} ({stimulus_id=}") + worker_set = set(workers) + nanny_workers = { addr: self.workers[addr].nanny for addr in workers @@ -6305,6 +6302,7 @@ async def restart_workers( # Close non-Nanny workers. We have no way to restart them, so we just let them # go, and assume a deployment system is going to restart them for us. 
no_nanny_workers = [addr for addr in workers if addr not in nanny_workers] + n_workers -= len(no_nanny_workers) if no_nanny_workers: logger.warning( f"Workers {no_nanny_workers} do not use a nanny and will be terminated " @@ -6316,64 +6314,49 @@ async def restart_workers( for addr in no_nanny_workers ) ) - out: dict[str, Literal["OK", "removed", "timed out"]] + out: dict[str, Literal["OK", "removed", "error"]] out = {addr: "removed" for addr in no_nanny_workers} deadline = Deadline.after(timeout) logger.debug("Send kill signal to nannies: %s", nanny_workers) - async with contextlib.AsyncExitStack() as stack: - nannies = await asyncio.gather( - *( - stack.enter_async_context( - rpc(nanny_address, connection_args=self.connection_args) - ) - for nanny_address in nanny_workers.values() - ) + resps = await asyncio.gather( + *( + self.rpc(nanny_address).kill(reason=stimulus_id, timeout=timeout * 0.8) + for nanny_address in nanny_workers.values() + ), + return_exceptions=True, + ) + # NOTE: the `WorkerState` entries for these workers will be removed + # naturally when they disconnect from the scheduler. + + # Remove any workers that failed to shut down, so we can guarantee + # that after `restart`, there are no old workers around. 
+ bad_nannies = set() + for addr, resp in zip(nanny_workers, resps): + if resp is None: + out[addr] = "OK" + elif on_error == "return": + bad_nannies.add(addr) + out[addr] = "error" + else: # pragma: nocover + raise RuntimeError("Exception during worker restart") from resp + + if bad_nannies: + logger.error( + f"Workers {list(bad_nannies)} did not shut down within {timeout}s; " + "force closing" ) - resps = await asyncio.gather( + await asyncio.gather( *( - wait_for( - # FIXME does not raise if the process fails to shut down, - # see https://github.com/dask/distributed/pull/6427/files#r894917424 - # NOTE: Nanny will automatically restart worker process when it's killed - nanny.kill(reason=stimulus_id, timeout=timeout), - timeout, - ) - for nanny in nannies - ), - return_exceptions=True, - ) - # NOTE: the `WorkerState` entries for these workers will be removed - # naturally when they disconnect from the scheduler. - - # Remove any workers that failed to shut down, so we can guarantee - # that after `restart`, there are no old workers around. 
- bad_nannies = set() - for addr, resp in zip(nanny_workers, resps): - if resp is None: - out[addr] = "OK" - elif isinstance(resp, (OSError, TimeoutError)): - bad_nannies.add(addr) - out[addr] = "timed out" - else: # pragma: nocover - raise resp - - if bad_nannies: - logger.error( - f"Workers {list(bad_nannies)} did not shut down within {timeout}s; " - "force closing" + self.remove_worker(addr, stimulus_id=stimulus_id) + for addr in bad_nannies ) - await asyncio.gather( - *( - self.remove_worker(addr, stimulus_id=stimulus_id) - for addr in bad_nannies - ) + ) + if on_error == "raise": + raise TimeoutError( + f"{len(bad_nannies)}/{len(nanny_workers)} nanny worker(s) did not " + f"shut down within {timeout}s: {bad_nannies}" ) - if on_error == "raise": - raise TimeoutError( - f"{len(bad_nannies)}/{len(nannies)} nanny worker(s) did not " - f"shut down within {timeout}s: {bad_nannies}" - ) if client: self.log_event(client, {"action": "restart-workers", "workers": workers}) @@ -6388,21 +6371,36 @@ async def restart_workers( ) return out - # NOTE: if new (unrelated) workers join while we're waiting, we may return - # before our shut-down workers have come back up. That's fine; workers are - # interchangeable. - while not deadline.expired and len(self.workers) < n_workers: + def _connected_workers(): + return { + ws.address + for ws in self.workers.values() + if ws.status in (Status.running, Status.paused) + } + + while ( + not deadline.expired + and len(_connected_workers()) < n_workers + and not worker_set.intersection(set(self.workers)) + ): await asyncio.sleep(0.2) - if len(self.workers) >= n_workers: + if ( + not no_nanny_workers + and len(_connected_workers()) >= n_workers + and not worker_set.intersection(set(self.workers)) + ): logger.info(f"Workers restart finished ({stimulus_id=}") return out - msg = ( - f"Waited for {len(workers)} worker(s) to reconnect after restarting but, " - f"after {timeout}s, {n_workers - len(self.workers)} have not returned. 
" - "Consider a longer timeout, or `wait_for_workers=False`." - ) + remaining_workers = n_workers + len(no_nanny_workers) - len(self.workers) + msg = "" + if remaining_workers: + msg += ( + f"Waited for {len(workers)} worker(s) to reconnect after restarting but, " + f"after {timeout}s, {remaining_workers} have not returned. " + "Consider a longer timeout, or `wait_for_workers=False`." + ) if no_nanny_workers: msg += ( f" The {len(no_nanny_workers)} worker(s) not using Nannies were just shut " @@ -6411,7 +6409,7 @@ async def restart_workers( "processes, then those workers will never come back, and `Client.restart` " "will always time out. Do not use `Client.restart` in that case." ) - + assert msg if on_error == "raise": raise TimeoutError(msg) logger.error(f"{msg} ({stimulus_id=})") @@ -6419,7 +6417,7 @@ async def restart_workers( new_nannies = {ws.nanny for ws in self.workers.values() if ws.nanny} for worker_addr, nanny_addr in nanny_workers.items(): if nanny_addr not in new_nannies: - out[worker_addr] = "timed out" + out[worker_addr] = "error" return out @@ -7611,9 +7609,11 @@ def get_processing( def get_who_has(self, keys: Iterable[Key] | None = None) -> dict[Key, list[str]]: if keys is not None: return { - key: [ws.address for ws in self.tasks[key].who_has or ()] - if key in self.tasks - else [] + key: ( + [ws.address for ws in self.tasks[key].who_has or ()] + if key in self.tasks + else [] + ) for key in keys } else: @@ -7628,9 +7628,11 @@ def get_has_what( if workers is not None: workers = map(self.coerce_address, workers) return { - w: [ts.key for ts in self.workers[w].has_what] - if w in self.workers - else [] + w: ( + [ts.key for ts in self.workers[w].has_what] + if w in self.workers + else [] + ) for w in workers } else: diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 09d97fffc9a..05dc70ef43b 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -155,7 
+155,7 @@ def get_or_create( self.active_shuffles[spec.id] = state self._shuffles[spec.id].add(state) state.participating_workers.add(worker) - logger.warning( + logger.debug( "Shuffle %s initialized by task %r executed on worker %s", spec.id, key, @@ -426,7 +426,7 @@ def _fail_on_workers(self, shuffle: SchedulerShuffleState, message: str) -> None def _clean_on_scheduler(self, id: ShuffleId, stimulus_id: str) -> None: shuffle = self.active_shuffles.pop(id) - logger.warning("Shuffle %s deactivated due to stimulus '%s'", id, stimulus_id) + logger.debug("Shuffle %s deactivated due to stimulus '%s'", id, stimulus_id) if not shuffle._archived_by: shuffle._archived_by = stimulus_id self._archived_by_stimulus[stimulus_id].add(shuffle) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 0d25f0bfe4a..8c83a8509c0 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -89,7 +89,6 @@ BlockedGatherDep, BlockedGetData, BlockedInstantiateNanny, - BlockedKillNanny, TaskStateMetadataPlugin, _UnhashableCallable, async_poll_for, @@ -3657,7 +3656,8 @@ async def test_Client_clears_references_after_restart(c, s, a, b): with pytest.raises(TimeoutError): await c.restart(timeout=1) - assert x.key not in c.refcount + while x.key in c.refcount: + await asyncio.sleep(0.01) assert not c.futures key = x.key @@ -5037,26 +5037,6 @@ async def test_restart_workers_no_nanny_raises(c, s, a, b): assert a.address in msg -@pytest.mark.slow -@pytest.mark.parametrize("raise_for_error", (True, False)) -@gen_cluster(client=True, nthreads=[("", 1)], Worker=BlockedKillNanny) -async def test_restart_workers_kill_timeout(c, s, a, raise_for_error): - # FIXME a timeout _too_ tight causes the scheduler to hang as wait_for cancels - # the nanny.kill RPC too soon. 
- kwargs = dict(workers=[a.worker_address], timeout=2) - - if raise_for_error: # default is to raise - with pytest.raises(TimeoutError) as excinfo: - await c.restart_workers(**kwargs) - msg = str(excinfo.value) - assert "1/1 nanny worker(s) did not shut down within 2s" in msg - assert a.worker_address in msg - else: - results = await c.restart_workers(raise_for_error=raise_for_error, **kwargs) - assert results == {a.worker_address: "timed out"} - a.wait_kill.set() - - @pytest.mark.slow @pytest.mark.parametrize("raise_for_error", (True, False)) @gen_cluster(client=True, nthreads=[]) @@ -5080,17 +5060,17 @@ async def test_restart_workers_restart_timeout(c, s, raise_for_error): "after 3s, 1 have not returned" ) in msg else: + worker_addr = a.worker_address results = await c.restart_workers(raise_for_error=raise_for_error, **kwargs) - assert results == {a.worker_address: "timed out"} + assert results == {worker_addr: "error"} -@pytest.mark.slow @gen_cluster(client=True, Worker=Nanny) async def test_restart_workers_exception(c, s, a, b): async def fail_instantiate(*_args, **_kwargs): raise ValueError("broken") - a.instantiate = fail_instantiate + a.start_worker_process = fail_instantiate with captured_logger("distributed.nanny") as log, pytest.raises(TimeoutError): await c.restart_workers(workers=[a.worker_address], timeout=3) diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 7c7c667e4b7..2eca028f638 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -215,9 +215,10 @@ def test_worker_doesnt_await_task_completion(loop): @gen_cluster(Worker=Nanny, timeout=60) async def test_multiple_clients_restart(s, a, b): - async with Client(s.address, asynchronous=True) as c1, Client( - s.address, asynchronous=True - ) as c2: + async with ( + Client(s.address, asynchronous=True) as c1, + Client(s.address, asynchronous=True) as c2, + ): x = c1.submit(inc, 1) y = c2.submit(inc, 2) 
xx = await x @@ -288,10 +289,7 @@ async def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b): async def test_broken_worker_during_computation(c, s, a, b): s.allowed_failures = 100 async with Nanny(s.address, nthreads=2) as n: - start = time() - while len(s.workers) < 3: - await asyncio.sleep(0.01) - assert time() < start + 5 + await c.wait_for_workers(3) N = 256 expected_result = N * (N + 1) // 2 @@ -308,10 +306,12 @@ async def test_broken_worker_during_computation(c, s, a, b): await asyncio.sleep(random.random() / 20) with suppress(CommClosedError): # comm will be closed abrupty await c.run(os._exit, 1, workers=[n.worker_address]) + assert not n.process.is_alive() + while not n.process.is_alive(): + await asyncio.sleep(0.01) + await c.wait_for_workers(3) await asyncio.sleep(random.random() / 20) - while len(s.workers) < 3: - await asyncio.sleep(0.01) with suppress( CommClosedError, EnvironmentError diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 32e86e98e45..6dddebf6be0 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -22,7 +22,7 @@ from distributed import Nanny, Scheduler, Worker, profile, rpc, wait, worker from distributed.compatibility import LINUX, WINDOWS -from distributed.core import CommClosedError, Status, error_message +from distributed.core import CommClosedError, ConnectionPool, Status, error_message from distributed.diagnostics import SchedulerPlugin from distributed.diagnostics.plugin import NannyPlugin, WorkerPlugin from distributed.metrics import time @@ -215,13 +215,7 @@ async def test_scheduler_file(): @gen_cluster(client=True, Worker=Nanny, nthreads=[("127.0.0.1", 2)]) async def test_nanny_timeout(c, s, a): x = await c.scatter(123) - with captured_logger( - logging.getLogger("distributed.nanny"), level=logging.ERROR - ) as logger: - await a.restart(timeout=0.1) - - out = logger.getvalue() - assert "timed out" in out.lower() + await a.restart(timeout=0.1) start = 
time() while x.status != "cancelled": @@ -460,7 +454,7 @@ async def test_nanny_closes_cleanly(s): async with Nanny(s.address) as n: assert n.process.pid proc = n.process.process - assert not n.process + assert not n.process.is_alive() assert not proc.is_alive() assert proc.exitcode == 0 @@ -576,10 +570,10 @@ async def test_worker_start_exception(s): pass assert nanny.status == Status.failed # ^ NOTE: `Nanny.close` sets it to `closed`, then `Server.start._close_on_failure` sets it to `failed` - assert nanny.process is None + assert not nanny.process.is_alive() assert "Restarting worker" not in logs.getvalue() # Avoid excessive spewing. (It's also printed once extra within the subprocess, which is okay.) - assert logs.getvalue().count("ValueError: broken") == 1, logs.getvalue() + assert logs.getvalue().count("ValueError: broken") == 2, logs.getvalue() @gen_cluster(nthreads=[]) @@ -693,7 +687,7 @@ async def test_close_joins(s): await close_t assert nanny.status == Status.closed - assert not nanny.process + assert not nanny.process.is_alive() assert p.status == Status.stopped assert not p.process @@ -712,7 +706,7 @@ async def test_scheduler_crash_doesnt_restart(s, a): await a.finished() assert a.status == Status.closed - assert a.process is None + assert not a.process.is_alive() @pytest.mark.slow @@ -875,7 +869,7 @@ def __init__(self, *args, **kwargs): self.in_instantiate = asyncio.Event() self.wait_instantiate = asyncio.Event() - async def instantiate(self): + async def start_worker_process(self): self.in_instantiate.set() await self.wait_instantiate.wait() raise RuntimeError("Nope") @@ -905,10 +899,10 @@ def __init__(self, *args, in_instantiate, wait_instantiate, **kwargs): self.in_instantiate = in_instantiate self.wait_instantiate = wait_instantiate - async def instantiate(self): + async def start_worker_process(self): self.in_instantiate.set() self.wait_instantiate.wait() - return await super().instantiate() + return await super().start_worker_process() def 
run_nanny(scheduler_addr, in_instantiate, wait_instantiate): @@ -945,3 +939,55 @@ async def test_nanny_plugin_register_nanny_killed(c, s, restart): finally: proc.kill() assert await register == {} + + +@pytest.mark.parametrize("api", ["restart", "kill"]) +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_restart_stress(c, s, a, api): + async def keep_killing(): + pool = await ConnectionPool() + try: + rpc = pool(a.address) + for _ in range(2): + try: + import uuid + + meth = getattr(rpc, api) + uid = uuid.uuid4().hex + print(api, uid) + await meth(reason="scheduler-restart") + print(api, uid, "done") + + except OSError: + break + + await asyncio.sleep(0.1) + finally: + print("closing pool") + await pool.close() + + kill_tasks = [asyncio.create_task(keep_killing()) for _ in range(2)] + await asyncio.gather(*kill_tasks) + assert a.status == Status.running + + +@pytest.mark.parametrize( + "api", + [ + "kill", + "restart", + ], +) +@gen_cluster(nthreads=[]) +async def test_worker_start_exception_after_restart(s, api): + async with Nanny(s.address, death_timeout="2s") as nanny: + # Stop the listener on the scheduler, i.e. do not allow any new incoming + # connections. 
The restarting workers will fail while trying to attempt + # connection + s.stop() + if api == "kill": + await nanny.kill() + else: + await nanny.restart() + await nanny.finished() + assert nanny.status in (Status.failed, Status.closed) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 68ccfbfccbd..f0aefc61d66 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -35,7 +35,6 @@ Future, Lock, Nanny, - SchedulerPlugin, Worker, fire_and_forget, wait, @@ -53,7 +52,6 @@ NO_AMM, BlockedGatherDep, BlockedGetData, - BlockedKillNanny, BrokenComm, NoSchedulerDelayWorker, assert_story, @@ -1133,39 +1131,6 @@ async def test_restart_waits_for_new_workers(c, s, *workers): assert set(s.workers.values()).isdisjoint(original_workers.values()) -@pytest.mark.slow -@gen_cluster(client=True, Worker=BlockedKillNanny, nthreads=[("", 1)] * 2) -async def test_restart_nanny_timeout_exceeded(c, s, a, b): - try: - f = c.submit(div, 1, 0) - fr = c.submit(inc, 1, resources={"FOO": 1}) - await wait(f) - assert s.erred_tasks - assert s.computations - assert s.unrunnable - assert s.tasks - - with pytest.raises( - TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s" - ): - await c.restart(timeout="1s") - assert a.in_kill.is_set() - assert b.in_kill.is_set() - - assert not s.workers - assert not s.erred_tasks - assert not s.computations - assert not s.unrunnable - assert not s.tasks - - assert not c.futures - assert f.status == "cancelled" - assert fr.status == "cancelled" - finally: - a.wait_kill.set() - b.wait_kill.set() - - @gen_cluster(client=True, nthreads=[("", 1)] * 2) async def test_restart_not_all_workers_return(c, s, a, b): with pytest.raises(TimeoutError, match="Waited for 2 worker"): @@ -1176,43 +1141,6 @@ async def test_restart_not_all_workers_return(c, s, a, b): assert b.status in (Status.closed, Status.closing) -@gen_cluster(client=True, nthreads=[("", 1)]) -async def 
test_restart_worker_rejoins_after_timeout_expired(c, s, a): - """ - We don't want to see an error message like: - - ``Waited for 1 worker(s) to reconnect after restarting, but after 0s, 0 have not returned.`` - - If a worker rejoins after our last poll for new workers, but before we raise the error, - we shouldn't raise the error. - """ - # We'll use a 0s timeout on the restart, so it always expires. - # And we'll use a plugin to block the restart process, and spin up a new worker - # in the middle of it. - - class Plugin(SchedulerPlugin): - removed = asyncio.Event() - proceed = asyncio.Event() - - async def remove_worker(self, *args, **kwargs): - self.removed.set() - await self.proceed.wait() - - s.add_plugin(Plugin()) - - task = asyncio.create_task(c.restart(timeout=0)) - await Plugin.removed.wait() - assert not s.workers - - async with Worker(s.address, nthreads=1): - assert len(s.workers) == 1 - Plugin.proceed.set() - - # New worker has joined, but the timeout has expired (since it was 0). - # Still, we should not time out. - await task - - @gen_cluster(client=True, nthreads=[("", 1)] * 2) async def test_restart_no_wait_for_workers(c, s, a, b): await c.restart(timeout="1s", wait_for_workers=False) @@ -1232,7 +1160,6 @@ async def test_restart_some_nannies_some_not(c, s, a, b): async with Worker(s.address, nthreads=1) as w: await c.wait_for_workers(3) - # FIXME how to make this not always take 20s if the nannies do restart quickly? 
with pytest.raises(TimeoutError, match=r"The 1 worker\(s\) not using Nannies"): await c.restart(timeout="20s") @@ -1243,6 +1170,18 @@ async def test_restart_some_nannies_some_not(c, s, a, b): assert w.address not in s.workers +class BlockedKillNanny(Nanny): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.in_kill = asyncio.Event() + self.wait_kill = asyncio.Event() + + async def kill(self, **kwargs): + self.in_kill.set() + await self.wait_kill.wait() + return await super().kill(**kwargs) + + @pytest.mark.slow @gen_cluster( client=True, diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 63a7488efea..6c7df5df36e 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2402,22 +2402,10 @@ def __init__(self, *args, **kwargs): self.in_instantiate = asyncio.Event() self.wait_instantiate = asyncio.Event() - async def instantiate(self): + async def start_worker_process(self): self.in_instantiate.set() await self.wait_instantiate.wait() - return await super().instantiate() - - -class BlockedKillNanny(Nanny): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.in_kill = asyncio.Event() - self.wait_kill = asyncio.Event() - - async def kill(self, **kwargs): - self.in_kill.set() - await self.wait_kill.wait() - return await super().kill(**kwargs) + return await super().start_worker_process() async def wait_for_state( diff --git a/distributed/worker.py b/distributed/worker.py index cd7f60efea1..f88eadfd1f2 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1225,8 +1225,8 @@ async def _register_with_scheduler(self) -> None: raise ValueError(f"Unexpected response from register: {response!r}") self.batched_stream.start(comm) - self.status = Status.running + self.status = Status.running await asyncio.gather( *( self.plugin_add(name=name, plugin=plugin) @@ -1554,7 +1554,6 @@ async def close( # type: ignore # This also informs the scheduler about the status update 
self.status = Status.closing setproctitle("dask worker [closing]") - if nanny and self.nanny: with self.rpc(self.nanny) as r: await r.close_gracefully(reason=reason) @@ -2630,7 +2629,7 @@ def get_current_task(self) -> Key: return self.active_threads[threading.get_ident()] def _handle_remove_worker(self, worker: str, stimulus_id: str) -> None: - self.rpc.remove(worker) + self.rpc.remove(worker, reason=stimulus_id) self.handle_stimulus(RemoveWorkerEvent(worker=worker, stimulus_id=stimulus_id)) def validate_state(self) -> None: From 5c2279af264b16240f9605c2f042f5f466bd9da9 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 12 Apr 2024 11:11:14 +0200 Subject: [PATCH 2/5] deal with signal handling --- distributed/cli/dask_worker.py | 6 ++- distributed/nanny.py | 65 +++++++++++++----------- distributed/tests/test_failed_workers.py | 7 ++- distributed/tests/test_nanny.py | 1 + distributed/worker.py | 2 +- 5 files changed, 48 insertions(+), 33 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 8ee87d0fee2..ef149ba6fbc 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -417,7 +417,11 @@ async def run(): async def wait_for_nannies_to_finish(): """Wait for all nannies to initialize and finish""" - await asyncio.gather(*nannies) + try: + await asyncio.gather(*nannies) + except Exception: + if not signal_fired: + raise await asyncio.gather(*(n.finished() for n in nannies)) async def wait_for_signals_and_close(): diff --git a/distributed/nanny.py b/distributed/nanny.py index 1e980767e18..b7094f3a3e9 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -366,6 +366,7 @@ async def start_unsafe(self): self._reconciler(), name="WorkerProcess reconciler" ) + assert self.worker_address, self.worker_address for name, plugin in msg["nanny-plugins"].items(): await self.plugin_add(plugin=plugin, name=name) except Exception: @@ -382,8 +383,6 @@ async def start_unsafe(self): finally: await comm.close() - 
assert self.worker_address - self.start_periodic_callbacks() return self @@ -511,7 +510,10 @@ async def successful_instantiation(self): finally: response = await self._reconciler_queue.get() if response["status"] != Status.running: - raise response["exception"] + if "exception" in response: + raise response["exception"] + else: + raise RuntimeError("Worker failed to start") async def restart( self, timeout: float = 30, reason: str = "nanny-restart" @@ -566,7 +568,6 @@ async def _on_worker_exit(self, exitcode): if self.status in (Status.starting, Status.running): await self._unregister(old_proc) except OSError: - self.status = Status.failed logger.exception("Failed to unregister") except asyncio.CancelledError: # Can happen during teardown. @@ -599,9 +600,12 @@ async def close( # type:ignore[override] """ Close the worker process, stop all comms. """ + if self.status == Status.starting: + await self + assert self.status in (Status.running, Status.failed) if self.status == Status.closing: await self.finished() - assert self.status == Status.closed + assert self.status in (Status.closed, Status.failed), self.status if self.status == Status.closed: return "OK" @@ -699,6 +703,9 @@ def __init__( except ValueError: pass + self.process_up = asyncio.Event() + self.running = asyncio.Event() + self.stopped = asyncio.Event() # Initialized when worker is ready self.worker_dir = None self.worker_address = None @@ -741,38 +748,36 @@ async def start(self) -> Status: ) self.process.daemon = dask.config.get("distributed.worker.daemon", default=True) self.process.set_exit_callback(self._on_exit) - self.running = asyncio.Event() - self.process_up = asyncio.Event() - self.stopped = asyncio.Event() self.status = Status.starting # Set selected environment variables before spawning the subprocess. # See note in Nanny docstring. 
os.environ.update(self.pre_spawn_env) - try: - await self.process.start() - except OSError: - logger.exception("Nanny failed to start process", exc_info=True) - # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent - await self.process.terminate() - self.status = Status.failed + try: + await self.process.start() + except OSError: + logger.exception("Nanny failed to start process", exc_info=True) + # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent + await self.process.terminate() + self.status = Status.failed + finally: + self.process_up.set() + try: + msg = await self._wait_until_connected(uid) + except Exception: + # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent + await self.process.terminate() + self.status = Status.failed + raise + if not msg: + return self.status + self.worker_address = msg["address"] + self.worker_dir = msg["dir"] + assert self.worker_address + self.status = Status.running finally: - self.process_up.set() - try: - msg = await self._wait_until_connected(uid) - except Exception: - # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent - await self.process.terminate() - self.status = Status.failed - raise - if not msg: - return self.status - self.worker_address = msg["address"] - self.worker_dir = msg["dir"] - assert self.worker_address - self.status = Status.running - self.running.set() + self.running.set() return self.status diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 2eca028f638..5dd23be057d 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -303,11 +303,16 @@ async def test_broken_worker_during_computation(c, s, a, b): key=["add-%d-%d" % (i, j) for j in range(len(L) // 2)], ) + old_worker_address = n.worker_address await asyncio.sleep(random.random() / 20) with suppress(CommClosedError): # comm will be 
closed abrupty await c.run(os._exit, 1, workers=[n.worker_address]) assert not n.process.is_alive() - while not n.process.is_alive(): + while ( + not n.process.is_alive() + or n.worker_address == old_worker_address + or n.worker_address is None + ): await asyncio.sleep(0.01) await c.wait_for_workers(3) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 6dddebf6be0..0dba0dab2f0 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -971,6 +971,7 @@ async def keep_killing(): assert a.status == Status.running +@pytest.mark.slow @pytest.mark.parametrize( "api", [ diff --git a/distributed/worker.py b/distributed/worker.py index f88eadfd1f2..57be631de36 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1693,7 +1693,7 @@ async def close_gracefully( async def wait_until_closed(self): warnings.warn("wait_until_closed has moved to finished()") await self.finished() - assert self.status == Status.closed + assert self.status == Status.closed, self.status ################ # Worker Peers # From 70d1684707e6565cec644ef36c2de20effc3e4f4 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 12 Apr 2024 12:37:16 +0200 Subject: [PATCH 3/5] fix --- distributed/nanny.py | 11 +++-------- distributed/tests/test_failed_workers.py | 2 +- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index b7094f3a3e9..ba356e99375 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -439,11 +439,6 @@ async def start_worker_process(self) -> dict[str, Status]: try: result = await wait_for(self.process.start(), self.death_timeout) except asyncio.TimeoutError: - logger.error( - "Timed out connecting Nanny '%s' to scheduler '%s'", - self, - self.scheduler_addr, - ) raise except Exception: logger.error("Failed to start process", exc_info=True) @@ -573,6 +568,8 @@ async def _on_worker_exit(self, exitcode): # Can happen during teardown. 
pass finally: + # NOTE: This causes the new worker to only be started once the old + # one is unregistered self._proc_changed.set() @property @@ -600,12 +597,10 @@ async def close( # type:ignore[override] """ Close the worker process, stop all comms. """ - if self.status == Status.starting: - await self - assert self.status in (Status.running, Status.failed) if self.status == Status.closing: await self.finished() assert self.status in (Status.closed, Status.failed), self.status + return "OK" if self.status == Status.closed: return "OK" diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 5dd23be057d..af938f44fcf 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -307,7 +307,7 @@ async def test_broken_worker_during_computation(c, s, a, b): await asyncio.sleep(random.random() / 20) with suppress(CommClosedError): # comm will be closed abrupty await c.run(os._exit, 1, workers=[n.worker_address]) - assert not n.process.is_alive() + while ( not n.process.is_alive() or n.worker_address == old_worker_address From e2339d0c1a01f59e445e154778ae8423ff05065f Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 12 Jul 2024 18:02:52 +0200 Subject: [PATCH 4/5] fix test_worker_ttl_restarts_worker --- distributed/scheduler.py | 8 +++- distributed/tests/test_failed_workers.py | 48 +++++++++++++++--------- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6a5540db910..8fde6c23ba5 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -8473,12 +8473,18 @@ async def check_worker_ttl(self) -> None: f"Worker failed to heartbeat for {last_seen:.0f}s; " f"{'attempting restart' if ws.nanny else 'removing'}: {ws}" ) - if to_restart: await self.restart_workers( to_restart, wait_for_workers=False, stimulus_id=stimulus_id, + # At this point we gave up on the worker, no reason to add an + # artificial timeout. 
Kill it right away. + # FIXME: Setting this to zero somehow causes things to hang. + timeout=0.1, + ) + self.log_event( + "all", {"action": "worker-ttl-restart", "workers": to_restart.copy()} ) def check_idle(self) -> float | None: diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 04ceeac42c5..43de7621b05 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -17,7 +17,6 @@ from distributed import Client, KilledWorker, Nanny, get_worker, profile, wait from distributed.comm import CommClosedError from distributed.compatibility import MACOS -from distributed.core import Status from distributed.metrics import time from distributed.utils import CancelledError, sync from distributed.utils_test import ( @@ -477,48 +476,61 @@ async def test_worker_time_to_live(c, s, a, b): @pytest.mark.slow -@pytest.mark.parametrize("block_evloop", [False, True]) +@pytest.mark.parametrize("block_on", [None, "event_loop", "threadpool"]) @gen_cluster( client=True, Worker=Nanny, nthreads=[("", 1)], - scheduler_kwargs={"worker_ttl": "500ms", "allowed_failures": 0}, + scheduler_kwargs={"worker_ttl": "100ms", "allowed_failures": 0}, ) -async def test_worker_ttl_restarts_worker(c, s, a, block_evloop): +async def test_worker_ttl_restarts_worker(c, s, a, block_on, monkeypatch): """If the event loop of a worker becomes completely unresponsive, the scheduler will restart it through the nanny. 
""" - ws = s.workers[a.worker_address] + ev = asyncio.Event() - async def f(): + def wait_for_restart(event): + _, msg = event + if msg.get("action") == "worker-ttl-restart": + ev.set() + + c.subscribe_topic("all", wait_for_restart) + + def f(): w = get_worker() w.periodic_callbacks["heartbeat"].stop() - if block_evloop: - sleep(9999) # Block event loop indefinitely - else: - await asyncio.sleep(9999) + if block_on is None: + return + elif block_on == "event_loop": + + async def _(): + sleep(9999) # Block event loop indefinitely + + w.loop.add_callback(_) + elif block_on == "threadpool": + sleep(9999) fut = c.submit(f, key="x") - while not s.workers or ( - (new_ws := next(iter(s.workers.values()))) is ws - or new_ws.status != Status.running - ): - await asyncio.sleep(0.01) + # TTL is set to at least 10 heartbeats + import distributed.scheduler + + monkeypatch.setattr(distributed.scheduler, "heartbeat_interval", lambda n: 0.001) + + await ev.wait() - if block_evloop: + if block_on: # The nanny killed the worker with SIGKILL. # The restart has increased the suspicious count. with pytest.raises(KilledWorker): await fut - assert s.tasks["x"].state == "erred" assert s.tasks["x"].suspicious == 1 else: # The nanny sent to the WorkerProcess a {op: stop} through IPC, which in turn # successfully invoked Worker.close(nanny=False). # This behaviour makes sense as the worker-ttl timeout was most likely caused # by a failure in networking, rather than a hung process. 
- assert s.tasks["x"].state == "processing" + await fut assert s.tasks["x"].suspicious == 0 From 014ed0753ac48db667ad9fb154163b46a37a5608 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 12 Jul 2024 18:36:00 +0200 Subject: [PATCH 5/5] ignore leaked threads --- distributed/tests/test_failed_workers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 43de7621b05..3c2e08ee2b5 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -482,6 +482,7 @@ async def test_worker_time_to_live(c, s, a, b): Worker=Nanny, nthreads=[("", 1)], scheduler_kwargs={"worker_ttl": "100ms", "allowed_failures": 0}, + clean_kwargs={"threads": False}, ) async def test_worker_ttl_restarts_worker(c, s, a, block_on, monkeypatch): """If the event loop of a worker becomes completely unresponsive, the scheduler will