Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ComputeTaskEvent,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FreeKeysEvent,
GatherDep,
Instruction,
PauseEvent,
Expand Down Expand Up @@ -1037,3 +1038,58 @@ async def test_clean_log(s, a, b):
"""Test that brand new workers start with a clean log"""
assert not a.state.log
assert not a.state.stimulus_log


def test_running_task_in_all_running_tasks(ws_with_running_task):
ws = ws_with_running_task

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ws = ws_with_running_task
ws = ws_with_running_task
ts = ws.tasks["x"]
assert ts in ws.all_running_tasks

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's our philosophy here? Do we care about testing the precondition/setup? This is already being tested in a different unit test.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing the same thing in multiple places is not a good idea, unless it helps improve clarity.
In this case, it makes it absolutely clear to the reader that the task was there to begin with and it left, instead of forcing them to figure out if another test is checking for it.

ws2 = "127.0.0.1:2"
ts = ws.tasks["x"]
assert ts in ws.all_running_tasks

ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1"))
assert ts.state == "cancelled"
assert ts in ws.all_running_tasks

ws.handle_stimulus(
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2")
)
assert ts.state == "resumed"
assert ts in ws.all_running_tasks


@pytest.mark.xfail(reason="distributed#6565, distributed#6692")
@pytest.mark.parametrize(
    "done_ev_cls,done_status",
    [(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")],
)
def test_done_task_not_in_all_running_tasks(
    ws_with_running_task, done_ev_cls, done_status
):
    """Check that a task leaves ``all_running_tasks`` once it is done.

    Task ``x`` starts out running (precondition asserted below). After the
    parametrized completion event for ``x`` is handled, the task must be in
    ``done_status`` and no longer be reported by ``all_running_tasks``.

    Currently expected to fail; see the issues named in the xfail marker.
    """
    state = ws_with_running_task
    task = state.tasks["x"]
    # Precondition: the task occupies a thread before it completes.
    assert task in state.all_running_tasks

    completion = done_ev_cls.dummy("x", stimulus_id="s1")
    state.handle_stimulus(completion)

    assert task.state == done_status
    assert task not in state.all_running_tasks


@pytest.mark.xfail(reason="distributed#6565, distributed#6689, distributed#6692")
@pytest.mark.parametrize(
    "done_ev_cls,done_status",
    [(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")],
)
def test_done_resumed_task_not_in_all_running_tasks(
    ws_with_running_task, done_ev_cls, done_status
):
    """Check that a resumed task leaves ``all_running_tasks`` once it is done.

    Task ``x`` starts out running. It is then freed, requested again as a
    dependency of ``y`` that is held by another worker, and finally its
    execution completes with the parametrized event; all three events are
    delivered in a single ``handle_stimulus`` call. Afterwards ``x`` must be
    in ``done_status`` and no longer be reported by ``all_running_tasks``.

    Currently expected to fail; see the issues named in the xfail marker.
    """
    ws = ws_with_running_task
    ws2 = "127.0.0.1:2"
    # Precondition, stated explicitly like in the sibling tests: the task is
    # occupying a thread to begin with, so the final assertion shows it left.
    assert ws.tasks["x"] in ws.all_running_tasks

    ws.handle_stimulus(
        FreeKeysEvent(keys=["x"], stimulus_id="s1"),
        ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"),
        done_ev_cls.dummy("x", stimulus_id="s3"),
    )
    # Look the task up only after the events were handled, as the original
    # test did, so the assertions apply to whatever object is registered now.
    ts = ws.tasks["x"]
    assert ts.state == done_status
    assert ts not in ws.all_running_tasks
27 changes: 21 additions & 6 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,7 @@ def handle_stimulus(self, *stims: StateMachineEvent) -> Instructions:
@property
def executing_count(self) -> int:
"""Count of tasks currently executing on this worker.
Does not include long running (a.k.a. seceded) and cancelled tasks.

See also
--------
Expand All @@ -1212,6 +1213,17 @@ def executing_count(self) -> int:
"""
return len(self.executing)

@property
def all_running_tasks(self) -> set[TaskState]:
    """All tasks that are currently occupying a thread.

    These are:

    - ``ts.state in ("executing", "long-running", "cancelled")``
    - ``ts.state == "resumed" and ts._previous in ("executing", "long-running")``

    NOTE(review): a cancelled task presumably only appears here when it was
    executing or long-running before being cancelled, since only the
    ``executing`` and ``long_running`` collections are consulted -- confirm.

    Returns
    -------
    set[TaskState]
        The union of ``self.executing`` and the tasks looked up in
        ``self.tasks`` for every key in ``self.long_running``.
    """
    # Note: cancelled and resumed tasks are still in either of these sets
    return self.executing | {self.tasks[key] for key in self.long_running}

@property
def in_flight_tasks_count(self) -> int:
"""Count of tasks currently being replicated from other workers to this one.
Expand Down Expand Up @@ -1981,7 +1993,7 @@ def _transition_cancelled_fetch(
ts.state = ts._previous
return {}, []
else:
assert ts._previous == "executing"
assert ts._previous in {"executing", "long-running"}
ts.state = "resumed"
ts._next = "fetch"
return {}, []
Expand Down Expand Up @@ -3119,11 +3131,14 @@ def validate_state(self) -> None:
waiting_for_data_count += 1
for ts_wait in ts.waiting_for_data:
assert ts_wait.key in self.tasks
assert (
ts_wait.state in READY | {"executing", "flight", "fetch", "missing"}
or ts_wait in self.missing_dep_flight
or ts_wait.who_has.issubset(self.in_flight_workers)
), (ts, ts_wait, self.story(ts), self.story(ts_wait))
assert ts_wait.state in READY | {
"executing",
"long-running",
"resumed",
"flight",
"fetch",
"missing",
}, (ts, ts_wait, self.story(ts), self.story(ts_wait))
# FIXME https://github.com/dask/distributed/issues/6319
# assert self.waiting_for_data_count == waiting_for_data_count
for worker, keys in self.has_what.items():
Expand Down