Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions distributed/tests/test_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,14 @@ async def test_async_ctx(s, a, b):

@pytest.mark.slow
def test_worker_dies():
config = {
"distributed.scheduler.locks.lease-timeout": "0.5s",
"distributed.scheduler.locks.lease-validation-interval": "0.2s",
}
with cluster(
config={
"distributed.scheduler.locks.lease-timeout": "0.1s",
}
config=config,
# Since we kill off the worker, the disconnect will never be successful.
disconnect_timeout=1,
) as (scheduler, workers):
with Client(scheduler["address"]) as client:
sem = Semaphore(name="x", max_leases=1)
Expand All @@ -147,12 +151,16 @@ def f(x, sem, kill_address):
os.kill(os.getpid(), 15)
return x

futures = client.map(
f, range(10), sem=sem, kill_address=workers[0]["address"]
futs = client.map(
f,
range(10),
sem=sem,
kill_address=workers[0]["address"],
workers=[workers[0]["address"]],
allow_other_workers=True,
)
results = client.gather(futures)

assert sorted(results) == list(range(10))
result = client.gather(futs)
assert result == list(range(10))


@gen_cluster(client=True)
Expand Down
257 changes: 136 additions & 121 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tempfile
import threading
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Callable
Expand All @@ -27,6 +28,7 @@
from itertools import count
from time import sleep
from typing import Any, Literal
from unittest.mock import patch

from distributed.compatibility import MACOS
from distributed.scheduler import Scheduler
Expand Down Expand Up @@ -80,6 +82,9 @@
except ImportError:
pass

# https://docs.github.com/en/actions/learn-github-actions/environment-variables#default-environment-variables
ON_CI = bool(os.environ.get("CI", False))


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -226,6 +231,11 @@ def get_ip():


def reset_config():
"""Reset the global Dask config."""
warnings.warn(
"The function reset_config will be removed soon. Instead, consider using unitest.mock.patch.dict",
FutureWarning,
)
dask.config.config.clear()
dask.config.config.update(copy.deepcopy(original_config))

Expand Down Expand Up @@ -645,133 +655,136 @@ def cluster(
enable_proctitle_on_children()

with clean(timeout=active_rpc_timeout, threads=False) as loop:
if nanny:
_run_worker = run_nanny
else:
_run_worker = run_worker
with dask.config.set(config):
if nanny:
_run_worker = run_nanny
else:
_run_worker = run_worker

# The scheduler queue will receive the scheduler's address
scheduler_q = mp_context.Queue()
# The scheduler queue will receive the scheduler's address
scheduler_q = mp_context.Queue()

# Launch scheduler
scheduler = mp_context.Process(
name="Dask cluster test: Scheduler",
target=run_scheduler,
args=(scheduler_q, nworkers + 1, config),
kwargs=scheduler_kwargs,
)
ws.add(scheduler)
scheduler.daemon = True
scheduler.start()

# Launch workers
workers = []
for i in range(nworkers):
q = mp_context.Queue()
fn = "_test_worker-%s" % uuid.uuid4()
kwargs = merge(
{
"nthreads": 1,
"local_directory": fn,
"memory_limit": system.MEMORY_LIMIT,
},
worker_kwargs,
)
proc = mp_context.Process(
name="Dask cluster test: Worker",
target=_run_worker,
args=(q, scheduler_q, config),
kwargs=kwargs,
# Launch scheduler
scheduler = mp_context.Process(
name="Dask cluster test: Scheduler",
target=run_scheduler,
args=(scheduler_q, nworkers + 1, config),
kwargs=scheduler_kwargs,
)
ws.add(proc)
workers.append({"proc": proc, "queue": q, "dir": fn})

for worker in workers:
worker["proc"].start()
saddr_or_exception = scheduler_q.get()
if isinstance(saddr_or_exception, Exception):
raise saddr_or_exception
saddr = saddr_or_exception

for worker in workers:
addr_or_exception = worker["queue"].get()
if isinstance(addr_or_exception, Exception):
raise addr_or_exception
worker["address"] = addr_or_exception
ws.add(scheduler)
scheduler.daemon = True
scheduler.start()

start = time()
try:
# Launch workers
workers = []
for i in range(nworkers):
q = mp_context.Queue()
fn = "_test_worker-%s" % uuid.uuid4()
kwargs = merge(
{
"nthreads": 1,
"local_directory": fn,
"memory_limit": system.MEMORY_LIMIT,
},
worker_kwargs,
)
proc = mp_context.Process(
name="Dask cluster test: Worker",
target=_run_worker,
args=(q, scheduler_q, config),
kwargs=kwargs,
)
ws.add(proc)
workers.append({"proc": proc, "queue": q, "dir": fn})

for worker in workers:
worker["proc"].start()
saddr_or_exception = scheduler_q.get()
if isinstance(saddr_or_exception, Exception):
raise saddr_or_exception
saddr = saddr_or_exception

for worker in workers:
addr_or_exception = worker["queue"].get()
if isinstance(addr_or_exception, Exception):
raise addr_or_exception
worker["address"] = addr_or_exception

start = time()
try:
security = scheduler_kwargs["security"]
rpc_kwargs = {"connection_args": security.get_connection_args("client")}
except KeyError:
rpc_kwargs = {}

with rpc(saddr, **rpc_kwargs) as s:
while True:
nthreads = loop.run_sync(s.ncores)
if len(nthreads) == nworkers:
break
if time() - start > 5:
raise Exception("Timeout on cluster creation")

# avoid sending processes down to function
yield {"address": saddr}, [
{"address": w["address"], "proc": weakref.ref(w["proc"])}
for w in workers
]
finally:
logger.debug("Closing out test cluster")
try:
security = scheduler_kwargs["security"]
rpc_kwargs = {
"connection_args": security.get_connection_args("client")
}
except KeyError:
rpc_kwargs = {}

with rpc(saddr, **rpc_kwargs) as s:
while True:
nthreads = loop.run_sync(s.ncores)
if len(nthreads) == nworkers:
break
if time() - start > 5:
raise Exception("Timeout on cluster creation")

# avoid sending processes down to function
yield {"address": saddr}, [
{"address": w["address"], "proc": weakref.ref(w["proc"])}
for w in workers
]
finally:
logger.debug("Closing out test cluster")

loop.run_sync(
lambda: disconnect_all(
[w["address"] for w in workers],
timeout=disconnect_timeout,
rpc_kwargs=rpc_kwargs,
loop.run_sync(
lambda: disconnect_all(
[w["address"] for w in workers],
timeout=disconnect_timeout,
rpc_kwargs=rpc_kwargs,
)
)
)
loop.run_sync(
lambda: disconnect(
saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs
loop.run_sync(
lambda: disconnect(
saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs
)
)
)

scheduler.terminate()
scheduler_q.close()
scheduler_q._reader.close()
scheduler_q._writer.close()
scheduler.terminate()
scheduler_q.close()
scheduler_q._reader.close()
scheduler_q._writer.close()

for w in workers:
w["proc"].terminate()
w["queue"].close()
w["queue"]._reader.close()
w["queue"]._writer.close()
for w in workers:
w["proc"].terminate()
w["queue"].close()
w["queue"]._reader.close()
w["queue"]._writer.close()

scheduler.join(2)
del scheduler
for proc in [w["proc"] for w in workers]:
proc.join(timeout=30)
scheduler.join(2)
del scheduler
for proc in [w["proc"] for w in workers]:
proc.join(timeout=30)

with suppress(UnboundLocalError):
del worker, w, proc
del workers[:]
with suppress(UnboundLocalError):
del worker, w, proc
del workers[:]

for fn in glob("_test_worker-*"):
with suppress(OSError):
shutil.rmtree(fn)
for fn in glob("_test_worker-*"):
with suppress(OSError):
shutil.rmtree(fn)

try:
client = default_client()
except ValueError:
pass
else:
client.close()
try:
client = default_client()
except ValueError:
pass
else:
client.close()

start = time()
while any(proc.is_alive() for proc in ws):
text = str(list(ws))
sleep(0.2)
assert time() < start + 5, ("Workers still around after five seconds", text)
start = time()
while any(proc.is_alive() for proc in ws):
text = str(list(ws))
sleep(0.2)
assert time() < start + 5, ("Workers still around after five seconds", text)


async def disconnect(addr, timeout=3, rpc_kwargs=None):
Expand All @@ -787,7 +800,8 @@ async def do_disconnect():
# the timeout
await w.terminate(reply=False)

await asyncio.wait_for(do_disconnect(), timeout=timeout)
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(do_disconnect(), timeout=timeout)


async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):
Expand Down Expand Up @@ -1733,15 +1747,16 @@ def clean(threads=not WINDOWS, instances=True, timeout=1, processes=True):
with check_process_leak(check=processes):
with check_instances() if instances else nullcontext():
with check_active_rpc(loop, timeout):
reset_config()

dask.config.set({"distributed.comm.timeouts.connect": "5s"})
# Restore default logging levels
# XXX use pytest hooks/fixtures instead?
for name, level in logging_levels.items():
logging.getLogger(name).setLevel(level)
with patch.dict(original_config, clear=True):
with dask.config.set(
{"distributed.comm.timeouts.connect": "300s"}
):
# Restore default logging levels
# XXX use pytest hooks/fixtures instead?
for name, level in logging_levels.items():
logging.getLogger(name).setLevel(level)

yield loop
yield loop


@pytest.fixture
Expand Down