diff --git a/distributed/comm/core.py b/distributed/comm/core.py index fa3b8cb52a5..c9c4c880f26 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -7,7 +7,6 @@ import sys import weakref from abc import ABC, abstractmethod -from contextlib import suppress from typing import Any, ClassVar import dask @@ -264,20 +263,8 @@ async def on_connection( ) -> None: local_info = {**comm.handshake_info(), **(handshake_overrides or {})} - timeout = dask.config.get("distributed.comm.timeouts.connect") - timeout = parse_timedelta(timeout, default="seconds") - try: - # Timeout is to ensure that we'll terminate connections eventually. - # Connector side will employ smaller timeouts and we should only - # reach this if the comm is dead anyhow. - await wait_for(comm.write(local_info), timeout=timeout) - handshake = await wait_for(comm.read(), timeout=timeout) - # This would be better, but connections leak if worker is closed quickly - # write, handshake = await asyncio.gather(comm.write(local_info), comm.read()) - except Exception as e: - with suppress(Exception): - await comm.close() - raise CommClosedError(f"Comm {comm!r} closed.") from e + await comm.write(local_info) + handshake = await comm.read() comm.remote_info = handshake comm.remote_info["address"] = comm.peer_address @@ -386,17 +373,8 @@ def time_left(): **comm.handshake_info(), **(handshake_overrides or {}), } - try: - # This would be better, but connections leak if worker is closed quickly - # write, handshake = await asyncio.gather(comm.write(local_info), comm.read()) - handshake = await wait_for(comm.read(), time_left()) - await wait_for(comm.write(local_info), time_left()) - except Exception as exc: - with suppress(Exception): - await comm.close() - raise OSError( - f"Timed out during handshake while connecting to {addr} after {timeout} s" - ) from exc + await comm.write(local_info) + handshake = await comm.read() comm.remote_info = handshake comm.remote_info["address"] = comm._peer_addr diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 070953eeb86..c8a2264f37c 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -961,6 +961,7 @@ class UnreliableBackend(tcp.TCPBackend): listener.stop() +@pytest.mark.slow @gen_test() async def test_handshake_slow_comm(tcp, monkeypatch): class SlowComm(tcp.TCP): @@ -995,11 +996,9 @@ def get_connector(self): import dask - with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): - with pytest.raises( - IOError, match="Timed out during handshake while connecting to" - ): - await connect(listener.contact_address) + # The connect itself is fast. Only the handshake is slow + with dask.config.set({"distributed.comm.timeouts.connect": "500ms"}): + await connect(listener.contact_address) finally: listener.stop() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 2f9c252b0c3..6fb24546c0b 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5864,22 +5864,6 @@ async def test_client_timeout_2(): assert stop - start < 1 -@gen_test() -async def test_client_active_bad_port(): - import tornado.httpserver - import tornado.web - - application = tornado.web.Application([(r"/", tornado.web.RequestHandler)]) - http_server = tornado.httpserver.HTTPServer(application) - http_server.listen(8080) - with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}): - c = Client("127.0.0.1:8080", asynchronous=True) - with pytest.raises((TimeoutError, IOError)): - async with c: - pass - http_server.stop() - - @pytest.mark.parametrize("direct", [True, False]) @gen_cluster(client=True, client_kwargs={"serializers": ["dask", "msgpack"]}) async def test_turn_off_pickle(c, s, a, b, direct): diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 497cc1dd35f..38d481fc369 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1063,30 +1063,35 @@ async def kill(self, *, timeout, reason=None): @pytest.mark.slow @gen_cluster(client=True, Worker=SlowKillNanny, nthreads=[("", 1)] * 2) async def test_restart_nanny_timeout_exceeded(c, s, a, b): - f = c.submit(div, 1, 0) - fr = c.submit(inc, 1, resources={"FOO": 1}) - await wait(f) - assert s.erred_tasks - assert s.computations - assert s.unrunnable - assert s.tasks - - with pytest.raises( - TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s" - ): - await c.restart(timeout="1s") - assert a.kill_called.is_set() - assert b.kill_called.is_set() + try: + f = c.submit(div, 1, 0) + fr = c.submit(inc, 1, resources={"FOO": 1}) + await wait(f) + assert s.erred_tasks + assert s.computations + assert s.unrunnable + assert s.tasks - assert not s.workers - assert not s.erred_tasks - assert not s.computations - assert not s.unrunnable - assert not s.tasks + with pytest.raises( + TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s" + ): + await c.restart(timeout="1s") + assert a.kill_called.is_set() + assert b.kill_called.is_set() + + assert not s.workers + assert not s.erred_tasks + assert not s.computations + assert not s.unrunnable + assert not s.tasks + + assert not c.futures + assert f.status == "cancelled" + assert fr.status == "cancelled" + finally: + a.kill_proceed.set() + b.kill_proceed.set() - assert not c.futures - assert f.status == "cancelled" - assert fr.status == "cancelled" @gen_cluster(client=True, nthreads=[("", 1)] * 2) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 0d3f88180bf..30896b4f423 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -598,16 +598,9 @@ async def test_dump_cluster_state_unresponsive_local_worker(s, a, b, tmp_path): @pytest.mark.slow -@gen_cluster( - client=True, - Worker=Nanny, - config={"distributed.comm.timeouts.connect": "600ms"}, -) +@gen_cluster(client=True, Worker=Nanny) async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmp_path): - clog_fut = asyncio.create_task( - c.run(lambda dask_scheduler: dask_scheduler.stop(), workers=[a.worker_address]) - ) - await asyncio.sleep(0.2) + await c.run(lambda dask_worker: dask_worker.stop(), workers=[a.worker_address]) await dump_cluster_state(s, [a, b], str(tmp_path), "dump") with open(f"{tmp_path}/dump.yaml") as fh: @@ -619,8 +612,6 @@ async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmp_path): "OSError('Timed out trying to connect to" ) - clog_fut.cancel() - # Note: WINDOWS constant doesn't work with `mypy --platform win32` if sys.platform == "win32": diff --git a/distributed/worker.py b/distributed/worker.py index 7d89998c4d6..fc3ac36205b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1538,6 +1538,8 @@ async def close( # type: ignore for pc in self.periodic_callbacks.values(): pc.stop() + self.stop() + # Cancel async instructions await BaseWorker.close(self, timeout=timeout) @@ -1640,7 +1642,6 @@ def _close(executor, wait): executor=executor, wait=executor_wait ) # Just run it directly - self.stop() await self.rpc.close() self.status = Status.closed