Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def test_workerstate_executing_skips_fetch_on_success(ws_with_running_task):
assert ws.data["x"] == 123


@pytest.mark.xfail(reason="distributed#6565, distributed#6689")
@pytest.mark.xfail(reason="distributed#6689")
def test_workerstate_executing_failure_to_fetch(ws_with_running_task):
"""Test state loops:

Expand Down
11 changes: 8 additions & 3 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,6 @@ def test_running_task_in_all_running_tasks(ws_with_running_task):
assert ts in ws.all_running_tasks


@pytest.mark.xfail(reason="distributed#6565, distributed#6692")
@pytest.mark.parametrize(
"done_ev_cls,done_status",
[(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")],
Expand All @@ -1074,10 +1073,16 @@ def test_done_task_not_in_all_running_tasks(
assert ts not in ws.all_running_tasks


@pytest.mark.xfail(reason="distributed#6565, distributed#6689, distributed#6692")
@pytest.mark.parametrize(
"done_ev_cls,done_status",
[(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")],
[
(ExecuteSuccessEvent, "memory"),
pytest.param(
ExecuteFailureEvent,
"error",
marks=pytest.mark.xfail(reason="distributed#6689"),
),
],
)
def test_done_resumed_task_not_in_all_running_tasks(
ws_with_running_task, done_ev_cls, done_status
Expand Down
58 changes: 38 additions & 20 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,10 +1092,10 @@ class WorkerState:
#: See also :meth:`executing_count` and :attr:`long_runing`.
executing: set[TaskState]

#: Set of keys of tasks that are currently running and have called
#: Set of tasks that are currently running and have called
#: :func:`~distributed.secede`.
#: These tasks do not appear in the :attr:`executing` set.
long_running: set[str]
long_running: set[TaskState]

#: A number of tasks that this worker has run in its lifetime.
#: See also :meth:`executing_count`.
Expand Down Expand Up @@ -1234,7 +1234,7 @@ def all_running_tasks(self) -> set[TaskState]:
- ``ts.status == "resumed" and ts._previous in ("executing", "long-running")``
"""
# Note: cancelled and resumed tasks are still in either of these sets
return self.executing | {self.tasks[key] for key in self.long_running}
return self.executing | self.long_running

@property
def in_flight_tasks_count(self) -> int:
Expand Down Expand Up @@ -1335,6 +1335,7 @@ def _purge_state(self, ts: TaskState) -> None:
ts.done = False

self.executing.discard(ts)
self.long_running.discard(ts)
self.in_flight_tasks.discard(ts)

def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
Expand Down Expand Up @@ -1801,6 +1802,7 @@ def _transition_executing_rescheduled(

self._release_resources(ts)
self.executing.discard(ts)
self.long_running.discard(ts)

return merge_recs_instructions(
({}, [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)]),
Expand Down Expand Up @@ -1838,7 +1840,7 @@ def _transition_cancelled_error(
*,
stimulus_id: str,
) -> RecsInstrs:
assert ts._previous == "executing" or ts.key in self.long_running
assert ts._previous in ("executing", "long-running")
recs, instructions = self._transition_executing_error(
ts,
exception,
Expand Down Expand Up @@ -1897,6 +1899,7 @@ def _transition_executing_error(
) -> RecsInstrs:
self._release_resources(ts)
self.executing.discard(ts)
self.long_running.discard(ts)

return merge_recs_instructions(
self._transition_generic_error(
Expand Down Expand Up @@ -1949,7 +1952,7 @@ def _transition_from_resumed(
assert finish != "memory"
next_state = ts._next
assert next_state in {"waiting", "fetch"}, next_state
assert ts._previous in {"executing", "flight"}, ts._previous
assert ts._previous in {"executing", "long-running", "flight"}, ts._previous

if next_state != finish:
recs, instructions = self._transition_generic_released(
Expand Down Expand Up @@ -2005,7 +2008,7 @@ def _transition_cancelled_fetch(
ts.state = ts._previous
return {}, []
else:
assert ts._previous in {"executing", "long-running"}
assert ts._previous in ("executing", "long-running")
ts.state = "resumed"
ts._next = "fetch"
return {}, []
Expand Down Expand Up @@ -2038,6 +2041,7 @@ def _transition_cancelled_released(
if not ts.done:
return {}, []
self.executing.discard(ts)
self.long_running.discard(ts)
self.in_flight_tasks.discard(ts)

self._release_resources(ts)
Expand All @@ -2053,12 +2057,6 @@ def _transition_executing_released(
ts.done = False
return {}, []

def _transition_long_running_memory(
self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str
) -> RecsInstrs:
self.executed_count += 1
return self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id)

def _transition_generic_memory(
self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str
) -> RecsInstrs:
Expand All @@ -2069,6 +2067,7 @@ def _transition_generic_memory(

self._release_resources(ts)
self.executing.discard(ts)
self.long_running.discard(ts)
self.in_flight_tasks.discard(ts)
ts.coming_from = None

Expand All @@ -2091,11 +2090,12 @@ def _transition_executing_memory(
self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str
) -> RecsInstrs:
if self.validate:
assert ts.state == "executing" or ts.key in self.long_running
assert ts.state in ("executing", "long-running")
assert not ts.waiting_for_data
assert ts.key not in self.ready

self.executing.discard(ts)
self.long_running.discard(ts)
self.executed_count += 1
return merge_recs_instructions(
self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id),
Expand Down Expand Up @@ -2195,7 +2195,7 @@ def _transition_executing_long_running(
) -> RecsInstrs:
ts.state = "long-running"
self.executing.discard(ts)
self.long_running.add(ts.key)
self.long_running.add(ts)

smsg = LongRunningMsg(
key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id
Expand Down Expand Up @@ -2287,8 +2287,8 @@ def _transition_released_forgotten(
("flight", "memory"): _transition_flight_memory,
("flight", "missing"): _transition_flight_missing,
("flight", "released"): _transition_flight_released,
("long-running", "error"): _transition_generic_error,
("long-running", "memory"): _transition_long_running_memory,
("long-running", "error"): _transition_executing_error,
("long-running", "memory"): _transition_executing_memory,
("long-running", "rescheduled"): _transition_executing_rescheduled,
("long-running", "released"): _transition_executing_released,
("memory", "released"): _transition_memory_released,
Expand Down Expand Up @@ -2984,7 +2984,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
for w, tss in self.data_needed.items()
},
"executing": {ts.key for ts in self.executing},
"long_running": self.long_running,
"long_running": {ts.key for ts in self.long_running},
"in_flight_tasks": {ts.key for ts in self.in_flight_tasks},
"in_flight_workers": self.in_flight_workers,
"busy_workers": self.busy_workers,
Expand All @@ -3008,7 +3008,14 @@ def _validate_task_memory(self, ts: TaskState) -> None:
assert ts.state == "memory"

def _validate_task_executing(self, ts: TaskState) -> None:
assert ts.state == "executing"
if ts.state == "executing":
assert ts in self.executing
Comment thread
crusaderky marked this conversation as resolved.
assert ts not in self.long_running
else:
assert ts.state == "long-running"
assert ts not in self.executing
assert ts in self.long_running
Comment thread
crusaderky marked this conversation as resolved.

assert ts.run_spec is not None
assert ts.key not in self.data
assert not ts.waiting_for_data
Expand Down Expand Up @@ -3065,7 +3072,7 @@ def _validate_task_cancelled(self, ts: TaskState) -> None:

def _validate_task_resumed(self, ts: TaskState) -> None:
assert ts.key not in self.data
assert ts._next
assert ts._next in {"fetch", "waiting"}
assert ts._previous in {"long-running", "executing", "flight"}

def _validate_task_released(self, ts: TaskState) -> None:
Expand Down Expand Up @@ -3106,7 +3113,7 @@ def validate_task(self, ts: TaskState) -> None:
self._validate_task_resumed(ts)
elif ts.state == "ready":
self._validate_task_ready(ts)
elif ts.state == "executing":
elif ts.state in ("executing", "long-running"):
self._validate_task_executing(ts)
elif ts.state == "flight":
self._validate_task_flight(ts)
Expand Down Expand Up @@ -3157,13 +3164,24 @@ def validate_state(self) -> None:
assert ts.state == "fetch"
assert worker in ts.who_has

# FIXME https://github.com/dask/distributed/issues/6689
# for ts in self.executing:
# assert ts.state == "executing" or (
# ts.state in ("cancelled", "resumed") and ts._previous == "executing"
# ), self.story(ts)
# for ts in self.long_running:
# assert ts.state == "long-running" or (
# ts.state in ("cancelled", "resumed") and ts._previous == "long-running"
# ), self.story(ts)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same as #6689 : there are tasks with "error" state in self.executing.


# Test that there aren't multiple TaskState objects with the same key in any
# Set[TaskState]. See note in TaskState.__hash__.
for ts in chain(
*self.data_needed.values(),
self.missing_dep_flight,
self.in_flight_tasks,
self.executing,
self.long_running,
):
assert self.tasks[ts.key] is ts

Expand Down