From fa4763bf5a911bc20c1993dc3886409717dd1902 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 26 Oct 2021 23:26:15 -0600 Subject: [PATCH 1/5] Consistent worker Client instance in `get_client` Fixes #4959 `get_client` was calling the private `Worker._get_client` method when it ran within a task. `_get_client` should really have been called `_make_client`, since it created a new client every time. The simplest correct thing to do instead would have been to use the `Worker.client` property, which caches this instance. In order to pass the `timeout` parameter through though, I changed `Worker.get_client` to actually match its docstring and always return the same instance. --- distributed/tests/test_worker.py | 23 +++++++ distributed/worker.py | 100 +++++++++++++++---------------- 2 files changed, 73 insertions(+), 50 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 7a3126e9d1f..ca87b2d24ab 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -44,6 +44,7 @@ TaskStateMetadataPlugin, _LockedCommPool, captured_logger, + cluster, dec, div, gen_cluster, @@ -965,6 +966,28 @@ def f(x): assert a._client is a_client +@gen_cluster(client=True, nthreads=[("127.0.0.1", 4)]) +async def test_get_client_threadsafe(c, s, a): + def f(x): + return get_client().id + + futures = c.map(f, range(100)) + ids = await c.gather(futures) + assert len(set(ids)) == 1 + + +def test_get_client_threadsafe_sync(): + def f(x): + return get_client().id + + with cluster(nworkers=1, worker_kwargs={"nthreads": 4}) as (scheduler, workers): + with Client(scheduler["address"]) as client: + futures = client.map(f, range(100)) + ids = client.gather(futures) + assert len(set(ids)) == 1 + assert set(ids) != {client.id} + + def test_get_client_sync(client): def f(x): cc = get_client() diff --git a/distributed/worker.py b/distributed/worker.py index ee17e115d4e..caf0103e189 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -449,7 +449,7 @@ def __init__( self.has_what = defaultdict(set) self.pending_data_per_worker = defaultdict(deque) self.nanny = nanny - self._lock = threading.Lock() + self._client_lock = threading.Lock() self.data_needed = [] @@ -3554,11 +3554,7 @@ def validate_state(self): @property def client(self) -> Client: - with self._lock: - if self._client: - return self._client - else: - return self._get_client() + return self._get_client() def _get_client(self, timeout=None) -> Client: """Get local client attached to this worker @@ -3569,56 +3565,60 @@ def _get_client(self, timeout=None) -> Client: -------- get_client """ + with self._client_lock: + if self._client: + return self._client - if timeout is None: - timeout = dask.config.get("distributed.comm.timeouts.connect") - - timeout = parse_timedelta(timeout, "s") + if timeout is None: + timeout = dask.config.get("distributed.comm.timeouts.connect") - try: - from .client import default_client + timeout = parse_timedelta(timeout, "s") - client = default_client() - except ValueError: # no clients found, need to make a new one - pass - else: - # must be lazy import otherwise cyclic import - from distributed.deploy.cluster import Cluster + try: + from .client import default_client - if ( - client.scheduler - and client.scheduler.address == self.scheduler.address - # The below conditions should only happen in case a second - # cluster is alive, e.g. if a submitted task spawned its onwn - # LocalCluster, see gh4565 - or ( - isinstance(client._start_arg, str) - and client._start_arg == self.scheduler.address - or isinstance(client._start_arg, Cluster) - and client._start_arg.scheduler_address == self.scheduler.address + client = default_client() + except ValueError: # no clients found, need to make a new one + pass + else: + # must be lazy import otherwise cyclic import + from distributed.deploy.cluster import Cluster + + if ( + client.scheduler + and client.scheduler.address == self.scheduler.address + # The below conditions should only happen in case a second + # cluster is alive, e.g. if a submitted task spawned its onwn + # LocalCluster, see gh4565 + or ( + isinstance(client._start_arg, str) + and client._start_arg == self.scheduler.address + or isinstance(client._start_arg, Cluster) + and client._start_arg.scheduler_address + == self.scheduler.address + ) + ): + self._client = client + + if not self._client: + from .client import Client + + asynchronous = self.loop is IOLoop.current() + self._client = Client( + self.scheduler, + loop=self.loop, + security=self.security, + set_as_default=True, + asynchronous=asynchronous, + direct_to_workers=True, + name="worker", + timeout=timeout, ) - ): - self._client = client - - if not self._client: - from .client import Client - - asynchronous = self.loop is IOLoop.current() - self._client = Client( - self.scheduler, - loop=self.loop, - security=self.security, - set_as_default=True, - asynchronous=asynchronous, - direct_to_workers=True, - name="worker", - timeout=timeout, - ) - Worker._initialized_clients.add(self._client) - if not asynchronous: - assert self._client.status == "running" + Worker._initialized_clients.add(self._client) + if not asynchronous: + assert self._client.status == "running" - return self._client + return self._client def get_current_task(self): """Get the key of the task we are currently running From 12b8278f0296657979adfae2c912002a81ecd310 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Mon, 1 Nov 2021 15:06:05 -0600 Subject: [PATCH 2/5] typo Co-authored-by: crusaderky --- distributed/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index caf0103e189..62798b5ca0e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3588,7 +3588,7 @@ def _get_client(self, timeout=None) -> Client: client.scheduler and client.scheduler.address == self.scheduler.address # The below conditions should only happen in case a second - # cluster is alive, e.g. if a submitted task spawned its onwn + # cluster is alive, e.g. if a submitted task spawned its own # LocalCluster, see gh4565 or ( isinstance(client._start_arg, str) From ecaf5ab26ace5dee8c0afd2b942d7c98149c82a9 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Mon, 1 Nov 2021 15:15:40 -0600 Subject: [PATCH 3/5] skip lock once client exists --- distributed/worker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/distributed/worker.py b/distributed/worker.py index 62798b5ca0e..233e26a64f1 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3565,6 +3565,10 @@ def _get_client(self, timeout=None) -> Client: -------- get_client """ + if self._client: + # Assuming `self._client` cannot become None again, this lets us skip the lock + # once a client already exists. + return self._client with self._client_lock: if self._client: return self._client From b7a8d0484e3dacb92d7184d2b521ae1891e12d5c Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Mon, 1 Nov 2021 15:37:34 -0600 Subject: [PATCH 4/5] Create futures within task to exercise client --- distributed/tests/test_worker.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index ca87b2d24ab..8da48884560 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -969,24 +969,36 @@ def f(x): @gen_cluster(client=True, nthreads=[("127.0.0.1", 4)]) async def test_get_client_threadsafe(c, s, a): def f(x): - return get_client().id + client = get_client() + # use client just to prove it's working + future = client.submit(inc, x) + return client.id, future futures = c.map(f, range(100)) - ids = await c.gather(futures) + ids, new_futures = zip(*await c.gather(futures)) assert len(set(ids)) == 1 + new_results = await c.gather(new_futures) + assert sorted(new_results) == list(range(1, 101)) + def test_get_client_threadsafe_sync(): def f(x): - return get_client().id + client = get_client() + # use client just to prove it's working + future = client.submit(inc, x) + return client.id, future with cluster(nworkers=1, worker_kwargs={"nthreads": 4}) as (scheduler, workers): with Client(scheduler["address"]) as client: futures = client.map(f, range(100)) - ids = client.gather(futures) + ids, new_futures = zip(*client.gather(futures)) assert len(set(ids)) == 1 assert set(ids) != {client.id} + new_results = client.gather(new_futures) + assert sorted(new_results) == list(range(1, 101)) + def test_get_client_sync(client): def f(x): From d41c82b2d6e30e54270b8e082aecf1bfba57fcfa Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Mon, 1 Nov 2021 15:41:23 -0600 Subject: [PATCH 5/5] Use a simpler client RPC Futures work fine, but there's all sorts of extra complexity with returning futures from futures. Don't want this to become a flaky test because of distributed reference-counting bugs with pickling futures, since that's not what this test is about. --- distributed/tests/test_worker.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 8da48884560..8303c18d478 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -971,34 +971,30 @@ async def test_get_client_threadsafe(c, s, a): def f(x): client = get_client() # use client just to prove it's working - future = client.submit(inc, x) - return client.id, future + info = client.scheduler_info() + assert info + return client.id futures = c.map(f, range(100)) - ids, new_futures = zip(*await c.gather(futures)) + ids = await c.gather(futures) assert len(set(ids)) == 1 - new_results = await c.gather(new_futures) - assert sorted(new_results) == list(range(1, 101)) - def test_get_client_threadsafe_sync(): def f(x): client = get_client() # use client just to prove it's working - future = client.submit(inc, x) - return client.id, future + info = client.scheduler_info() + assert info + return client.id with cluster(nworkers=1, worker_kwargs={"nthreads": 4}) as (scheduler, workers): with Client(scheduler["address"]) as client: futures = client.map(f, range(100)) - ids, new_futures = zip(*client.gather(futures)) + ids = client.gather(futures) assert len(set(ids)) == 1 assert set(ids) != {client.id} - new_results = client.gather(new_futures) - assert sorted(new_results) == list(range(1, 101)) - def test_get_client_sync(client): def f(x):