diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index f08af1af8ca..fab4a4cf3fd 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -28,6 +28,7 @@
     ComputeTaskEvent,
     ExecuteFailureEvent,
     ExecuteSuccessEvent,
+    FreeKeysEvent,
     GatherDep,
     Instruction,
     PauseEvent,
@@ -1037,3 +1038,61 @@ async def test_clean_log(s, a, b):
     """Test that brand new workers start with a clean log"""
     assert not a.state.log
     assert not a.state.stimulus_log
+
+
+def test_running_task_in_all_running_tasks(ws_with_running_task):
+    """A running task remains in all_running_tasks when cancelled and then resumed"""
+    ws = ws_with_running_task
+    ws2 = "127.0.0.1:2"
+    ts = ws.tasks["x"]
+    assert ts in ws.all_running_tasks
+
+    ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1"))
+    assert ts.state == "cancelled"
+    assert ts in ws.all_running_tasks
+
+    ws.handle_stimulus(
+        ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2")
+    )
+    assert ts.state == "resumed"
+    assert ts in ws.all_running_tasks
+
+
+@pytest.mark.xfail(reason="distributed#6565, distributed#6692")
+@pytest.mark.parametrize(
+    "done_ev_cls,done_status",
+    [(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")],
+)
+def test_done_task_not_in_all_running_tasks(
+    ws_with_running_task, done_ev_cls, done_status
+):
+    """A task that completed, successfully or not, must leave all_running_tasks"""
+    ws = ws_with_running_task
+    ts = ws.tasks["x"]
+    assert ts in ws.all_running_tasks
+
+    ws.handle_stimulus(done_ev_cls.dummy("x", stimulus_id="s1"))
+    assert ts.state == done_status
+    assert ts not in ws.all_running_tasks
+
+
+@pytest.mark.xfail(reason="distributed#6565, distributed#6689, distributed#6692")
+@pytest.mark.parametrize(
+    "done_ev_cls,done_status",
+    [(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")],
+)
+def test_done_resumed_task_not_in_all_running_tasks(
+    ws_with_running_task, done_ev_cls, done_status
+):
+    """A cancelled-then-resumed task that completed must leave all_running_tasks"""
+    ws = ws_with_running_task
+    ws2 = "127.0.0.1:2"
+
+    ws.handle_stimulus(
+        FreeKeysEvent(keys=["x"], stimulus_id="s1"),
+        ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"),
+        done_ev_cls.dummy("x", stimulus_id="s3"),
+    )
+    ts = ws.tasks["x"]
+    assert ts.state == done_status
+    assert ts not in ws.all_running_tasks
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index e5b04da75c2..b31db943c7a 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -1203,6 +1203,7 @@ def handle_stimulus(self, *stims: StateMachineEvent) -> Instructions:
     @property
     def executing_count(self) -> int:
         """Count of tasks currently executing on this worker.
+        Does not include long running (a.k.a. seceded) and cancelled tasks.
 
         See also
         --------
@@ -1212,6 +1213,17 @@ def executing_count(self) -> int:
         """
         return len(self.executing)
 
+    @property
+    def all_running_tasks(self) -> set[TaskState]:
+        """All tasks that are currently occupying a thread.
+        These are:
+
+        - ``ts.state in ("executing", "long-running", "cancelled")``
+        - ``ts.state == "resumed" and ts._previous in ("executing", "long-running")``
+        """
+        # Note: cancelled and resumed tasks are still in either of these sets
+        return self.executing | {self.tasks[key] for key in self.long_running}
+
     @property
     def in_flight_tasks_count(self) -> int:
         """Count of tasks currently being replicated from other workers to this one.
@@ -1981,7 +1993,7 @@ def _transition_cancelled_fetch(
             ts.state = ts._previous
             return {}, []
         else:
-            assert ts._previous == "executing"
+            assert ts._previous in {"executing", "long-running"}
             ts.state = "resumed"
             ts._next = "fetch"
             return {}, []
@@ -3119,11 +3131,14 @@ def validate_state(self) -> None:
                 waiting_for_data_count += 1
             for ts_wait in ts.waiting_for_data:
                 assert ts_wait.key in self.tasks
-                assert (
-                    ts_wait.state in READY | {"executing", "flight", "fetch", "missing"}
-                    or ts_wait in self.missing_dep_flight
-                    or ts_wait.who_has.issubset(self.in_flight_workers)
-                ), (ts, ts_wait, self.story(ts), self.story(ts_wait))
+                assert ts_wait.state in READY | {
+                    "executing",
+                    "long-running",
+                    "resumed",
+                    "flight",
+                    "fetch",
+                    "missing",
+                }, (ts, ts_wait, self.story(ts), self.story(ts_wait))
         # FIXME https://github.com/dask/distributed/issues/6319
         # assert self.waiting_for_data_count == waiting_for_data_count
         for worker, keys in self.has_what.items():