diff --git a/distributed/_async_taskgroup.py b/distributed/_async_taskgroup.py new file mode 100644 index 00000000000..a048491d302 --- /dev/null +++ b/distributed/_async_taskgroup.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import asyncio +import threading +from collections.abc import Callable, Coroutine +from typing import TYPE_CHECKING, Any, TypeVar + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + P = ParamSpec("P") + R = TypeVar("R") + T = TypeVar("T") + Coro = Coroutine[Any, Any, T] + + +class _LoopBoundMixin: + """Backport of the private asyncio.mixins._LoopBoundMixin from 3.11""" + + _global_lock = threading.Lock() + + _loop = None + + def _get_loop(self): + loop = asyncio.get_running_loop() + + if self._loop is None: + with self._global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + + +class AsyncTaskGroupClosedError(RuntimeError): + pass + + +def _delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]: + """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" + + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + await asyncio.sleep(delay) + return await corofunc(*args, **kwargs) + + return wrapper + + +class AsyncTaskGroup(_LoopBoundMixin): + """Collection tracking all currently running asynchronous tasks within a group""" + + #: If True, the group is closed and does not allow adding new tasks. + closed: bool + + def __init__(self) -> None: + self.closed = False + self._ongoing_tasks: set[asyncio.Task[None]] = set() + + def call_soon( + self, afunc: Callable[P, Coro[None]], /, *args: P.args, **kwargs: P.kwargs + ) -> None: + """Schedule a coroutine function to be executed as an `asyncio.Task`. + + The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments + as an `asyncio.Task`. + + Parameters + ---------- + afunc + Coroutine function to schedule. + *args + Arguments to be passed to `afunc`. + **kwargs + Keyword arguments to be passed to `afunc` + + Returns + ------- + None + + Raises + ------ + AsyncTaskGroupClosedError + If the task group is closed. + """ + if self.closed: # Avoid creating a coroutine + raise AsyncTaskGroupClosedError( + "Cannot schedule a new coroutine function as the group is already closed." + ) + task = self._get_loop().create_task(afunc(*args, **kwargs)) + task.add_done_callback(self._ongoing_tasks.remove) + self._ongoing_tasks.add(task) + return None + + def call_later( + self, + delay: float, + afunc: Callable[P, Coro[None]], + /, + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. + + The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments + as an `asyncio.Task` that is executed after `delay` seconds. + + Parameters + ---------- + delay + Delay in seconds. + afunc + Coroutine function to schedule. + *args + Arguments to be passed to `afunc`. + **kwargs + Keyword arguments to be passed to `afunc` + + Returns + ------- + The None + + Raises + ------ + AsyncTaskGroupClosedError + If the task group is closed. + """ + self.call_soon(_delayed(afunc, delay), *args, **kwargs) + + def close(self) -> None: + """Closes the task group so that no new tasks can be scheduled. + + Existing tasks continue to run. + """ + self.closed = True + + async def stop(self) -> None: + """Close the group and stop all currently running tasks. + + Closes the task group and cancels all tasks. All tasks are cancelled + an additional time for each time this task is cancelled. + """ + self.close() + + current_task = asyncio.current_task(self._get_loop()) + err = None + while tasks_to_stop := (self._ongoing_tasks - {current_task}): + for task in tasks_to_stop: + task.cancel() + try: + await asyncio.wait(tasks_to_stop) + except asyncio.CancelledError as e: + err = e + + if err is not None: + raise err + + def __len__(self): + return len(self._ongoing_tasks) diff --git a/distributed/core.py b/distributed/core.py index 90705e80515..ce94e6fb983 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -35,6 +35,7 @@ from dask.utils import parse_timedelta from distributed import profile, protocol +from distributed._async_taskgroup import AsyncTaskGroup, AsyncTaskGroupClosedError from distributed.comm import ( Comm, CommClosedError, @@ -138,151 +139,6 @@ def _expects_comm(func: Callable) -> bool: return False -class _LoopBoundMixin: - """Backport of the private asyncio.mixins._LoopBoundMixin from 3.11""" - - _global_lock = threading.Lock() - - _loop = None - - def _get_loop(self): - loop = asyncio.get_running_loop() - - if self._loop is None: - with self._global_lock: - if self._loop is None: - self._loop = loop - if loop is not self._loop: - raise RuntimeError(f"{self!r} is bound to a different event loop") - return loop - - -class AsyncTaskGroupClosedError(RuntimeError): - pass - - -def _delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]: - """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" - - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - await asyncio.sleep(delay) - return await corofunc(*args, **kwargs) - - return wrapper - - -class AsyncTaskGroup(_LoopBoundMixin): - """Collection tracking all currently running asynchronous tasks within a group""" - - #: If True, the group is closed and does not allow adding new tasks. - closed: bool - - def __init__(self) -> None: - self.closed = False - self._ongoing_tasks: set[asyncio.Task[None]] = set() - - def call_soon( - self, afunc: Callable[P, Coro[None]], /, *args: P.args, **kwargs: P.kwargs - ) -> None: - """Schedule a coroutine function to be executed as an `asyncio.Task`. - - The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments - as an `asyncio.Task`. - - Parameters - ---------- - afunc - Coroutine function to schedule. - *args - Arguments to be passed to `afunc`. - **kwargs - Keyword arguments to be passed to `afunc` - - Returns - ------- - None - - Raises - ------ - AsyncTaskGroupClosedError - If the task group is closed. - """ - if self.closed: # Avoid creating a coroutine - raise AsyncTaskGroupClosedError( - "Cannot schedule a new coroutine function as the group is already closed." - ) - task = self._get_loop().create_task(afunc(*args, **kwargs)) - task.add_done_callback(self._ongoing_tasks.remove) - self._ongoing_tasks.add(task) - return None - - def call_later( - self, - delay: float, - afunc: Callable[P, Coro[None]], - /, - *args: P.args, - **kwargs: P.kwargs, - ) -> None: - """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. - - The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments - as an `asyncio.Task` that is executed after `delay` seconds. - - Parameters - ---------- - delay - Delay in seconds. - afunc - Coroutine function to schedule. - *args - Arguments to be passed to `afunc`. - **kwargs - Keyword arguments to be passed to `afunc` - - Returns - ------- - The None - - Raises - ------ - AsyncTaskGroupClosedError - If the task group is closed. - """ - self.call_soon(_delayed(afunc, delay), *args, **kwargs) - - def close(self) -> None: - """Closes the task group so that no new tasks can be scheduled. - - Existing tasks continue to run. - """ - self.closed = True - - async def stop(self) -> None: - """Close the group and stop all currently running tasks. - - Closes the task group and cancels all tasks. All tasks are cancelled - an additional time for each time this task is cancelled. - """ - self.close() - - current_task = asyncio.current_task(self._get_loop()) - err = None - while tasks_to_stop := (self._ongoing_tasks - {current_task}): - for task in tasks_to_stop: - task.cancel() - try: - await asyncio.wait(tasks_to_stop) - except asyncio.CancelledError as e: - err = e - - if err is not None: - raise err - - def __len__(self): - return len(self._ongoing_tasks) - - class Server: """Dask Distributed Server diff --git a/distributed/nanny.py b/distributed/nanny.py index 52e4ad5b360..af0d9a62ad5 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -25,12 +25,12 @@ from dask.utils import parse_timedelta from distributed import preloading +from distributed._async_taskgroup import AsyncTaskGroupClosedError from distributed.comm import get_address_host from distributed.comm.addressing import address_from_user_args from distributed.compatibility import asyncio_run from distributed.config import get_loop_factory from distributed.core import ( - AsyncTaskGroupClosedError, CommClosedError, ErrorMessage, OKMessage, diff --git a/distributed/tests/test_async_task_group.py b/distributed/tests/test_async_task_group.py new file mode 100644 index 00000000000..12a019521c4 --- /dev/null +++ b/distributed/tests/test_async_task_group.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import asyncio +import time as timemod + +import pytest + +from distributed._async_taskgroup import AsyncTaskGroup, AsyncTaskGroupClosedError +from distributed.utils_test import gen_test + + +async def _wait_for_n_loop_cycles(n): + for _ in range(n): + await asyncio.sleep(0) + + +def test_async_task_group_initialization(): + group = AsyncTaskGroup() + assert not group.closed + assert len(group) == 0 + + +@gen_test() +async def test_async_task_group_call_soon_executes_task_in_background(): + group = AsyncTaskGroup() + ev = asyncio.Event() + flag = False + + async def set_flag(): + nonlocal flag + await ev.wait() + flag = True + + assert group.call_soon(set_flag) is None + assert len(group) == 1 + ev.set() + await _wait_for_n_loop_cycles(2) + assert len(group) == 0 + assert flag + + +@gen_test() +async def test_async_task_group_call_later_executes_delayed_task_in_background(): + group = AsyncTaskGroup() + ev = asyncio.Event() + + start = timemod.monotonic() + assert group.call_later(1, ev.set) is None + assert len(group) == 1 + await ev.wait() + end = timemod.monotonic() + # the task must be removed in exactly 1 event loop cycle + await _wait_for_n_loop_cycles(2) + assert len(group) == 0 + assert end - start > 1 - timemod.get_clock_info("monotonic").resolution + + +def test_async_task_group_close_closes(): + group = AsyncTaskGroup() + group.close() + assert group.closed + + # Test idempotency + group.close() + assert group.closed + + +@gen_test() +async def test_async_task_group_close_does_not_cancel_existing_tasks(): + group = AsyncTaskGroup() + + ev = asyncio.Event() + flag = False + + async def set_flag(): + nonlocal flag + await ev.wait() + flag = True + return None + + assert group.call_soon(set_flag) is None + + group.close() + + assert len(group) == 1 + + ev.set() + await _wait_for_n_loop_cycles(2) + assert len(group) == 0 + + +@gen_test() +async def test_async_task_group_close_prohibits_new_tasks(): + group = AsyncTaskGroup() + group.close() + + ev = asyncio.Event() + flag = False + + async def set_flag(): + nonlocal flag + await ev.wait() + flag = True + return True + + with pytest.raises(AsyncTaskGroupClosedError): + group.call_soon(set_flag) + assert len(group) == 0 + + with pytest.raises(AsyncTaskGroupClosedError): + group.call_later(1, set_flag) + assert len(group) == 0 + + await asyncio.sleep(0.01) + assert not flag + + +@gen_test() +async def test_async_task_group_stop_disallows_shutdown(): + group = AsyncTaskGroup() + + task = None + + async def set_flag(): + nonlocal task + task = asyncio.current_task() + + assert group.call_soon(set_flag) is None + assert len(group) == 1 + # tasks are not given a grace period, and are not even allowed to start + # if the group is closed immediately + await group.stop() + assert task is None + + +@gen_test() +async def test_async_task_group_stop_cancels_long_running(): + group = AsyncTaskGroup() + + task = None + flag = False + started = asyncio.Event() + + async def set_flag(): + nonlocal task + task = asyncio.current_task() + started.set() + await asyncio.sleep(10) + nonlocal flag + flag = True + return True + + assert group.call_soon(set_flag) is None + assert len(group) == 1 + await started.wait() + await group.stop() + assert task + assert task.cancelled() + assert not flag diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 93bb18f16d3..af25353bf20 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -8,7 +8,6 @@ import socket import sys import threading -import time as timemod import weakref from unittest import mock @@ -22,8 +21,6 @@ from distributed.comm.registry import backends from distributed.comm.tcp import TCPBackend, TCPListener from distributed.core import ( - AsyncTaskGroup, - AsyncTaskGroupClosedError, ConnectionPool, Server, Status, @@ -84,156 +81,6 @@ def echo_no_serialize(comm, x): return {"result": x} -def test_async_task_group_initialization(): - group = AsyncTaskGroup() - assert not group.closed - assert len(group) == 0 - - -async def _wait_for_n_loop_cycles(n): - for _ in range(n): - await asyncio.sleep(0) - - -@gen_test() -async def test_async_task_group_call_soon_executes_task_in_background(): - group = AsyncTaskGroup() - ev = asyncio.Event() - flag = False - - async def set_flag(): - nonlocal flag - await ev.wait() - flag = True - - assert group.call_soon(set_flag) is None - assert len(group) == 1 - ev.set() - await _wait_for_n_loop_cycles(2) - assert len(group) == 0 - assert flag - - -@gen_test() -async def test_async_task_group_call_later_executes_delayed_task_in_background(): - group = AsyncTaskGroup() - ev = asyncio.Event() - - start = timemod.monotonic() - assert group.call_later(1, ev.set) is None - assert len(group) == 1 - await ev.wait() - end = timemod.monotonic() - # the task must be removed in exactly 1 event loop cycle - await _wait_for_n_loop_cycles(2) - assert len(group) == 0 - assert end - start > 1 - timemod.get_clock_info("monotonic").resolution - - -def test_async_task_group_close_closes(): - group = AsyncTaskGroup() - group.close() - assert group.closed - - # Test idempotency - group.close() - assert group.closed - - -@gen_test() -async def test_async_task_group_close_does_not_cancel_existing_tasks(): - group = AsyncTaskGroup() - - ev = asyncio.Event() - flag = False - - async def set_flag(): - nonlocal flag - await ev.wait() - flag = True - return None - - assert group.call_soon(set_flag) is None - - group.close() - - assert len(group) == 1 - - ev.set() - await _wait_for_n_loop_cycles(2) - assert len(group) == 0 - - -@gen_test() -async def test_async_task_group_close_prohibits_new_tasks(): - group = AsyncTaskGroup() - group.close() - - ev = asyncio.Event() - flag = False - - async def set_flag(): - nonlocal flag - await ev.wait() - flag = True - return True - - with pytest.raises(AsyncTaskGroupClosedError): - group.call_soon(set_flag) - assert len(group) == 0 - - with pytest.raises(AsyncTaskGroupClosedError): - group.call_later(1, set_flag) - assert len(group) == 0 - - await asyncio.sleep(0.01) - assert not flag - - -@gen_test() -async def test_async_task_group_stop_disallows_shutdown(): - group = AsyncTaskGroup() - - task = None - - async def set_flag(): - nonlocal task - task = asyncio.current_task() - - assert group.call_soon(set_flag) is None - assert len(group) == 1 - # tasks are not given a grace period, and are not even allowed to start - # if the group is closed immediately - await group.stop() - assert task is None - - -@gen_test() -async def test_async_task_group_stop_cancels_long_running(): - group = AsyncTaskGroup() - - task = None - flag = False - started = asyncio.Event() - - async def set_flag(): - nonlocal task - task = asyncio.current_task() - started.set() - await asyncio.sleep(10) - nonlocal flag - flag = True - return True - - assert group.call_soon(set_flag) is None - assert len(group) == 1 - await started.wait() - await group.stop() - assert task - assert task.cancelled() - assert not flag - - @gen_test() async def test_server_status_is_always_enum(): """Assignments with strings is forbidden"""