diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 7a3126e9d1f..8303c18d478 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -44,6 +44,7 @@ TaskStateMetadataPlugin, _LockedCommPool, captured_logger, + cluster, dec, div, gen_cluster, @@ -965,6 +966,36 @@ def f(x): assert a._client is a_client +@gen_cluster(client=True, nthreads=[("127.0.0.1", 4)]) +async def test_get_client_threadsafe(c, s, a): + def f(x): + client = get_client() + # use client just to prove it's working + info = client.scheduler_info() + assert info + return client.id + + futures = c.map(f, range(100)) + ids = await c.gather(futures) + assert len(set(ids)) == 1 + + +def test_get_client_threadsafe_sync(): + def f(x): + client = get_client() + # use client just to prove it's working + info = client.scheduler_info() + assert info + return client.id + + with cluster(nworkers=1, worker_kwargs={"nthreads": 4}) as (scheduler, workers): + with Client(scheduler["address"]) as client: + futures = client.map(f, range(100)) + ids = client.gather(futures) + assert len(set(ids)) == 1 + assert set(ids) != {client.id} + + def test_get_client_sync(client): def f(x): cc = get_client() diff --git a/distributed/worker.py b/distributed/worker.py index ee17e115d4e..233e26a64f1 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -449,7 +449,7 @@ def __init__( self.has_what = defaultdict(set) self.pending_data_per_worker = defaultdict(deque) self.nanny = nanny - self._lock = threading.Lock() + self._client_lock = threading.Lock() self.data_needed = [] @@ -3554,11 +3554,7 @@ def validate_state(self): @property def client(self) -> Client: - with self._lock: - if self._client: - return self._client - else: - return self._get_client() + return self._get_client() def _get_client(self, timeout=None) -> Client: """Get local client attached to this worker @@ -3569,56 +3565,64 @@ def _get_client(self, timeout=None) -> Client: -------- get_client """ + if self._client: + # Assuming `self._client` cannot become None again, this lets us skip the lock + # once a client already exists. + return self._client + with self._client_lock: + if self._client: + return self._client - if timeout is None: - timeout = dask.config.get("distributed.comm.timeouts.connect") - - timeout = parse_timedelta(timeout, "s") + if timeout is None: + timeout = dask.config.get("distributed.comm.timeouts.connect") - try: - from .client import default_client + timeout = parse_timedelta(timeout, "s") - client = default_client() - except ValueError: # no clients found, need to make a new one - pass - else: - # must be lazy import otherwise cyclic import - from distributed.deploy.cluster import Cluster + try: + from .client import default_client - if ( - client.scheduler - and client.scheduler.address == self.scheduler.address - # The below conditions should only happen in case a second - # cluster is alive, e.g. if a submitted task spawned its onwn - # LocalCluster, see gh4565 - or ( - isinstance(client._start_arg, str) - and client._start_arg == self.scheduler.address - or isinstance(client._start_arg, Cluster) - and client._start_arg.scheduler_address == self.scheduler.address + client = default_client() + except ValueError: # no clients found, need to make a new one + pass + else: + # must be lazy import otherwise cyclic import + from distributed.deploy.cluster import Cluster + + if ( + client.scheduler + and client.scheduler.address == self.scheduler.address + # The below conditions should only happen in case a second + # cluster is alive, e.g. if a submitted task spawned its own + # LocalCluster, see gh4565 + or ( + isinstance(client._start_arg, str) + and client._start_arg == self.scheduler.address + or isinstance(client._start_arg, Cluster) + and client._start_arg.scheduler_address + == self.scheduler.address + ) + ): + self._client = client + + if not self._client: + from .client import Client + + asynchronous = self.loop is IOLoop.current() + self._client = Client( + self.scheduler, + loop=self.loop, + security=self.security, + set_as_default=True, + asynchronous=asynchronous, + direct_to_workers=True, + name="worker", + timeout=timeout, ) - ): - self._client = client - - if not self._client: - from .client import Client - - asynchronous = self.loop is IOLoop.current() - self._client = Client( - self.scheduler, - loop=self.loop, - security=self.security, - set_as_default=True, - asynchronous=asynchronous, - direct_to_workers=True, - name="worker", - timeout=timeout, - ) - Worker._initialized_clients.add(self._client) - if not asynchronous: - assert self._client.status == "running" + Worker._initialized_clients.add(self._client) + if not asynchronous: + assert self._client.status == "running" - return self._client + return self._client def get_current_task(self): """Get the key of the task we are currently running