Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,11 +606,10 @@ def test_workerstate_executing_skips_fetch_on_success(ws_with_running_task):
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"),
ExecuteSuccessEvent.dummy("x", 123, stimulus_id="s3"),
)
assert len(instructions) == 2
assert isinstance(instructions[0], TaskFinishedMsg)
assert instructions[0].key == "x"
assert instructions[0].stimulus_id == "s3"
assert instructions[1] == Execute(key="y", stimulus_id="s3")
assert instructions == [
TaskFinishedMsg.match(key="x", stimulus_id="s3"),
Execute(key="y", stimulus_id="s3"),
]
assert ws.tasks["x"].state == "memory"
assert ws.data["x"] == 123

Expand Down Expand Up @@ -662,12 +661,10 @@ def test_workerstate_flight_skips_executing_on_success(ws):
worker=ws2, total_nbytes=1, data={"x": 123}, stimulus_id="s4"
),
)
assert len(instructions) == 2
assert instructions[0] == GatherDep(
worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s1"
)
assert isinstance(instructions[1], TaskFinishedMsg)
assert instructions[1].stimulus_id == "s4"
assert instructions == [
GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s1"),
TaskFinishedMsg.match(key="x", stimulus_id="s4"),
]
assert ws.tasks["x"].state == "memory"
assert ws.data["x"] == 123

Expand Down
27 changes: 19 additions & 8 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@
)


def test_instruction_match():
i = ReleaseWorkerDataMsg(key="x", stimulus_id="s1")
assert i == ReleaseWorkerDataMsg(key="x", stimulus_id="s1")
assert i != ReleaseWorkerDataMsg(key="y", stimulus_id="s1")
assert i != ReleaseWorkerDataMsg(key="x", stimulus_id="s2")
assert i != RescheduleMsg(key="x", stimulus_id="s1")

assert i == ReleaseWorkerDataMsg.match(key="x")
assert i == ReleaseWorkerDataMsg.match(stimulus_id="s1")
assert i != ReleaseWorkerDataMsg.match(key="y")
assert i != ReleaseWorkerDataMsg.match(stimulus_id="s2")
assert i != RescheduleMsg.match(key="x")


def test_TaskState_tracking(cleanup):
gc.collect()
x = TaskState("x")
Expand Down Expand Up @@ -961,19 +975,16 @@ async def test_deprecated_worker_attributes(s, a, b):
],
)
def test_aggregate_gather_deps(ws, nbytes, n_in_flight):
ws2 = "127.0.0.1:2"
instructions = ws.handle_stimulus(
AcquireReplicasEvent(
who_has={
"x1": ["127.0.0.1:1235"],
"x2": ["127.0.0.1:1235"],
"x3": ["127.0.0.1:1235"],
},
who_has={"x1": [ws2], "x2": [ws2], "x3": [ws2]},
nbytes={"x1": nbytes, "x2": nbytes, "x3": nbytes},
stimulus_id="test",
stimulus_id="s1",
)
)
assert len(instructions) == 1
assert isinstance(instructions[0], GatherDep)
assert instructions == [GatherDep.match(worker=ws2, stimulus_id="s1")]
assert len(instructions[0].to_gather) == n_in_flight
assert len(ws.in_flight_tasks) == n_in_flight


Expand Down
54 changes: 51 additions & 3 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,56 @@ class Instruction:
__slots__ = ("stimulus_id",)
stimulus_id: str

@classmethod
def match(cls, **kwargs: Any) -> _InstructionMatch:
"""Generate a partial match to compare against an Instruction instance.
The typical usage is to compare a list of instructions returned by
:meth:`WorkerState.handle_stimulus` or in :attr:`WorkerState.stimulus_log` vs.
an expected list of matches.

Example
-------

.. code-block:: python

instructions = ws.handle_stimulus(...)
assert instructions == [
TaskFinishedMsg.match(key="x"),
...
]
"""
return _InstructionMatch(cls, **kwargs)

def __eq__(self, other: object) -> bool:
if isinstance(other, _InstructionMatch):
return other == self
else:
# Revert to default dataclass behaviour
return super().__eq__(other)


class _InstructionMatch:
"""Utility class, to be used to test an instructions list.
See :meth:`Instruction.match`.
"""

cls: type[Instruction]
kwargs: dict[str, Any]

def __init__(self, cls: type[Instruction], **kwargs: Any):
self.cls = cls
self.kwargs = kwargs

def __repr__(self) -> str:
cls_str = self.cls.__name__
kwargs_str = ", ".join(f"{k}={v}" for k, v in self.kwargs.items())
return f"{cls_str}({kwargs_str}) (partial match)"

def __eq__(self, other: object) -> bool:
if type(other) is not self.cls:
return False
return all(getattr(other, k) == v for k, v in self.kwargs.items())


@dataclass
class GatherDep(Instruction):
Expand Down Expand Up @@ -1861,9 +1911,7 @@ def _transition_cancelled_error(
# message to the scheduler since from the schedulers POV it already
# released this task
if self.validate:
assert len(instructions) == 1
assert isinstance(instructions[0], TaskErredMsg)
assert instructions[0].key == ts.key
assert instructions == [TaskErredMsg.match(key=ts.key)]
instructions.clear()
# Workers should never "retry" tasks. A transition to error should, by
# default, be the end. Since cancelled indicates that the scheduler lost
Expand Down