From f2e2af0b62bf04692fbd4ed21248035f807ce7af Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 28 Jun 2022 17:35:34 +0100 Subject: [PATCH 1/6] Rewrite test using WSM --- distributed/tests/test_cancelled_state.py | 84 +++++++------------ .../tests/test_worker_state_machine.py | 34 +++++--- distributed/worker_state_machine.py | 32 +++++++ 3 files changed, 86 insertions(+), 64 deletions(-) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index f4c806e6af4..f3ba4c532e1 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -16,6 +16,13 @@ slowinc, wait_for_state, ) +from distributed.worker_state_machine import ( + ComputeTaskEvent, + Execute, + FreeKeysEvent, + GatherDep, + GatherDepNetworkFailureEvent, +) async def wait_for_cancelled(key, dask_worker): @@ -375,60 +382,29 @@ def block_execution(event, lock): assert await fut2 == 2 -@gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_cancelled_resumed_after_flight_with_dependencies(c, s, w2, w3): - # See https://github.com/dask/distributed/pull/6327#discussion_r872231090 - block_get_data_1 = asyncio.Lock() - enter_get_data_1 = asyncio.Event() - await block_get_data_1.acquire() - - class BlockGetDataWorker(Worker): - def __init__(self, *args, get_data_event, get_data_lock, **kwargs): - self._get_data_event = get_data_event - self._get_data_lock = get_data_lock - super().__init__(*args, **kwargs) - - async def get_data(self, comm, *args, **kwargs): - self._get_data_event.set() - async with self._get_data_lock: - return await super().get_data(comm, *args, **kwargs) - - async with await BlockGetDataWorker( - s.address, - get_data_event=enter_get_data_1, - get_data_lock=block_get_data_1, - name="w1", - ) as w1: - - f1 = c.submit(inc, 1, key="f1", workers=[w1.address]) - f2 = c.submit(inc, 2, key="f2", workers=[w1.address]) - f3 = c.submit(sum, [f1, f2], key="f3", workers=[w1.address]) - - await wait(f3) - f4 = c.submit(inc, f3, key="f4", workers=[w2.address]) - - await enter_get_data_1.wait() - s.set_restrictions( - { - f1.key: {w3.address}, - f2.key: {w3.address}, - f3.key: {w2.address}, - } - ) - await s.remove_worker(w1.address, stimulus_id="stim-id") - - await wait_for_state(f3.key, "resumed", w2) - assert_story( - w2.state.log, - [ - (f3.key, "flight", "released", "cancelled", {}), - # ... - (f3.key, "cancelled", "waiting", "resumed", {}), - ], - ) - # w1 closed - - assert await f4 == 6 +def test_cancelled_resumed_after_flight_with_dependencies(ws): + """See https://github.com/dask/distributed/pull/6327#discussion_r872231090""" + ws2 = "127.0.0.1:2" + instructions = ws.handle_stimulus( + # Create task x and put it in flight from ws2 + ComputeTaskEvent.dummy(key="y", who_has={"x": [ws2]}, stimulus_id="s1"), + # The scheduler realises that ws2 is unresponsive, although ws doesn't know yet. + # Having lost the last surviving replica of x, the scheduler cancels all of its + # dependents. This also cancels x. + FreeKeysEvent(keys=["y"], stimulus_id="s2"), + # The scheduler reschedules x on another worker, which just happens to be one + # that was previously fetching it. This does not generate an Execute + # instruction, because the GatherDep instruction isn't complete yet. + ComputeTaskEvent.dummy(key="x", stimulus_id="s3"), + # After ~30s, the TCP socket with ws2 finally times out and collapses. + # This triggers the Execute instruction. + GatherDepNetworkFailureEvent(worker=ws2, total_nbytes=1, stimulus_id="s4"), + ) + assert instructions == [ + GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s1"), + Execute(key="x", stimulus_id="s4"), # Note the stimulus_id! + ] + assert ws.tasks["x"].state == "executing" @pytest.mark.parametrize("wait_for_processing", [True, False]) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 4e532116f83..ef5b670e4b1 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -307,6 +307,29 @@ def test_computetask_to_dict(): assert ev3.priority == (0,) # List is automatically converted back to tuple +def test_computetask_dummy(): + ev = ComputeTaskEvent.dummy(key="x", stimulus_id="s") + assert ev == ComputeTaskEvent( + key="x", + who_has={}, + nbytes={}, + priority=(0,), + duration=1.0, + run_spec=None, + resource_restrictions={}, + actor=False, + annotations={}, + stimulus_id="s", + function=None, + args=None, + kwargs=None, + ) + + # nbytes is generated from who_has if omitted + ev2 = ComputeTaskEvent.dummy(key="x", who_has={"y": "127.0.0.1:2"}, stimulus_id="s") + assert ev2.nbytes == {"y": 1} + + def test_updatedata_to_dict(): """The potentially very large UpdateDataEvent.data is not stored in the log""" ev = UpdateDataEvent( @@ -933,19 +956,10 @@ def test_gather_priority(ws): stimulus_id="compute1", ), # A higher-priority task, even if scheduled later, is fetched first - ComputeTaskEvent( + ComputeTaskEvent.dummy( key="z", who_has={"y": ["127.0.0.7:1"]}, - nbytes={"y": 1}, priority=(0,), - duration=1.0, - run_spec=None, - function=None, - args=None, - kwargs=None, - resource_restrictions={}, - actor=False, - annotations={}, stimulus_id="compute2", ), UnpauseEvent(stimulus_id="unpause"), diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 33d6ce63093..cec51d8c542 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -694,6 +694,38 @@ def to_loggable(self, *, handled: float) -> StateMachineEvent: def _after_from_dict(self) -> None: self.run_spec = SerializedTask(task=None, function=None, args=None, kwargs=None) + @staticmethod + def dummy( + *, + key: str, + who_has: dict[str, Collection[str]] | None = None, + nbytes: dict[str, int] | None = None, + priority: tuple[int, ...] = (0,), + duration: float = 1.0, + resource_restrictions: dict[str, float] | None = None, + actor: bool = False, + annotations: dict | None = None, + stimulus_id: str, + ) -> ComputeTaskEvent: + """Build a dummy event, with most attributes set to a reasonable default. + This is a convenience method to be used in unit testing only. + """ + return ComputeTaskEvent( + key=key, + who_has=who_has or {}, + nbytes=nbytes or {k: 1 for k in who_has or ()}, + priority=priority, + duration=duration, + run_spec=None, + function=None, + args=None, + kwargs=None, + resource_restrictions=resource_restrictions or {}, + actor=actor, + annotations=annotations or {}, + stimulus_id=stimulus_id, + ) + @dataclass class ExecuteSuccessEvent(StateMachineEvent): From 687dba3ae0c340874b4349da5113c8c9c4427e05 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jun 2022 13:59:24 +0100 Subject: [PATCH 2/6] Reintroduce simplified integration test --- distributed/tests/test_cancelled_state.py | 35 +++++++++++++++++++++-- distributed/tests/test_utils_test.py | 20 ++++++++++++- distributed/utils_test.py | 22 +++++++++++++- 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index f3ba4c532e1..de1db27bd31 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -15,6 +15,7 @@ inc, slowinc, wait_for_state, + wait_for_stimulus, ) from distributed.worker_state_machine import ( ComputeTaskEvent, @@ -382,8 +383,38 @@ def block_execution(event, lock): assert await fut2 == 2 -def test_cancelled_resumed_after_flight_with_dependencies(ws): - """See https://github.com/dask/distributed/pull/6327#discussion_r872231090""" +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_cancelled_resumed_after_flight_with_dependencies(c, s, a): + """A task is in flight from b to a. + While a is waiting, b dies. The scheduler notices before a and reschedules the + task on a itself (as the only surviving replica was just lost). + Test that the worker eventually computes the task. + + See https://github.com/dask/distributed/pull/6327#discussion_r872231090 + See test_cancelled_resumed_after_flight_with_dependencies_workerstate below. + """ + async with await BlockedGetData(s.address) as b: + x = c.submit(inc, 1, key="x", workers=[b.address], allow_other_workers=True) + y = c.submit(inc, x, key="y", workers=[a.address]) + await b.in_get_data.wait() + + # Make b dead to s, but not to a + await s.remove_worker(b.address, stimulus_id="stim-id") + + # Wait for the scheduler to reschedule x on a. + # We want the comms from the scheduler to reach a before b closes the RPC + # channel, causing a.gather_dep() to raise OSError. + await wait_for_stimulus(a.state, ComputeTaskEvent, key="x") + + # b closed; a.gather_dep() fails. Note that, in the current implementation, x won't + # be recomputed on a until this happens. + assert await y == 3 + + +def test_cancelled_resumed_after_flight_with_dependencies_workerstate(ws): + """Same as test_cancelled_resumed_after_flight_with_dependencies, but testing the + WorkerState in isolation + """ ws2 = "127.0.0.1:2" instructions = ws.handle_stimulus( # Create task x and put it in flight from ws2 diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index d0cf73a432b..f63ead6541b 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -45,9 +45,11 @@ raises_with_cause, tls_only_security, wait_for_state, + wait_for_stimulus, ) from distributed.worker import fail_hard from distributed.worker_state_machine import ( + ComputeTaskEvent, InvalidTaskState, InvalidTransition, PauseEvent, @@ -920,7 +922,7 @@ async def test_freeze_batched_send(): assert e.count == 3 -@gen_cluster(client=True, nthreads=[("", 1)], timeout=2) +@gen_cluster(client=True, nthreads=[("", 1)]) async def test_wait_for_state(c, s, a, capsys): ev = Event() x = c.submit(lambda ev: ev.wait(), ev, key="x") @@ -947,3 +949,19 @@ async def test_wait_for_state(c, s, a, capsys): f"tasks[x].state='memory' on {s.address}; expected state='bad_state'\n" f"tasks[y] not found on {s.address}\n" ) + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_wait_for_stimulus(c, s, a): + t1 = asyncio.create_task(wait_for_stimulus(a.state, ComputeTaskEvent)) + t2 = asyncio.create_task(wait_for_stimulus(a.state, ComputeTaskEvent, key="y")) + await asyncio.sleep(0.05) + assert not t1.done() + assert not t2.done() + + x = c.submit(inc, 1, key="x") + await t1 + await wait_for_stimulus(a.state, ComputeTaskEvent, key="x") + assert not t2.done() + y = c.submit(inc, 1, key="y") + await t2 diff --git a/distributed/utils_test.py b/distributed/utils_test.py index cd0145c82c0..2ebcb0685c4 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -73,7 +73,7 @@ sync, ) from distributed.worker import WORKER_ANY_RUNNING, Worker -from distributed.worker_state_machine import InvalidTransition +from distributed.worker_state_machine import InvalidTransition, StateMachineEvent from distributed.worker_state_machine import TaskState as WorkerTaskState from distributed.worker_state_machine import WorkerState @@ -2400,6 +2400,9 @@ def freeze_batched_send(bcomm: BatchedSend) -> Iterator[LockedComm]: async def wait_for_state( key: str, state: str, dask_worker: Worker | Scheduler, *, interval: float = 0.01 ) -> None: + """Wait for a task to appear on a Worker or on the Scheduler and to be in a specific + state. + """ if isinstance(dask_worker, Worker): tasks = dask_worker.state.tasks elif isinstance(dask_worker, Scheduler): @@ -2424,6 +2427,23 @@ async def wait_for_state( raise +async def wait_for_stimulus( + ws: WorkerState, + type_: type[StateMachineEvent] | tuple[type[StateMachineEvent], ...], + *, + interval: float = 0.01, + **matches: Any, +) -> StateMachineEvent: + """Wait for a specific stimulus to appear in the log of the WorkerState.""" + while True: + for ev in ws.stimulus_log: + if not isinstance(ev, type_): + continue + if not matches or all(getattr(ev, k) == v for k, v in matches.items()): + return ev + await asyncio.sleep(interval) + + @pytest.fixture def ws(): state = WorkerState(address="127.0.0.1:1", transition_counter_max=50_000) From e03f0fdc5426bb80b4710ba113a6a3c757dde399 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jun 2022 14:08:30 +0100 Subject: [PATCH 3/6] speed optimization --- distributed/utils_test.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 2ebcb0685c4..5088c74e6d7 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2435,12 +2435,15 @@ async def wait_for_stimulus( **matches: Any, ) -> StateMachineEvent: """Wait for a specific stimulus to appear in the log of the WorkerState.""" + last_ev = None while True: - for ev in ws.stimulus_log: - if not isinstance(ev, type_): - continue - if not matches or all(getattr(ev, k) == v for k, v in matches.items()): - return ev + if ws.stimulus_log and ws.stimulus_log[-1] is not last_ev: + last_ev = ws.stimulus_log[-1] + for ev in ws.stimulus_log: + if not isinstance(ev, type_): + continue + if not matches or all(getattr(ev, k) == v for k, v in matches.items()): + return ev await asyncio.sleep(interval) From 97759661d6572f7c94cefcb30b96f2dbf6b930d2 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jun 2022 14:15:36 +0100 Subject: [PATCH 4/6] Support for remote calls --- distributed/tests/test_cancelled_state.py | 2 +- distributed/tests/test_utils_test.py | 8 +++++--- distributed/utils_test.py | 9 +++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index de1db27bd31..0ae1bec2479 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -404,7 +404,7 @@ async def test_cancelled_resumed_after_flight_with_dependencies(c, s, a): # Wait for the scheduler to reschedule x on a. # We want the comms from the scheduler to reach a before b closes the RPC # channel, causing a.gather_dep() to raise OSError. - await wait_for_stimulus(a.state, ComputeTaskEvent, key="x") + await wait_for_stimulus(ComputeTaskEvent, a, key="x") # b closed; a.gather_dep() fails. Note that, in the current implementation, x won't # be recomputed on a until this happens. diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index f63ead6541b..9a10a13d391 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -953,15 +953,17 @@ async def test_wait_for_state(c, s, a, capsys): @gen_cluster(client=True, nthreads=[("", 1)]) async def test_wait_for_stimulus(c, s, a): - t1 = asyncio.create_task(wait_for_stimulus(a.state, ComputeTaskEvent)) - t2 = asyncio.create_task(wait_for_stimulus(a.state, ComputeTaskEvent, key="y")) + t1 = asyncio.create_task(wait_for_stimulus(ComputeTaskEvent, a)) + t2 = asyncio.create_task(wait_for_stimulus(ComputeTaskEvent, a, key="y")) await asyncio.sleep(0.05) assert not t1.done() assert not t2.done() x = c.submit(inc, 1, key="x") await t1 - await wait_for_stimulus(a.state, ComputeTaskEvent, key="x") + await wait_for_stimulus(ComputeTaskEvent, a, key="x") + await c.run(wait_for_stimulus, ComputeTaskEvent, key="x") assert not t2.done() + y = c.submit(inc, 1, key="y") await t2 diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 5088c74e6d7..9f9ea59e316 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2428,18 +2428,19 @@ async def wait_for_state( async def wait_for_stimulus( - ws: WorkerState, type_: type[StateMachineEvent] | tuple[type[StateMachineEvent], ...], + dask_worker: Worker, *, interval: float = 0.01, **matches: Any, ) -> StateMachineEvent: """Wait for a specific stimulus to appear in the log of the WorkerState.""" + log = dask_worker.state.stimulus_log last_ev = None while True: - if ws.stimulus_log and ws.stimulus_log[-1] is not last_ev: - last_ev = ws.stimulus_log[-1] - for ev in ws.stimulus_log: + if log and log[-1] is not last_ev: + last_ev = log[-1] + for ev in log: if not isinstance(ev, type_): continue if not matches or all(getattr(ev, k) == v for k, v in matches.items()): From dc7ebad1389c69e4b3c6d0d28b35a55d3b14361e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jun 2022 14:17:34 +0100 Subject: [PATCH 5/6] Test return type --- distributed/tests/test_utils_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 9a10a13d391..8e39f4bc3af 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -960,7 +960,8 @@ async def test_wait_for_stimulus(c, s, a): assert not t2.done() x = c.submit(inc, 1, key="x") - await t1 + ev = await t1 + assert isinstance(ev, ComputeTaskEvent) await wait_for_stimulus(ComputeTaskEvent, a, key="x") await c.run(wait_for_stimulus, ComputeTaskEvent, key="x") assert not t2.done() From bfcf5eb0c14043a66fd3b7082edf307f0b52207f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jun 2022 14:18:34 +0100 Subject: [PATCH 6/6] simplify --- distributed/utils_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 9f9ea59e316..b7dececfadd 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2443,7 +2443,7 @@ async def wait_for_stimulus( for ev in log: if not isinstance(ev, type_): continue - if not matches or all(getattr(ev, k) == v for k, v in matches.items()): + if all(getattr(ev, k) == v for k, v in matches.items()): return ev await asyncio.sleep(interval)