From cee08313bc77b58ebceaa421067a02a96be8a6ab Mon Sep 17 00:00:00 2001 From: Casey Jao Date: Thu, 14 Jul 2022 10:25:17 -0400 Subject: [PATCH 1/3] Support `asynchronous` option in `get_client()` --- distributed/worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index 221e4fc6c29..cef6d83e904 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2555,7 +2555,7 @@ def get_worker() -> Worker: raise ValueError("No workers found") -def get_client(address=None, timeout=None, resolve_address=True) -> Client: +def get_client(address=None, timeout=None, resolve_address=True, asynchronous=False) -> Client: """Get a client while within a task. This client connects to the same scheduler to which the worker is connected @@ -2618,7 +2618,7 @@ def get_client(address=None, timeout=None, resolve_address=True) -> Client: if client and (not address or client.scheduler.address == address): return client elif address: - return Client(address, timeout=timeout) + return Client(address, timeout=timeout, asynchronous=asynchronous) else: raise ValueError("No global client found and no address provided") From dad90d61e8bcc3a3bb33041b55ba81f00391acfa Mon Sep 17 00:00:00 2001 From: Casey Jao Date: Thu, 14 Jul 2022 12:07:56 -0400 Subject: [PATCH 2/3] In get_client(), don't assume that async Client has been awaited on --- distributed/worker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index cef6d83e904..1e38fa3b33e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2615,7 +2615,9 @@ def get_client(address=None, timeout=None, resolve_address=True, asynchronous=Fa client = Client.current() # TODO: assumes the same scheduler except ValueError: client = None - if client and (not address or client.scheduler.address == address): + if client and (not address or not client.scheduler or client.scheduler.address == address): + # 
Client may not have yet been awaited on if get_client() was + # previously called with asynchronous=True return client elif address: return Client(address, timeout=timeout, asynchronous=asynchronous) From 38e903b392292a5c1b90b36c5287b3f60b591dc5 Mon Sep 17 00:00:00 2001 From: Casey Jao Date: Thu, 14 Jul 2022 13:32:28 -0400 Subject: [PATCH 3/3] Add test for async get_client --- distributed/tests/test_client.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c8c849eea79..15fa0d96928 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5064,6 +5064,31 @@ def f(x): del distributed.tmp_client +@pytest.mark.parametrize("asynchronous", [True, False]) +@gen_cluster(client=False) +async def test_get_async_clients(s, a, b, asynchronous): + """Concurrent get_client() calls share one client honouring ``asynchronous``.""" + async def create_client(scheduler): + client = await get_client( + address=scheduler.address, timeout=2, asynchronous=asynchronous + ) + return client + + futures = [asyncio.create_task(create_client(s)) for i in range(5)] + + clients = await asyncio.gather(*futures) + for i, client in enumerate(clients): + assert client.scheduler.address == s.address + assert client._asynchronous == asynchronous + if i < 4: + assert clients[i] == clients[i + 1] + + if asynchronous: + await clients[0].close() + else: + clients[0].close() + + def test_get_client_no_cluster(): # Clean up any global workers added by other tests. This test requires that # there are no global workers.