Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
TaskStateMetadataPlugin,
_LockedCommPool,
captured_logger,
cluster,
dec,
div,
gen_cluster,
Expand Down Expand Up @@ -965,6 +966,36 @@ def f(x):
assert a._client is a_client


@gen_cluster(client=True, nthreads=[("127.0.0.1", 4)])
async def test_get_client_threadsafe(c, s, a):
def f(x):
client = get_client()
# use client just to prove it's working
info = client.scheduler_info()
assert info
return client.id

futures = c.map(f, range(100))
ids = await c.gather(futures)
assert len(set(ids)) == 1


def test_get_client_threadsafe_sync():
def f(x):
client = get_client()
# use client just to prove it's working
info = client.scheduler_info()
assert info
return client.id

with cluster(nworkers=1, worker_kwargs={"nthreads": 4}) as (scheduler, workers):
with Client(scheduler["address"]) as client:
futures = client.map(f, range(100))
ids = client.gather(futures)
assert len(set(ids)) == 1
assert set(ids) != {client.id}


Comment thread
gjoseph92 marked this conversation as resolved.
def test_get_client_sync(client):
def f(x):
cc = get_client()
Expand Down
104 changes: 54 additions & 50 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def __init__(
self.has_what = defaultdict(set)
self.pending_data_per_worker = defaultdict(deque)
self.nanny = nanny
self._lock = threading.Lock()
self._client_lock = threading.Lock()

self.data_needed = []

Expand Down Expand Up @@ -3554,11 +3554,7 @@ def validate_state(self):

@property
def client(self) -> Client:
with self._lock:
if self._client:
return self._client
else:
return self._get_client()
return self._get_client()

def _get_client(self, timeout=None) -> Client:
"""Get local client attached to this worker
Expand All @@ -3569,56 +3565,64 @@ def _get_client(self, timeout=None) -> Client:
--------
get_client
"""
if self._client:
# Assuming `self._client` cannot become None again, this lets us skip the lock
# once a client already exists.
return self._client
with self._client_lock:
if self._client:
return self._client
Comment thread
gjoseph92 marked this conversation as resolved.

if timeout is None:
timeout = dask.config.get("distributed.comm.timeouts.connect")

timeout = parse_timedelta(timeout, "s")
if timeout is None:
timeout = dask.config.get("distributed.comm.timeouts.connect")

try:
from .client import default_client
timeout = parse_timedelta(timeout, "s")

client = default_client()
except ValueError: # no clients found, need to make a new one
pass
else:
# must be lazy import otherwise cyclic import
from distributed.deploy.cluster import Cluster
try:
from .client import default_client

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be cleaner to move all imports to the top of the function

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this is just existing code. But I can refactor it.


if (
client.scheduler
and client.scheduler.address == self.scheduler.address
# The below conditions should only happen in case a second
# cluster is alive, e.g. if a submitted task spawned its onwn
# LocalCluster, see gh4565
or (
isinstance(client._start_arg, str)
and client._start_arg == self.scheduler.address
or isinstance(client._start_arg, Cluster)
and client._start_arg.scheduler_address == self.scheduler.address
client = default_client()
except ValueError: # no clients found, need to make a new one
pass
else:
# must be lazy import otherwise cyclic import
from distributed.deploy.cluster import Cluster

if (
client.scheduler
and client.scheduler.address == self.scheduler.address
# The below conditions should only happen in case a second
# cluster is alive, e.g. if a submitted task spawned its own
# LocalCluster, see gh4565
or (
isinstance(client._start_arg, str)
and client._start_arg == self.scheduler.address
or isinstance(client._start_arg, Cluster)
and client._start_arg.scheduler_address
== self.scheduler.address
)
):
self._client = client

if not self._client:
from .client import Client

asynchronous = self.loop is IOLoop.current()
self._client = Client(
self.scheduler,
loop=self.loop,
security=self.security,
set_as_default=True,
asynchronous=asynchronous,
direct_to_workers=True,
name="worker",
timeout=timeout,
)
):
self._client = client

if not self._client:
from .client import Client

asynchronous = self.loop is IOLoop.current()
self._client = Client(
self.scheduler,
loop=self.loop,
security=self.security,
set_as_default=True,
asynchronous=asynchronous,
direct_to_workers=True,
name="worker",
timeout=timeout,
)
Worker._initialized_clients.add(self._client)
if not asynchronous:
assert self._client.status == "running"
Worker._initialized_clients.add(self._client)
if not asynchronous:
assert self._client.status == "running"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need this? Isn't this already done by the sync Client constructor?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know. This is all existing code I just indented under the with self._client_lock.


return self._client
return self._client

def get_current_task(self):
"""Get the key of the task we are currently running
Expand Down