From 656ff58fd7b3ad52be59c32accad08847d175a25 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 4 Mar 2022 13:27:15 +0100 Subject: [PATCH 1/4] Reset config properly if multiple default clients are used --- distributed/actor.py | 2 +- distributed/client.py | 145 ++++++++++++++---------- distributed/tests/test_client.py | 102 +++++++++++++++-- distributed/tests/test_worker_client.py | 51 ++++++--- distributed/utils_test.py | 14 ++- 5 files changed, 220 insertions(+), 94 deletions(-) diff --git a/distributed/actor.py b/distributed/actor.py index 064c1432725..bb323bc4fdd 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -108,7 +108,7 @@ def __init__(self, cls, address, key, worker=None): self._worker = None try: self._client = get_client() - self._future = Future(key, inform=self._worker is None) + self._future = Future(key, self._client, inform=self._worker is None) # ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable. except ValueError: self._client = None diff --git a/distributed/client.py b/distributed/client.py index c3952ec078b..1102e58a128 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -50,6 +50,8 @@ from dask.delayed import single_key except ImportError: single_key = first +from typing import Generator + from tornado import gen from tornado.ioloop import IOLoop @@ -110,11 +112,6 @@ logger = logging.getLogger(__name__) -_global_clients: weakref.WeakValueDictionary[ - int, Client -] = weakref.WeakValueDictionary() -_global_client_index = [0] - _current_client = ContextVar("_current_client", default=None) DEFAULT_EXTENSIONS = { @@ -122,30 +119,65 @@ } -def _get_global_client() -> Client | None: - L = sorted(list(_global_clients), reverse=True) - for k in L: - c = _global_clients[k] - if c.status != "closed": - return c - else: - del _global_clients[k] - return None +class _GlobalClientManager: + def __init__(self): + self._lock = threading.RLock() + self._global_clients: weakref.WeakValueDictionary[ + int, Client + ] = weakref.WeakValueDictionary() + self._global_client_index = 0 + self._set = None + + def _get_global_client(self) -> Client | None: + with self._lock: + L = sorted(self._global_clients, reverse=True) + for k in L: + return self._global_clients[k] + return None -def _set_global_client(c: Client | None) -> None: - if c is not None: - _global_clients[_global_client_index[0]] = c - _global_client_index[0] += 1 + def _set_global_client(self, c: Client) -> None: + with self._lock: + if not self._set: + self._set = dask.config.set( + scheduler="dask.distributed", shuffle="tasks" + ) + self._global_clients[self._global_client_index] = c + self._global_client_index += 1 + def _del_global_client(self, c: Client) -> None: + with self._lock: + for k in list(self._global_clients): + try: + if self._global_clients[k] is c: + del self._global_clients[k] + except KeyError: # pragma: no cover + pass -def _del_global_client(c: Client) -> None: - for k in list(_global_clients): - try: - if _global_clients[k] is c: - del _global_clients[k] - except KeyError: # pragma: no cover - pass + if not self._global_clients and self._set: + self._set.__exit__(None, None, None) + self._set = None + + def clear(self): + with self._lock: + self._global_clients.clear() + self._global_client_index = 0 + + def close(self, timeout=3): + while c := self._get_global_client(): + if c.asynchronous: + c.loop.add_callback(c.close, timeout=timeout) + else: + c.close(timeout=timeout) + + +_global_client_manager = _GlobalClientManager() +_set_global_client = _global_client_manager._set_global_client +_get_global_client = _global_client_manager._get_global_client +_del_global_client = _global_client_manager._del_global_client + + +atexit.register(_global_client_manager.close) class Future(WrappedKey): @@ -189,11 +221,17 @@ class Future(WrappedKey): _cb_executor = None _cb_executor_pid = None - def __init__(self, key, client=None, inform=True, state=None): + def __init__( + self, + key: str | tuple, + client: Client, + inform: bool = True, + state: FutureState | None = None, + ): self.key = key self._cleared = False tkey = stringify(key) - self.client = client or Client.current() + self.client = client self.client._inc_ref(tkey) self._generation = self.client.generation @@ -944,10 +982,6 @@ def __init__( self._start_arg = address self._set_as_default = set_as_default - if set_as_default: - self._set_config = dask.config.set( - scheduler="dask.distributed", shuffle="tasks" - ) self._event_handlers = {} self._stream_handlers = { @@ -1176,8 +1210,6 @@ def start(self, **kwargs): return self._loop_runner.start() - if self._set_as_default: - _set_global_client(self) if self.asynchronous: self._started = asyncio.ensure_future(self._start(**kwargs)) @@ -1215,6 +1247,8 @@ def _send_to_scheduler(self, msg): async def _start(self, timeout=no_default, **kwargs): self.status = "connecting" + if self._set_as_default: + _set_global_client(self) await self.rpc.start() @@ -1599,6 +1633,7 @@ async def _close(self, fast=False): return self.status = "closing" + _del_global_client(self) for preload in self.preloads: await preload.teardown() @@ -1608,14 +1643,7 @@ async def _close(self, fast=False): pc.stop() with log_errors(): - _del_global_client(self) self._scheduler_identity = {} - with suppress(AttributeError): - # clear the dask.config set keys - with self._set_config: - pass - if self.get == dask.config.get("get", None): - del dask.config.config["get"] if ( self.scheduler_comm @@ -1654,9 +1682,6 @@ async def _close(self, fast=False): self.status = "closed" - if _get_global_client() is self: - _set_global_client(None) - if ( handle_report_task is not None and handle_report_task is not current_task @@ -5182,7 +5207,7 @@ def AsCompleted(*args, **kwargs): raise Exception("This has moved to as_completed") -def default_client(c=None): +def default_client(c: Client | None = None) -> Client: """Return a client if one has started Parameters @@ -5207,7 +5232,7 @@ def default_client(c=None): ) -def ensure_default_client(client): +def ensure_default_client(client: Client) -> None: """Ensures the client passed as argument is set as the default Parameters @@ -5215,7 +5240,6 @@ def ensure_default_client(client): client : Client The client """ - dask.config.set(scheduler="dask.distributed") _set_global_client(client) @@ -5518,7 +5542,7 @@ def __exit__(self, exc_type, exc_value, traceback): @contextmanager -def temp_default_client(c): +def temp_default_client(c: Client) -> Generator[None, None, None]: """Set the default client for the duration of the context .. note:: @@ -5543,19 +5567,14 @@ def temp_default_client(c): _set_global_client(old_exec) -def _close_global_client(): - """ - Force close of global client. This cleans up when a client - wasn't close explicitly, e.g. interactive sessions. - """ - c = _get_global_client() - if c is not None: - c._should_close_loop = False - with suppress(TimeoutError, RuntimeError): - if c.asynchronous: - c.loop.add_callback(c.close, timeout=3) - else: - c.close(timeout=3) - - -atexit.register(_close_global_client) +def __getattr__(name): + if name == "ensure_default_get": + warnings.warn( + "`ensure_default_get` is deprecated and will be removed in a future release. " + "Please use `distributed.client.ensure_default_client` instead.", + category=FutureWarning, + stacklevel=2, + ) + return ensure_default_client + else: + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 3ee6c779f38..98ed30a0eb1 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -120,6 +120,37 @@ pytestmark = pytest.mark.ci1 +@contextmanager +def _ensure_no_global_client_state_left(): + assert _get_global_client() is None + assert dask.base.get_scheduler() is None + assert dask.config.get("shuffle", None) is None + assert dask.config.get("scheduler", None) is None + + yield + assert _get_global_client() is None + try: + assert dask.config.get("shuffle", None) is None + assert dask.config.get("scheduler", None) is None + except AssertionError: + dask.config.config.pop("shuffle", None) + dask.config.config.pop("scheduler", None) + raise + assert dask.base.get_scheduler() is None + + +@pytest.fixture(autouse=True) +def ensure_no_global_clients_are_leaking(): + with _ensure_no_global_client_state_left(): + yield + + +def test_modify_global_state_raises(): + with pytest.raises(AssertionError): + with _ensure_no_global_client_state_left(): + dask.config.config["scheduler"] = "test-value" + + @gen_cluster(client=True) async def test_submit(c, s, a, b): x = c.submit(inc, 10, key="x") @@ -3365,6 +3396,68 @@ def test_default_get(loop_in_thread): assert dask.base.get_scheduler() == pre_get +@gen_cluster(nthreads=[]) +async def test_default_global_client_multi_clients(s): + pre_get = dask.base.get_scheduler() + assert _get_global_client() is None + pytest.raises(KeyError, dask.config.get, "shuffle") + + async with Client(s.address, asynchronous=True) as c: + assert dask.base.get_scheduler() == c.get + assert dask.config.get("shuffle") == "tasks" + assert dask.config.get("scheduler") == "dask.distributed" + + async with Client(s.address, asynchronous=True) as d: + assert dask.base.get_scheduler() == d.get + + assert dask.base.get_scheduler() == c.get + + assert dask.base.get_scheduler() is None + + assert _get_global_client() is None + c = await Client(s.address, asynchronous=True) + + assert dask.base.get_scheduler() == c.get + assert dask.config.get("shuffle") == "tasks" + assert dask.config.get("scheduler") == "dask.distributed" + + d = await Client(s.address, asynchronous=True) + + assert dask.base.get_scheduler() == d.get + assert dask.config.get("shuffle") == "tasks" + assert dask.config.get("scheduler") == "dask.distributed" + + await c.close() + + assert dask.base.get_scheduler() == d.get + assert dask.config.get("shuffle") == "tasks" + assert dask.config.get("scheduler") == "dask.distributed" + await d.close() + + assert dask.config.get("shuffle", None) is None + assert dask.config.get("scheduler", None) is None + assert dask.base.get_scheduler() is None + + c = await Client(s.address, asynchronous=True) + + assert dask.base.get_scheduler() == c.get + assert dask.config.get("shuffle") == "tasks" + assert dask.config.get("scheduler") == "dask.distributed" + + non_default = await Client(s.address, asynchronous=True, set_as_default=False) + + assert dask.base.get_scheduler() == c.get + assert dask.config.get("shuffle") == "tasks" + assert dask.config.get("scheduler") == "dask.distributed" + + await c.close() + + assert dask.config.get("shuffle", None) is None + assert dask.config.get("scheduler", None) is None + assert dask.base.get_scheduler() is None + await d.close() + + @gen_cluster(client=True) async def test_ensure_default_client(c, s, a, b): assert c is default_client() @@ -5592,15 +5685,6 @@ async def test_client_with_name(s, a, b): assert "foo" in text -@gen_cluster(client=True) -async def test_future_defaults_to_default_client(c, s, a, b): - x = c.submit(inc, 1) - await wait(x) - - future = Future(x.key) - assert future.client is c - - @gen_cluster(client=True) async def test_future_auto_inform(c, s, a, b): x = c.submit(inc, 1) diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index bf52d7cd8a3..0750fad2e70 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -297,24 +297,45 @@ def f(): assert result -@pytest.mark.xfail( - reason="Flaky due to https://github.com/dask/distributed/issues/5915" +@pytest.mark.parametrize("pass_addr_get_client", [True, False]) +@pytest.mark.parametrize( + "scheduler_addr_fact", + [ + lambda s: s.address, + pytest.param( + lambda s: "localhost:" + s.address.split(":")[-1], + marks=pytest.mark.xfail( + reason="Flaky due to https://github.com/dask/distributed/issues/5915" + ), + ), + ], ) @gen_cluster() -async def test_submit_different_names(s, a, b): +async def test_submit_different_names( + s, a, b, scheduler_addr_fact, pass_addr_get_client +): # https://github.com/dask/distributed/issues/2058 - da = pytest.importorskip("dask.array") - c = await Client( - "localhost:" + s.address.split(":")[-1], loop=s.loop, asynchronous=True - ) - try: - X = c.persist(da.random.uniform(size=(100, 10), chunks=50)) - await wait(X) - - fut = await c.submit(lambda x: x.sum().compute(), X) - assert fut > 0 - finally: - await c.close() + s_addr = scheduler_addr_fact(s) + async with Client( + s_addr, + loop=s.loop, + asynchronous=True, + name="test-localhost", + ) as c: + assert c.asynchronous is True + + def in_threadpool(x): + c_inner = get_client(s_addr if pass_addr_get_client else None) + assert c_inner.asynchronous is False + assert c_inner.status == "running" + assert len(Client._instances) == 1 + + assert c.asynchronous is True + # Ensure that we're scheduling on both workers, regardless of scheduling + # heuristics + futsA = c.map(in_threadpool, range(5), workers=[a.address]) + futsB = c.map(in_threadpool, range(5, 10), workers=[b.address]) + await c.gather(futsA + futsB) @gen_cluster(client=True) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 9213916f563..a6fe385b206 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -42,7 +42,7 @@ from distributed import Event, Scheduler, system from distributed import versions as version_module from distributed.batched import BatchedSend -from distributed.client import Client, _global_clients, default_client +from distributed.client import Client, _global_client_manager, default_client from distributed.comm import Comm from distributed.comm.tcp import TCP from distributed.compatibility import MACOS, WINDOWS @@ -1055,7 +1055,9 @@ async def async_fn(): def get_unclosed(): return [c for c in Comm._instances if not c.closed()] + [ - c for c in _global_clients.values() if c.status != "closed" + c + for c in _global_client_manager._global_clients.values() + if c.status != "closed" ] try: @@ -1072,7 +1074,7 @@ def get_unclosed(): raise RuntimeError("Unclosed Comms", get_unclosed()) finally: Comm._instances.clear() - _global_clients.clear() + _global_client_manager.clear() for w in workers: if getattr(w, "data", None): @@ -1716,17 +1718,17 @@ def check_instances(): SchedulerTaskState._instances.clear() WorkerTaskState._instances.clear() Nanny._instances.clear() - _global_clients.clear() + _global_client_manager.clear() Comm._instances.clear() yield start = time() - while set(_global_clients): + while _global_client_manager._get_global_client(): sleep(0.1) assert time() < start + 10 - _global_clients.clear() + _global_client_manager.clear() for w in Worker._instances: with suppress(RuntimeError): # closed IOLoop From 103db32871ccd952dac0b3062a78657570c69958 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 1 Apr 2022 13:17:42 +0200 Subject: [PATCH 2/4] ensure_default_client updates state on client --- distributed/client.py | 4 +++- distributed/tests/test_client.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 1102e58a128..83b20d2dd8e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1633,7 +1633,8 @@ async def _close(self, fast=False): return self.status = "closing" - _del_global_client(self) + if self._set_as_default: + _del_global_client(self) for preload in self.preloads: await preload.teardown() @@ -5240,6 +5241,7 @@ def ensure_default_client(client: Client) -> None: client : Client The client """ + client._set_as_default = True _set_global_client(client) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 98ed30a0eb1..64a465ff128 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3462,7 +3462,9 @@ async def test_default_global_client_multi_clients(s): async def test_ensure_default_client(c, s, a, b): assert c is default_client() - async with Client(s.address, set_as_default=False, asynchronous=True) as c2: + async with Client( + s.address, set_as_default=False, asynchronous=True, name="c2" + ) as c2: assert c is default_client() assert c2 is not default_client() ensure_default_client(c2) From 0f4683441ba72dfc8627a11efd3cbb62055e63ef Mon Sep 17 00:00:00 2001 From: fjetter Date: Mon, 12 Dec 2022 16:28:59 +0100 Subject: [PATCH 3/4] Fix linting --- distributed/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index 83b20d2dd8e..bcd61301638 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -119,7 +119,6 @@ } - class _GlobalClientManager: def __init__(self): self._lock = threading.RLock() From caf75f518c02ebd783739d58fa470468dc923911 Mon Sep 17 00:00:00 2001 From: fjetter Date: Mon, 12 Dec 2022 17:39:20 +0100 Subject: [PATCH 4/4] Remove getattr --- distributed/client.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index bcd61301638..d72b85e17da 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5566,16 +5566,3 @@ def temp_default_client(c: Client) -> Generator[None, None, None]: yield finally: _set_global_client(old_exec) - - -def __getattr__(name): - if name == "ensure_default_get": - warnings.warn( - "`ensure_default_get` is deprecated and will be removed in a future release. " - "Please use `distributed.client.ensure_default_client` instead.", - category=FutureWarning, - stacklevel=2, - ) - return ensure_default_client - else: - raise AttributeError(f"module {__name__} has no attribute {name}")