Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3916,8 +3916,19 @@ async def _restart_workers(
timeout = self._timeout * 4
timeout = parse_timedelta(cast("str|int|float", timeout), "s")

info = self.scheduler_info()
name_to_addr = {meta["name"]: addr for addr, meta in info["workers"].items()}
# Fetch info for all workers (not the truncated cache from
# ``scheduler_info``) so that name resolution and the nanny check below
# are correct on clusters with many workers.
workers_info = (await self.scheduler.identity(n_workers=-1))["workers"]

for addr, meta in workers_info.items():
if (addr in workers or meta["name"] in workers) and meta["nanny"] is None:
raise ValueError(
f"Restarting workers requires a nanny to be used. Worker "
f"{addr} has type {meta['type']}."
)

name_to_addr = {meta["name"]: addr for addr, meta in workers_info.items()}
worker_addrs = [name_to_addr.get(w, w) for w in workers]

out: dict[
Expand Down Expand Up @@ -3989,14 +4000,6 @@ def restart_workers(
--------
Client.restart
"""
info = self.scheduler_info()

for worker, meta in info["workers"].items():
if (worker in workers or meta["name"] in workers) and meta["nanny"] is None:
raise ValueError(
f"Restarting workers requires a nanny to be used. Worker "
f"{worker} has type {info['workers'][worker]['type']}."
)
return self.sync(
self._restart_workers,
workers=workers,
Expand Down Expand Up @@ -4475,14 +4478,21 @@ async def _profile(
else:
return state

def scheduler_info(self, n_workers: int = 5, **kwargs: Any) -> SchedulerInfo:
def scheduler_info(self, n_workers: int = -1, **kwargs: Any) -> SchedulerInfo:
"""Basic information about the workers in the cluster

Parameters
----------
n_workers: int
The number of workers for which to fetch information. To fetch all,
use -1.
The number of workers for which to fetch information. Defaults to
``-1``, which fetches all workers. Pass a positive integer to limit
the number of workers returned.

Note: this argument is only honored for synchronous clients. For
asynchronous clients this method returns the most recently cached
value (refreshed periodically) without a fresh fetch; use
``await client.scheduler.identity(n_workers=-1)`` to fetch all
workers on demand.
**kwargs : dict
Optional keyword arguments for the remote function

Expand Down
26 changes: 23 additions & 3 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2112,10 +2112,12 @@ def test_repr(loop, func):
*workers,
):
with Client(s["address"], loop=loop) as c:
# NOTE: Intentionally testing when we have more workers than the default
# in `client.scheduler_info()` (xref https://github.com/dask/distributed/issues/9065)
# NOTE: Intentionally testing the repr with more workers than the repr
# itself displays (the HTML view caps the worker table at 5).
# ``scheduler_info()`` returns all workers by default
# (xref https://github.com/dask/distributed/issues/9065).
info = c.scheduler_info()
assert len(info["workers"]) < nworkers
assert len(info["workers"]) == nworkers

text = func(c)
assert c.scheduler.address in text
Expand Down Expand Up @@ -3948,12 +3950,15 @@ async def test_idempotence(s, a, b):


def test_scheduler_info(c):
    """Check ``scheduler_info`` with the default, a positive and a ``-1`` ``n_workers``.

    The assertions expect two workers to be visible through the client ``c``.
    """
    # No argument given: every worker is expected in the result.
    default_info = c.scheduler_info()
    assert isinstance(default_info, dict)
    assert len(default_info["workers"]) == 2
    assert isinstance(default_info["started"], float)

    # A positive ``n_workers`` caps how many workers are reported.
    limited_info = c.scheduler_info(n_workers=1)
    assert len(limited_info["workers"]) == 1

    # Passing ``-1`` explicitly must match the default: all workers come back.
    unlimited_info = c.scheduler_info(n_workers=-1)
    assert len(unlimited_info["workers"]) == 2

Expand Down Expand Up @@ -4768,6 +4773,21 @@ async def test_restart_workers_no_nanny_raises(c, s, a, b):
assert a.address in msg


@gen_cluster(client=True, nthreads=[("", 1)] * 7)
async def test_restart_workers_no_nanny_raises_many_workers(c, s, *workers):
    """Regression test for the nanny check on clusters with many workers.

    ``restart_workers`` must inspect every worker rather than only the
    truncated subset that ``scheduler_info`` historically returned by default
    (xref https://github.com/dask/distributed/issues/9065).
    """
    assert len(s.workers) == 7

    # Take the last worker the scheduler lists, so the target lies beyond the
    # historical 5-worker truncation cap.
    *_, last_worker = s.workers

    with pytest.raises(ValueError) as raised:
        await c.restart_workers(workers=[last_worker])

    # Compare against the lower-cased message so the check does not depend on
    # the capitalisation of the error text.
    message = str(raised.value).lower()
    assert "restarting workers requires a nanny" in message
    assert last_worker in message


@pytest.mark.slow
@pytest.mark.parametrize("raise_for_error", (True, False))
@gen_cluster(client=True, nthreads=[("", 1)], Worker=BlockedKillNanny)
Expand Down
Loading