diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c8c849eea79..15fa0d96928 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5064,6 +5064,30 @@ def f(x): del distributed.tmp_client +@pytest.mark.parametrize("asynchronous", [True, False]) +@gen_cluster(client=False) +async def test_get_async_clients(c, s, a, b, asynchronous): + async def create_client(scheduler): + client = await get_client( + address=scheduler.address, timeout=2, asynchronous=asynchronous + ) + return client + + futures = [asyncio.create_task(create_client(s)) for i in range(5)] + + clients = await asyncio.gather(*futures) + for i, client in enumerate(clients): + assert client.scheduler.address == s.address + assert client._asynchronous == asynchronous + if i < 4: + assert clients[i] == clients[i + 1] + + if asynchronous: + await clients[0].close() + else: + clients[0].close() + + def test_get_client_no_cluster(): # Clean up any global workers added by other tests. This test requires that # there are no global workers. diff --git a/distributed/worker.py b/distributed/worker.py index 221e4fc6c29..1e38fa3b33e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2555,7 +2555,7 @@ def get_worker() -> Worker: raise ValueError("No workers found") -def get_client(address=None, timeout=None, resolve_address=True) -> Client: +def get_client(address=None, timeout=None, resolve_address=True, asynchronous=False) -> Client: """Get a client while within a task. This client connects to the same scheduler to which the worker is connected @@ -2615,10 +2615,12 @@ def get_client(address=None, timeout=None, resolve_address=True) -> Client: client = Client.current() # TODO: assumes the same scheduler except ValueError: client = None - if client and (not address or client.scheduler.address == address): + if client and (not address or not client.scheduler or client.scheduler.address == address): + # Client may not have yet been awaited on if get_client() was + # previously called with asynchronous=True return client elif address: - return Client(address, timeout=timeout) + return Client(address, timeout=timeout, asynchronous=asynchronous) else: raise ValueError("No global client found and no address provided")