diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 6bea4a3fdf5..9346da09e95 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -3550,3 +3550,43 @@ async def test_broken_comm(c, s, a, b):
     )
     s = df.shuffle("id", shuffle="tasks")
     await c.compute(s.size)
+
+
+@gen_cluster(nthreads=[])
+async def test_do_not_block_event_loop_during_shutdown(s):
+    """``Worker.close`` must not block the event loop during executor shutdown.
+
+    ``fn`` occupies an executor thread until ``set_future``, which runs on the
+    event loop, sees the executor refuse new work and releases it. If closing
+    blocked the loop, ``fn`` would time out and its assertion would fail.
+    """
+    loop = asyncio.get_running_loop()
+    called_handler = threading.Event()
+    block_handler = threading.Event()
+
+    w = await Worker(s.address)
+    executor = w.executors["default"]
+
+    # The block wait must be smaller than the test timeout and smaller than the
+    # default value for timeout in `Worker.close`
+    async def block():
+        def fn():
+            called_handler.set()
+            assert block_handler.wait(20)
+
+        await loop.run_in_executor(executor, fn)
+
+    async def set_future():
+        while True:
+            try:
+                await loop.run_in_executor(executor, sleep, 0.1)
+            except RuntimeError:  # executor has started shutting down
+                block_handler.set()
+                return
+
+    async def close():
+        called_handler.wait()
+        # executor_wait is True by default but we want to be explicit here
+        await w.close(executor_wait=True)
+
+    await asyncio.gather(block(), close(), set_future())
diff --git a/distributed/worker.py b/distributed/worker.py
index e5563ed09d7..6a18bd1cd75 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -53,7 +53,7 @@
 from distributed.comm import connect, get_address_host
 from distributed.comm.addressing import address_from_user_args, parse_address
 from distributed.comm.utils import OFFLOAD_THRESHOLD
-from distributed.compatibility import randbytes
+from distributed.compatibility import randbytes, to_thread
 from distributed.core import (
     CommClosedError,
     ConnectionPool,
@@ -1525,11 +1525,19 @@ async def close(
         for executor in self.executors.values():
             if executor is utils._offload_executor:
                 continue  # Never shutdown the offload executor
-            if isinstance(executor, ThreadPoolExecutor):
-                executor._work_queue.queue.clear()
-                executor.shutdown(wait=executor_wait, timeout=timeout)
-            else:
-                executor.shutdown(wait=executor_wait)
+
+            def _close():
+                if isinstance(executor, ThreadPoolExecutor):
+                    executor._work_queue.queue.clear()
+                    executor.shutdown(wait=executor_wait, timeout=timeout)
+                else:
+                    executor.shutdown(wait=executor_wait)
+
+            # Waiting for the shutdown can block the event loop causing
+            # weird deadlocks particularly if the task that is executing in
+            # the thread is waiting for a server reply, e.g. when using
+            # worker clients, semaphores, etc.
+            await to_thread(_close)
 
         self.stop()
         await self.rpc.close()