Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3196,3 +3196,52 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
args, kwargs = mocked_gather.call_args
await Worker.gather_dep(b, *args, **kwargs)
await fut3


@gen_cluster(client=True)
async def test_get_worker_in_task(c: Client, s: Scheduler, *workers: Worker):
with pytest.raises(ValueError, match="Not running within a worker"):
get_worker()

def check_worker(expected: str):
w = get_worker()
assert w.address == expected

for w in workers:
await c.submit(check_worker, w.address, workers=[w.address]).result()


@gen_cluster(client=True)
async def test_get_worker_in_run(c: Client, s: Scheduler, *workers: Worker):
def check_worker(dask_worker: Worker):
w = get_worker()
assert w is dask_worker
return w.address

results = await c.run(check_worker)
assert results == {w.address: w.address for w in workers}


@gen_cluster(client=True)
async def test_get_worker_serialize_deserialize(
c: Client, s: Scheduler, a: Worker, b: Worker
):
class Checker:
def __init__(self) -> None:
self.sender = None
self.receiver = None

def result(self):
return self.sender, self.receiver

def __getstate__(self):
return (get_worker().address,)

def __setstate__(self, state):
self.sender = state[0]
self.receiver = get_worker().address

f = c.submit(Checker, workers=[a.address])
moved = c.submit(Checker.result, f, workers=[b.address])
result = await moved.result()
assert result == (a.address, b.address)
67 changes: 58 additions & 9 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import bisect
import builtins
import contextvars
import errno
import heapq
import logging
Expand All @@ -19,10 +20,11 @@
from datetime import timedelta
from inspect import isawaitable
from pickle import PicklingError
from typing import TYPE_CHECKING, Any, ClassVar, Container
from typing import TYPE_CHECKING, Any, Awaitable, ClassVar, Container, TypeVar

from typing_extensions import Concatenate, Literal, ParamSpec, TypeGuard

if TYPE_CHECKING:
from typing_extensions import Literal
from .diagnostics.plugin import WorkerPlugin
from .actor import Actor
from .client import Client
Expand Down Expand Up @@ -125,6 +127,51 @@

SerializedTask = namedtuple("SerializedTask", ["function", "args", "kwargs", "task"])

_current_worker: contextvars.ContextVar[Worker] = contextvars.ContextVar(
"current_worker"
)

P = ParamSpec("P")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was recently burned by ParamSpec in that not all ecosystem libraries supported typing_extensions~=3.10 (in my case it was tensorflow). That may have improved in the last few months, but something to consider

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, trying to get ParamSpec working probably took 4x longer than actually implementing the content of this PR. I'm just really excited about it.

Since mypy doesn't even support it properly, there's not much value in it here yet. I had thought we could leave it around and remove the # type: ignores in the future, but I probably should just take it out before merging this.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also excited about ParamSpec, and would like to use it for every decorator I see

T = TypeVar("T")

# All type-ignores here and next function are because mypy doesn't
# properly support ParamSpec yet: https://github.com/python/mypy/issues/8645


def is_coroutine_function(
user_callback: Callable[P, T | Awaitable[T]], # type: ignore
) -> TypeGuard[Callable[P, Awaitable[T]]]: # type: ignore
# mypy/pyright doesn't support narrowing on `iscoroutinefunction`, so we define it ourselves.
# https://github.com/microsoft/pyright/issues/2142#issuecomment-891985575
return asyncio.iscoroutinefunction(user_callback)


def as_current_worker(
f: Callable[Concatenate[Worker, P], T] # type: ignore
) -> Callable[Concatenate[Worker, P], T]: # type: ignore
"Decorator that sets `_current_worker` to `self` while the function is running"

if is_coroutine_function(f):

async def inner_async(self: Worker, *args: P.args, **kwargs: P.kwargs) -> T: # type: ignore
token = _current_worker.set(self)
try:
return await f(self, *args, **kwargs)
finally:
_current_worker.reset(token)

return inner_async
else:

def inner(self: Worker, *args: P.args, **kwargs: P.kwargs) -> T: # type: ignore
token = _current_worker.set(self)
try:
return f(self, *args, **kwargs)
finally:
_current_worker.reset(token)

return inner


class InvalidTransition(Exception):
pass
Expand Down Expand Up @@ -536,6 +583,7 @@ class Worker(ServerNode):
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]

@as_current_worker
def __init__(
self,
scheduler_ip: str | None = None,
Expand Down Expand Up @@ -1399,6 +1447,7 @@ def get_monitor_info(self, comm=None, recent=False, start=0):
# Lifecycle #
#############

@as_current_worker
async def start(self):
if self.status and self.status in (
Status.closed,
Expand Down Expand Up @@ -3288,6 +3337,7 @@ async def execute(self, key, *, stimulus_id):
args2,
kwargs2,
self.execution_state,
contextvars.copy_context(),
ts.key,
self.active_threads,
self.active_threads_lock,
Expand Down Expand Up @@ -3913,12 +3963,9 @@ def get_worker() -> Worker:
worker_client
"""
try:
return thread_state.execution_state["worker"]
except AttributeError:
try:
return first(w for w in Worker._instances if w.status in RUNNING)
except StopIteration:
raise ValueError("No workers found")
return _current_worker.get()
except LookupError:
raise ValueError("Not running within a worker") from None


def get_client(address=None, timeout=None, resolve_address=True) -> Client:
Expand Down Expand Up @@ -4242,6 +4289,7 @@ def apply_function(
args,
kwargs,
execution_state,
context: contextvars.Context,
key,
active_threads,
active_threads_lock,
Expand All @@ -4260,7 +4308,8 @@ def apply_function(
thread_state.execution_state = execution_state
thread_state.key = key

msg = apply_function_simple(function, args, kwargs, time_delay)
msg = context.run(apply_function_simple, function, args, kwargs, time_delay)
# NOTE: context passed manually because of https://bugs.python.org/issue34014

with active_threads_lock:
del active_threads[ident]
Expand Down