From 9caa747cdf0e1931f255c44efe40ba92fb0d746c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 26 Jun 2026 16:48:06 -0500 Subject: [PATCH] Return all workers from scheduler_info() by default `Client.scheduler_info()` defaulted to `n_workers=5`, silently truncating the "workers" dict to the first five workers. This was misleading and, more importantly, broke `restart_workers`: both the nanny-required validation and the worker name -> address resolution read this truncated set, so on clusters with more than five workers the nanny check was silently skipped and name-based restarts of later workers misfired. Changes: - `scheduler_info()` now defaults to `n_workers=-1` (all workers). This only affects explicit, on-demand calls, not the periodic per-client poll, so the scheduler-scalability fix from #9045 / #9043 is not reintroduced (the periodic cache populator stays capped at 5). - `_restart_workers` now fetches a fresh, full identity via `scheduler.identity(n_workers=-1)` and performs both the nanny check and the name resolution there, correct for any cluster size. This also collapses the previous two identity RPCs on the sync path into one. - Document the sync/async asymmetry: async clients read the periodically cached value, so `n_workers` is only honored for synchronous clients. Adds a regression test that restarts a non-nanny worker beyond the historical five-worker cap; it raises the expected error only with the fix. xref https://github.com/dask/distributed/issues/9065 Co-Authored-By: Claude Opus 4.8 --- distributed/client.py | 36 ++++++++++++++++++++------------ distributed/tests/test_client.py | 26 ++++++++++++++++++++--- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 31bc53ae37..242cbdd76e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3916,8 +3916,19 @@ async def _restart_workers( timeout = self._timeout * 4 timeout = parse_timedelta(cast("str|int|float", timeout), "s") - info = self.scheduler_info() - name_to_addr = {meta["name"]: addr for addr, meta in info["workers"].items()} + # Fetch info for all workers (not the truncated cache from + # ``scheduler_info``) so that name resolution and the nanny check below + # are correct on clusters with many workers. + workers_info = (await self.scheduler.identity(n_workers=-1))["workers"] + + for addr, meta in workers_info.items(): + if (addr in workers or meta["name"] in workers) and meta["nanny"] is None: + raise ValueError( + f"Restarting workers requires a nanny to be used. Worker " + f"{addr} has type {meta['type']}." + ) + + name_to_addr = {meta["name"]: addr for addr, meta in workers_info.items()} worker_addrs = [name_to_addr.get(w, w) for w in workers] out: dict[ @@ -3989,14 +4000,6 @@ def restart_workers( -------- Client.restart """ - info = self.scheduler_info() - - for worker, meta in info["workers"].items(): - if (worker in workers or meta["name"] in workers) and meta["nanny"] is None: - raise ValueError( - f"Restarting workers requires a nanny to be used. Worker " - f"{worker} has type {info['workers'][worker]['type']}." - ) return self.sync( self._restart_workers, workers=workers, @@ -4475,14 +4478,21 @@ async def _profile( else: return state - def scheduler_info(self, n_workers: int = 5, **kwargs: Any) -> SchedulerInfo: + def scheduler_info(self, n_workers: int = -1, **kwargs: Any) -> SchedulerInfo: """Basic information about the workers in the cluster Parameters ---------- n_workers: int - The number of workers for which to fetch information. To fetch all, - use -1. + The number of workers for which to fetch information. Defaults to + ``-1``, which fetches all workers. Pass a positive integer to limit + the number of workers returned. + + Note: this argument is only honored for synchronous clients. For + asynchronous clients this method returns the most recently cached + value (refreshed periodically) without a fresh fetch; use + ``await client.scheduler.identity(n_workers=-1)`` to fetch all + workers on demand. **kwargs : dict Optional keyword arguments for the remote function diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index d2123d43ad..d0dffeedbc 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2112,10 +2112,12 @@ def test_repr(loop, func): *workers, ): with Client(s["address"], loop=loop) as c: - # NOTE: Intentionally testing when we have more workers than the default - # in `client.scheduler_info()` (xref https://github.com/dask/distributed/issues/9065) + # NOTE: Intentionally testing the repr with more workers than the repr + # itself displays (the HTML view caps the worker table at 5). + # ``scheduler_info()`` returns all workers by default + # (xref https://github.com/dask/distributed/issues/9065). info = c.scheduler_info() - assert len(info["workers"]) < nworkers + assert len(info["workers"]) == nworkers text = func(c) assert c.scheduler.address in text @@ -3948,12 +3950,15 @@ async def test_idempotence(s, a, b): def test_scheduler_info(c): + # By default all workers are returned info = c.scheduler_info() assert isinstance(info, dict) assert len(info["workers"]) == 2 assert isinstance(info["started"], float) + # A positive ``n_workers`` truncates the worker list info = c.scheduler_info(n_workers=1) assert len(info["workers"]) == 1 + # ``-1`` is equivalent to the default and returns all workers info = c.scheduler_info(n_workers=-1) assert len(info["workers"]) == 2 @@ -4768,6 +4773,21 @@ async def test_restart_workers_no_nanny_raises(c, s, a, b): assert a.address in msg +@gen_cluster(client=True, nthreads=[("", 1)] * 7) +async def test_restart_workers_no_nanny_raises_many_workers(c, s, *workers): + # Regression test: the nanny check must consider all workers, not just the + # truncated set that ``scheduler_info`` used to return by default + # (xref https://github.com/dask/distributed/issues/9065). Target a worker + # beyond the historical 5-worker truncation cap. + assert len(s.workers) == 7 + target = list(s.workers)[-1] + with pytest.raises(ValueError) as excinfo: + await c.restart_workers(workers=[target]) + msg = str(excinfo.value).lower() + assert "restarting workers requires a nanny" in msg + assert target in msg + + @pytest.mark.slow @pytest.mark.parametrize("raise_for_error", (True, False)) @gen_cluster(client=True, nthreads=[("", 1)], Worker=BlockedKillNanny)