diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 50ad43dfce8..0b069ca7845 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -128,10 +128,14 @@ async def test_async_ctx(s, a, b): @pytest.mark.slow def test_worker_dies(): + config = { + "distributed.scheduler.locks.lease-timeout": "0.5s", + "distributed.scheduler.locks.lease-validation-interval": "0.2s", + } with cluster( - config={ - "distributed.scheduler.locks.lease-timeout": "0.1s", - } + config=config, + # Since we kill off the worker, the disconnect will never be successful. + disconnect_timeout=1, ) as (scheduler, workers): with Client(scheduler["address"]) as client: sem = Semaphore(name="x", max_leases=1) @@ -147,12 +151,16 @@ def f(x, sem, kill_address): os.kill(os.getpid(), 15) return x - futures = client.map( - f, range(10), sem=sem, kill_address=workers[0]["address"] + futs = client.map( + f, + range(10), + sem=sem, + kill_address=workers[0]["address"], + workers=[workers[0]["address"]], + allow_other_workers=True, ) - results = client.gather(futures) - - assert sorted(results) == list(range(10)) + result = client.gather(futs) + assert result == list(range(10)) @gen_cluster(client=True) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index f6625e74180..d28b64ce035 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -19,6 +19,7 @@ import tempfile import threading import uuid +import warnings import weakref from collections import defaultdict from collections.abc import Callable @@ -27,6 +28,7 @@ from itertools import count from time import sleep from typing import Any, Literal +from unittest.mock import patch from distributed.compatibility import MACOS from distributed.scheduler import Scheduler @@ -80,6 +82,9 @@ except ImportError: pass +# https://docs.github.com/en/actions/learn-github-actions/environment-variables#default-environment-variables +ON_CI = bool(os.environ.get("CI", False)) + logger = logging.getLogger(__name__) @@ -226,6 +231,11 @@ def get_ip(): def reset_config(): + """Reset the global Dask config.""" + warnings.warn( + "The function reset_config will be removed soon. Instead, consider using unitest.mock.patch.dict", + FutureWarning, + ) dask.config.config.clear() dask.config.config.update(copy.deepcopy(original_config)) @@ -645,133 +655,136 @@ def cluster( enable_proctitle_on_children() with clean(timeout=active_rpc_timeout, threads=False) as loop: - if nanny: - _run_worker = run_nanny - else: - _run_worker = run_worker + with dask.config.set(config): + if nanny: + _run_worker = run_nanny + else: + _run_worker = run_worker - # The scheduler queue will receive the scheduler's address - scheduler_q = mp_context.Queue() + # The scheduler queue will receive the scheduler's address + scheduler_q = mp_context.Queue() - # Launch scheduler - scheduler = mp_context.Process( - name="Dask cluster test: Scheduler", - target=run_scheduler, - args=(scheduler_q, nworkers + 1, config), - kwargs=scheduler_kwargs, - ) - ws.add(scheduler) - scheduler.daemon = True - scheduler.start() - - # Launch workers - workers = [] - for i in range(nworkers): - q = mp_context.Queue() - fn = "_test_worker-%s" % uuid.uuid4() - kwargs = merge( - { - "nthreads": 1, - "local_directory": fn, - "memory_limit": system.MEMORY_LIMIT, - }, - worker_kwargs, - ) - proc = mp_context.Process( - name="Dask cluster test: Worker", - target=_run_worker, - args=(q, scheduler_q, config), - kwargs=kwargs, + # Launch scheduler + scheduler = mp_context.Process( + name="Dask cluster test: Scheduler", + target=run_scheduler, + args=(scheduler_q, nworkers + 1, config), + kwargs=scheduler_kwargs, ) - ws.add(proc) - workers.append({"proc": proc, "queue": q, "dir": fn}) - - for worker in workers: - worker["proc"].start() - saddr_or_exception = scheduler_q.get() - if isinstance(saddr_or_exception, Exception): - raise saddr_or_exception - saddr = saddr_or_exception - - for worker in workers: - addr_or_exception = worker["queue"].get() - if isinstance(addr_or_exception, Exception): - raise addr_or_exception - worker["address"] = addr_or_exception + ws.add(scheduler) + scheduler.daemon = True + scheduler.start() - start = time() - try: + # Launch workers + workers = [] + for i in range(nworkers): + q = mp_context.Queue() + fn = "_test_worker-%s" % uuid.uuid4() + kwargs = merge( + { + "nthreads": 1, + "local_directory": fn, + "memory_limit": system.MEMORY_LIMIT, + }, + worker_kwargs, + ) + proc = mp_context.Process( + name="Dask cluster test: Worker", + target=_run_worker, + args=(q, scheduler_q, config), + kwargs=kwargs, + ) + ws.add(proc) + workers.append({"proc": proc, "queue": q, "dir": fn}) + + for worker in workers: + worker["proc"].start() + saddr_or_exception = scheduler_q.get() + if isinstance(saddr_or_exception, Exception): + raise saddr_or_exception + saddr = saddr_or_exception + + for worker in workers: + addr_or_exception = worker["queue"].get() + if isinstance(addr_or_exception, Exception): + raise addr_or_exception + worker["address"] = addr_or_exception + + start = time() try: - security = scheduler_kwargs["security"] - rpc_kwargs = {"connection_args": security.get_connection_args("client")} - except KeyError: - rpc_kwargs = {} - - with rpc(saddr, **rpc_kwargs) as s: - while True: - nthreads = loop.run_sync(s.ncores) - if len(nthreads) == nworkers: - break - if time() - start > 5: - raise Exception("Timeout on cluster creation") - - # avoid sending processes down to function - yield {"address": saddr}, [ - {"address": w["address"], "proc": weakref.ref(w["proc"])} - for w in workers - ] - finally: - logger.debug("Closing out test cluster") + try: + security = scheduler_kwargs["security"] + rpc_kwargs = { + "connection_args": security.get_connection_args("client") + } + except KeyError: + rpc_kwargs = {} + + with rpc(saddr, **rpc_kwargs) as s: + while True: + nthreads = loop.run_sync(s.ncores) + if len(nthreads) == nworkers: + break + if time() - start > 5: + raise Exception("Timeout on cluster creation") + + # avoid sending processes down to function + yield {"address": saddr}, [ + {"address": w["address"], "proc": weakref.ref(w["proc"])} + for w in workers + ] + finally: + logger.debug("Closing out test cluster") - loop.run_sync( - lambda: disconnect_all( - [w["address"] for w in workers], - timeout=disconnect_timeout, - rpc_kwargs=rpc_kwargs, + loop.run_sync( + lambda: disconnect_all( + [w["address"] for w in workers], + timeout=disconnect_timeout, + rpc_kwargs=rpc_kwargs, + ) ) - ) - loop.run_sync( - lambda: disconnect( - saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs + loop.run_sync( + lambda: disconnect( + saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs + ) ) - ) - scheduler.terminate() - scheduler_q.close() - scheduler_q._reader.close() - scheduler_q._writer.close() + scheduler.terminate() + scheduler_q.close() + scheduler_q._reader.close() + scheduler_q._writer.close() - for w in workers: - w["proc"].terminate() - w["queue"].close() - w["queue"]._reader.close() - w["queue"]._writer.close() + for w in workers: + w["proc"].terminate() + w["queue"].close() + w["queue"]._reader.close() + w["queue"]._writer.close() - scheduler.join(2) - del scheduler - for proc in [w["proc"] for w in workers]: - proc.join(timeout=30) + scheduler.join(2) + del scheduler + for proc in [w["proc"] for w in workers]: + proc.join(timeout=30) - with suppress(UnboundLocalError): - del worker, w, proc - del workers[:] + with suppress(UnboundLocalError): + del worker, w, proc + del workers[:] - for fn in glob("_test_worker-*"): - with suppress(OSError): - shutil.rmtree(fn) + for fn in glob("_test_worker-*"): + with suppress(OSError): + shutil.rmtree(fn) - try: - client = default_client() - except ValueError: - pass - else: - client.close() + try: + client = default_client() + except ValueError: + pass + else: + client.close() - start = time() - while any(proc.is_alive() for proc in ws): - text = str(list(ws)) - sleep(0.2) - assert time() < start + 5, ("Workers still around after five seconds", text) + start = time() + while any(proc.is_alive() for proc in ws): + text = str(list(ws)) + sleep(0.2) + assert time() < start + 5, ("Workers still around after five seconds", text) async def disconnect(addr, timeout=3, rpc_kwargs=None): @@ -787,7 +800,8 @@ async def do_disconnect(): # the timeout await w.terminate(reply=False) - await asyncio.wait_for(do_disconnect(), timeout=timeout) + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(do_disconnect(), timeout=timeout) async def disconnect_all(addresses, timeout=3, rpc_kwargs=None): @@ -1733,15 +1747,16 @@ def clean(threads=not WINDOWS, instances=True, timeout=1, processes=True): with check_process_leak(check=processes): with check_instances() if instances else nullcontext(): with check_active_rpc(loop, timeout): - reset_config() - - dask.config.set({"distributed.comm.timeouts.connect": "5s"}) - # Restore default logging levels - # XXX use pytest hooks/fixtures instead? - for name, level in logging_levels.items(): - logging.getLogger(name).setLevel(level) + with patch.dict(original_config, clear=True): + with dask.config.set( + {"distributed.comm.timeouts.connect": "300s"} + ): + # Restore default logging levels + # XXX use pytest hooks/fixtures instead? + for name, level in logging_levels.items(): + logging.getLogger(name).setLevel(level) - yield loop + yield loop @pytest.fixture