diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 9080c1d77f7..de1e4ecb880 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -16,20 +16,22 @@ gen_cluster, inc, lock_inc, - slowinc, wait_for_state, wait_for_stimulus, ) from distributed.worker_state_machine import ( + AddKeysMsg, ComputeTaskEvent, Execute, ExecuteFailureEvent, ExecuteSuccessEvent, FreeKeysEvent, GatherDep, + GatherDepFailureEvent, GatherDepNetworkFailureEvent, GatherDepSuccessEvent, TaskFinishedMsg, + UpdateDataEvent, ) @@ -231,53 +233,30 @@ async def wait_and_raise(*args, **kwargs): w.state.story(f1.key), [ (f1.key, "executing", "released", "cancelled", {}), - ( - f1.key, - "cancelled", - "error", - "error", - {f2.key: "executing", f1.key: "released"}, - ), - (f1.key, "error", "released", "released", {f1.key: "forgotten"}), + (f1.key, "cancelled", "error", "released", {f1.key: "forgotten"}), (f1.key, "released", "forgotten", "forgotten", {}), ], ) -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_flight_cancelled_error(c, s, b): - """One worker with one thread. We provoke an flight->cancelled transition - and let the task err.""" - lock = asyncio.Lock() - await lock.acquire() +def test_flight_cancelled_error(ws): + """Test flight -> cancelled -> error transition loop. + This can be caused by an issue while (un)pickling or a bug in the network stack. 
- class BrokenWorker(Worker): - block_get_data = True - - async def get_data(self, comm, *args, **kwargs): - if self.block_get_data: - async with lock: - comm.abort() - return await super().get_data(comm, *args, **kwargs) - - async with BrokenWorker(s.address) as a: - await c.wait_for_workers(2) - fut1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True) - fut2 = c.submit(inc, fut1, workers=[b.address]) - await wait_for_state(fut1.key, "flight", b) - fut2.release() - fut1.release() - await wait_for_state(fut1.key, "cancelled", b) - lock.release() - # At this point we do not fetch the result of the future since the - # future itself would raise a cancelled exception. At this point we're - # concerned about the worker. The task should transition over error to - # be eventually forgotten since we no longer hold a ref. - while fut1.key in b.state.tasks: - await asyncio.sleep(0.01) - a.block_get_data = False - # Everything should still be executing as usual after this - assert await c.submit(sum, c.map(inc, range(10))) == sum(map(inc, range(10))) + See https://github.com/dask/distributed/issues/6877 + """ + ws2 = "127.0.0.1:2" + instructions = ws.handle_stimulus( + ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s1"), + FreeKeysEvent(keys=["y", "x"], stimulus_id="s2"), + GatherDepFailureEvent.from_exception( + Exception(), worker=ws2, total_nbytes=1, stimulus_id="s3" + ), + ) + assert instructions == [ + GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s1") + ] + assert not ws.tasks @gen_cluster(client=True, nthreads=[("", 1)]) @@ -332,6 +311,7 @@ def block_execution(lock): (fut1.key, "resumed", "released", "cancelled", {}), # After gather_dep receives the data, the task is forgotten (fut1.key, "cancelled", "memory", "released", {fut1.key: "forgotten"}), + (fut1.key, "released", "forgotten", "forgotten", {}), ], ) @@ -369,7 +349,8 @@ def block_execution(event, lock): b.state.story(fut1.key), [ (fut1.key, "executing", 
"released", "cancelled", {}), - (fut1.key, "cancelled", "error", "error", {fut1.key: "released"}), + (fut1.key, "cancelled", "error", "released", {fut1.key: "forgotten"}), + (fut1.key, "released", "forgotten", "forgotten", {}), ], ) @@ -480,14 +461,18 @@ async def test_resumed_cancelled_handle_compute( lock_compute = Lock() await lock_compute.acquire() enter_compute = Event() + exit_compute = Event() - def block(x, lock, enter_event): + def block(x, lock, enter_event, exit_event): enter_event.set() - with lock: - if raise_error: - raise RuntimeError("test error") - else: - return x + 1 + try: + with lock: + if raise_error: + raise RuntimeError("test error") + else: + return x + 1 + finally: + exit_event.set() f1 = c.submit(inc, 1, key="f1", workers=[a.address]) f2 = c.submit(inc, f1, key="f2", workers=[a.address]) @@ -496,6 +481,7 @@ def block(x, lock, enter_event): f2, lock=lock_compute, enter_event=enter_compute, + exit_event=exit_compute, key="f3", workers=[b.address], ) @@ -523,6 +509,10 @@ async def release_all_futures(): await wait_for_state(f3.key, "resumed", b) await release_all_futures() + if not wait_for_processing: + await lock_compute.release() + await exit_compute.wait() + f1 = c.submit(inc, 1, key="f1", workers=[a.address]) f2 = c.submit(inc, f1, key="f2", workers=[a.address]) f3 = c.submit(inc, f2, key="f3", workers=[b.address]) @@ -530,10 +520,9 @@ async def release_all_futures(): if wait_for_processing: await wait_for_state(f3.key, "processing", s) + await lock_compute.release() - await lock_compute.release() - - if not raise_error: + if not wait_for_processing and not raise_error: assert await f4 == 4 + 2 assert_story( @@ -546,19 +535,55 @@ async def release_all_futures(): ], ) - else: + elif not wait_for_processing and raise_error: + assert await f4 == 4 + 2 + + assert_story( + b.state.story(f3.key), + expect=[ + (f3.key, "ready", "executing", "executing", {}), + (f3.key, "executing", "released", "cancelled", {}), + (f3.key, "cancelled", "fetch", 
"resumed", {}), + (f3.key, "resumed", "error", "released", {f3.key: "fetch"}), + (f3.key, "fetch", "flight", "flight", {}), + (f3.key, "flight", "missing", "missing", {}), + (f3.key, "missing", "waiting", "waiting", {f2.key: "fetch"}), + (f3.key, "waiting", "ready", "ready", {f3.key: "executing"}), + (f3.key, "ready", "executing", "executing", {}), + (f3.key, "executing", "memory", "memory", {}), + ], + ) + + elif wait_for_processing and not raise_error: + assert await f4 == 4 + 2 + + assert_story( + b.state.story(f3.key), + expect=[ + (f3.key, "ready", "executing", "executing", {}), + (f3.key, "executing", "released", "cancelled", {}), + (f3.key, "cancelled", "fetch", "resumed", {}), + (f3.key, "resumed", "waiting", "executing", {}), + (f3.key, "executing", "memory", "memory", {}), + ], + ) + + elif wait_for_processing and raise_error: with pytest.raises(RuntimeError, match="test error"): await f3 assert_story( b.state.story(f3.key), - expect=[ + [ (f3.key, "ready", "executing", "executing", {}), (f3.key, "executing", "released", "cancelled", {}), (f3.key, "cancelled", "fetch", "resumed", {}), - (f3.key, "resumed", "error", "error", {}), + (f3.key, "resumed", "waiting", "executing", {}), + (f3.key, "executing", "error", "error", {}), ], ) + else: + assert False, "unreachable" @pytest.mark.parametrize("intermediate_state", ["resumed", "cancelled"]) @@ -570,13 +595,9 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker( """If a task was transitioned to in-flight, the gather_dep coroutine was scheduled but a cancel request came in before gather_data_from_worker was issued. This might corrupt the state machine if the cancelled key is not properly handled. 
- - See also - -------- - test_workerstate_deadlock_cancelled_after_inflight_before_gather_from_worker """ - fut1 = c.submit(slowinc, 1, workers=[a.address], key="f1") - fut1B = c.submit(slowinc, 2, workers=[x.address], key="f1B") + fut1 = c.submit(inc, 1, workers=[a.address], key="f1") + fut1B = c.submit(inc, 2, workers=[x.address], key="f1B") fut2 = c.submit(sum, [fut1, fut1B], workers=[x.address], key="f2") await fut2 @@ -661,14 +682,13 @@ def test_workerstate_executing_skips_fetch_on_success(ws_with_running_task): ExecuteSuccessEvent.dummy("x", 123, stimulus_id="s3"), ) assert instructions == [ - TaskFinishedMsg.match(key="x", stimulus_id="s3"), + AddKeysMsg(keys=["x"], stimulus_id="s3"), Execute(key="y", stimulus_id="s3"), ] assert ws.tasks["x"].state == "memory" assert ws.data["x"] == 123 -@pytest.mark.xfail(reason="distributed#6689") def test_workerstate_executing_failure_to_fetch(ws_with_running_task): """Test state loops: @@ -887,3 +907,39 @@ async def resume(): # Test that x does not get stuck. assert await fut == expect + + +@pytest.mark.parametrize("release_dep", [False, True]) +@pytest.mark.parametrize("done_ev_cls", [ExecuteSuccessEvent, ExecuteFailureEvent]) +def test_cancel_with_dependencies_in_memory(ws, release_dep, done_ev_cls): + """Cancel an executing task y with an in-memory dependency x; then simulate that x + did not have any further dependents, so cancel x as well. + + Test that x immediately transitions to released state and is forgotten as soon as + y finishes computing. 
+ + Read: https://github.com/dask/distributed/issues/6893""" + ws.handle_stimulus( + UpdateDataEvent(data={"x": 1}, report=False, stimulus_id="s1"), + ComputeTaskEvent.dummy("y", who_has={"x": [ws.address]}, stimulus_id="s2"), + ) + assert ws.tasks["x"].state == "memory" + assert ws.tasks["y"].state == "executing" + + ws.handle_stimulus(FreeKeysEvent(keys=["y"], stimulus_id="s3")) + assert ws.tasks["x"].state == "memory" + assert ws.tasks["y"].state == "cancelled" + + if release_dep: + # This will happen iff x has no dependents or waiters on the scheduler + ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s4")) + assert ws.tasks["x"].state == "released" + assert ws.tasks["y"].state == "cancelled" + + ws.handle_stimulus(done_ev_cls.dummy("y", stimulus_id="s5")) + assert "y" not in ws.tasks + assert "x" not in ws.tasks + else: + ws.handle_stimulus(done_ev_cls.dummy("y", stimulus_id="s5")) + assert "y" not in ws.tasks + assert ws.tasks["x"].state == "memory" diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index ae96835605c..91810357857 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1171,16 +1171,7 @@ def test_task_with_dependencies_acquires_resources(ws): @pytest.mark.parametrize( "done_ev_cls,done_status", - [ - (ExecuteSuccessEvent, "memory"), - pytest.param( - ExecuteFailureEvent, - "flight", - marks=pytest.mark.xfail( - reason="distributed#6682,distributed#6689,distributed#6693" - ), - ), - ], + [(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "flight")], ) def test_resumed_task_releases_resources( ws_with_running_task, done_ev_cls, done_status @@ -1247,14 +1238,7 @@ def test_done_task_not_in_all_running_tasks( @pytest.mark.parametrize( "done_ev_cls,done_status", - [ - (ExecuteSuccessEvent, "memory"), - pytest.param( - ExecuteFailureEvent, - "flight", - marks=pytest.mark.xfail(reason="distributed#6689"), - ), - ], + 
[(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "flight")], ) def test_done_resumed_task_not_in_all_running_tasks( ws_with_running_task, done_ev_cls, done_status diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 094b90a5498..6f6c6978a26 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -238,9 +238,9 @@ class TaskState: #: The current state of the task state: TaskStateState = "released" #: The previous state of the task. It is not None iff state in (cancelled, resumed). - previous: TaskStateState | None = None + previous: Literal["executing", "long-running", "flight", None] = None #: The next state of the task. It is not None iff state == resumed. - next: TaskStateState | None = None + next: Literal["fetch", "waiting", None] = None #: Expected duration of the task duration: float | None = None @@ -592,20 +592,13 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: -------- distributed.utils.recursive_to_dict """ - info = { - "cls": type(self).__name__, - "stimulus_id": self.stimulus_id, - "handled": self.handled, - } - info.update( - { - k: getattr(self, k) - for k in self.__annotations__ - # Necessary for subclasses that don't define their own annotations - if k != "_classes" - } - ) - info = {k: v for k, v in info.items() if k not in exclude} + info = {"cls": type(self).__name__} + for k in dir(self): + if k in exclude or k.startswith("_"): + continue + v = getattr(self, k) + if not callable(v): + info[k] = v return recursive_to_dict(info, exclude=exclude) @staticmethod @@ -808,8 +801,17 @@ def dummy( @dataclass -class ExecuteSuccessEvent(StateMachineEvent): +class ExecuteDoneEvent(StateMachineEvent): + """Abstract base event for all the possible outcomes of a :class:`Compute` + instruction + """ + key: str + __slots__ = ("key",) + + +@dataclass +class ExecuteSuccessEvent(ExecuteDoneEvent): value: object start: float stop: float @@ -823,6 +825,13 @@ def 
to_loggable(self, *, handled: float) -> StateMachineEvent: out.value = None return out + def _to_dict(self, *, exclude: Container[str] = ()) -> dict: + d = super()._to_dict(exclude=exclude) + # This is excluded by the parent class as it is a callable + if "type" not in exclude: + d["type"] = str(self.type) + return d + def _after_from_dict(self) -> None: self.value = None self.type = None @@ -850,8 +859,7 @@ def dummy( @dataclass -class ExecuteFailureEvent(StateMachineEvent): - key: str +class ExecuteFailureEvent(ExecuteDoneEvent): start: float | None stop: float | None exception: Serialize @@ -911,15 +919,14 @@ def dummy( ) +# Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception @dataclass -class CancelComputeEvent(StateMachineEvent): - __slots__ = ("key",) - key: str +class RescheduleEvent(ExecuteDoneEvent): + __slots__ = () -# Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception @dataclass -class RescheduleEvent(StateMachineEvent): +class CancelComputeEvent(StateMachineEvent): __slots__ = ("key",) key: str @@ -1112,7 +1119,11 @@ class WorkerState: missing_dep_flight: set[TaskState] #: Which tasks that are coming to us in current peer-to-peer connections. - #: All and only tasks with TaskState.state == 'flight'. + #: This set includes exclusively: + #: - tasks with :attr:`state` == 'flight' + #: - tasks with :attr:`state` in ('cancelled', 'resumed') and + #: :attr:`previous` == 'flight' + #: #: See also :meth:`in_flight_tasks_count`. in_flight_tasks: set[TaskState] @@ -1153,6 +1164,11 @@ class WorkerState: available_resources: dict[str, float] #: Set of tasks that are currently running. + #: This set includes exclusively: + #: - tasks with :attr:`state` == 'executing' + #: - tasks with :attr:`state` in ('cancelled', 'resumed') and + #: :attr:`previous` == 'executing' + #: #: See also :meth:`executing_count` and :attr:`long_running`. 
executing: set[TaskState] @@ -1160,6 +1176,11 @@ class WorkerState: #: :func:`~distributed.secede`, so they no longer count towards the maximum number #: of concurrent tasks (nthreads). #: These tasks do not appear in the :attr:`executing` set. + #: This set includes exclusively: + #: - tasks with :attr:`state` == 'long-running' + #: - tasks with :attr:`state` in ('cancelled', 'resumed') and + #: :attr:`previous` == 'long-running' + #: long_running: set[TaskState] #: A number of tasks that this worker has run in its lifetime; this includes failed @@ -1408,6 +1429,7 @@ def _purge_state(self, ts: TaskState) -> None: ts.previous = None ts.next = None ts.done = False + ts.coming_from = None self.missing_dep_flight.discard(ts) self.ready.discard(ts) @@ -1787,10 +1809,7 @@ def _transition_generic_released( if not ts.dependents: recs[ts] = "forgotten" - return merge_recs_instructions( - (recs, []), - self._ensure_computing(), - ) + return recs, [] def _transition_released_waiting( self, ts: TaskState, *, stimulus_id: str @@ -1861,24 +1880,14 @@ def _transition_executing_rescheduled( ) -> RecsInstrs: """Note: this transition is triggered exclusively by a task raising the Reschedule() Exception; it is not involved in work stealing. - The task is always done. """ - if self.validate: - # Notably, we're missing the third state in which a task can raise - # Reschedule(), which is "cancelled" - assert ts.state in ("executing", "long-running"), ts - - self._release_resources(ts) - self.executing.discard(ts) - self.long_running.discard(ts) - + assert ts.done return merge_recs_instructions( ({}, [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)]), # Note: this is not the same as recommending {ts: "released"} on the - # previous line, as it would instead transition the task to cancelled - but - # a task that raised the Reschedule() exception is finished! 
+ # previous line, as it would instead run the ("executing", "released") + # transition, which would need special code for ts.done=True. self._transition_generic_released(ts, stimulus_id=stimulus_id), - self._ensure_computing(), ) def _transition_waiting_ready( @@ -1902,41 +1911,6 @@ def _transition_waiting_ready( return self._ensure_computing() - def _transition_cancelled_error( - self, - ts: TaskState, - exception: Serialize, - traceback: Serialize | None, - exception_text: str, - traceback_text: str, - *, - stimulus_id: str, - ) -> RecsInstrs: - assert ts.previous in ( - "executing", - "long-running", - ), f"Expected 'executing' or 'long-running'; got '{ts.previous}'" - recs, instructions = self._transition_executing_error( - ts, - exception, - traceback, - exception_text, - traceback_text, - stimulus_id=stimulus_id, - ) - # We'll ignore instructions, i.e. we choose to not submit the failure - # message to the scheduler since from the schedulers POV it already - # released this task - if self.validate: - assert instructions == [TaskErredMsg.match(key=ts.key)] - instructions.clear() - # Workers should never "retry" tasks. A transition to error should, by - # default, be the end. 
Since cancelled indicates that the scheduler lost - # interest, we can transition straight to released - assert ts not in recs - recs[ts] = "released" - return recs, instructions - def _transition_generic_error( self, ts: TaskState, @@ -1960,7 +1934,7 @@ def _transition_generic_error( return {}, [smsg] - def _transition_executing_error( + def _transition_resumed_error( self, ts: TaskState, exception: Serialize, @@ -1970,123 +1944,129 @@ def _transition_executing_error( *, stimulus_id: str, ) -> RecsInstrs: - self._release_resources(ts) - self.executing.discard(ts) - self.long_running.discard(ts) - - return merge_recs_instructions( - self._transition_generic_error( - ts, - exception, - traceback, - exception_text, - traceback_text, - stimulus_id=stimulus_id, - ), - self._ensure_computing(), - ) - - def _transition_from_resumed( - self, ts: TaskState, finish: TaskStateState, stimulus_id: str - ) -> RecsInstrs: - """`resumed` is an intermediate degenerate state which splits further up - into two states depending on what the last signal / next state is - intended to be. There are only two viable choices depending on whether - the task is required to be fetched from another worker `resumed(fetch)` - or the task shall be computed on this worker `resumed(waiting)`. - - The only viable state transitions ending up here are - - flight -> cancelled -> resumed(waiting) - - or - - executing -> cancelled -> resumed(fetch) - - depending on the origin. Equally, only `fetch`, `waiting`, or `released` - are allowed output states. - - See also `_transition_resumed_waiting` + """In case of failure of the previous state, discard the error and kick off the + next state without informing the scheduler """ - recs: Recs = {} - instructions: Instructions = [] - - if ts.previous == finish: - # We're back where we started. 
We should forget about the entire - # cancellation attempt - ts.state = finish - ts.next = None - ts.previous = None - elif not ts.done: - # If we're not done, yet, just remember where we want to be next - ts.next = finish + assert ts.done + if ts.previous in ("executing", "long-running"): + assert ts.next == "fetch" + recs: Recs = {ts: "fetch"} else: - # Flight/executing finished unsuccessfully, i.e. not in memory - assert finish != "memory" - next_state = ts.next - assert next_state in {"waiting", "fetch"}, next_state - assert ts.previous in {"executing", "long-running", "flight"}, ts.previous - - if ts.previous in ("executing", "long-running"): - self._release_resources(ts) - self.executing.discard(ts) - self.long_running.discard(ts) - - if next_state != finish: - recs, instructions = self._transition_generic_released( - ts, stimulus_id=stimulus_id - ) - recs[ts] = next_state + assert ts.previous == "flight" + assert ts.next == "waiting" + recs = {ts: "waiting"} - return recs, instructions + ts.state = "released" + ts.done = False + ts.previous = None + ts.next = None + return recs, [] def _transition_resumed_fetch( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - """See Worker._transition_from_resumed""" - recs, instructions = self._transition_from_resumed( - ts, "fetch", stimulus_id=stimulus_id - ) - if self.validate: - # This would only be possible in a fetch->cancelled->resumed->fetch loop, - # but there are no transitions from fetch which set the state to cancelled. - # If this assertion failed, we' need to call _ensure_communicating like in - # the other transitions that set ts.status = "fetch". 
- assert ts.state != "fetch" - return recs, instructions + """ + See also + -------- + _transition_cancelled_fetch + _transition_cancelled_waiting + _transition_resumed_waiting + _transition_flight_fetch + """ + if ts.previous == "flight": + if self.validate: + assert ts.next == "waiting" + if ts.done: + # We arrived here either from GatherDepNetworkFailureEvent or from + # GatherDepSuccessEvent but without the key in the data attribute. + # We would now normally try to fetch the task from another peer worker + # or transition it to missing if none are left; here instead we're going + # to compute the task as we had been asked by the scheduler. + ts.state = "released" + ts.done = False + ts.previous = None + ts.next = None + return {ts: "waiting"}, [] + else: + # We're back where we started. We should forget about the entire + # cancellation attempt + ts.state = "flight" + ts.previous = None + ts.next = None + + elif self.validate: + assert ts.previous in ("executing", "long-running") + assert ts.next == "fetch" + # None of the exit events of execute recommend a transition to fetch + assert not ts.done + + return {}, [] def _transition_resumed_missing( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - """See Worker._transition_from_resumed""" - return self._transition_from_resumed(ts, "missing", stimulus_id=stimulus_id) + return {ts: "fetch"}, [] def _transition_resumed_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - if not ts.done: - ts.state = "cancelled" - ts.next = None - return {}, [] - else: - return self._transition_generic_released(ts, stimulus_id=stimulus_id) + # None of the exit events of execute or gather_dep recommend a transition to + # released + assert not ts.done + ts.state = "cancelled" + ts.next = None + return {}, [] def _transition_resumed_waiting( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - """See Worker._transition_from_resumed""" - return self._transition_from_resumed(ts, "waiting", 
stimulus_id=stimulus_id) + """ + See also + -------- + _transition_cancelled_fetch + _transition_cancelled_waiting + _transition_resumed_fetch + """ + # None of the exit events of execute or gather_dep recommend a transition to + # waiting + assert not ts.done + if ts.previous in ("executing", "long-running"): + assert ts.next == "fetch" + # We're back where we started. We should forget about the entire + # cancellation attempt + ts.state = ts.previous + ts.next = None + ts.previous = None + elif self.validate: + assert ts.previous == "flight" + assert ts.next == "waiting" + + return {}, [] def _transition_cancelled_fetch( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - if ts.done: - return {ts: "released"}, [] - elif ts.previous == "flight": - ts.state = ts.previous - return {}, [] + """ + See also + -------- + _transition_cancelled_waiting + _transition_resumed_fetch + _transition_resumed_waiting + """ + if ts.previous == "flight": + if ts.done: + # gather_dep just completed for a cancelled task. 
+ # Discard output and possibly forget + return {ts: "released"}, [] + else: + # Forget the task was cancelled to begin with + ts.state = "flight" + ts.previous = None + return {}, [] else: assert ts.previous in ("executing", "long-running") + # None of the exit events of execute recommend a transition to fetch + assert not ts.done ts.state = "resumed" ts.next = "fetch" return {}, [] @@ -2094,10 +2074,20 @@ def _transition_cancelled_fetch( def _transition_cancelled_waiting( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - if ts.done: - return {ts: "released"}, [] - elif ts.previous in ("executing", "long-running"): + """ + See also + -------- + _transition_cancelled_fetch + _transition_resumed_fetch + _transition_resumed_waiting + """ + # None of the exit events of gather_dep or execute recommend a transition to + # waiting + assert not ts.done + if ts.previous in ("executing", "long-running"): + # Forget the task was cancelled to begin with ts.state = ts.previous + ts.previous = None return {}, [] else: assert ts.previous == "flight" @@ -2105,79 +2095,32 @@ def _transition_cancelled_waiting( ts.next = "waiting" return {}, [] - def _transition_cancelled_forgotten( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - ts.next = "forgotten" - if not ts.done: - return {}, [] - return {ts: "released"}, [] - def _transition_cancelled_released( - self, ts: TaskState, *, stimulus_id: str + self, + ts: TaskState, + *args: Any, # extra arguments of transitions to memory or error - ignored + stimulus_id: str, ) -> RecsInstrs: if not ts.done: return {}, [] - self.executing.discard(ts) - self.long_running.discard(ts) - self.in_flight_tasks.discard(ts) - self._release_resources(ts) + ts.previous = None + ts.done = False return self._transition_generic_released(ts, stimulus_id=stimulus_id) def _transition_executing_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - ts.previous = ts.state - ts.next = None - # See 
https://github.com/dask/distributed/pull/5046#discussion_r685093940 - ts.state = "cancelled" - ts.done = False - return {}, [] - - def _transition_generic_memory( - self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str - ) -> RecsInstrs: - if value is NO_VALUE and ts.key not in self.data: - raise RuntimeError( - f"Tried to transition task {ts} to `memory` without data available" - ) - - instructions: Instructions = [] - try: - recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) - except Exception as e: - msg = error_message(e) - recs = {ts: tuple(msg.values())} - else: - self._release_resources(ts) - self.executing.discard(ts) - self.long_running.discard(ts) - self.in_flight_tasks.discard(ts) - ts.coming_from = None - - if self.validate: - assert ts.key in self.data or ts.key in self.actors - instructions.append( - self._get_task_finished_msg(ts, stimulus_id=stimulus_id) - ) - - return recs, instructions - - def _transition_executing_memory( - self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str - ) -> RecsInstrs: + """We can't stop executing a task just because the scheduler asked us to, + so we're entering cancelled state and waiting until it completes. 
+ """ if self.validate: assert ts.state in ("executing", "long-running") - assert not ts.waiting_for_data - - self.executing.discard(ts) - self.long_running.discard(ts) - self.executed_count += 1 - return merge_recs_instructions( - self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id), - self._ensure_computing(), - ) + assert not ts.next + assert not ts.done + ts.previous = cast(Literal["executing", "long-running"], ts.state) + ts.state = "cancelled" + return {}, [] def _transition_constrained_executing( self, ts: TaskState, *, stimulus_id: str @@ -2222,54 +2165,18 @@ def _transition_flight_fetch( if not ts.done: return {}, [] - ts.coming_from = None return self._transition_generic_fetch(ts, stimulus_id=stimulus_id) - def _transition_flight_error( - self, - ts: TaskState, - exception: Serialize, - traceback: Serialize | None, - exception_text: str, - traceback_text: str, - *, - stimulus_id: str, - ) -> RecsInstrs: - self.in_flight_tasks.discard(ts) - ts.coming_from = None - return self._transition_generic_error( - ts, - exception, - traceback, - exception_text, - traceback_text, - stimulus_id=stimulus_id, - ) - def _transition_flight_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - if ts.done: - # FIXME: Is this even possible? Would an assert instead be more - # sensible? - return self._transition_generic_released(ts, stimulus_id=stimulus_id) - else: - ts.previous = "flight" - ts.next = None - # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 - ts.state = "cancelled" - return {}, [] - - def _transition_cancelled_memory( - self, ts: TaskState, value: object, *, stimulus_id: str - ) -> RecsInstrs: - """We only need this because the to-memory signatures require a value but - we do not want to store a cancelled result and want to release immediately. 
- - See also ``_transition_cancelled_error`` - """ - assert ts.done - return self._transition_cancelled_released(ts, stimulus_id=stimulus_id) + # None of the exit events of gather_dep recommend a transition to released + assert not ts.done + ts.previous = "flight" + ts.next = None + # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 + ts.state = "cancelled" + return {}, [] def _transition_executing_long_running( self, ts: TaskState, compute_duration: float, *, stimulus_id: str @@ -2286,30 +2193,82 @@ def _transition_executing_long_running( self._ensure_computing(), ) + def _transition_executing_memory( + self, ts: TaskState, value: object, *, stimulus_id: str + ) -> RecsInstrs: + """This transition is *normally* triggered by ExecuteSuccessEvent. + However, beware that it can also be triggered by scatter(). + """ + return self._transition_to_memory( + ts, value, "task-finished", stimulus_id=stimulus_id + ) + def _transition_released_memory( self, ts: TaskState, value: object, *, stimulus_id: str ) -> RecsInstrs: - try: - recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) - except Exception as e: - msg = error_message(e) - recs = {ts: tuple(msg.values())} - return recs, [] - smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) - return recs, [smsg] + """This transition is triggered by scatter()""" + return self._transition_to_memory( + ts, value, "add-keys", stimulus_id=stimulus_id + ) def _transition_flight_memory( self, ts: TaskState, value: object, *, stimulus_id: str ) -> RecsInstrs: - self.in_flight_tasks.discard(ts) - ts.coming_from = None + """This transition is *normally* triggered by GatherDepSuccessEvent. + However, beware that it can also be triggered by scatter(). 
+ """ + return self._transition_to_memory( + ts, value, "add-keys", stimulus_id=stimulus_id + ) + + def _transition_resumed_memory( + self, ts: TaskState, value: object, *, stimulus_id: str + ) -> RecsInstrs: + """Normally, we send to the scheduler a 'task-finished' message for a completed + execution and 'add-keys' for a completed replication from another worker. The + scheduler's reaction to the two messages is fundamentally different; namely, + add-keys is only admissible for tasks that are already in memory on another + worker, and won't trigger transitions. + + In the case of resumed tasks, the scheduler's expectation is set by ts.next - + which means, the opposite of what the worker actually just completed. + """ + msg_type: Literal["add-keys", "task-finished"] + if ts.previous in ("executing", "long-running"): + assert ts.next == "fetch" + msg_type = "add-keys" + else: + assert ts.previous == "flight" + assert ts.next == "waiting" + msg_type = "task-finished" + + ts.previous = None + ts.next = None + return self._transition_to_memory(ts, value, msg_type, stimulus_id=stimulus_id) + + def _transition_to_memory( + self, + ts: TaskState, + value: object, + msg_type: Literal["add-keys", "task-finished"], + *, + stimulus_id: str, + ) -> RecsInstrs: try: recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) except Exception as e: msg = error_message(e) recs = {ts: tuple(msg.values())} return recs, [] - smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) + + # NOTE: The scheduler's reaction to these two messages is fundamentally + # different. Namely, add-keys is only admissible for tasks that are already in + # memory on another worker, and won't trigger transitions. 
+ if msg_type == "add-keys": + smsg: Instruction = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) + else: + assert msg_type == "task-finished" + smsg = self._get_task_finished_msg(ts, stimulus_id=stimulus_id) return recs, [smsg] def _transition_released_forgotten( @@ -2339,15 +2298,14 @@ def _transition_released_forgotten( Mapping[tuple[TaskStateState, TaskStateState], Callable[..., RecsInstrs]] ] = { ("cancelled", "fetch"): _transition_cancelled_fetch, - ("cancelled", "released"): _transition_cancelled_released, + ("cancelled", "error"): _transition_cancelled_released, + ("cancelled", "memory"): _transition_cancelled_released, ("cancelled", "missing"): _transition_cancelled_released, - ("cancelled", "waiting"): _transition_cancelled_waiting, - ("cancelled", "forgotten"): _transition_cancelled_forgotten, + ("cancelled", "released"): _transition_cancelled_released, ("cancelled", "rescheduled"): _transition_cancelled_released, - ("cancelled", "memory"): _transition_cancelled_memory, - ("cancelled", "error"): _transition_cancelled_error, - ("resumed", "memory"): _transition_generic_memory, - ("resumed", "error"): _transition_generic_error, + ("cancelled", "waiting"): _transition_cancelled_waiting, + ("resumed", "memory"): _transition_resumed_memory, + ("resumed", "error"): _transition_resumed_error, ("resumed", "released"): _transition_resumed_released, ("resumed", "waiting"): _transition_resumed_waiting, ("resumed", "fetch"): _transition_resumed_fetch, @@ -2355,7 +2313,7 @@ def _transition_released_forgotten( ("constrained", "executing"): _transition_constrained_executing, ("constrained", "released"): _transition_generic_released, ("error", "released"): _transition_generic_released, - ("executing", "error"): _transition_executing_error, + ("executing", "error"): _transition_generic_error, ("executing", "long-running"): _transition_executing_long_running, ("executing", "memory"): _transition_executing_memory, ("executing", "released"): 
_transition_executing_released, @@ -2363,12 +2321,12 @@ def _transition_released_forgotten( ("fetch", "flight"): _transition_fetch_flight, ("fetch", "missing"): _transition_generic_missing, ("fetch", "released"): _transition_generic_released, - ("flight", "error"): _transition_flight_error, + ("flight", "error"): _transition_generic_error, ("flight", "fetch"): _transition_flight_fetch, ("flight", "memory"): _transition_flight_memory, ("flight", "missing"): _transition_flight_missing, ("flight", "released"): _transition_flight_released, - ("long-running", "error"): _transition_executing_error, + ("long-running", "error"): _transition_generic_error, ("long-running", "memory"): _transition_executing_memory, ("long-running", "rescheduled"): _transition_executing_rescheduled, ("long-running", "released"): _transition_executing_released, @@ -2755,12 +2713,20 @@ def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState] """Common code for the handlers of all subclasses of GatherDepDoneEvent. Yields the tasks that need to transition out of flight. + The task states can be flight, cancelled, or resumed, but in case of scatter() + they can also be in memory or error states. + + See also + -------- + _execute_done_common """ self.comm_nbytes -= ev.total_nbytes keys = self.in_flight_workers.pop(ev.worker) for key in keys: ts = self.tasks[key] ts.done = True + ts.coming_from = None + self.in_flight_tasks.remove(ts) yield ts @_handle_event.register @@ -2925,43 +2891,60 @@ def _handle_cancel_compute(self, ev: CancelComputeEvent) -> RecsInstrs: assert not ts.dependents return {ts: "released"}, [] - @_handle_event.register - def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs: - """Task completed successfully""" - # key *must* be still in tasks. 
Releasing it directly is forbidden - # without going through cancelled + def _execute_done_common( + self, ev: ExecuteDoneEvent + ) -> tuple[TaskState, Recs, Instructions]: + """Common code for the handlers of all subclasses of ExecuteDoneEvent. + + The task state can be executing, cancelled, or resumed, but in case of scatter() + it can also be in memory or error state. + + See also + -------- + _gather_dep_done_common + """ + # key *must* be still in tasks - see _transition_released_forgotten ts = self.tasks.get(ev.key) assert ts, self.story(ev.key) - + if self.validate: + assert (ts in self.executing) != (ts in self.long_running) # XOR ts.done = True + + self.executed_count += 1 + self._release_resources(ts) + self.executing.discard(ts) + self.long_running.discard(ts) + + recs, instr = self._ensure_computing() + assert ts not in recs + return ts, recs, instr + + @_handle_event.register + def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs: + """Task completed successfully""" + ts, recs, instr = self._execute_done_common(ev) ts.startstops.append({"action": "compute", "start": ev.start, "stop": ev.stop}) ts.nbytes = ev.nbytes ts.type = ev.type - return {ts: ("memory", ev.value)}, [] + recs[ts] = ("memory", ev.value) + return recs, instr @_handle_event.register def _handle_execute_failure(self, ev: ExecuteFailureEvent) -> RecsInstrs: """Task execution failed""" - # key *must* be still in tasks. 
Releasing it directly is forbidden - # without going through cancelled - ts = self.tasks.get(ev.key) - assert ts, self.story(ev.key) - - ts.done = True + ts, recs, instr = self._execute_done_common(ev) if ev.start is not None and ev.stop is not None: ts.startstops.append( {"action": "compute", "start": ev.start, "stop": ev.stop} ) - - return { - ts: ( - "error", - ev.exception, - ev.traceback, - ev.exception_text, - ev.traceback_text, - ) - }, [] + recs[ts] = ( + "error", + ev.exception, + ev.traceback, + ev.exception_text, + ev.traceback_text, + ) + return recs, instr @_handle_event.register def _handle_reschedule(self, ev: RescheduleEvent) -> RecsInstrs: @@ -2970,13 +2953,9 @@ def _handle_reschedule(self, ev: RescheduleEvent) -> RecsInstrs: Note: this has nothing to do with work stealing, which instead causes a FreeKeysEvent. """ - # key *must* be still in tasks. Releasing it directly is forbidden - # without going through cancelled - ts = self.tasks.get(ev.key) - assert ts, self.story(ev.key) - - ts.done = True - return {ts: "rescheduled"}, [] + ts, recs, instr = self._execute_done_common(ev) + recs[ts] = "rescheduled" + return recs, instr @_handle_event.register def _handle_find_missing(self, ev: FindMissingEvent) -> RecsInstrs: @@ -3081,22 +3060,41 @@ def _validate_task_memory(self, ts: TaskState) -> None: assert not ts.waiting_for_data def _validate_task_executing(self, ts: TaskState) -> None: - if ts.state == "executing": + """Validate tasks: + + - ts.state == executing + - ts.state == long-running + - ts.state == cancelled, ts.previous == executing + - ts.state == cancelled, ts.previous == long-running + - ts.state == resumed, ts.previous == executing, ts.next == fetch + - ts.state == resumed, ts.previous == long-running, ts.next == fetch + """ + if ts.state == "executing" or ts.previous == "executing": assert ts in self.executing assert ts not in self.long_running else: - assert ts.state == "long-running" + assert ts.state == "long-running" or ts.previous 
== "long-running" assert ts not in self.executing assert ts in self.long_running assert ts.run_spec is not None assert ts.key not in self.data assert not ts.waiting_for_data - for dep in ts.dependencies: - assert dep.state == "memory", self.story(dep) - assert dep.key in self.data or dep.key in self.actors + + # FIXME https://github.com/dask/distributed/issues/6893 + # This assertion can be false for + # - cancelled or resumed tasks + # - executing tasks which used to be cancelled in the past + # for dep in ts.dependencies: + # assert dep.state == "memory", self.story(dep) + # assert dep.key in self.data or dep.key in self.actors def _validate_task_ready(self, ts: TaskState) -> None: + """Validate tasks: + + - ts.state == ready + - ts.state == constrained + """ if ts.state == "ready": assert not ts.resource_restrictions assert ts in self.ready @@ -3121,6 +3119,12 @@ def _validate_task_waiting(self, ts: TaskState) -> None: assert not all(dep.key in self.data for dep in ts.dependencies) def _validate_task_flight(self, ts: TaskState) -> None: + """Validate tasks: + + - ts.state == flight + - ts.state == cancelled, ts.previous == flight + - ts.state == resumed, ts.previous == flight, ts.next == waiting + """ assert ts.key not in self.data assert ts in self.in_flight_tasks for dep in ts.dependents: @@ -3147,15 +3151,21 @@ def _validate_task_missing(self, ts: TaskState) -> None: assert ts in self.missing_dep_flight def _validate_task_cancelled(self, ts: TaskState) -> None: - assert ts.key not in self.data - assert ts.previous in {"long-running", "executing", "flight"} - # We'll always transition to released after it is done assert ts.next is None + if ts.previous in ("executing", "long-running"): + self._validate_task_executing(ts) + else: + assert ts.previous == "flight" + self._validate_task_flight(ts) def _validate_task_resumed(self, ts: TaskState) -> None: - assert ts.key not in self.data - assert ts.next in {"fetch", "waiting"} - assert ts.previous in 
{"long-running", "executing", "flight"} + if ts.previous in ("executing", "long-running"): + assert ts.next == "fetch" + self._validate_task_executing(ts) + else: + assert ts.previous == "flight" + assert ts.next == "waiting" + self._validate_task_flight(ts) def _validate_task_released(self, ts: TaskState) -> None: assert ts.key not in self.data @@ -3257,15 +3267,18 @@ def validate_state(self) -> None: # assert ts.state == "flight" or ( # ts.state in ("cancelled", "resumed") and ts.previous == "flight" # ) - # FIXME https://github.com/dask/distributed/issues/6689 - # for ts in self.executing: - # assert ts.state == "executing" or ( - # ts.state in ("cancelled", "resumed") and ts.previous == "executing" - # ) - # for ts in self.long_running: - # assert ts.state == "long-running" or ( - # ts.state in ("cancelled", "resumed") and ts.previous == "long-running" - # ) + for ts in self.executing: + assert ts.state == "executing" or ( + ts.state in ("cancelled", "resumed") and ts.previous == "executing" + ), ts + for ts in self.long_running: + assert ts.state == "long-running" or ( + ts.state in ("cancelled", "resumed") and ts.previous == "long-running" + ), ts + for ts in self.in_flight_tasks: + assert ts.state == "flight" or ( + ts.state in ("cancelled", "resumed") and ts.previous == "flight" + ), ts # Test that there aren't multiple TaskState objects with the same key in any # Set[TaskState]. See note in TaskState.__hash__.