# Reviewer notes (leading text ahead of the first diff header; not part of any hunk).
#
# distributed/utils_test.py
#   * Adds wait_for_stimulus(type_, dask_worker, *, interval=0.01, **matches).
#     It polls dask_worker.state.stimulus_log: whenever the newest entry differs from
#     the one seen on the previous scan, it walks the whole log and returns the first
#     event that is an instance of type_ and whose attributes equal every keyword in
#     **matches. Otherwise it sleeps `interval` seconds and retries. The loop has no
#     timeout of its own.
#   * Adds a docstring to wait_for_state(); imports StateMachineEvent.
#
# distributed/worker_state_machine.py
#   * Adds the static factory ComputeTaskEvent.dummy(). Only key and stimulus_id are
#     required (keyword-only). who_has, resource_restrictions and annotations fall
#     back to {}; run_spec, function, args and kwargs are always None; nbytes falls
#     back to {k: 1 for k in who_has} when it is omitted or empty.
#
# distributed/tests/test_cancelled_state.py
#   * Rewrites test_cancelled_resumed_after_flight_with_dependencies around
#     BlockedGetData and wait_for_stimulus (one gen_cluster worker instead of two)
#     and adds a WorkerState-only variant that asserts on the emitted instructions.
#
# distributed/tests/test_utils_test.py
#   * Adds test_wait_for_stimulus and drops timeout=2 from test_wait_for_state's
#     gen_cluster decorator.
#
# distributed/tests/test_worker_state_machine.py
#   * Adds test_computetask_dummy and switches one event in test_gather_priority to
#     ComputeTaskEvent.dummy().
#   * NOTE(review): test_computetask_dummy passes who_has={"y": "127.0.0.1:2"} (a bare
#     string), while the other call sites in this patch pass a list of addresses and
#     dummy() annotates the values as Collection[str] -- confirm this is intended.
#
diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index f4c806e6af4..0ae1bec2479 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -15,6 +15,14 @@
     inc,
     slowinc,
     wait_for_state,
+    wait_for_stimulus,
+)
+from distributed.worker_state_machine import (
+    ComputeTaskEvent,
+    Execute,
+    FreeKeysEvent,
+    GatherDep,
+    GatherDepNetworkFailureEvent,
 )
 
 
@@ -375,60 +383,59 @@ def block_execution(event, lock):
     assert await fut2 == 2
 
 
-@gen_cluster(client=True, nthreads=[("", 1)] * 2)
-async def test_cancelled_resumed_after_flight_with_dependencies(c, s, w2, w3):
-    # See https://github.com/dask/distributed/pull/6327#discussion_r872231090
-    block_get_data_1 = asyncio.Lock()
-    enter_get_data_1 = asyncio.Event()
-    await block_get_data_1.acquire()
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_cancelled_resumed_after_flight_with_dependencies(c, s, a):
+    """A task is in flight from b to a.
+    While a is waiting, b dies. The scheduler notices before a and reschedules the
+    task on a itself (as the only surviving replica was just lost).
+    Test that the worker eventually computes the task.
+
+    See https://github.com/dask/distributed/pull/6327#discussion_r872231090
+    See test_cancelled_resumed_after_flight_with_dependencies_workerstate below.
+    """
+    async with await BlockedGetData(s.address) as b:
+        x = c.submit(inc, 1, key="x", workers=[b.address], allow_other_workers=True)
+        y = c.submit(inc, x, key="y", workers=[a.address])
+        await b.in_get_data.wait()
 
-    class BlockGetDataWorker(Worker):
-        def __init__(self, *args, get_data_event, get_data_lock, **kwargs):
-            self._get_data_event = get_data_event
-            self._get_data_lock = get_data_lock
-            super().__init__(*args, **kwargs)
+        # Make b dead to s, but not to a
+        await s.remove_worker(b.address, stimulus_id="stim-id")
 
-        async def get_data(self, comm, *args, **kwargs):
-            self._get_data_event.set()
-            async with self._get_data_lock:
-                return await super().get_data(comm, *args, **kwargs)
-
-    async with await BlockGetDataWorker(
-        s.address,
-        get_data_event=enter_get_data_1,
-        get_data_lock=block_get_data_1,
-        name="w1",
-    ) as w1:
-
-        f1 = c.submit(inc, 1, key="f1", workers=[w1.address])
-        f2 = c.submit(inc, 2, key="f2", workers=[w1.address])
-        f3 = c.submit(sum, [f1, f2], key="f3", workers=[w1.address])
-
-        await wait(f3)
-        f4 = c.submit(inc, f3, key="f4", workers=[w2.address])
-
-        await enter_get_data_1.wait()
-        s.set_restrictions(
-            {
-                f1.key: {w3.address},
-                f2.key: {w3.address},
-                f3.key: {w2.address},
-            }
-        )
-        await s.remove_worker(w1.address, stimulus_id="stim-id")
+        # Wait for the scheduler to reschedule x on a.
+        # We want the comms from the scheduler to reach a before b closes the RPC
+        # channel, causing a.gather_dep() to raise OSError.
+        await wait_for_stimulus(ComputeTaskEvent, a, key="x")
+
+    # b closed; a.gather_dep() fails. Note that, in the current implementation, x won't
+    # be recomputed on a until this happens.
+    assert await y == 3
 
-        await wait_for_state(f3.key, "resumed", w2)
-        assert_story(
-            w2.state.log,
-            [
-                (f3.key, "flight", "released", "cancelled", {}),
-                # ...
-                (f3.key, "cancelled", "waiting", "resumed", {}),
-            ],
-        )
 
-    # w1 closed
-    assert await f4 == 6
+def test_cancelled_resumed_after_flight_with_dependencies_workerstate(ws):
+    """Same as test_cancelled_resumed_after_flight_with_dependencies, but testing the
+    WorkerState in isolation
+    """
+    ws2 = "127.0.0.1:2"
+    instructions = ws.handle_stimulus(
+        # Create task x and put it in flight from ws2
+        ComputeTaskEvent.dummy(key="y", who_has={"x": [ws2]}, stimulus_id="s1"),
+        # The scheduler realises that ws2 is unresponsive, although ws doesn't know yet.
+        # Having lost the last surviving replica of x, the scheduler cancels all of its
+        # dependents. This also cancels x.
+        FreeKeysEvent(keys=["y"], stimulus_id="s2"),
+        # The scheduler reschedules x on another worker, which just happens to be one
+        # that was previously fetching it. This does not generate an Execute
+        # instruction, because the GatherDep instruction isn't complete yet.
+        ComputeTaskEvent.dummy(key="x", stimulus_id="s3"),
+        # After ~30s, the TCP socket with ws2 finally times out and collapses.
+        # This triggers the Execute instruction.
+        GatherDepNetworkFailureEvent(worker=ws2, total_nbytes=1, stimulus_id="s4"),
+    )
+    assert instructions == [
+        GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s1"),
+        Execute(key="x", stimulus_id="s4"),  # Note the stimulus_id!
+    ]
+    assert ws.tasks["x"].state == "executing"
 
 
 @pytest.mark.parametrize("wait_for_processing", [True, False])
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index d0cf73a432b..8e39f4bc3af 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -45,9 +45,11 @@
     raises_with_cause,
     tls_only_security,
     wait_for_state,
+    wait_for_stimulus,
 )
 from distributed.worker import fail_hard
 from distributed.worker_state_machine import (
+    ComputeTaskEvent,
     InvalidTaskState,
     InvalidTransition,
     PauseEvent,
@@ -920,7 +922,7 @@ async def test_freeze_batched_send():
         assert e.count == 3
 
 
-@gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
+@gen_cluster(client=True, nthreads=[("", 1)])
 async def test_wait_for_state(c, s, a, capsys):
     ev = Event()
     x = c.submit(lambda ev: ev.wait(), ev, key="x")
@@ -947,3 +949,22 @@ async def test_wait_for_state(c, s, a, capsys):
         f"tasks[x].state='memory' on {s.address}; expected state='bad_state'\n"
         f"tasks[y] not found on {s.address}\n"
     )
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_wait_for_stimulus(c, s, a):
+    t1 = asyncio.create_task(wait_for_stimulus(ComputeTaskEvent, a))
+    t2 = asyncio.create_task(wait_for_stimulus(ComputeTaskEvent, a, key="y"))
+    await asyncio.sleep(0.05)
+    assert not t1.done()
+    assert not t2.done()
+
+    x = c.submit(inc, 1, key="x")
+    ev = await t1
+    assert isinstance(ev, ComputeTaskEvent)
+    await wait_for_stimulus(ComputeTaskEvent, a, key="x")
+    await c.run(wait_for_stimulus, ComputeTaskEvent, key="x")
+    assert not t2.done()
+
+    y = c.submit(inc, 1, key="y")
+    await t2
diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index 4e532116f83..ef5b670e4b1 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -307,6 +307,29 @@ def test_computetask_to_dict():
     assert ev3.priority == (0,)  # List is automatically converted back to tuple
 
 
+def test_computetask_dummy():
+    ev = ComputeTaskEvent.dummy(key="x", stimulus_id="s")
+    assert ev == ComputeTaskEvent(
+        key="x",
+        who_has={},
+        nbytes={},
+        priority=(0,),
+        duration=1.0,
+        run_spec=None,
+        resource_restrictions={},
+        actor=False,
+        annotations={},
+        stimulus_id="s",
+        function=None,
+        args=None,
+        kwargs=None,
+    )
+
+    # nbytes is generated from who_has if omitted
+    ev2 = ComputeTaskEvent.dummy(key="x", who_has={"y": "127.0.0.1:2"}, stimulus_id="s")
+    assert ev2.nbytes == {"y": 1}
+
+
 def test_updatedata_to_dict():
     """The potentially very large UpdateDataEvent.data is not stored in the log"""
     ev = UpdateDataEvent(
@@ -933,19 +956,10 @@ def test_gather_priority(ws):
             stimulus_id="compute1",
         ),
         # A higher-priority task, even if scheduled later, is fetched first
-        ComputeTaskEvent(
+        ComputeTaskEvent.dummy(
             key="z",
             who_has={"y": ["127.0.0.7:1"]},
-            nbytes={"y": 1},
             priority=(0,),
-            duration=1.0,
-            run_spec=None,
-            function=None,
-            args=None,
-            kwargs=None,
-            resource_restrictions={},
-            actor=False,
-            annotations={},
             stimulus_id="compute2",
         ),
         UnpauseEvent(stimulus_id="unpause"),
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index cd0145c82c0..b7dececfadd 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -73,7 +73,7 @@
     sync,
 )
 from distributed.worker import WORKER_ANY_RUNNING, Worker
-from distributed.worker_state_machine import InvalidTransition
+from distributed.worker_state_machine import InvalidTransition, StateMachineEvent
 from distributed.worker_state_machine import TaskState as WorkerTaskState
 from distributed.worker_state_machine import WorkerState
 
@@ -2400,6 +2400,9 @@ def freeze_batched_send(bcomm: BatchedSend) -> Iterator[LockedComm]:
 async def wait_for_state(
     key: str, state: str, dask_worker: Worker | Scheduler, *, interval: float = 0.01
 ) -> None:
+    """Wait for a task to appear on a Worker or on the Scheduler and to be in a specific
+    state.
+    """
     if isinstance(dask_worker, Worker):
         tasks = dask_worker.state.tasks
     elif isinstance(dask_worker, Scheduler):
@@ -2424,6 +2427,27 @@ async def wait_for_state(
         raise
 
 
+async def wait_for_stimulus(
+    type_: type[StateMachineEvent] | tuple[type[StateMachineEvent], ...],
+    dask_worker: Worker,
+    *,
+    interval: float = 0.01,
+    **matches: Any,
+) -> StateMachineEvent:
+    """Wait for a specific stimulus to appear in the log of the WorkerState."""
+    log = dask_worker.state.stimulus_log
+    last_ev = None
+    while True:
+        if log and log[-1] is not last_ev:
+            last_ev = log[-1]
+            for ev in log:
+                if not isinstance(ev, type_):
+                    continue
+                if all(getattr(ev, k) == v for k, v in matches.items()):
+                    return ev
+        await asyncio.sleep(interval)
+
+
 @pytest.fixture
 def ws():
     state = WorkerState(address="127.0.0.1:1", transition_counter_max=50_000)
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 33d6ce63093..cec51d8c542 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -694,6 +694,38 @@ def to_loggable(self, *, handled: float) -> StateMachineEvent:
     def _after_from_dict(self) -> None:
         self.run_spec = SerializedTask(task=None, function=None, args=None, kwargs=None)
 
+    @staticmethod
+    def dummy(
+        *,
+        key: str,
+        who_has: dict[str, Collection[str]] | None = None,
+        nbytes: dict[str, int] | None = None,
+        priority: tuple[int, ...] = (0,),
+        duration: float = 1.0,
+        resource_restrictions: dict[str, float] | None = None,
+        actor: bool = False,
+        annotations: dict | None = None,
+        stimulus_id: str,
+    ) -> ComputeTaskEvent:
+        """Build a dummy event, with most attributes set to a reasonable default.
+        This is a convenience method to be used in unit testing only.
+        """
+        return ComputeTaskEvent(
+            key=key,
+            who_has=who_has or {},
+            nbytes=nbytes or {k: 1 for k in who_has or ()},
+            priority=priority,
+            duration=duration,
+            run_spec=None,
+            function=None,
+            args=None,
+            kwargs=None,
+            resource_restrictions=resource_restrictions or {},
+            actor=actor,
+            annotations=annotations or {},
+            stimulus_id=stimulus_id,
+        )
+
 
 @dataclass
 class ExecuteSuccessEvent(StateMachineEvent):