diff --git a/distributed/client.py b/distributed/client.py index 31bc53ae37..242cbdd76e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3916,8 +3916,19 @@ async def _restart_workers( timeout = self._timeout * 4 timeout = parse_timedelta(cast("str|int|float", timeout), "s") - info = self.scheduler_info() - name_to_addr = {meta["name"]: addr for addr, meta in info["workers"].items()} + # Fetch info for all workers (not the truncated cache from + # ``scheduler_info``) so that name resolution and the nanny check below + # are correct on clusters with many workers. + workers_info = (await self.scheduler.identity(n_workers=-1))["workers"] + + for addr, meta in workers_info.items(): + if (addr in workers or meta["name"] in workers) and meta["nanny"] is None: + raise ValueError( + f"Restarting workers requires a nanny to be used. Worker " + f"{addr} has type {meta['type']}." + ) + + name_to_addr = {meta["name"]: addr for addr, meta in workers_info.items()} worker_addrs = [name_to_addr.get(w, w) for w in workers] out: dict[ @@ -3989,14 +4000,6 @@ def restart_workers( -------- Client.restart """ - info = self.scheduler_info() - - for worker, meta in info["workers"].items(): - if (worker in workers or meta["name"] in workers) and meta["nanny"] is None: - raise ValueError( - f"Restarting workers requires a nanny to be used. Worker " - f"{worker} has type {info['workers'][worker]['type']}." - ) return self.sync( self._restart_workers, workers=workers, @@ -4475,14 +4478,21 @@ async def _profile( else: return state - def scheduler_info(self, n_workers: int = 5, **kwargs: Any) -> SchedulerInfo: + def scheduler_info(self, n_workers: int = -1, **kwargs: Any) -> SchedulerInfo: """Basic information about the workers in the cluster Parameters ---------- n_workers: int - The number of workers for which to fetch information. To fetch all, - use -1. + The number of workers for which to fetch information. Defaults to + ``-1``, which fetches all workers. Pass a positive integer to limit + the number of workers returned. + + Note: this argument is only honored for synchronous clients. For + asynchronous clients this method returns the most recently cached + value (refreshed periodically) without a fresh fetch; use + ``await client.scheduler.identity(n_workers=-1)`` to fetch all + workers on demand. **kwargs : dict Optional keyword arguments for the remote function diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index d2123d43ad..d0dffeedbc 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2112,10 +2112,12 @@ def test_repr(loop, func): *workers, ): with Client(s["address"], loop=loop) as c: - # NOTE: Intentionally testing when we have more workers than the default - # in `client.scheduler_info()` (xref https://github.com/dask/distributed/issues/9065) + # NOTE: Intentionally testing the repr with more workers than the repr + # itself displays (the HTML view caps the worker table at 5). + # ``scheduler_info()`` returns all workers by default + # (xref https://github.com/dask/distributed/issues/9065). info = c.scheduler_info() - assert len(info["workers"]) < nworkers + assert len(info["workers"]) == nworkers text = func(c) assert c.scheduler.address in text @@ -3948,12 +3950,15 @@ async def test_idempotence(s, a, b): def test_scheduler_info(c): + # By default all workers are returned info = c.scheduler_info() assert isinstance(info, dict) assert len(info["workers"]) == 2 assert isinstance(info["started"], float) + # A positive ``n_workers`` truncates the worker list info = c.scheduler_info(n_workers=1) assert len(info["workers"]) == 1 + # ``-1`` is equivalent to the default and returns all workers info = c.scheduler_info(n_workers=-1) assert len(info["workers"]) == 2 @@ -4768,6 +4773,21 @@ async def test_restart_workers_no_nanny_raises(c, s, a, b): assert a.address in msg +@gen_cluster(client=True, nthreads=[("", 1)] * 7) +async def test_restart_workers_no_nanny_raises_many_workers(c, s, *workers): + # Regression test: the nanny check must consider all workers, not just the + # truncated set that ``scheduler_info`` used to return by default + # (xref https://github.com/dask/distributed/issues/9065). Target a worker + # beyond the historical 5-worker truncation cap. + assert len(s.workers) == 7 + target = list(s.workers)[-1] + with pytest.raises(ValueError) as excinfo: + await c.restart_workers(workers=[target]) + msg = str(excinfo.value).lower() + assert "restarting workers requires a nanny" in msg + assert target in msg + + @pytest.mark.slow @pytest.mark.parametrize("raise_for_error", (True, False)) @gen_cluster(client=True, nthreads=[("", 1)], Worker=BlockedKillNanny)