diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b01d3f06ad0..3a72976c47b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4876,6 +4876,7 @@ def stimulus_task_erred( exception=None, stimulus_id=None, traceback=None, + run_id=None, **kwargs, ): """Mark that a task has erred on a particular worker""" @@ -4885,6 +4886,11 @@ def stimulus_task_erred( if ts is None or ts.state != "processing": return {}, {}, {} + if ts.run_id != run_id: + if ts.processing_on and ts.processing_on.address == worker: + return self._transition(key, "released", stimulus_id) + return {}, {}, {} + if ts.retries > 0: ts.retries -= 1 return self._transition(key, "waiting", stimulus_id) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 7f95ab36ddc..1609972fa8b 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -14,6 +14,7 @@ _LockedCommPool, assert_story, async_poll_for, + freeze_batched_send, gen_cluster, inc, lock_inc, @@ -488,6 +489,8 @@ async def release_all_futures(): await lock_compute.release() await exit_compute.wait() + await async_poll_for(lambda: f3.key not in b.state.tasks, timeout=5) + f1 = c.submit(inc, 1, key="f1", workers=[a.address]) f2 = c.submit(inc, f1, key="f2", workers=[a.address]) f3 = c.submit(inc, f2, key="f3", workers=[b.address]) @@ -569,8 +572,7 @@ async def release_all_futures(): ) elif wait_for_processing and raise_error: - with pytest.raises(RuntimeError, match="test error"): - await f3 + assert await f4 == 4 + 2 assert_story( b.state.story(f3.key), @@ -581,20 +583,31 @@ async def release_all_futures(): (f3.key, "resumed", "released", "cancelled", {}), (f3.key, "cancelled", "waiting", "executing", {}), (f3.key, "executing", "error", "error", {}), - # FIXME: (distributed#7489) + ( + f3.key, + "error", + "released", + "released", + {f2.key: "released", f3.key: "forgotten"}, + ), + (f3.key, "released", "forgotten", "forgotten", {f2.key: "forgotten"}), + (f3.key, "ready", "executing", "executing", {}), + (f3.key, "executing", "memory", "memory", {}), ], ) else: assert False, "unreachable" +@pytest.mark.parametrize("raise_error", [True, False]) @gen_cluster(client=True) -async def test_cancelled_handle_compute(c, s, a, b): +async def test_cancelled_handle_compute(c, s, a, b, raise_error): """ Given the history of a task executing -> cancelled - A handle_compute should properly restore executing. + A handle_compute should cause the result of the cancelled task to be rejected + by the scheduler and the task to be re-run. See Also -------- @@ -611,6 +624,8 @@ async def test_cancelled_handle_compute(c, s, a, b): def block(x, lock, enter_event, exit_event): enter_event.set() with lock: + if raise_error: + raise RuntimeError("test error") return x + 1 f1 = c.submit(inc, 1, key="f1", workers=[a.address]) @@ -650,22 +665,151 @@ async def release_all_futures(): assert await f4 == 4 + 2 - story = b.state.story(f3.key) + if raise_error: + assert_story( + b.state.story(f3.key), + expect=[ + (f3.key, "ready", "executing", "executing", {}), + (f3.key, "executing", "released", "cancelled", {}), + (f3.key, "cancelled", "waiting", "executing", {}), + (f3.key, "executing", "error", "error", {}), + ( + f3.key, + "error", + "released", + "released", + {f2.key: "released", f3.key: "forgotten"}, + ), + (f3.key, "released", "forgotten", "forgotten", {f2.key: "forgotten"}), + (f3.key, "ready", "executing", "executing", {}), + (f3.key, "executing", "memory", "memory", {}), + ], + ) + else: + assert_story( + b.state.story(f3.key), + expect=[ + (f3.key, "ready", "executing", "executing", {}), + (f3.key, "executing", "released", "cancelled", {}), + (f3.key, "cancelled", "waiting", "executing", {}), + (f3.key, "executing", "memory", "memory", {}), + ( + f3.key, + "memory", + "released", + "released", + {f2.key: "released", f3.key: "forgotten"}, + ), + (f3.key, "released", "forgotten", "forgotten", {f2.key: "forgotten"}), + (f3.key, "ready", "executing", "executing", {}), + (f3.key, "executing", "memory", "memory", {}), + ], + ) + + +@gen_cluster(client=True) +async def test_cancelled_task_error_rejected(c, s, a, b): + """ + Given the history of a task + executing -> cancelled + + An error in the cancelled task is rejected by the scheduler and superseded + by a more recent run on another worker. + + """ + # This test is heavily using set_restrictions to simulate certain scheduler + # decisions of placing keys + + lock_erring = Lock() + enter_compute_erring = Event() + exit_compute_erring = Event() + lock_successful = Lock() + enter_compute_successful = Event() + exit_compute_successful = Event() + + await lock_erring.acquire() + await lock_successful.acquire() + + def block(x, lock, enter_event, exit_event, raise_error): + enter_event.set() + try: + with lock: + if raise_error: + raise RuntimeError("test_error") + return x + 1 + finally: + exit_event.set() + + f1 = c.submit(inc, 1, key="f1", workers=[a.address]) + f2 = c.submit(inc, f1, key="f2", workers=[a.address]) + f3 = c.submit( + block, + f2, + lock=lock_erring, + enter_event=enter_compute_erring, + exit_event=exit_compute_erring, + raise_error=True, + key="f3", + workers=[b.address], + ) + + f4 = c.submit(sum, [f1, f3], key="f4", workers=[b.address]) + + await enter_compute_erring.wait() + + async def release_all_futures(): + futs = [f1, f2, f3, f4] + for fut in futs: + fut.release() + + while any(fut.key in s.tasks for fut in futs): + await asyncio.sleep(0.05) + + with freeze_batched_send(s.stream_comms[b.address]): + await release_all_futures() + + f1 = c.submit(inc, 1, key="f1", workers=[a.address]) + f2 = c.submit(inc, f1, key="f2", workers=[a.address]) + f3 = c.submit( + block, + f2, + lock=lock_successful, + enter_event=enter_compute_successful, + exit_event=exit_compute_successful, + raise_error=False, + key="f3", + workers=[a.address], + ) + f4 = c.submit(sum, [f1, f3], key="f4", workers=[b.address]) + + await wait_for_state(f3.key, "processing", s) + await enter_compute_successful.wait() + + await lock_erring.release() + await wait_for_state(f3.key, "error", b) + + await lock_successful.release() + assert await f4 == 4 + 2 + assert_story( b.state.story(f3.key), expect=[ (f3.key, "ready", "executing", "executing", {}), - (f3.key, "executing", "released", "cancelled", {}), - (f3.key, "cancelled", "waiting", "executing", {}), - (f3.key, "executing", "memory", "memory", {}), + (f3.key, "executing", "error", "error", {}), ( f3.key, - "memory", + "error", "released", "released", - {f2.key: "released", f3.key: "forgotten"}, + {f3.key: "forgotten"}, ), - (f3.key, "released", "forgotten", "forgotten", {f2.key: "forgotten"}), + (f3.key, "released", "forgotten", "forgotten", {}), + ], + ) + + assert_story( + a.state.story(f3.key), + expect=[ (f3.key, "ready", "executing", "executing", {}), (f3.key, "executing", "memory", "memory", {}), ], @@ -787,7 +931,7 @@ def test_workerstate_executing_failure_to_fetch(ws_with_running_task): - executing -> long-running -> cancelled -> resumed(fetch) The task execution later terminates with a failure. - This is an edge case interaction between work stealing and a task that does not + This is an edge case interaction involving task cancellation and a task that does not deterministically succeed or fail when run multiple times or on different workers. Test that the task is fetched from the other worker. This is to avoid having to deal diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index b2817363001..1c87d1d39f4 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -531,6 +531,7 @@ def test_executefailure_to_dict(): ev = ExecuteFailureEvent( stimulus_id="test", key="x", + run_id=1, start=123.4, stop=456.7, exception=Serialize(ValueError("foo")), @@ -546,6 +547,7 @@ def test_executefailure_to_dict(): "stimulus_id": "test", "handled": 11.22, "key": "x", + "run_id": 1, "start": 123.4, "stop": 456.7, "exception": "", @@ -571,6 +573,7 @@ def test_executefailure_dummy(): ev = ExecuteFailureEvent.dummy("x", stimulus_id="s") assert ev == ExecuteFailureEvent( key="x", + run_id=1, start=None, stop=None, exception=Serialize(None), diff --git a/distributed/worker.py b/distributed/worker.py index 9e2e94b2b63..b7841ba2f04 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2241,6 +2241,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: return ExecuteFailureEvent.from_exception( exc, key=key, + run_id=run_id, stimulus_id=f"run-spec-deserialize-failed-{time()}", ) @@ -2365,6 +2366,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: return ExecuteFailureEvent.from_exception( result, key=key, + run_id=run_id, start=result["start"], stop=result["stop"], stimulus_id=f"task-erred-{time()}", @@ -2375,6 +2377,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: return ExecuteFailureEvent.from_exception( exc, key=key, + run_id=run_id, stimulus_id=f"execute-unknown-error-{time()}", ) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 504eddfe40e..d71a8e0258f 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -482,6 +482,7 @@ class TaskErredMsg(SendMessageToScheduler): op = "task-erred" key: str + run_id: int exception: Serialize traceback: Serialize | None exception_text: str @@ -497,11 +498,12 @@ def to_dict(self) -> dict[str, Any]: @staticmethod def from_task( - ts: TaskState, stimulus_id: str, thread: int | None = None + ts: TaskState, run_id: int, stimulus_id: str, thread: int | None = None ) -> TaskErredMsg: assert ts.exception return TaskErredMsg( key=ts.key, + run_id=run_id, exception=ts.exception, traceback=ts.traceback, exception_text=ts.exception_text, @@ -903,6 +905,7 @@ def dummy( @dataclass class ExecuteFailureEvent(ExecuteDoneEvent): + run_id: int # FIXME: Utilize the run ID in all ExecuteDoneEvents start: float | None stop: float | None exception: Serialize @@ -921,6 +924,7 @@ def from_exception( err_or_msg: BaseException | ErrorMessage, *, key: str, + run_id: int, start: float | None = None, stop: float | None = None, stimulus_id: str, @@ -932,6 +936,7 @@ def from_exception( return cls( key=key, + run_id=run_id, start=start, stop=stop, exception=msg["exception"], @@ -945,6 +950,7 @@ def from_exception( def dummy( key: str, *, + run_id: int = 1, stimulus_id: str, ) -> ExecuteFailureEvent: """Build a dummy event, with most attributes set to a reasonable default. @@ -952,6 +958,7 @@ def dummy( """ return ExecuteFailureEvent( key=key, + run_id=run_id, start=None, stop=None, exception=Serialize(None), @@ -2025,6 +2032,7 @@ def _transition_generic_error( traceback: Serialize | None, exception_text: str, traceback_text: str, + run_id: int, *, stimulus_id: str, ) -> RecsInstrs: @@ -2035,6 +2043,7 @@ def _transition_generic_error( ts.state = "error" smsg = TaskErredMsg.from_task( ts, + run_id=run_id, stimulus_id=stimulus_id, thread=self.threads.get(ts.key), ) @@ -2048,6 +2057,7 @@ def _transition_resumed_error( traceback: Serialize | None, exception_text: str, traceback_text: str, + run_id: int, *, stimulus_id: str, ) -> RecsInstrs: @@ -2438,7 +2448,7 @@ def _transition_to_memory( # Third-party MutableMappings (dask-cuda etc.) may have other use cases # for this. msg = error_message(e) - return {ts: tuple(msg.values())}, [] + return {ts: tuple(msg.values()) + (run_id,)}, [] stop = time() if stop - start > 0.005: @@ -2859,7 +2869,9 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs: ) ) elif ts.state == "error": - instructions.append(TaskErredMsg.from_task(ts, stimulus_id=ev.stimulus_id)) + instructions.append( + TaskErredMsg.from_task(ts, run_id=ev.run_id, stimulus_id=ev.stimulus_id) + ) elif ts.state in { "released", "fetch", @@ -3042,6 +3054,7 @@ def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs: ev.traceback, ev.exception_text, ev.traceback_text, + ts.run_id, ) for ts in self._gather_dep_done_common(ev) } @@ -3179,6 +3192,7 @@ def _handle_execute_failure(self, ev: ExecuteFailureEvent) -> RecsInstrs: ev.traceback, ev.exception_text, ev.traceback_text, + ev.run_id, ) return recs, instr