diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index df464164f5d..ef2d1e79c63 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -21,7 +21,7 @@ import zipfile from collections import deque from collections.abc import Generator -from contextlib import contextmanager, nullcontext +from contextlib import ExitStack, contextmanager, nullcontext from functools import partial from operator import add from threading import Semaphore @@ -71,7 +71,7 @@ from distributed.cluster_dump import load_cluster_dump from distributed.comm import CommClosedError from distributed.compatibility import LINUX, WINDOWS -from distributed.core import Server, Status +from distributed.core import Status from distributed.metrics import time from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler from distributed.sizeof import sizeof @@ -3592,65 +3592,53 @@ async def test_scatter_raises_if_no_workers(c, s): async def test_reconnect(): port = open_port() - async def hard_stop(s): - for pc in s.periodic_callbacks.values(): - pc.stop() + stack = ExitStack() + proc = popen(["dask-scheduler", "--no-dashboard", f"--port={port}"]) + stack.enter_context(proc) + async with Client(f"127.0.0.1:{port}", asynchronous=True) as c, Worker( + f"127.0.0.1:{port}" + ) as w: + await c.wait_for_workers(1, timeout=10) + x = c.submit(inc, 1) + assert (await x) == 2 + stack.close() - s.stop_services() - for comm in list(s.stream_comms.values()): - comm.abort() - for comm in list(s.client_comms.values()): - comm.abort() + start = time() + while c.status != "connecting": + assert time() < start + 10 + await asyncio.sleep(0.01) - await s.rpc.close() - s.stop() - await Server.close(s) + assert x.status == "cancelled" + with pytest.raises(CancelledError): + await x - async with Scheduler(port=port) as s: - async with Client(f"127.0.0.1:{port}", asynchronous=True) as c: - async with Worker(f"127.0.0.1:{port}") as w: - await c.wait_for_workers(1, timeout=10) - x = c.submit(inc, 1) - assert (await x) == 2 - await hard_stop(s) + with popen(["dask-scheduler", "--no-dashboard", f"--port={port}"]): + start = time() + while c.status != "running": + await asyncio.sleep(0.1) + assert time() < start + 10 + await w.finished() + async with Worker(f"127.0.0.1:{port}"): start = time() - while c.status != "connecting": + while len(await c.nthreads()) != 1: + await asyncio.sleep(0.05) assert time() < start + 10 - await asyncio.sleep(0.01) - - assert x.status == "cancelled" - with pytest.raises(CancelledError): - await x - async with Scheduler(port=port) as s2: - start = time() - while c.status != "running": - await asyncio.sleep(0.1) - assert time() < start + 10 - - await w.finished() - async with Worker(f"127.0.0.1:{port}"): - start = time() - while len(await c.nthreads()) != 1: - await asyncio.sleep(0.05) - assert time() < start + 10 - - x = c.submit(inc, 1) - assert (await x) == 2 - await hard_stop(s2) + x = c.submit(inc, 1) + assert (await x) == 2 - start = time() - while True: - assert time() < start + 10 - try: - await x - assert False - except CommClosedError: - continue - except CancelledError: - break - await c._close(fast=True) + start = time() + while True: + assert time() < start + 10 + try: + await x + assert False + except CommClosedError: + continue + except CancelledError: + break + await c._close(fast=True) class UnhandledExceptions(Exception):