From b834918a8502f531f44cc6de6d23659f1ea8b85c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 24 Jun 2022 00:46:29 +0100 Subject: [PATCH 1/4] Pickle WorkerState --- distributed/collections.py | 16 ++++++ distributed/tests/test_collections.py | 51 +++++++++++++++---- .../tests/test_worker_state_machine.py | 22 ++++++++ distributed/worker_state_machine.py | 7 ++- 4 files changed, 82 insertions(+), 14 deletions(-) diff --git a/distributed/collections.py b/distributed/collections.py index c25001047f6..b074c353ef1 100644 --- a/distributed/collections.py +++ b/distributed/collections.py @@ -53,6 +53,22 @@ def __init__(self, *, key: Callable[[T], Any]): def __repr__(self) -> str: return f"<{type(self).__name__}: {len(self)} items>" + def __reduce__(self) -> tuple[Callable, tuple]: + heap = [(k, i, v) for k, i, vref in self._heap if (v := vref()) in self._data] + return HeapSet._unpickle, (self.key, self._inc, heap) + + @staticmethod + def _unpickle( + key: Callable[[T], Any], inc: int, heap: list[tuple[Any, int, T]] + ) -> HeapSet[T]: + self = object.__new__(HeapSet) + self.key = key # type: ignore + self._data = {v for _, _, v in heap} + self._inc = inc + self._heap = [(k, i, weakref.ref(v)) for k, i, v in heap] + heapq.heapify(self._heap) + return self + def __contains__(self, value: object) -> bool: return value in self._data diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index 25853495ebf..9664e84a66e 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -1,3 +1,7 @@ +import operator +import pickle +import random + import pytest from distributed.collections import LRU, HeapSet @@ -20,19 +24,20 @@ def test_lru(): assert list(l.keys()) == ["c", "a", "d"] -def test_heapset(): - class C: - def __init__(self, k, i): - self.k = k - self.i = i +class C: + def __init__(self, k, i): + self.k = k + self.i = i - def __hash__(self): - return hash(self.k) + def __hash__(self): + return hash(self.k) - def __eq__(self, other): - return isinstance(other, C) and other.k == self.k + def __eq__(self, other): + return isinstance(other, C) and other.k == self.k - heap = HeapSet(key=lambda c: c.i) + +def test_heapset(): + heap = HeapSet(key=operator.attrgetter("i")) cx = C("x", 2) cy = C("y", 1) @@ -148,3 +153,29 @@ def __init__(self, i): heap.add(C("unsortable_key", None)) assert len(heap) == 1 assert set(heap) == {cx} + + +def test_heapset_pickle(): + """Test pickle roundtrip for a HeapSet. + + Note + ---- + To make this test work with plain pickle and not need cloudpickle, we had to avoid + lambdas and local classes in our test. Here we're testing that HeapSet doesn't add + lambdas etc. of its own. + """ + heap = HeapSet(key=operator.attrgetter("i")) + + # Test edge case with broken weakrefs + for i in range(200): + c = C(f"y{i}", random.random()) + heap.add(c) + if random.random() > 0.7: + heap.remove(c) + + heap2 = pickle.loads(pickle.dumps(heap)) + assert len(heap) == len(heap2) + # Test that the heap has been re-heapified upon unpickle + assert len(heap2._heap) < len(heap._heap) + while heap: + assert heap.pop() == heap2.pop() diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index eac95602f9f..3fb15871365 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -2,6 +2,7 @@ import asyncio import gc +import pickle from collections.abc import Iterator import pytest @@ -152,6 +153,27 @@ def test_WorkerState__to_dict(ws): assert actual == expect +def test_WorkerState_pickle(ws): + """Test pickle round-trip. + + Big caveat + ---------- + WorkerState, on its own, can be serialized with pickle; it doesn't need cloudpickle. + A WorkerState extracted from a Worker might, as data contents may only be + serializable with cloudpickle. Some objects created externally - namely, the + SpillBuffer - may not be serializable at all. + """ + ws.handle_stimulus( + AcquireReplicasEvent( + who_has={"x": ["127.0.0.1:1235"]}, nbytes={"x": 123}, stimulus_id="s1" + ) + ) + ws.handle_stimulus(UpdateDataEvent(data={"y": 123}, report=False, stimulus_id="s")) + ws2 = pickle.loads(pickle.dumps(ws)) + assert ws2.tasks.keys() == {"x", "y"} + assert ws2.data == {"y": 123} + + def traverse_subclasses(cls: type) -> Iterator[type]: yield cls for subcls in cls.__subclasses__(): diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index f55b468c9f8..90eacc51fc9 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -2,7 +2,6 @@ import abc import asyncio -import functools import heapq import logging import operator @@ -21,7 +20,7 @@ ) from copy import copy from dataclasses import dataclass, field -from functools import lru_cache +from functools import lru_cache, partial, singledispatchmethod from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast from tlz import peekn, pluck @@ -1062,7 +1061,7 @@ def __init__( self.has_what = defaultdict(set) self.data_needed = HeapSet(key=operator.attrgetter("priority")) self.data_needed_per_worker = defaultdict( - lambda: HeapSet(key=operator.attrgetter("priority")) + cast(Callable, partial(HeapSet, key=operator.attrgetter("priority"))) ) self.in_flight_workers = {} self.busy_workers = set() @@ -2300,7 +2299,7 @@ def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> Instructio # Events # ########## - @functools.singledispatchmethod + @singledispatchmethod def _handle_event(self, ev: StateMachineEvent) -> RecsInstrs: raise TypeError(ev) # pragma: nocover From 3644a225078e643499ca96151be3c24bea9e247f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 24 Jun 2022 01:57:07 +0100 Subject: [PATCH 2/4] Update distributed/tests/test_collections.py --- distributed/tests/test_collections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index 9664e84a66e..d48a2dd37fc 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -166,7 +166,7 @@ def test_heapset_pickle(): """ heap = HeapSet(key=operator.attrgetter("i")) - # Test edge case with broken weakrefs + # The heap contains broken weakrefs for i in range(200): c = C(f"y{i}", random.random()) heap.add(c) From eb308ba5da5ce208dad504320886a696d8b15e95 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 24 Jun 2022 01:58:37 +0100 Subject: [PATCH 3/4] Update distributed/tests/test_worker_state_machine.py --- distributed/tests/test_worker_state_machine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 3fb15871365..81f19963d68 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -160,8 +160,8 @@ def test_WorkerState_pickle(ws): ---------- WorkerState, on its own, can be serialized with pickle; it doesn't need cloudpickle. A WorkerState extracted from a Worker might, as data contents may only be - serializable with cloudpickle. Some objects created externally - namely, the - SpillBuffer - may not be serializable at all. + serializable with cloudpickle. Some objects created externally and not designed + for network transfer - namely, the SpillBuffer - may not be serializable at all. """ ws.handle_stimulus( AcquireReplicasEvent( From f15252a68c50e7e0aa4539d906792f0ae715b08a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 24 Jun 2022 15:53:22 +0100 Subject: [PATCH 4/4] Update distributed/worker_state_machine.py Co-authored-by: Thomas Grainger --- distributed/worker_state_machine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 90eacc51fc9..15e3bbaf421 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1061,7 +1061,7 @@ def __init__( self.has_what = defaultdict(set) self.data_needed = HeapSet(key=operator.attrgetter("priority")) self.data_needed_per_worker = defaultdict( - cast(Callable, partial(HeapSet, key=operator.attrgetter("priority"))) + partial(HeapSet[TaskState], key=operator.attrgetter("priority")) ) self.in_flight_workers = {} self.busy_workers = set()