diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index d7bcf1746a4..3f65cfb3fc8 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -606,11 +606,10 @@ def test_workerstate_executing_skips_fetch_on_success(ws_with_running_task): ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"), ExecuteSuccessEvent.dummy("x", 123, stimulus_id="s3"), ) - assert len(instructions) == 2 - assert isinstance(instructions[0], TaskFinishedMsg) - assert instructions[0].key == "x" - assert instructions[0].stimulus_id == "s3" - assert instructions[1] == Execute(key="y", stimulus_id="s3") + assert instructions == [ + TaskFinishedMsg.match(key="x", stimulus_id="s3"), + Execute(key="y", stimulus_id="s3"), + ] assert ws.tasks["x"].state == "memory" assert ws.data["x"] == 123 @@ -662,12 +661,10 @@ def test_workerstate_flight_skips_executing_on_success(ws): worker=ws2, total_nbytes=1, data={"x": 123}, stimulus_id="s4" ), ) - assert len(instructions) == 2 - assert instructions[0] == GatherDep( - worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s1" - ) - assert isinstance(instructions[1], TaskFinishedMsg) - assert instructions[1].stimulus_id == "s4" + assert instructions == [ + GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s1"), + TaskFinishedMsg.match(key="x", stimulus_id="s4"), + ] assert ws.tasks["x"].state == "memory" assert ws.data["x"] == 123 diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index e163cee626e..4581505be3d 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -50,6 +50,20 @@ ) +def test_instruction_match(): + i = ReleaseWorkerDataMsg(key="x", stimulus_id="s1") + assert i == ReleaseWorkerDataMsg(key="x", stimulus_id="s1") + assert i != ReleaseWorkerDataMsg(key="y", stimulus_id="s1") + assert i != ReleaseWorkerDataMsg(key="x", stimulus_id="s2") + assert i != RescheduleMsg(key="x", stimulus_id="s1") + + assert i == ReleaseWorkerDataMsg.match(key="x") + assert i == ReleaseWorkerDataMsg.match(stimulus_id="s1") + assert i != ReleaseWorkerDataMsg.match(key="y") + assert i != ReleaseWorkerDataMsg.match(stimulus_id="s2") + assert i != RescheduleMsg.match(key="x") + + def test_TaskState_tracking(cleanup): gc.collect() x = TaskState("x") @@ -961,19 +975,16 @@ async def test_deprecated_worker_attributes(s, a, b): ], ) def test_aggregate_gather_deps(ws, nbytes, n_in_flight): + ws2 = "127.0.0.1:2" instructions = ws.handle_stimulus( AcquireReplicasEvent( - who_has={ - "x1": ["127.0.0.1:1235"], - "x2": ["127.0.0.1:1235"], - "x3": ["127.0.0.1:1235"], - }, + who_has={"x1": [ws2], "x2": [ws2], "x3": [ws2]}, nbytes={"x1": nbytes, "x2": nbytes, "x3": nbytes}, - stimulus_id="test", + stimulus_id="s1", ) ) - assert len(instructions) == 1 - assert isinstance(instructions[0], GatherDep) + assert instructions == [GatherDep.match(worker=ws2, stimulus_id="s1")] + assert len(instructions[0].to_gather) == n_in_flight assert len(ws.in_flight_tasks) == n_in_flight diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 1e14fc8f7d7..a2f6e431dd3 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -340,6 +340,56 @@ class Instruction: __slots__ = ("stimulus_id",) stimulus_id: str + @classmethod + def match(cls, **kwargs: Any) -> _InstructionMatch: + """Generate a partial match to compare against an Instruction instance. + The typical usage is to compare a list of instructions returned by + :meth:`WorkerState.handle_stimulus` or in :attr:`WorkerState.stimulus_log` vs. + an expected list of matches. + + Example + ------- + + .. code-block:: python + + instructions = ws.handle_stimulus(...) + assert instructions == [ + TaskFinishedMsg.match(key="x"), + ... + ] + """ + return _InstructionMatch(cls, **kwargs) + + def __eq__(self, other: object) -> bool: + if isinstance(other, _InstructionMatch): + return other == self + else: + # Revert to default dataclass behaviour + return super().__eq__(other) + + +class _InstructionMatch: + """Utility class, to be used to test an instructions list. + See :meth:`Instruction.match`. + """ + + cls: type[Instruction] + kwargs: dict[str, Any] + + def __init__(self, cls: type[Instruction], **kwargs: Any): + self.cls = cls + self.kwargs = kwargs + + def __repr__(self) -> str: + cls_str = self.cls.__name__ + kwargs_str = ", ".join(f"{k}={v}" for k, v in self.kwargs.items()) + return f"{cls_str}({kwargs_str}) (partial match)" + + def __eq__(self, other: object) -> bool: + if type(other) is not self.cls: + return False + return all(getattr(other, k) == v for k, v in self.kwargs.items()) + @dataclass class GatherDep(Instruction): @@ -1861,9 +1911,7 @@ def _transition_cancelled_error( # message to the scheduler since from the schedulers POV it already # released this task if self.validate: - assert len(instructions) == 1 - assert isinstance(instructions[0], TaskErredMsg) - assert instructions[0].key == ts.key + assert instructions == [TaskErredMsg.match(key=ts.key)] instructions.clear() # Workers should never "retry" tasks. A transition to error should, by # default, be the end. Since cancelled indicates that the scheduler lost