Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion distributed/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(self, cls, address, key, worker=None):
self._worker = None
try:
self._client = get_client()
self._future = Future(key, inform=self._worker is None)
self._future = Future(key, self._client, inform=self._worker is None)
# ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable.
except ValueError:
self._client = None
Expand Down
137 changes: 72 additions & 65 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
from dask.delayed import single_key
except ImportError:
single_key = first
from typing import Generator

from tornado import gen
from tornado.ioloop import IOLoop

Expand Down Expand Up @@ -110,42 +112,71 @@

logger = logging.getLogger(__name__)

# Registry of clients that may act as the default client, keyed by an integer
# registration index. Values are weakly referenced, so an entry disappears
# once nothing else keeps its client alive.
_global_clients: weakref.WeakValueDictionary[
int, Client
] = weakref.WeakValueDictionary()
# Next registration index, boxed in a one-element list so that it can be
# incremented in place by the helper functions.
_global_client_index = [0]

# Context variable whose value defaults to None.
# NOTE(review): presumably holds the client active in the current context --
# confirm against its users.
_current_client = ContextVar("_current_client", default=None)

# Maps an extension name to the class implementing it.
DEFAULT_EXTENSIONS = {
"pubsub": PubSubClientExtension,
}


def _get_global_client() -> Client | None:
    """Return the most recently registered client that is not closed.

    Registered clients are visited from the highest registration index to the
    lowest.  Entries whose client reports ``status == "closed"`` are dropped
    from the registry along the way.

    Returns
    -------
    Client or None
        The newest non-closed client, or ``None`` if there is none.
    """
    # ``sorted`` snapshots the keys, so the registry may be mutated below.
    for k in sorted(_global_clients, reverse=True):
        # The registry holds weak references: an entry can vanish between the
        # key snapshot and this lookup, hence ``get`` rather than indexing.
        c = _global_clients.get(k)
        if c is None:
            continue
        if c.status != "closed":
            return c
        _global_clients.pop(k, None)
    return None
class _GlobalClientManager:
def __init__(self):
self._lock = threading.RLock()
self._global_clients: weakref.WeakValueDictionary[
int, Client
] = weakref.WeakValueDictionary()
self._global_client_index = 0
self._set = None

def _get_global_client(self) -> Client | None:
with self._lock:
L = sorted(self._global_clients, reverse=True)
for k in L:
return self._global_clients[k]
return None

def _set_global_client(self, c: Client) -> None:
with self._lock:
if not self._set:
self._set = dask.config.set(
scheduler="dask.distributed", shuffle="tasks"
)
self._global_clients[self._global_client_index] = c
self._global_client_index += 1

def _set_global_client(c: Client | None) -> None:
if c is not None:
_global_clients[_global_client_index[0]] = c
_global_client_index[0] += 1
def _del_global_client(self, c: Client) -> None:
with self._lock:
for k in list(self._global_clients):
try:
if self._global_clients[k] is c:
del self._global_clients[k]
except KeyError: # pragma: no cover
pass

if not self._global_clients and self._set:
self._set.__exit__(None, None, None)
self._set = None

def clear(self):
with self._lock:
self._global_clients.clear()
self._global_client_index = 0

def close(self, timeout=3):
while c := self._get_global_client():
if c.asynchronous:
c.loop.add_callback(c.close, timeout=timeout)
else:
c.close(timeout=timeout)

def _del_global_client(c: Client) -> None:
    """Drop every registry entry whose client is ``c`` (identity comparison).

    Entries that disappear while the registry is being walked are skipped.
    """
    for key in tuple(_global_clients):
        try:
            if _global_clients[key] is not c:
                continue
            del _global_clients[key]
        except KeyError:  # pragma: no cover
            pass

# Single module-level manager instance. Its bound methods are exposed under
# the names of the former module-level helper functions.
_global_client_manager = _GlobalClientManager()
_set_global_client = _global_client_manager._set_global_client
_get_global_client = _global_client_manager._get_global_client
_del_global_client = _global_client_manager._del_global_client


# Ask the manager to close the registered clients when the interpreter exits.
atexit.register(_global_client_manager.close)


class Future(WrappedKey):
Expand Down Expand Up @@ -189,11 +220,17 @@ class Future(WrappedKey):
_cb_executor = None
_cb_executor_pid = None

def __init__(self, key, client=None, inform=True, state=None):
def __init__(
self,
key: str | tuple,
client: Client,
inform: bool = True,
state: FutureState | None = None,
):
self.key = key
self._cleared = False
tkey = stringify(key)
self.client = client or Client.current()
self.client = client
self.client._inc_ref(tkey)
self._generation = self.client.generation

Expand Down Expand Up @@ -944,10 +981,6 @@ def __init__(

self._start_arg = address
self._set_as_default = set_as_default
if set_as_default:
self._set_config = dask.config.set(
scheduler="dask.distributed", shuffle="tasks"
)
self._event_handlers = {}

self._stream_handlers = {
Expand Down Expand Up @@ -1176,8 +1209,6 @@ def start(self, **kwargs):
return

self._loop_runner.start()
if self._set_as_default:
_set_global_client(self)

if self.asynchronous:
self._started = asyncio.ensure_future(self._start(**kwargs))
Expand Down Expand Up @@ -1215,6 +1246,8 @@ def _send_to_scheduler(self, msg):

async def _start(self, timeout=no_default, **kwargs):
self.status = "connecting"
if self._set_as_default:
_set_global_client(self)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think moving `_set_global_client` into `async def _start` is worse: `_start` is likely to be reached via `Client(asynchronous=False)`, in which case it runs on the event loop's thread rather than on the main thread.


await self.rpc.start()

Expand Down Expand Up @@ -1599,6 +1632,8 @@ async def _close(self, fast=False):
return

self.status = "closing"
if self._set_as_default:
_del_global_client(self)

for preload in self.preloads:
await preload.teardown()
Expand All @@ -1608,14 +1643,7 @@ async def _close(self, fast=False):
pc.stop()

with log_errors():
_del_global_client(self)
self._scheduler_identity = {}
with suppress(AttributeError):
# clear the dask.config set keys
with self._set_config:
pass
if self.get == dask.config.get("get", None):
del dask.config.config["get"]

if (
self.scheduler_comm
Expand Down Expand Up @@ -1654,9 +1682,6 @@ async def _close(self, fast=False):

self.status = "closed"

if _get_global_client() is self:
_set_global_client(None)

if (
handle_report_task is not None
and handle_report_task is not current_task
Expand Down Expand Up @@ -5182,7 +5207,7 @@ def AsCompleted(*args, **kwargs):
raise Exception("This has moved to as_completed")


def default_client(c=None):
def default_client(c: Client | None = None) -> Client:
"""Return a client if one has started

Parameters
Expand All @@ -5207,15 +5232,15 @@ def default_client(c=None):
)


def ensure_default_client(client: Client) -> None:
    """Ensures the client passed as argument is set as the default

    Parameters
    ----------
    client : Client
        The client
    """
    # The scheduler config is applied by ``_set_global_client`` when the first
    # client is registered.  Setting it here as well would make that override
    # record "dask.distributed" as the value to restore, so the setting would
    # outlive the last client.
    client._set_as_default = True
    _set_global_client(client)


Expand Down Expand Up @@ -5518,7 +5543,7 @@ def __exit__(self, exc_type, exc_value, traceback):


@contextmanager
def temp_default_client(c):
def temp_default_client(c: Client) -> Generator[None, None, None]:
"""Set the default client for the duration of the context

.. note::
Expand All @@ -5541,21 +5566,3 @@ def temp_default_client(c):
yield
finally:
_set_global_client(old_exec)


def _close_global_client():
    """
    Force close of the global client. This cleans up when a client
    wasn't closed explicitly, e.g. in interactive sessions.
    """
    client = _get_global_client()
    if client is None:
        return

    client._should_close_loop = False
    # Best effort only: timeouts and runtime errors while closing are ignored.
    with suppress(TimeoutError, RuntimeError):
        if not client.asynchronous:
            client.close(timeout=3)
        else:
            # Asynchronous clients are closed from their own event loop.
            client.loop.add_callback(client.close, timeout=3)


# Run the forced close of the global client when the interpreter exits.
atexit.register(_close_global_client)
106 changes: 96 additions & 10 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,37 @@
pytestmark = pytest.mark.ci1


@contextmanager
def _ensure_no_global_client_state_left():
    """Assert that no global client state exists before or after the body.

    Checked on both sides of the ``yield``: no global client is registered,
    dask resolves no scheduler, and the ``shuffle`` / ``scheduler`` config
    keys are unset.  If one of the config keys is still set afterwards, both
    keys are removed before the assertion error is re-raised.
    """
    config_keys = ("shuffle", "scheduler")

    assert _get_global_client() is None
    assert dask.base.get_scheduler() is None
    for name in config_keys:
        assert dask.config.get(name, None) is None

    yield

    assert _get_global_client() is None
    try:
        for name in config_keys:
            assert dask.config.get(name, None) is None
    except AssertionError:
        # Remove the leaked keys before reporting the failure.
        for name in config_keys:
            dask.config.config.pop(name, None)
        raise
    assert dask.base.get_scheduler() is None


@pytest.fixture(autouse=True)
def ensure_no_global_clients_are_leaking():
    """Autouse fixture that runs each test inside the global-state leak check."""
    with _ensure_no_global_client_state_left():
        yield


def test_modify_global_state_raises():
    """The leak check raises ``AssertionError`` when dask config is left modified."""
    with pytest.raises(AssertionError):
        with _ensure_no_global_client_state_left():
            # Write straight into the config dict to simulate leaked state.
            dask.config.config["scheduler"] = "test-value"


@gen_cluster(client=True)
async def test_submit(c, s, a, b):
x = c.submit(inc, 10, key="x")
Expand Down Expand Up @@ -3365,11 +3396,75 @@ def test_default_get(loop_in_thread):
assert dask.base.get_scheduler() == pre_get


@gen_cluster(nthreads=[])
async def test_default_global_client_multi_clients(s):
    """Default-client and dask config state track the set of open clients.

    Covers nested ``async with`` clients, clients closed in registration
    order, and a ``set_as_default=False`` client that must not keep the
    config override alive.
    """
    assert _get_global_client() is None
    pytest.raises(KeyError, dask.config.get, "shuffle")

    async with Client(s.address, asynchronous=True) as c:
        assert dask.base.get_scheduler() == c.get
        assert dask.config.get("shuffle") == "tasks"
        assert dask.config.get("scheduler") == "dask.distributed"

        async with Client(s.address, asynchronous=True) as d:
            assert dask.base.get_scheduler() == d.get

        assert dask.base.get_scheduler() == c.get

    assert dask.base.get_scheduler() is None

    assert _get_global_client() is None
    c = await Client(s.address, asynchronous=True)

    assert dask.base.get_scheduler() == c.get
    assert dask.config.get("shuffle") == "tasks"
    assert dask.config.get("scheduler") == "dask.distributed"

    d = await Client(s.address, asynchronous=True)

    assert dask.base.get_scheduler() == d.get
    assert dask.config.get("shuffle") == "tasks"
    assert dask.config.get("scheduler") == "dask.distributed"

    # Closing the older client leaves the newer one as the default.
    await c.close()

    assert dask.base.get_scheduler() == d.get
    assert dask.config.get("shuffle") == "tasks"
    assert dask.config.get("scheduler") == "dask.distributed"
    await d.close()

    assert dask.config.get("shuffle", None) is None
    assert dask.config.get("scheduler", None) is None
    assert dask.base.get_scheduler() is None

    c = await Client(s.address, asynchronous=True)

    assert dask.base.get_scheduler() == c.get
    assert dask.config.get("shuffle") == "tasks"
    assert dask.config.get("scheduler") == "dask.distributed"

    non_default = await Client(s.address, asynchronous=True, set_as_default=False)

    # A non-default client does not take over as the default ...
    assert dask.base.get_scheduler() == c.get
    assert dask.config.get("shuffle") == "tasks"
    assert dask.config.get("scheduler") == "dask.distributed"

    await c.close()

    # ... and does not keep the config override alive either.
    assert dask.config.get("shuffle", None) is None
    assert dask.config.get("scheduler", None) is None
    assert dask.base.get_scheduler() is None
    # ``d`` was already closed above; the client still open is ``non_default``.
    await non_default.close()


@gen_cluster(client=True)
async def test_ensure_default_client(c, s, a, b):
assert c is default_client()

async with Client(s.address, set_as_default=False, asynchronous=True) as c2:
async with Client(
s.address, set_as_default=False, asynchronous=True, name="c2"
) as c2:
assert c is default_client()
assert c2 is not default_client()
ensure_default_client(c2)
Expand Down Expand Up @@ -5592,15 +5687,6 @@ async def test_client_with_name(s, a, b):
assert "foo" in text


@gen_cluster(client=True)
async def test_future_defaults_to_default_client(c, s, a, b):
    """A ``Future`` created from a bare key ends up attached to client ``c``."""
    x = c.submit(inc, 1)
    await wait(x)

    # No client is passed here, so the future has to resolve one on its own.
    future = Future(x.key)
    assert future.client is c


@gen_cluster(client=True)
async def test_future_auto_inform(c, s, a, b):
x = c.submit(inc, 1)
Expand Down
Loading