diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index e163cee626e..9aa8c970461 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -34,6 +34,8 @@ GatherDep, GatherDepSuccessEvent, Instruction, + InvalidTaskState, + InvalidTransition, PauseEvent, RecommendationsConflict, RefreshWhoHasEvent, @@ -44,6 +46,7 @@ SerializedTask, StateMachineEvent, TaskState, + TransitionCounterMaxExceeded, UnpauseEvent, UpdateDataEvent, merge_recs_instructions, @@ -194,6 +197,32 @@ def test_WorkerState_pickle(ws): assert ws2.data == {"y": 123} +@pytest.mark.parametrize( + "cls,kwargs", + [ + ( + InvalidTransition, + dict(key="x", start="released", finish="waiting", story=[]), + ), + ( + TransitionCounterMaxExceeded, + dict(key="x", start="released", finish="waiting", story=[]), + ), + (InvalidTaskState, dict(key="x", state="released", story=[])), + ], +) +@pytest.mark.parametrize("positional", [False, True]) +def test_pickle_exceptions(cls, kwargs, positional): + if positional: + e = cls(*kwargs.values()) + else: + e = cls(**kwargs) + e2 = pickle.loads(pickle.dumps(e)) + assert type(e2) is type(e) + for k, v in kwargs.items(): + assert getattr(e2, k) == v + + def traverse_subclasses(cls: type) -> Iterator[type]: yield cls for subcls in cls.__subclasses__(): diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 1e14fc8f7d7..56ae11ab033 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -130,6 +130,9 @@ def __init__( self.finish = finish self.story = story + def __reduce__(self) -> tuple[Callable, tuple]: + return type(self), (self.key, self.start, self.finish, self.story) + def __repr__(self) -> str: return ( f"{self.__class__.__name__}: {self.key} :: {self.start}->{self.finish}" @@ -169,6 +172,9 @@ def __init__( self.state = state self.story = story + def __reduce__(self) -> tuple[Callable, tuple]: + return type(self), (self.key, self.state, self.story) + def __repr__(self) -> str: return ( f"{self.__class__.__name__}: {self.key} :: {self.state}" @@ -2415,7 +2421,10 @@ def _transition( # final ts.state, # new recommendations - {ts.key: new for ts, new in recs.items()}, + { + ts.key: new[0] if isinstance(new, tuple) else new + for ts, new in recs.items() + }, stimulus_id, time(), )