diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 9070024c430..e67fa06d23a 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -1,5 +1,6 @@ import asyncio import atexit +import functools import logging import gc import os @@ -13,7 +14,7 @@ from dask.system import CPU_COUNT from distributed import Nanny, Worker from distributed.security import Security -from distributed.cli.utils import check_python_3, install_signal_handlers +from distributed.cli.utils import check_python_3 from distributed.comm import get_address_host_port from distributed.preloading import validate_preload_argv from distributed.proctitle import ( @@ -386,26 +387,32 @@ def del_pid_file(): for i in range(nprocs) ] - async def close_all(): - # Unregister all workers from scheduler - if nanny: - await asyncio.gather(*[n.close(timeout=2) for n in nannies]) - signal_fired = False - def on_signal(signum): + async def _on_signal(signum): nonlocal signal_fired - signal_fired = True - if signum != signal.SIGINT: + from distributed.utils import log_errors + + with log_errors(): logger.info("Exiting on signal %d", signum) - asyncio.ensure_future(close_all()) + signal_fired = True + if signum == signal.SIGINT: + logger.info("Gracefully closing worker because of SIGINT call") + await asyncio.gather(*[n.close_gracefully() for n in nannies]) + logger.info("Closing workers") + await asyncio.gather(*[n.close() for n in nannies]) + + def on_signal(sig): + asyncio.ensure_future(_on_signal(sig)) async def run(): await asyncio.gather(*nannies) await asyncio.gather(*[n.finished() for n in nannies]) - install_signal_handlers(loop, cleanup=on_signal) - + for sig in [signal.SIGINT, signal.SIGTERM]: + asyncio.get_event_loop().add_signal_handler( + sig, functools.partial(on_signal, sig) + ) try: loop.run_sync(run) except TimeoutError: diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index c509772d113..9d488e2c4a7 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -1,4 +1,6 @@ import asyncio +import signal + import pytest from click.testing import CliRunner @@ -10,10 +12,11 @@ from time import sleep import distributed.cli.dask_worker -from distributed import Client, Scheduler +from distributed import Client, Scheduler, Worker, wait +from distributed.compatibility import WINDOWS from distributed.metrics import time from distributed.utils import sync, tmpfile -from distributed.utils_test import popen, terminate_process, wait_for_port +from distributed.utils_test import popen, terminate_process, wait_for_port, slowinc from distributed.utils_test import loop, cleanup # noqa: F401 @@ -47,6 +50,36 @@ def test_nanny_worker_ports(loop): ) +@pytest.mark.skipif(WINDOWS, reason="Not supported on Windows") +@pytest.mark.asyncio +async def test_sigint(cleanup): + async with Scheduler(port=0) as s: + with popen(["dask-worker", s.address, "--name", "alice"]) as worker: + async with Client(s.address, asynchronous=True) as c: + async with Worker(s.address) as w: + await c.wait_for_workers(2) + a, b = s.workers.values() + scattered = await asyncio.gather( + c.scatter(list(range(0, 10)), workers=[a.address]), + c.scatter(list(range(10, 20)), workers=[b.address]), + ) + scattered = scattered[0] + scattered[1] + assert a.has_what and b.has_what + + submitted = c.map(slowinc, range(10), delay=0.05) + await asyncio.sleep(0.10) + + worker.send_signal(signal.SIGINT) + while len(s.workers) > 1: + await asyncio.sleep(0.01) + + await asyncio.sleep(0.5) + + await wait(submitted) + assert all(future.status == "finished" for future in scattered) + assert all(future.status == "finished" for future in submitted) + + def test_memory_limit(loop): with popen(["dask-scheduler", "--no-dashboard"]) as sched: with popen( diff --git a/distributed/nanny.py b/distributed/nanny.py index dc2e8a3ea48..946b7b83b6d 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -170,7 +170,7 @@ def __init__( # cannot call it 'close' on the rpc side for naming conflict "get_logs": self.get_logs, "terminate": self.close, - "close_gracefully": self.close_gracefully, + "close_gracefully": self.close_gracefully_signal, "run": self.run, } @@ -423,7 +423,7 @@ def _close(self, *args, **kwargs): warnings.warn("Worker._close has moved to Worker.close", stacklevel=2) return self.close(*args, **kwargs) - def close_gracefully(self, comm=None): + def close_gracefully_signal(self, comm=None): """ A signal that we shouldn't try to restart workers if they go away @@ -431,6 +431,13 @@ def close_gracefully(self, comm=None): """ self.status = "closing-gracefully" + async def close_gracefully(self): + try: + await self.rpc(self.worker_address).close_gracefully() + except CommClosedError: # worker will have closed connection + pass + await self.close() + async def close(self, comm=None, timeout=5, report=None): """ Close the worker process, stop all comms. diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index cacd98477e0..be399f38ad7 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -511,3 +511,14 @@ async def test_config(cleanup): async with Client(s.address, asynchronous=True) as client: config = await client.run(dask.config.get, "foo") assert config[n.worker_address] == "bar" + + +@gen_cluster(client=True, Worker=Nanny) +async def test_close_gracefully(c, s, a, b): + futures = await c.scatter(list(range(10))) + assert all(ws.has_what for ws in s.workers.values()) + + await a.close_gracefully() + assert a.status == "closed" + assert len(s.workers) == 1 + assert all(f.status == "finished" for f in futures) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index aef7bde8eee..680b3835bb4 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -982,7 +982,7 @@ def terminate_process(proc): if sys.platform.startswith("win"): proc.send_signal(signal.CTRL_BREAK_EVENT) else: - proc.send_signal(signal.SIGINT) + proc.send_signal(signal.SIGTERM) try: if sys.version_info[0] == 3: proc.wait(10) diff --git a/distributed/worker.py b/distributed/worker.py index 90059002e9a..897af8c19a2 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2,6 +2,7 @@ import bisect from collections import defaultdict, deque, namedtuple from collections.abc import MutableMapping +import concurrent.futures from datetime import timedelta import heapq from inspect import isawaitable @@ -619,6 +620,7 @@ def __init__( "actor_execute": self.actor_execute, "actor_attribute": self.actor_attribute, "plugin-add": self.plugin_add, + "close_gracefully": self.close_gracefully, } stream_handlers = { @@ -1103,27 +1105,29 @@ async def close( self.rpc.close() self.status = "closed" - await ServerNode.close(self) + with ignoring(concurrent.futures.CancelledError): + await ServerNode.close(self) setproctitle("dask-worker [closed]") return "OK" - async def close_gracefully(self): + async def close_gracefully(self, comm=None): """ Gracefully shut down a worker This first informs the scheduler that we're shutting down, and asks it to move our data elsewhere. Afterwards, we close as normal """ - if self.status.startswith("closing"): - await self.finished() + with log_errors(): + if self.status.startswith("closing"): + await self.finished() - if self.status == "closed": - return + if self.status == "closed": + return - logger.info("Closing worker gracefully: %s", self.address) - self.status = "closing-gracefully" - await self.scheduler.retire_workers(workers=[self.address], remove=False) - await self.close(safe=True, nanny=not self.lifetime_restart) + logger.info("Closing worker gracefully: %s", self.address) + self.status = "closing-gracefully" + await self.scheduler.retire_workers(workers=[self.address], remove=False) + await self.close(safe=True, nanny=not self.lifetime_restart) async def terminate(self, comm, report=True, **kwargs): await self.close(report=report, **kwargs)