diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 449b406c928..04e31847609 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -1,7 +1,4 @@ import asyncio -import contextvars -import functools -import sys import pytest from click.testing import CliRunner @@ -18,33 +15,12 @@ import distributed.cli.dask_worker from distributed import Client -from distributed.compatibility import LINUX +from distributed.compatibility import LINUX, to_thread from distributed.deploy.utils import nprocesses_nthreads from distributed.metrics import time from distributed.utils import parse_ports, sync from distributed.utils_test import gen_cluster, popen, requires_ipv6 -if sys.version_info >= (3, 9): - from asyncio import to_thread -else: - - async def to_thread(*func_args, **kwargs): - """Asynchronously run function *func* in a separate thread. - Any *args and **kwargs supplied for this function are directly passed - to *func*. Also, the current :class:`contextvars.Context` is propagated, - allowing context variables from the main thread to be accessed in the - separate thread. - Return a coroutine that can be awaited to get the eventual result of *func*. 
- - backport from - https://github.com/python/cpython/blob/3f1ea163ea54513e00e0e9d5442fee1b639825cc/Lib/asyncio/threads.py#L12-L25 - """ - func, *args = func_args - loop = asyncio.get_running_loop() - ctx = contextvars.copy_context() - func_call = functools.partial(ctx.run, func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) - def test_nanny_worker_ports(loop): with popen(["dask-scheduler", "--port", "9359", "--no-dashboard"]): diff --git a/distributed/compatibility.py b/distributed/compatibility.py index 352667ec2bd..32c94151d55 100644 --- a/distributed/compatibility.py +++ b/distributed/compatibility.py @@ -12,3 +12,28 @@ LINUX = sys.platform == "linux" MACOS = sys.platform == "darwin" WINDOWS = sys.platform.startswith("win") + + +if sys.version_info >= (3, 9): + from asyncio import to_thread +else: + import contextvars + import functools + from asyncio import events + + async def to_thread(func, /, *args, **kwargs): + """Asynchronously run function *func* in a separate thread. + Any *args and **kwargs supplied for this function are directly passed + to *func*. Also, the current :class:`contextvars.Context` is propagated, + allowing context variables from the main thread to be accessed in the + separate thread. + + Return a coroutine that can be awaited to get the eventual result of *func*. 
+ + backport from + https://github.com/python/cpython/blob/3f1ea163ea54513e00e0e9d5442fee1b639825cc/Lib/asyncio/threads.py#L12-L25 + """ + loop = events.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 2cb36819713..c75d8abd7ea 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3805,3 +3805,37 @@ def test_unique_task_heap(): assert heap.pop() == ts assert repr(heap) == "" + + +@gen_cluster(nthreads=[]) +async def test_do_not_block_event_loop_during_shutdown(s): + loop = asyncio.get_running_loop() + called_handler = threading.Event() + block_handler = threading.Event() + + w = await Worker(s.address) + executor = w.executors["default"] + + # The block wait must be smaller than the test timeout and smaller than the + # default value for timeout in `Worker.close` + async def block(): + def fn(): + called_handler.set() + assert block_handler.wait(20) + + await loop.run_in_executor(executor, fn) + + async def set_future(): + while True: + try: + await loop.run_in_executor(executor, sleep, 0.1) + except RuntimeError: # executor has started shutting down + block_handler.set() + return + + async def close(): + called_handler.wait() + # executor_wait is True by default but we want to be explicit here + await w.close(executor_wait=True) + + await asyncio.gather(block(), close(), set_future()) diff --git a/distributed/worker.py b/distributed/worker.py index eb0d07c6505..589ab3aee54 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -50,6 +50,7 @@ from .comm import Comm, connect, get_address_host from .comm.addressing import address_from_user_args, parse_address from .comm.utils import OFFLOAD_THRESHOLD +from .compatibility import to_thread from .core import ( CommClosedError, Status, @@ -1735,11 +1736,19 @@ async def
close( for executor in self.executors.values(): if executor is utils._offload_executor: continue # Never shutdown the offload executor - if isinstance(executor, ThreadPoolExecutor): - executor._work_queue.queue.clear() - executor.shutdown(wait=executor_wait, timeout=timeout) - else: - executor.shutdown(wait=executor_wait) + + def _close(): + if isinstance(executor, ThreadPoolExecutor): + executor._work_queue.queue.clear() + executor.shutdown(wait=executor_wait, timeout=timeout) + else: + executor.shutdown(wait=executor_wait) + + # Waiting for the shutdown can block the event loop causing + # weird deadlocks particularly if the task that is executing in + # the thread is waiting for a server reply, e.g. when using + # worker clients, semaphores, etc. + await to_thread(_close) self.stop() await self.rpc.close()