diff --git a/distributed/collections.py b/distributed/collections.py index c25001047f6..b074c353ef1 100644 --- a/distributed/collections.py +++ b/distributed/collections.py @@ -53,6 +53,22 @@ def __init__(self, *, key: Callable[[T], Any]): def __repr__(self) -> str: return f"<{type(self).__name__}: {len(self)} items>" + def __reduce__(self) -> tuple[Callable, tuple]: + heap = [(k, i, v) for k, i, vref in self._heap if (v := vref()) in self._data] + return HeapSet._unpickle, (self.key, self._inc, heap) + + @staticmethod + def _unpickle( + key: Callable[[T], Any], inc: int, heap: list[tuple[Any, int, T]] + ) -> HeapSet[T]: + self = object.__new__(HeapSet) + self.key = key # type: ignore + self._data = {v for _, _, v in heap} + self._inc = inc + self._heap = [(k, i, weakref.ref(v)) for k, i, v in heap] + heapq.heapify(self._heap) + return self + def __contains__(self, value: object) -> bool: return value in self._data diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index 4fd5e7b1151..9b752b37285 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -1,5 +1,9 @@ from __future__ import annotations +import operator +import pickle +import random + import pytest from distributed.collections import LRU, HeapSet @@ -22,19 +26,20 @@ def test_lru(): assert list(l.keys()) == ["c", "a", "d"] -def test_heapset(): - class C: - def __init__(self, k, i): - self.k = k - self.i = i +class C: + def __init__(self, k, i): + self.k = k + self.i = i - def __hash__(self): - return hash(self.k) + def __hash__(self): + return hash(self.k) - def __eq__(self, other): - return isinstance(other, C) and other.k == self.k + def __eq__(self, other): + return isinstance(other, C) and other.k == self.k - heap = HeapSet(key=lambda c: c.i) + +def test_heapset(): + heap = HeapSet(key=operator.attrgetter("i")) cx = C("x", 2) cy = C("y", 1) @@ -150,3 +155,29 @@ def __init__(self, i): heap.add(C("unsortable_key", None)) assert len(heap) == 1 assert set(heap) == {cx} + + +def test_heapset_pickle(): + """Test pickle roundtrip for a HeapSet. + + Note + ---- + To make this test work with plain pickle and not need cloudpickle, we had to avoid + lambdas and local classes in our test. Here we're testing that HeapSet doesn't add + lambdas etc. of its own. + """ + heap = HeapSet(key=operator.attrgetter("i")) + + # The heap contains broken weakrefs + for i in range(200): + c = C(f"y{i}", random.random()) + heap.add(c) + if random.random() > 0.7: + heap.remove(c) + + heap2 = pickle.loads(pickle.dumps(heap)) + assert len(heap) == len(heap2) + # Test that the heap has been re-heapified upon unpickle + assert len(heap2._heap) < len(heap._heap) + while heap: + assert heap.pop() == heap2.pop() diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 667032f6e50..aa67892fa15 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -2,6 +2,7 @@ import asyncio import gc +import pickle from collections.abc import Iterator import pytest @@ -163,6 +164,27 @@ def test_WorkerState__to_dict(ws): assert actual == expect +def test_WorkerState_pickle(ws): + """Test pickle round-trip. + + Big caveat + ---------- + WorkerState, on its own, can be serialized with pickle; it doesn't need cloudpickle. + A WorkerState extracted from a Worker might, as data contents may only be + serializable with cloudpickle. Some objects created externally and not designed + for network transfer - namely, the SpillBuffer - may not be serializable at all. + """ + ws.handle_stimulus( + AcquireReplicasEvent( + who_has={"x": ["127.0.0.1:1235"]}, nbytes={"x": 123}, stimulus_id="s1" + ) + ) + ws.handle_stimulus(UpdateDataEvent(data={"y": 123}, report=False, stimulus_id="s")) + ws2 = pickle.loads(pickle.dumps(ws)) + assert ws2.tasks.keys() == {"x", "y"} + assert ws2.data == {"y": 123} + + def traverse_subclasses(cls: type) -> Iterator[type]: yield cls for subcls in cls.__subclasses__(): diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 7dd3431ef14..40dab843982 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -2,7 +2,6 @@ import abc import asyncio -import functools import heapq import logging import operator @@ -21,7 +20,7 @@ ) from copy import copy from dataclasses import dataclass, field -from functools import lru_cache +from functools import lru_cache, partial, singledispatchmethod from itertools import chain from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast @@ -1086,7 +1085,7 @@ def __init__( self.has_what = defaultdict(set) self.data_needed = HeapSet(key=operator.attrgetter("priority")) self.data_needed_per_worker = defaultdict( - lambda: HeapSet(key=operator.attrgetter("priority")) + partial(HeapSet[TaskState], key=operator.attrgetter("priority")) ) self.in_flight_workers = {} self.busy_workers = set() @@ -2324,7 +2323,7 @@ def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> Instructio # Events # ########## - @functools.singledispatchmethod + @singledispatchmethod def _handle_event(self, ev: StateMachineEvent) -> RecsInstrs: raise TypeError(ev) # pragma: nocover