diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 63d34015914..b5a9c436782 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3196,3 +3196,52 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker( args, kwargs = mocked_gather.call_args await Worker.gather_dep(b, *args, **kwargs) await fut3 + + +@gen_cluster(client=True) +async def test_get_worker_in_task(c: Client, s: Scheduler, *workers: Worker): + with pytest.raises(ValueError, match="Not running within a worker"): + get_worker() + + def check_worker(expected: str): + w = get_worker() + assert w.address == expected + + for w in workers: + await c.submit(check_worker, w.address, workers=[w.address]).result() + + +@gen_cluster(client=True) +async def test_get_worker_in_run(c: Client, s: Scheduler, *workers: Worker): + def check_worker(dask_worker: Worker): + w = get_worker() + assert w is dask_worker + return w.address + + results = await c.run(check_worker) + assert results == {w.address: w.address for w in workers} + + +@gen_cluster(client=True) +async def test_get_worker_serialize_deserialize( + c: Client, s: Scheduler, a: Worker, b: Worker +): + class Checker: + def __init__(self) -> None: + self.sender = None + self.receiver = None + + def result(self): + return self.sender, self.receiver + + def __getstate__(self): + return (get_worker().address,) + + def __setstate__(self, state): + self.sender = state[0] + self.receiver = get_worker().address + + f = c.submit(Checker, workers=[a.address]) + moved = c.submit(Checker.result, f, workers=[b.address]) + result = await moved.result() + assert result == (a.address, b.address) diff --git a/distributed/worker.py b/distributed/worker.py index bbd707ae7bb..b72c9e18ad5 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3,6 +3,7 @@ import asyncio import bisect import builtins +import contextvars import errno import heapq import logging @@ -19,10 +20,11 @@ from datetime import timedelta from inspect import isawaitable from pickle import PicklingError -from typing import TYPE_CHECKING, Any, ClassVar, Container +from typing import TYPE_CHECKING, Any, Awaitable, ClassVar, Container, TypeVar + +from typing_extensions import Concatenate, Literal, ParamSpec, TypeGuard if TYPE_CHECKING: - from typing_extensions import Literal from .diagnostics.plugin import WorkerPlugin from .actor import Actor from .client import Client @@ -125,6 +127,51 @@ SerializedTask = namedtuple("SerializedTask", ["function", "args", "kwargs", "task"]) +_current_worker: contextvars.ContextVar[Worker] = contextvars.ContextVar( + "current_worker" +) + +P = ParamSpec("P") +T = TypeVar("T") + +# All type-ignores here and next function are because mypy doesn't +# properly support ParamSpec yet: https://github.com/python/mypy/issues/8645 + + +def is_coroutine_function( + user_callback: Callable[P, T | Awaitable[T]], # type: ignore +) -> TypeGuard[Callable[P, Awaitable[T]]]: # type: ignore + # mypy/pyright doesn't support narrowing on `iscoroutinefunction`, so we define it ourselves. + # https://github.com/microsoft/pyright/issues/2142#issuecomment-891985575 + return asyncio.iscoroutinefunction(user_callback) + + +def as_current_worker( + f: Callable[Concatenate[Worker, P], T] # type: ignore +) -> Callable[Concatenate[Worker, P], T]: # type: ignore + "Decorator that sets `_current_worker` to `self` while the function is running" + + if is_coroutine_function(f): + + async def inner_async(self: Worker, *args: P.args, **kwargs: P.kwargs) -> T: # type: ignore + token = _current_worker.set(self) + try: + return await f(self, *args, **kwargs) + finally: + _current_worker.reset(token) + + return inner_async + else: + + def inner(self: Worker, *args: P.args, **kwargs: P.kwargs) -> T: # type: ignore + token = _current_worker.set(self) + try: + return f(self, *args, **kwargs) + finally: + _current_worker.reset(token) + + return inner + class InvalidTransition(Exception): pass @@ -536,6 +583,7 @@ class Worker(ServerNode): plugins: dict[str, WorkerPlugin] _pending_plugins: tuple[WorkerPlugin, ...] + @as_current_worker def __init__( self, scheduler_ip: str | None = None, @@ -1399,6 +1447,7 @@ def get_monitor_info(self, comm=None, recent=False, start=0): # Lifecycle # ############# + @as_current_worker async def start(self): if self.status and self.status in ( Status.closed, @@ -3288,6 +3337,7 @@ async def execute(self, key, *, stimulus_id): args2, kwargs2, self.execution_state, + contextvars.copy_context(), ts.key, self.active_threads, self.active_threads_lock, @@ -3913,12 +3963,9 @@ def get_worker() -> Worker: worker_client """ try: - return thread_state.execution_state["worker"] - except AttributeError: - try: - return first(w for w in Worker._instances if w.status in RUNNING) - except StopIteration: - raise ValueError("No workers found") + return _current_worker.get() + except LookupError: + raise ValueError("Not running within a worker") from None def get_client(address=None, timeout=None, resolve_address=True) -> Client: @@ -4242,6 +4289,7 @@ def apply_function( args, kwargs, execution_state, + context: contextvars.Context, key, active_threads, active_threads_lock, @@ -4260,7 +4308,8 @@ def apply_function( thread_state.execution_state = execution_state thread_state.key = key - msg = apply_function_simple(function, args, kwargs, time_delay) + msg = context.run(apply_function_simple, function, args, kwargs, time_delay) + # NOTE: context passed manually because of https://bugs.python.org/issue34014 with active_threads_lock: del active_threads[ident]