From e2295c0abbeda654b39c1a70dfd95135d22acab2 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 3 May 2023 14:50:26 +0100 Subject: [PATCH 01/11] fix flaky test_single_executable_deprecated convert test_single_executable_deprecated into a sync test wait_for_log_line was intermittently blocking the handshake from occurring adding the extra asyncio.gather made this more likely --- distributed/cli/tests/test_dask_worker.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index ff2396f7031..2b342a1403b 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -723,11 +723,11 @@ def test_error_during_startup(monkeypatch, nanny, loop): assert worker.wait(10) == 1 -@gen_cluster(nthreads=[], client=True) -async def test_single_executable_deprecated(c, s): - with popen(["dask-worker", s.address], capture_output=True) as worker: - # ensure deprecation warning is emitted - wait_for_log_line(b"FutureWarning: dask-worker is deprecated", worker.stdout) +def test_single_executable_deprecated(): + assert ( + b"FutureWarning: dask-worker is deprecated" + in subprocess.run(["dask-worker"], capture_output=True).stderr + ) @pytest.mark.slow From 1dccdb02377ccb9dea2de8c9b1d5825a31c4e2c0 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 22 Mar 2023 17:41:10 +0100 Subject: [PATCH 02/11] Remove handshake from timeout --- distributed/comm/core.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 97f4b7d5b44..5645f289f3f 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -264,14 +264,12 @@ async def on_connection( ) -> None: local_info = {**comm.handshake_info(), **(handshake_overrides or {})} - timeout = dask.config.get("distributed.comm.timeouts.connect") - timeout = parse_timedelta(timeout, default="seconds") try: # 
Timeout is to ensure that we'll terminate connections eventually. # Connector side will employ smaller timeouts and we should only # reach this if the comm is dead anyhow. - await wait_for(comm.write(local_info), timeout=timeout) - handshake = await wait_for(comm.read(), timeout=timeout) + await comm.write(local_info) + handshake = await comm.read() # This would be better, but connections leak if worker is closed quickly # write, handshake = await asyncio.gather(comm.write(local_info), comm.read()) except Exception as e: @@ -370,8 +368,8 @@ def time_left(): try: # This would be better, but connections leak if worker is closed quickly # write, handshake = await asyncio.gather(comm.write(local_info), comm.read()) - handshake = await wait_for(comm.read(), time_left()) - await wait_for(comm.write(local_info), time_left()) + handshake = await comm.read() + await comm.write(local_info) except Exception as exc: with suppress(Exception): await comm.close() From c6af0aecc940788003ffcd03033b33653b183852 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 22 Mar 2023 17:46:25 +0100 Subject: [PATCH 03/11] simplify connection handshake --- distributed/comm/core.py | 27 +++------------------------ distributed/comm/tests/test_comms.py | 9 ++++----- 2 files changed, 7 insertions(+), 29 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 5645f289f3f..3e1b30fe308 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -7,7 +7,6 @@ import sys import weakref from abc import ABC, abstractmethod -from contextlib import suppress from typing import Any, ClassVar import dask @@ -264,18 +263,8 @@ async def on_connection( ) -> None: local_info = {**comm.handshake_info(), **(handshake_overrides or {})} - try: - # Timeout is to ensure that we'll terminate connections eventually. - # Connector side will employ smaller timeouts and we should only - # reach this if the comm is dead anyhow. 
- await comm.write(local_info) - handshake = await comm.read() - # This would be better, but connections leak if worker is closed quickly - # write, handshake = await asyncio.gather(comm.write(local_info), comm.read()) - except Exception as e: - with suppress(Exception): - await comm.close() - raise CommClosedError(f"Comm {comm!r} closed.") from e + await comm.write(local_info) + handshake = await comm.read() comm.remote_info = handshake comm.remote_info["address"] = comm.peer_address @@ -365,17 +354,7 @@ def time_left(): **comm.handshake_info(), **(handshake_overrides or {}), } - try: - # This would be better, but connections leak if worker is closed quickly - # write, handshake = await asyncio.gather(comm.write(local_info), comm.read()) - handshake = await comm.read() - await comm.write(local_info) - except Exception as exc: - with suppress(Exception): - await comm.close() - raise OSError( - f"Timed out during handshake while connecting to {addr} after {timeout} s" - ) from exc + _, handshake = await asyncio.gather(comm.write(local_info), comm.read()) comm.remote_info = handshake comm.remote_info["address"] = comm._peer_addr diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 93d8aa60789..f815d1ecb1e 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -962,6 +962,7 @@ class UnreliableBackend(tcp.TCPBackend): listener.stop() +@pytest.mark.slow @gen_test() async def test_handshake_slow_comm(tcp, monkeypatch): class SlowComm(tcp.TCP): @@ -996,11 +997,9 @@ def get_connector(self): import dask - with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): - with pytest.raises( - IOError, match="Timed out during handshake while connecting to" - ): - await connect(listener.contact_address) + # The connect itself is fast. 
Only the handshake is slow + with dask.config.set({"distributed.comm.timeouts.connect": "500ms"}): + await connect(listener.contact_address) finally: listener.stop() From 865bfb2b6b5abd5601c6ad6ab583b84e951c7011 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Wed, 22 Mar 2023 19:02:51 +0100 Subject: [PATCH 04/11] Update distributed/comm/core.py Co-authored-by: Thomas Grainger --- distributed/comm/core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 3e1b30fe308..918d7de74f8 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -263,8 +263,7 @@ async def on_connection( ) -> None: local_info = {**comm.handshake_info(), **(handshake_overrides or {})} - await comm.write(local_info) - handshake = await comm.read() + _, handshake = await asyncio.gather(comm.write(local_info), comm.read()) comm.remote_info = handshake comm.remote_info["address"] = comm.peer_address From a445ec390a3e3e792794f4c5227060d14f60b979 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 3 May 2023 14:13:29 +0100 Subject: [PATCH 05/11] remove distributed/tests/test_client.py::test_client_active_bad_port --- distributed/tests/test_client.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index b37cea6e810..655d8d9583f 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5881,22 +5881,6 @@ async def test_client_timeout_2(): assert stop - start < 1 -@gen_test() -async def test_client_active_bad_port(): - import tornado.httpserver - import tornado.web - - application = tornado.web.Application([(r"/", tornado.web.RequestHandler)]) - http_server = tornado.httpserver.HTTPServer(application) - http_server.listen(8080) - with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}): - c = Client("127.0.0.1:8080", asynchronous=True) - with pytest.raises((TimeoutError, 
IOError)): - async with c: - pass - http_server.stop() - - @pytest.mark.parametrize("direct", [True, False]) @gen_cluster(client=True, client_kwargs={"serializers": ["dask", "msgpack"]}) async def test_turn_off_pickle(c, s, a, b, direct): From d7214e252c5770b8c6b667e561c6e220587e9705 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 3 May 2023 15:58:18 +0100 Subject: [PATCH 06/11] remove asyncio.gather from handshake `await comm.write()` doesn't actually yield to the event loop (unless to_frames uses a thread) and so introducing `asyncio.gather` actually makes the handshake slower, which seems to be introducing more flaky test failures --- distributed/comm/core.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 918d7de74f8..24d117463a7 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -263,7 +263,8 @@ async def on_connection( ) -> None: local_info = {**comm.handshake_info(), **(handshake_overrides or {})} - _, handshake = await asyncio.gather(comm.write(local_info), comm.read()) + await comm.write(local_info) + handshake = await comm.read() comm.remote_info = handshake comm.remote_info["address"] = comm.peer_address @@ -353,7 +354,8 @@ def time_left(): **comm.handshake_info(), **(handshake_overrides or {}), } - _, handshake = await asyncio.gather(comm.write(local_info), comm.read()) + await comm.write(local_info) + handshake = await comm.read() comm.remote_info = handshake comm.remote_info["address"] = comm._peer_addr From 08353e1c6e3e5006af4c80eb5cdd78bfae160bcb Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 15 Jun 2023 11:23:05 +0100 Subject: [PATCH 07/11] abort connections that have not completed handshakes on Listener.stop --- distributed/comm/asyncio_tcp.py | 6 ++++-- distributed/comm/core.py | 22 +++++++++++++++++++++- distributed/comm/inproc.py | 6 ++++-- distributed/comm/tcp.py | 6 ++++-- distributed/comm/ucx.py | 6 ++++-- 
distributed/comm/ws.py | 6 ++++-- 6 files changed, 41 insertions(+), 11 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 22270d41d3d..f2419d354b2 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -17,7 +17,7 @@ import dask from distributed.comm.addressing import parse_host_port, unparse_host_port -from distributed.comm.core import Comm, CommClosedError, Connector, Listener +from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector from distributed.comm.registry import Backend from distributed.comm.utils import ( ensure_concrete_host, @@ -594,7 +594,7 @@ def _get_extra_kwargs(self, address: str, **kwargs: Any) -> dict[str, Any]: return {"ssl": ctx} -class TCPListener(Listener): +class TCPListener(BaseListener): prefix = "tcp://" comm_class = TCP @@ -608,6 +608,7 @@ def __init__( default_port=0, **kwargs, ): + super().__init__() self.ip, self.port = parse_host_port(address, default_port) self.default_host = default_host self.comm_handler = comm_handler @@ -733,6 +734,7 @@ def stop(self) -> None: # Stop listening for server in self._servers: server.close() + super().stop() def get_host_port(self): """ diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 24d117463a7..faceb1aea4c 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -220,7 +220,7 @@ async def start(self): """ @abstractmethod - def stop(self): + def stop(self) -> None: """ Stop listening. This does not shutdown already established communications, but prevents accepting new ones. 
@@ -276,6 +276,26 @@ async def on_connection( ) +class BaseListener(Listener): + def __init__(self) -> None: + self.__comms: set[Comm] = set() + + async def on_connection( + self, comm: Comm, handshake_overrides: dict[str, Any] | None = None + ) -> None: + self.__comms.add(comm) + try: + return await super().on_connection(comm, handshake_overrides) + finally: + self.__comms.discard(comm) + + def stop(self) -> None: + comms, self.__comms = self.__comms, set() + for comm in comms: + comm.abort() + super().stop() + + class Connector(ABC): @abstractmethod async def connect(self, address, deserialize=True): diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 5991217539f..02d26cd965e 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -11,7 +11,7 @@ from tornado.concurrent import Future from tornado.ioloop import IOLoop -from distributed.comm.core import Comm, CommClosedError, Connector, Listener +from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector from distributed.comm.registry import Backend, backends from distributed.protocol import nested_deserialize from distributed.utils import get_ip @@ -257,10 +257,11 @@ def closed(self): return False -class InProcListener(Listener): +class InProcListener(BaseListener): prefix = "inproc" def __init__(self, address, comm_handler, deserialize=True): + super().__init__() self.manager = global_manager self.address = address or self.manager.new_address() self.comm_handler = comm_handler @@ -303,6 +304,7 @@ async def start(self): def stop(self): self.listen_q.put_nowait(None) self.manager.remove_listener(self.address) + super().stop() @property def listen_address(self): diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index bc6bc2ac54a..306006b9f99 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -24,11 +24,11 @@ from distributed.comm.addressing import parse_host_port, unparse_host_port from distributed.comm.core import ( + 
BaseListener, Comm, CommClosedError, Connector, FatalCommClosedError, - Listener, ) from distributed.comm.registry import Backend from distributed.comm.utils import ( @@ -539,7 +539,7 @@ def _get_connect_args(self, **connection_args): return tls_args -class BaseTCPListener(Listener, RequireEncryptionMixin): +class BaseTCPListener(BaseListener, RequireEncryptionMixin): def __init__( self, address, @@ -550,6 +550,7 @@ def __init__( default_port=0, **connection_args, ): + super().__init__() self._check_encryption(address, connection_args) self.ip, self.port = parse_host_port(address, default_port) self.default_host = default_host @@ -590,6 +591,7 @@ def stop(self): tcp_server, self.tcp_server = self.tcp_server, None if tcp_server is not None: tcp_server.stop() + super().stop() def _check_started(self): if self.tcp_server is None: diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index d58d1bab765..a740fb8c93e 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -20,7 +20,7 @@ from dask.utils import parse_bytes from distributed.comm.addressing import parse_host_port, unparse_host_port -from distributed.comm.core import Comm, CommClosedError, Connector, Listener +from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector from distributed.comm.registry import Backend, backends from distributed.comm.utils import ( ensure_concrete_host, @@ -479,7 +479,7 @@ async def connect( ) -class UCXListener(Listener): +class UCXListener(BaseListener): prefix = UCXConnector.prefix comm_class = UCXConnector.comm_class encrypted = UCXConnector.encrypted @@ -492,6 +492,7 @@ def __init__( allow_offload: bool = True, **connection_args: Any, ): + super().__init__() if not address.startswith("ucx"): address = "ucx://" + address self.ip, self._input_port = parse_host_port(address, default_port=0) @@ -532,6 +533,7 @@ async def serve_forever(client_ep): def stop(self): self.ucp_server = None + super().stop() def get_host_port(self): # TODO: TCP 
raises if this hasn't started yet. diff --git a/distributed/comm/ws.py b/distributed/comm/ws.py index 5534dd09ab0..7373fa6f6a3 100644 --- a/distributed/comm/ws.py +++ b/distributed/comm/ws.py @@ -25,11 +25,11 @@ from distributed.comm.addressing import parse_host_port, unparse_host_port from distributed.comm.core import ( + BaseListener, Comm, CommClosedError, Connector, FatalCommClosedError, - Listener, ) from distributed.comm.registry import backends from distributed.comm.tcp import ( @@ -332,7 +332,7 @@ def _read_extra(self): ) -class WSListener(Listener): +class WSListener(BaseListener): prefix = "ws://" def __init__( @@ -343,6 +343,7 @@ def __init__( allow_offload: bool = False, **connection_args: Any, ): + super().__init__() if not address.startswith(self.prefix): address = f"{self.prefix}{address}" @@ -402,6 +403,7 @@ async def start(self): def stop(self): self.server.stop() + super().stop() def get_host_port(self): """ From d069db436bd579f0ac6ab9739fb16adf53cf63ed Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 15 Jun 2023 15:20:53 +0100 Subject: [PATCH 08/11] specialcase aborting handshaking comms --- distributed/comm/asyncio_tcp.py | 1 - distributed/comm/core.py | 5 ++--- distributed/comm/inproc.py | 1 - distributed/comm/tcp.py | 1 - distributed/comm/ucx.py | 1 - distributed/comm/ws.py | 1 - distributed/core.py | 12 ++++++++++++ 7 files changed, 14 insertions(+), 8 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index f2419d354b2..497288ff4a7 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -734,7 +734,6 @@ def stop(self) -> None: # Stop listening for server in self._servers: server.close() - super().stop() def get_host_port(self): """ diff --git a/distributed/comm/core.py b/distributed/comm/core.py index faceb1aea4c..c9c4c880f26 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -220,7 +220,7 @@ async def start(self): """ @abstractmethod - def 
stop(self) -> None: + def stop(self): """ Stop listening. This does not shutdown already established communications, but prevents accepting new ones. @@ -289,11 +289,10 @@ async def on_connection( finally: self.__comms.discard(comm) - def stop(self) -> None: + def abort_handshaking_comms(self) -> None: comms, self.__comms = self.__comms, set() for comm in comms: comm.abort() - super().stop() class Connector(ABC): diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 02d26cd965e..24bcd4c705b 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -304,7 +304,6 @@ async def start(self): def stop(self): self.listen_q.put_nowait(None) self.manager.remove_listener(self.address) - super().stop() @property def listen_address(self): diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 306006b9f99..6c521650941 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -591,7 +591,6 @@ def stop(self): tcp_server, self.tcp_server = self.tcp_server, None if tcp_server is not None: tcp_server.stop() - super().stop() def _check_started(self): if self.tcp_server is None: diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index a740fb8c93e..674d2f7e49d 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -533,7 +533,6 @@ async def serve_forever(client_ep): def stop(self): self.ucp_server = None - super().stop() def get_host_port(self): # TODO: TCP raises if this hasn't started yet. 
diff --git a/distributed/comm/ws.py b/distributed/comm/ws.py index 7373fa6f6a3..38aa840907e 100644 --- a/distributed/comm/ws.py +++ b/distributed/comm/ws.py @@ -403,7 +403,6 @@ async def start(self): def stop(self): self.server.stop() - super().stop() def get_host_port(self): """ diff --git a/distributed/core.py b/distributed/core.py index 5a026f69380..08a3a461064 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -673,6 +673,12 @@ def stop(self): future = listener.stop() if inspect.isawaitable(future): _stops.add(future) + try: + abort_handshaking_comms = listener.abort_handshaking_comms + except AttributeError: + pass + else: + abort_handshaking_comms() if _stops: @@ -1037,6 +1043,12 @@ async def close(self, timeout=None): PendingDeprecationWarning, ) _stops.add(future) + try: + abort_handshaking_comms = listener.abort_handshaking_comms + except AttributeError: + pass + else: + abort_handshaking_comms() if _stops: await asyncio.gather(*_stops) From 3e6f0f6a9c167a74bafa8148f56db2a8992cd9f6 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 6 Jul 2023 11:48:20 +0100 Subject: [PATCH 09/11] wait for the dask_worker to stop in test_dump_cluster_unresponsive_remote_worker --- distributed/tests/test_utils_test.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 0d3f88180bf..30896b4f423 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -598,16 +598,9 @@ async def test_dump_cluster_state_unresponsive_local_worker(s, a, b, tmp_path): @pytest.mark.slow -@gen_cluster( - client=True, - Worker=Nanny, - config={"distributed.comm.timeouts.connect": "600ms"}, -) +@gen_cluster(client=True, Worker=Nanny) async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmp_path): - clog_fut = asyncio.create_task( - c.run(lambda dask_scheduler: dask_scheduler.stop(), workers=[a.worker_address]) - ) - await 
asyncio.sleep(0.2) + await c.run(lambda dask_worker: dask_worker.stop(), workers=[a.worker_address]) await dump_cluster_state(s, [a, b], str(tmp_path), "dump") with open(f"{tmp_path}/dump.yaml") as fh: @@ -619,8 +612,6 @@ async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmp_path): "OSError('Timed out trying to connect to" ) - clog_fut.cancel() - # Note: WINDOWS constant doesn't work with `mypy --platform win32` if sys.platform == "win32": From a1723de0f14a5e1df4abfa9256654911524229d0 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 6 Jul 2023 12:10:01 +0100 Subject: [PATCH 10/11] stop worker as soon as possible --- distributed/worker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index 7d89998c4d6..fc3ac36205b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1538,6 +1538,8 @@ async def close( # type: ignore for pc in self.periodic_callbacks.values(): pc.stop() + self.stop() + # Cancel async instructions await BaseWorker.close(self, timeout=timeout) @@ -1640,7 +1642,6 @@ def _close(executor, wait): executor=executor, wait=executor_wait ) # Just run it directly - self.stop() await self.rpc.close() self.status = Status.closed From bd2ba8fd0e6142ee99a8055566285a688f81492a Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 6 Jul 2023 13:32:36 +0100 Subject: [PATCH 11/11] fix test_restart_nanny_timeout_exceeded changing when the worker calls self.stop() seems to cause this test to fail unless the SlowKillNanny.kill() is allowed to proceed when called during test teardown --- distributed/tests/test_scheduler.py | 49 ++++++++++++++++------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 497cc1dd35f..38d481fc369 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1063,30 +1063,35 @@ async def kill(self, *, 
timeout, reason=None): @pytest.mark.slow @gen_cluster(client=True, Worker=SlowKillNanny, nthreads=[("", 1)] * 2) async def test_restart_nanny_timeout_exceeded(c, s, a, b): - f = c.submit(div, 1, 0) - fr = c.submit(inc, 1, resources={"FOO": 1}) - await wait(f) - assert s.erred_tasks - assert s.computations - assert s.unrunnable - assert s.tasks - - with pytest.raises( - TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s" - ): - await c.restart(timeout="1s") - assert a.kill_called.is_set() - assert b.kill_called.is_set() + try: + f = c.submit(div, 1, 0) + fr = c.submit(inc, 1, resources={"FOO": 1}) + await wait(f) + assert s.erred_tasks + assert s.computations + assert s.unrunnable + assert s.tasks - assert not s.workers - assert not s.erred_tasks - assert not s.computations - assert not s.unrunnable - assert not s.tasks + with pytest.raises( + TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s" + ): + await c.restart(timeout="1s") + assert a.kill_called.is_set() + assert b.kill_called.is_set() + + assert not s.workers + assert not s.erred_tasks + assert not s.computations + assert not s.unrunnable + assert not s.tasks + + assert not c.futures + assert f.status == "cancelled" + assert fr.status == "cancelled" + finally: + a.kill_proceed.set() + b.kill_proceed.set() - assert not c.futures - assert f.status == "cancelled" - assert fr.status == "cancelled" @gen_cluster(client=True, nthreads=[("", 1)] * 2)