From d3806043e7332963ea06ee8fc5407e2d5c135e31 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 27 Jan 2022 15:14:39 +0100 Subject: [PATCH 1/7] Ensure missing transitions are safe --- distributed/tests/test_worker.py | 263 +++++++++++++++++++++++++------ distributed/worker.py | 194 ++++++++++++++--------- 2 files changed, 334 insertions(+), 123 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index b6c4961ff6d..55aa5d8f9db 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -57,7 +57,14 @@ slowinc, slowsum, ) -from distributed.worker import Worker, error_message, logger, parse_memory_limit +from distributed.worker import ( + TaskState, + Worker, + _UniqueTaskHeap, + error_message, + logger, + parse_memory_limit, +) pytestmark = pytest.mark.ci1 @@ -1390,7 +1397,7 @@ async def test_prefer_gather_from_local_address(c, s, w1, w2, w3): @gen_cluster( client=True, nthreads=[("127.0.0.1", 1)] * 20, - timeout=30, + timeout=500000, config={"distributed.worker.connections.incoming": 1}, ) async def test_avoid_oversubscription(c, s, *workers): @@ -1400,7 +1407,10 @@ async def test_avoid_oversubscription(c, s, *workers): futures = [c.submit(len, x, pure=False, workers=[w.address]) for w in workers[1:]] - await wait(futures) + try: + await asyncio.wait_for(wait(futures), 10) + except asyncio.TimeoutError: + breakpoint() # Original worker not responsible for all transfers assert len(workers[0].outgoing_transfer_log) < len(workers) - 2 @@ -2186,12 +2196,26 @@ async def test_gpu_executor(c, s, w): assert "gpu" not in w.executors -def assert_task_states_on_worker(expected, worker): - for dep_key, expected_state in expected.items(): - assert dep_key in worker.tasks, (worker.name, dep_key, worker.tasks) - dep_ts = worker.tasks[dep_key] - assert dep_ts.state == expected_state, (worker.name, dep_ts, expected_state) - assert set(expected) == set(worker.tasks) +async def assert_task_states_on_worker(expected, worker): + active_exc = None + for _ in range(10): + try: + for dep_key, expected_state in expected.items(): + assert dep_key in worker.tasks, (worker.name, dep_key, worker.tasks) + dep_ts = worker.tasks[dep_key] + assert dep_ts.state == expected_state, ( + worker.name, + dep_ts, + expected_state, + ) + assert set(expected) == set(worker.tasks) + return + except AssertionError as exc: + active_exc = exc + await asyncio.sleep(0.1) + # If after a second the workers are not in equilibrium, they are broken + assert active_exc + raise active_exc @gen_cluster(client=True) @@ -2235,7 +2259,7 @@ def raise_exc(*args): g.key: "memory", res.key: "error", } - assert_task_states_on_worker(expected_states, a) + await assert_task_states_on_worker(expected_states, a) # Expected states after we release references to the futures f.release() g.release() @@ -2251,7 +2275,7 @@ def raise_exc(*args): res.key: "error", } - assert_task_states_on_worker(expected_states, a) + await assert_task_states_on_worker(expected_states, a) res.release() @@ -2304,7 +2328,7 @@ def raise_exc(*args): g.key: "memory", res.key: "error", } - assert_task_states_on_worker(expected_states, a) + await assert_task_states_on_worker(expected_states, a) # Expected states after we release references to the futures res.release() @@ -2318,7 +2342,7 @@ def raise_exc(*args): g.key: "memory", } - assert_task_states_on_worker(expected_states, a) + await assert_task_states_on_worker(expected_states, a) f.release() g.release() @@ -2369,7 +2393,7 @@ def raise_exc(*args): g.key: "memory", res.key: "error", } - assert_task_states_on_worker(expected_states, a) + await assert_task_states_on_worker(expected_states, a) # Expected states after we release references to the futures f.release() @@ -2383,8 +2407,8 @@ def raise_exc(*args): g.key: "memory", } - assert_task_states_on_worker(expected_states, a) - assert_task_states_on_worker(expected_states, b) + await assert_task_states_on_worker(expected_states, a) + await assert_task_states_on_worker(expected_states, b) g.release() @@ -2418,8 +2442,7 @@ def raise_exc(*args): g.key: "memory", h.key: "memory", } - await asyncio.sleep(0.05) - assert_task_states_on_worker(expected_states_A, a) + await assert_task_states_on_worker(expected_states_A, a) expected_states_B = { f.key: "memory", @@ -2427,8 +2450,7 @@ def raise_exc(*args): h.key: "memory", res.key: "error", } - await asyncio.sleep(0.05) - assert_task_states_on_worker(expected_states_B, b) + await assert_task_states_on_worker(expected_states_B, b) f.release() @@ -2436,8 +2458,7 @@ def raise_exc(*args): g.key: "memory", h.key: "memory", } - await asyncio.sleep(0.05) - assert_task_states_on_worker(expected_states_A, a) + await assert_task_states_on_worker(expected_states_A, a) expected_states_B = { f.key: "released", @@ -2445,8 +2466,7 @@ def raise_exc(*args): h.key: "memory", res.key: "error", } - await asyncio.sleep(0.05) - assert_task_states_on_worker(expected_states_B, b) + await assert_task_states_on_worker(expected_states_B, b) g.release() @@ -2454,8 +2474,7 @@ def raise_exc(*args): g.key: "released", h.key: "memory", } - await asyncio.sleep(0.05) - assert_task_states_on_worker(expected_states_A, a) + await assert_task_states_on_worker(expected_states_A, a) # B must not forget a task since all have a still valid dependent expected_states_B = { @@ -2463,19 +2482,18 @@ def raise_exc(*args): h.key: "memory", res.key: "error", } - assert_task_states_on_worker(expected_states_B, b) + await assert_task_states_on_worker(expected_states_B, b) h.release() - await asyncio.sleep(0.05) expected_states_A = {} - assert_task_states_on_worker(expected_states_A, a) + await assert_task_states_on_worker(expected_states_A, a) expected_states_B = { f.key: "released", h.key: "released", res.key: "error", } - assert_task_states_on_worker(expected_states_B, b) + await assert_task_states_on_worker(expected_states_B, b) res.release() # We no longer hold any refs. Cluster should reset completely @@ -3130,33 +3148,38 @@ async def test_missing_released_zombie_tasks(c, s, a, b): @gen_cluster(client=True) async def test_missing_released_zombie_tasks_2(c, s, a, b): - a.total_in_connections = 0 - f1 = c.submit(inc, 1, key="f1", workers=[a.address]) - f2 = c.submit(inc, f1, key="f2", workers=[b.address]) + # If get_data_from_worker raises this will suggest a dead worker to B and it + # will transition the task to missing. We want to make sure that a missing + # task is properly released and not left as a zombie + with mock.patch.object( + distributed.worker, + "get_data_from_worker", + side_effect=CommClosedError, + ): + f1 = c.submit(inc, 1, key="f1", workers=[a.address]) + f2 = c.submit(inc, f1, key="f2", workers=[b.address]) - while f1.key not in b.tasks: - await asyncio.sleep(0) + while f1.key not in b.tasks: + await asyncio.sleep(0) - ts = b.tasks[f1.key] - assert ts.state == "fetch" + ts = b.tasks[f1.key] + assert ts.state == "fetch" - # A few things can happen to clear who_has. The dominant process is upon - # connection failure to a worker. Regardless of how the set was cleared, the - # task will be transitioned to missing where the worker is trying to - # reaquire this information from the scheduler. While this is happening on - # worker side, the tasks are released and we want to ensure that no dangling - # zombie tasks are left on the worker - ts.who_has.clear() + while not ts.state == "missing": + # If we sleep for a longer time, the worker will spin into an + # endless loop of asking the scheduler who_has and trying to connect + # to A + await asyncio.sleep(0) - del f1, f2 + del f1, f2 - while b.tasks: - await asyncio.sleep(0.01) + while b.tasks: + await asyncio.sleep(0.01) - assert_worker_story( - b.story(ts), - [("f1", "missing", "released", "released", {"f1": "forgotten"})], - ) + assert_worker_story( + b.story(ts), + [("f1", "missing", "released", "released", {"f1": "forgotten"})], + ) @pytest.mark.slow @@ -3441,6 +3464,8 @@ async def test_Worker__to_dict(c, s, a): "config", "incoming_transfer_log", "outgoing_transfer_log", + "data_needed", + "pending_data_per_worker", } assert d["tasks"]["x"]["key"] == "x" @@ -3462,3 +3487,139 @@ async def test_TaskState__to_dict(c, s, a): assert isinstance(tasks["z"], dict) assert tasks["x"]["dependents"] == [""] assert tasks["y"]["dependencies"] == [""] + + +@gen_cluster(client=True) +async def test_dups_in_pending_data_per_worker(c, s, a, b): + # There has been a condition leading to a deadlock (caught by AssertionError + # if validate is enabled) that was caused by not identifying a missing key + # properly + + # We need to fetch a key which is repeatedly selected as part of the Worker.select_from_gather optimization + # since if it goes through the ordinary channels of ensure_communicating it + # is flagged immediately as missing + + # this is a batch of futures we will fetch from A. We will use these as + # seeds, i.e. primary keys to fetch for the batched fetch optimization + futs = c.map(inc, range(100), workers=[a.address]) + # This will be the culprit/missing key we selectively insert into the fetch + # queue. We will manipulate the state machine such that this would raise the + # AssertionError + missing_fut = c.submit(inc, -1, workers=[a.address], key="culprit") + + # Ensure the data is available, scheduler is aware + await c.gather(futs) + await missing_fut + + # MOCKs: + # We will mock ensure_communicating and ensure_computing to disable the + # every_cycle callback of our handle_scheduler + # Effectively this allows us to intercept the moment directly after a + # handle_compute_task handler was executed + with mock.patch.object( + Worker, "ensure_communicating", return_value=None + ) as comm_mock: + with mock.patch.object( + Worker, "ensure_computing", return_value=None + ) as comp_mock: + # This new worker will be the one where the exception is provoked + x = await Worker(s.address, name=2, validate=True) + # fill up the data needed heap with tasks that are fine to be scheduled + + f1 = c.submit(sum, [*futs[:20]], workers=[x.address], key="f1", priority=10) + # Put the bad one in between + f2 = c.submit(inc, missing_fut, workers=[x.address], key="f2", priority=20) + # Put the bad one in between. Ensure the heap is full with all the + # tasks such that we have many batched fetches + f3 = c.submit( + sum, [*futs[20:40]], workers=[x.address], key="f3", priority=30 + ) + + # wait for all the tasks to be registered. Without the mocks we + # could not cleanly assert this but this is an important test + # assumption + while len(x.data_needed) != 41: + await asyncio.sleep(0.01) + assert missing_fut.key in x.tasks + ts = x.tasks[missing_fut.key] + assert ts in x.pending_data_per_worker[a.address] + # not at the top of the heap + key = x.pending_data_per_worker[a.address].peak() + assert key != missing_fut.key + + # This has been introducing duplicates to pending_data_per_worker + for _ in range(3): + await x.query_who_has(missing_fut.key) + + # We will now remove culprit from A such that X will handle the missing + # response. + # To not have the scheduler reschedule this task immediately, we will create + # another replica on another worker. We don't want X to be made aware of + # this which is why we're disabling / mocking the query_who_has to ensure + # that there is no background update. In a more realistic environment, this + # can be caused by certain delays in communication, particulary if AMM is + # runing + # We could also use AMM to create a replica if the API supports this. + with mock.patch.object(Worker, "query_who_has", return_value=None) as who_has_mock: + f_copy = c.submit( + inc, missing_fut, key="copy-intention-culprit", workers=[b.address] + ) + await f_copy + + # Now remove the replica from A *before* X requests the data + a.handle_remove_replicas([missing_fut.key], stimulus_id="test") + + # We want to ensure that the first batch includes all data for + x.target_message_size = sum(x.tasks[f.key].get_nbytes() for f in futs[:22]) + x.ensure_communicating() + + with mock.patch.object( + Worker, "ensure_communicating", return_value=None + ) as comm_mock: + with mock.patch.object( + Worker, "ensure_computing", return_value=None + ) as comp_mock: + while not x.data: + await asyncio.sleep(0.01) + + x.target_message_size = 1000000000 + x.ensure_communicating() + + await f1 + await f2 + await f3 + + +def test_unique_task_heap(): + heap = _UniqueTaskHeap() + + for x in range(10): + ts = TaskState(f"f{x}") + ts.priority = (0, 0, 1, x % 3) + heap.push(ts) + del ts + + heap_list = list(heap) + # iteration does not empty heap + assert heap + assert heap_list == sorted(heap_list, key=lambda ts: ts.priority) + + seen = set() + last_prio = (0, 0, 0, 0) + while heap: + peaked = heap.peak() + ts = heap.pop() + assert peaked == ts + seen.add(ts.key) + assert ts.priority + assert last_prio <= ts.priority + last_prio = last_prio + + ts = TaskState("foo") + heap.push(ts) + heap.push(ts) + assert len(heap) == 1 + assert heap.pop() == ts + assert not heap + + assert isinstance(repr(heap), str) diff --git a/distributed/worker.py b/distributed/worker.py index 72d3c93f5e3..b5424716d6d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -13,7 +13,14 @@ import warnings import weakref from collections import defaultdict, deque, namedtuple -from collections.abc import Callable, Collection, Iterable, Mapping, MutableMapping +from collections.abc import ( + Callable, + Collection, + Iterable, + Iterator, + Mapping, + MutableMapping, +) from concurrent.futures import Executor from contextlib import suppress from datetime import timedelta @@ -200,7 +207,7 @@ def __init__(self, key, runspec=None): self.dependencies = set() self.dependents = set() self.duration = None - self.priority = None + self.priority: tuple[int, ...] | None = None self.state = "released" self.who_has = set() self.coming_from = None @@ -267,6 +274,56 @@ def is_protected(self) -> bool: ) +class _UniqueTaskHeap: + def __init__(self, collection: Collection[TaskState] | None = None) -> None: + """A heap of TaskState objects ordered by TaskState.priority + Ties are broken by string comparison of the key. + Keys are guaranteed to be unique. + Iterating over this object returns the elements in priority order. + """ + if collection is None: + collection = [] + self._known = {ts.key for ts in collection} + self._heap = [(ts.priority, ts.key, ts) for ts in collection] + heapq.heapify(self._heap) + + def push(self, ts: TaskState) -> None: + """Add a new TaskState instance to the heap. If the key is already + known, no object is added. + + Note: This does not update the priority / heap order in case priority + changes. + """ + assert isinstance(ts, TaskState) + if ts.key not in self._known: + heapq.heappush(self._heap, (ts.priority, ts.key, ts)) + self._known.add(ts.key) + + def pop(self) -> TaskState: + """Pop the task with highest priority from the heap.""" + _, _, ts = heapq.heappop(self._heap) + self._known.remove(ts.key) + return ts + + def peak(self) -> TaskState: + """Get the highest priority TaskState without removing it from the heap""" + return self._heap[0][2] + + def __contains__(self, x: object) -> bool: + if isinstance(x, TaskState): + x = x.key + return x in self._known + + def __iter__(self) -> Iterator[TaskState]: + return iter([ts for _, _, ts in sorted(self._heap)]) + + def __len__(self) -> int: + return len(self._known) + + def __repr__(self) -> str: + return f"<{type(self).__name__}: {len(self)} items>" + + class Worker(ServerNode): """Worker node in a Dask distributed cluster @@ -342,7 +399,7 @@ class Worker(ServerNode): * **data.disk:** ``{key: object}``: Dictionary mapping keys to actual values stored on disk. Only available if condition for **data** being a zict.Buffer is met. - * **data_needed**: deque(keys) + * **data_needed**: heap(TaskState) The keys which still require data in order to execute, arranged in a deque * **ready**: [keys] Keys that are ready to run. Stored in a LIFO stack @@ -358,8 +415,8 @@ class Worker(ServerNode): long-running clients. * **has_what**: ``{worker: {deps}}`` The data that we care about that we think a worker has - * **pending_data_per_worker**: ``{worker: [dep]}`` - The data on each worker that we still want, prioritized as a deque + * **pending_data_per_worker**: ``{worker: heap(TaskState)}`` + The data on each worker that we still want, prioritized as a heap * **in_flight_tasks**: ``int`` A count of the number of tasks that are coming to us in current peer-to-peer connections @@ -457,10 +514,10 @@ class Worker(ServerNode): tasks: dict[str, TaskState] waiting_for_data_count: int has_what: defaultdict[str, set[str]] # {worker address: {ts.key, ...} - pending_data_per_worker: defaultdict[str, deque[str]] + pending_data_per_worker: defaultdict[str, _UniqueTaskHeap] nanny: Nanny | None _lock: threading.Lock - data_needed: list[tuple[int, str]] # heap[(ts.priority, ts.key)] + data_needed: _UniqueTaskHeap in_flight_workers: dict[str, set[str]] # {worker address: {ts.key, ...}} total_out_connections: int total_in_connections: int @@ -609,11 +666,11 @@ def __init__( self.tasks = {} self.waiting_for_data_count = 0 self.has_what = defaultdict(set) - self.pending_data_per_worker = defaultdict(deque) + self.pending_data_per_worker = defaultdict(_UniqueTaskHeap) self.nanny = nanny self._lock = threading.Lock() - self.data_needed = [] + self.data_needed = _UniqueTaskHeap() self.in_flight_workers = {} self.total_out_connections = dask.config.get( @@ -624,7 +681,7 @@ def __init__( ) self.comm_threshold_bytes = int(10e6) self.comm_nbytes = 0 - self._missing_dep_flight = set() + self._missing_dep_flight: set[TaskState] = set() self.threads = {} @@ -673,11 +730,11 @@ def __init__( ("executing", "released"): self.transition_executing_released, ("executing", "rescheduled"): self.transition_executing_rescheduled, ("fetch", "flight"): self.transition_fetch_flight, - ("fetch", "missing"): self.transition_fetch_missing, ("fetch", "released"): self.transition_generic_released, ("flight", "error"): self.transition_flight_error, ("flight", "fetch"): self.transition_flight_fetch, ("flight", "memory"): self.transition_flight_memory, + ("flight", "missing"): self.transition_flight_missing, ("flight", "released"): self.transition_flight_released, ("long-running", "error"): self.transition_generic_error, ("long-running", "memory"): self.transition_long_running_memory, @@ -1156,6 +1213,10 @@ def _to_dict( "status": self.status, "ready": self.ready, "constrained": self.constrained, + "data_needed": list(self.data_needed), + "pending_data_per_worker": { + w: list(v) for w, v in self.pending_data_per_worker.items() + }, "long_running": self.long_running, "executing_count": self.executing_count, "in_flight_tasks": self.in_flight_tasks, @@ -1913,7 +1974,6 @@ def handle_cancel_compute(self, key, reason): ts = self.tasks.get(key) if ts and ts.state in READY | {"waiting"}: self.log.append((key, "cancel-compute", reason, time())) - ts.scheduler_holds_ref = False # All possible dependents of TS should not be in state Processing on # scheduler side and therefore should not be assigned to a worker, # yet. @@ -1942,7 +2002,7 @@ def handle_acquire_replicas( if ts.state != "memory": recommendations[ts] = "fetch" - self.update_who_has(who_has, stimulus_id=stimulus_id) + self.update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) def ensure_task_exists( @@ -2038,11 +2098,10 @@ def handle_compute_task( for msg in scheduler_msgs: self.batched_stream.send(msg) + + self.update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) - # We received new info, that's great but not related to the compute-task - # instruction - self.update_who_has(who_has, stimulus_id=stimulus_id) if nbytes is not None: for key, value in nbytes.items(): self.tasks[key].nbytes = value @@ -2055,7 +2114,7 @@ def transition_missing_fetch(self, ts, *, stimulus_id): self._missing_dep_flight.discard(ts) ts.state = "fetch" ts.done = False - heapq.heappush(self.data_needed, (ts.priority, ts.key)) + self.data_needed.push(ts) return {}, [] def transition_missing_released(self, ts, *, stimulus_id): @@ -2066,10 +2125,11 @@ def transition_missing_released(self, ts, *, stimulus_id): assert ts.key in self.tasks return recommendations, smsgs - def transition_fetch_missing(self, ts, *, stimulus_id): - # handle_missing will append to self.data_needed if new workers are found + def transition_flight_missing(self, ts, *, stimulus_id): + assert ts.done ts.state = "missing" self._missing_dep_flight.add(ts) + ts.done = False return {}, [] def transition_released_fetch(self, ts, *, stimulus_id): @@ -2077,10 +2137,10 @@ def transition_released_fetch(self, ts, *, stimulus_id): assert ts.state == "released" assert ts.priority is not None for w in ts.who_has: - self.pending_data_per_worker[w].append(ts.key) + self.pending_data_per_worker[w].push(ts) ts.state = "fetch" ts.done = False - heapq.heappush(self.data_needed, (ts.priority, ts.key)) + self.data_needed.push(ts) return {}, [] def transition_generic_released(self, ts, *, stimulus_id): @@ -2126,7 +2186,6 @@ def transition_fetch_flight(self, ts, worker, *, stimulus_id): if self.validate: assert ts.state == "fetch" assert ts.who_has - assert ts.key not in self.data_needed ts.done = False ts.state = "flight" @@ -2425,11 +2484,17 @@ def transition_flight_fetch(self, ts, *, stimulus_id): # we can reset the task and transition to fetch again. If it is not yet # finished, this should be a no-op if ts.done: - recommendations, smsgs = self.transition_generic_released( - ts, stimulus_id=stimulus_id - ) - recommendations[ts] = "fetch" - return recommendations, smsgs + recommendations = {} + ts.state = "fetch" + ts.coming_from = None + ts.done = False + if not ts.who_has: + recommendations[ts] = "missing" + else: + self.data_needed.push(ts) + for w in ts.who_has: + self.pending_data_per_worker[w].push(ts) + return recommendations, [] else: return {}, [] @@ -2692,24 +2757,15 @@ def ensure_communicating(self): self.total_out_connections, ) - _, key = heapq.heappop(self.data_needed) - - try: - ts = self.tasks[key] - except KeyError: - continue + ts = self.data_needed.pop() if ts.state != "fetch": continue - if not ts.who_has: - self.transition(ts, "missing", stimulus_id=stimulus_id) - continue - workers = [w for w in ts.who_has if w not in self.in_flight_workers] if not workers: assert ts.priority is not None - skipped_worker_in_flight.append((ts.priority, ts.key)) + skipped_worker_in_flight.append(ts) continue host = get_address_host(self.address) @@ -2740,7 +2796,7 @@ def ensure_communicating(self): ) for el in skipped_worker_in_flight: - heapq.heappush(self.data_needed, el) + self.data_needed.push(el) def _get_task_finished_msg(self, ts): if ts.key not in self.data and ts.key not in self.actors: @@ -2834,13 +2890,12 @@ def select_keys_for_gather(self, worker, dep): L = self.pending_data_per_worker[worker] while L: - d = L.popleft() - ts = self.tasks.get(d) - if ts is None or ts.state != "fetch": + ts = L.pop() + if ts.state != "fetch": continue if total_bytes + ts.get_nbytes() > self.target_message_size: break - deps.add(d) + deps.add(ts.key) total_bytes += ts.get_nbytes() return deps, total_bytes @@ -3077,7 +3132,10 @@ async def gather_dep( self.batched_stream.send( {"op": "missing-data", "errant_worker": worker, "key": d} ) - recommendations[ts] = "fetch" + if not ts.who_has: + recommendations[ts] = "missing" + else: + recommendations[ts] = "fetch" del data, response self.transitions(recommendations, stimulus_id=stimulus_id) self.ensure_computing() @@ -3089,7 +3147,7 @@ async def gather_dep( self.repetitively_busy += 1 await asyncio.sleep(0.100 * 1.5 ** self.repetitively_busy) - await self.query_who_has(*to_gather_keys, stimulus_id=stimulus_id) + await self.query_who_has(*to_gather_keys) self.ensure_communicating() @@ -3108,7 +3166,12 @@ async def find_missing(self): keys=[ts.key for ts in self._missing_dep_flight], ) who_has = {k: v for k, v in who_has.items() if v} - self.update_who_has(who_has, stimulus_id=stimulus_id) + self.update_who_has(who_has) + recommendations = {} + for ts in self._missing_dep_flight: + if ts.who_has: + recommendations[ts] = "fetch" + self.transitions(recommendations, stimulus_id=stimulus_id) finally: # This is quite arbitrary but the heartbeat has scaling implemented @@ -3118,24 +3181,20 @@ async def find_missing(self): self.ensure_communicating() self.ensure_computing() - async def query_who_has( - self, *deps: str, stimulus_id: str - ) -> dict[str, Collection[str]]: + async def query_who_has(self, *deps: str) -> dict[str, Collection[str]]: with log_errors(): who_has = await retry_operation(self.scheduler.who_has, keys=deps) - self.update_who_has(who_has, stimulus_id=stimulus_id) + self.update_who_has(who_has) return who_has - def update_who_has( - self, who_has: dict[str, Collection[str]], *, stimulus_id: str - ) -> None: + def update_who_has(self, who_has: dict[str, Collection[str]]) -> None: try: - recommendations = {} for dep, workers in who_has.items(): if not workers: continue if dep in self.tasks: + dep_ts = self.tasks[dep] if self.address in workers and self.tasks[dep].state != "memory": logger.debug( "Scheduler claims worker %s holds data for task %s which is not true.", @@ -3144,18 +3203,11 @@ def update_who_has( ) # Do not mutate the input dict. That's rude workers = set(workers) - {self.address} - dep_ts = self.tasks[dep] - if dep_ts.state in FETCH_INTENDED: - dep_ts.who_has.update(workers) - - if dep_ts.state == "missing": - recommendations[dep_ts] = "fetch" - - for worker in workers: - self.has_what[worker].add(dep) - self.pending_data_per_worker[worker].append(dep_ts.key) + dep_ts.who_has.update(workers) - self.transitions(recommendations, stimulus_id=stimulus_id) + for worker in workers: + self.has_what[worker].add(dep) + self.pending_data_per_worker[worker].push(dep_ts) except Exception as e: logger.exception(e) if LOG_PDB: @@ -3874,16 +3926,19 @@ def validate_task_fetch(self, ts): assert ts.key not in self.data assert self.address not in ts.who_has assert not ts.done + assert ts in self.data_needed + assert ts.who_has for w in ts.who_has: assert ts.key in self.has_what[w] + assert ts in self.pending_data_per_worker[w] def validate_task_missing(self, ts): assert ts.key not in self.data assert not ts.who_has assert not ts.done assert not any(ts.key in has_what for has_what in self.has_what.values()) - assert ts.key in self._missing_dep_flight + assert ts in self._missing_dep_flight def validate_task_cancelled(self, ts): assert ts.key not in self.data @@ -3903,7 +3958,6 @@ def validate_task_released(self, ts): assert ts not in self._in_flight_tasks assert ts not in self._missing_dep_flight assert ts not in self._missing_dep_flight - assert not ts.who_has assert not any(ts.key in has_what for has_what in self.has_what.values()) assert not ts.waiting_for_data assert not ts.done @@ -3973,13 +4027,9 @@ def validate_state(self): assert ( ts_wait.state in READY | {"executing", "flight", "fetch", "missing"} - or ts_wait.key in self._missing_dep_flight + or ts_wait in self._missing_dep_flight or ts_wait.who_has.issubset(self.in_flight_workers) ), (ts, ts_wait, self.story(ts), self.story(ts_wait)) - if ts.state == "memory": - assert isinstance(ts.nbytes, int) - assert not ts.waiting_for_data - assert ts.key in self.data or ts.key in self.actors assert self.waiting_for_data_count == waiting_for_data_count for worker, keys in self.has_what.items(): for k in keys: From 417bf947bba0f2a9a607d259dd504129383abf75 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 27 Jan 2022 15:26:30 +0100 Subject: [PATCH 2/7] remove debug statements --- distributed/tests/test_worker.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 55aa5d8f9db..b170510c335 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1397,7 +1397,6 @@ async def test_prefer_gather_from_local_address(c, s, w1, w2, w3): @gen_cluster( client=True, nthreads=[("127.0.0.1", 1)] * 20, - timeout=500000, config={"distributed.worker.connections.incoming": 1}, ) async def test_avoid_oversubscription(c, s, *workers): @@ -1407,10 +1406,7 @@ async def test_avoid_oversubscription(c, s, *workers): futures = [c.submit(len, x, pure=False, workers=[w.address]) for w in workers[1:]] - try: - await asyncio.wait_for(wait(futures), 10) - except asyncio.TimeoutError: - breakpoint() + wait(futures) # Original worker not responsible for all transfers assert len(workers[0].outgoing_transfer_log) < len(workers) - 2 From 05122e81c3be04a481e7b265343464b420d0aab3 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 27 Jan 2022 17:12:04 +0100 Subject: [PATCH 3/7] revert test_avoid_oversubscription test --- distributed/tests/test_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index b170510c335..54bae7b7c16 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1406,7 +1406,7 @@ async def test_avoid_oversubscription(c, s, *workers): futures = [c.submit(len, x, pure=False, workers=[w.address]) for w in workers[1:]] - wait(futures) + await wait(futures) # Original worker not responsible for all transfers assert len(workers[0].outgoing_transfer_log) < len(workers) - 2 From dfd4d5eecaf4d9127012a85fa30246b6ece7ca7d Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 28 Jan 2022 12:38:19 +0100 Subject: [PATCH 4/7] Review comments --- distributed/tests/test_worker.py | 117 +++---------------------------- distributed/worker.py | 24 +++---- 2 files changed, 23 insertions(+), 118 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 54bae7b7c16..1e3c8732b41 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3485,107 +3485,6 @@ async def test_TaskState__to_dict(c, s, a): assert tasks["y"]["dependencies"] == [""] -@gen_cluster(client=True) -async def test_dups_in_pending_data_per_worker(c, s, a, b): - # There has been a condition leading to a deadlock (caught by AssertionError - # if validate is enabled) that was caused by not identifying a missing key - # properly - - # We need to fetch a key which is repeatedly selected as part of the Worker.select_from_gather optimization - # since if it goes through the ordinary channels of ensure_communicating it - # is flagged immediately as missing - - # this is a batch of futures we will fetch from A. We will use these as - # seeds, i.e. primary keys to fetch for the batched fetch optimization - futs = c.map(inc, range(100), workers=[a.address]) - # This will be the culprit/missing key we selectively insert into the fetch - # queue. We will manipulate the state machine such that this would raise the - # AssertionError - missing_fut = c.submit(inc, -1, workers=[a.address], key="culprit") - - # Ensure the data is available, scheduler is aware - await c.gather(futs) - await missing_fut - - # MOCKs: - # We will mock ensure_communicating and ensure_computing to disable the - # every_cycle callback of our handle_scheduler - # Effectively this allows us to intercept the moment directly after a - # handle_compute_task handler was executed - with mock.patch.object( - Worker, "ensure_communicating", return_value=None - ) as comm_mock: - with mock.patch.object( - Worker, "ensure_computing", return_value=None - ) as comp_mock: - # This new worker will be the one where the exception is provoked - x = await Worker(s.address, name=2, validate=True) - # fill up the data needed heap with tasks that are fine to be scheduled - - f1 = c.submit(sum, [*futs[:20]], workers=[x.address], key="f1", priority=10) - # Put the bad one in between - f2 = c.submit(inc, missing_fut, workers=[x.address], key="f2", priority=20) - # Put the bad one in between. Ensure the heap is full with all the - # tasks such that we have many batched fetches - f3 = c.submit( - sum, [*futs[20:40]], workers=[x.address], key="f3", priority=30 - ) - - # wait for all the tasks to be registered. Without the mocks we - # could not cleanly assert this but this is an important test - # assumption - while len(x.data_needed) != 41: - await asyncio.sleep(0.01) - assert missing_fut.key in x.tasks - ts = x.tasks[missing_fut.key] - assert ts in x.pending_data_per_worker[a.address] - # not at the top of the heap - key = x.pending_data_per_worker[a.address].peak() - assert key != missing_fut.key - - # This has been introducing duplicates to pending_data_per_worker - for _ in range(3): - await x.query_who_has(missing_fut.key) - - # We will now remove culprit from A such that X will handle the missing - # response. - # To not have the scheduler reschedule this task immediately, we will create - # another replica on another worker. We don't want X to be made aware of - # this which is why we're disabling / mocking the query_who_has to ensure - # that there is no background update. In a more realistic environment, this - # can be caused by certain delays in communication, particulary if AMM is - # runing - # We could also use AMM to create a replica if the API supports this. - with mock.patch.object(Worker, "query_who_has", return_value=None) as who_has_mock: - f_copy = c.submit( - inc, missing_fut, key="copy-intention-culprit", workers=[b.address] - ) - await f_copy - - # Now remove the replica from A *before* X requests the data - a.handle_remove_replicas([missing_fut.key], stimulus_id="test") - - # We want to ensure that the first batch includes all data for - x.target_message_size = sum(x.tasks[f.key].get_nbytes() for f in futs[:22]) - x.ensure_communicating() - - with mock.patch.object( - Worker, "ensure_communicating", return_value=None - ) as comm_mock: - with mock.patch.object( - Worker, "ensure_computing", return_value=None - ) as comp_mock: - while not x.data: - await asyncio.sleep(0.01) - - x.target_message_size = 1000000000 - x.ensure_communicating() - - await f1 - await f2 - await f3 - - def test_unique_task_heap(): heap = _UniqueTaskHeap() @@ -3593,19 +3492,18 @@ def test_unique_task_heap(): ts = TaskState(f"f{x}") ts.priority = (0, 0, 1, x % 3) heap.push(ts) - del ts heap_list = list(heap) # iteration does not empty heap - assert heap + assert len(heap) == 10 assert heap_list == sorted(heap_list, key=lambda ts: ts.priority) seen = set() last_prio = (0, 0, 0, 0) while heap: - peaked = heap.peak() + peeked = heap.peek() ts = heap.pop() - assert peaked == ts + assert peeked == ts seen.add(ts.key) assert ts.priority assert last_prio <= ts.priority @@ -3618,4 +3516,11 @@ def test_unique_task_heap(): assert heap.pop() == ts assert not heap - assert isinstance(repr(heap), str) + assert repr(heap) == "" + + # Test that we're cleaning the seen set on pop + heap.push(ts) + assert len(heap) == 1 + assert heap.pop() == ts + + assert repr(heap) == "" diff --git a/distributed/worker.py b/distributed/worker.py index b5424716d6d..9d90448e8bb 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -200,6 +200,8 @@ class TaskState: """ + priority: tuple[int, ...] | None + def __init__(self, key, runspec=None): assert key is not None self.key = key @@ -207,7 +209,7 @@ def __init__(self, key, runspec=None): self.dependencies = set() self.dependents = set() self.duration = None - self.priority: tuple[int, ...] | None = None + self.priority = None self.state = "released" self.who_has = set() self.coming_from = None @@ -275,14 +277,12 @@ def is_protected(self) -> bool: class _UniqueTaskHeap: - def __init__(self, collection: Collection[TaskState] | None = None) -> None: - """A heap of TaskState objects ordered by TaskState.priority - Ties are broken by string comparison of the key. - Keys are guaranteed to be unique. - Iterating over this object returns the elements in priority order. - """ - if collection is None: - collection = [] + """A heap of TaskState objects ordered by TaskState.priority + Ties are broken by string comparison of the key. Keys are guaranteed to be + unique. Iterating over this object returns the elements in priority order. + """ + + def __init__(self, collection: Collection[TaskState] = ()): self._known = {ts.key for ts in collection} self._heap = [(ts.priority, ts.key, ts) for ts in collection] heapq.heapify(self._heap) @@ -301,11 +301,11 @@ def push(self, ts: TaskState) -> None: def pop(self) -> TaskState: """Pop the task with highest priority from the heap.""" - _, _, ts = heapq.heappop(self._heap) - self._known.remove(ts.key) + _, key, ts = heapq.heappop(self._heap) + self._known.remove(key) return ts - def peak(self) -> TaskState: + def peek(self) -> TaskState: """Get the highest priority TaskState without removing it from the heap""" return self._heap[0][2] From 2d0abe1e1924ce7aea757b6c981d3b05bbb57beb Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 28 Jan 2022 17:58:04 +0100 Subject: [PATCH 5/7] fix assert in unique heap --- distributed/tests/test_worker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 1e3c8732b41..4ba5a887172 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3513,11 +3513,12 @@ def test_unique_task_heap(): heap.push(ts) heap.push(ts) assert len(heap) == 1 - assert heap.pop() == ts - assert not heap assert repr(heap) == "" + assert heap.pop() == ts + assert not heap + # Test that we're cleaning the seen set on pop heap.push(ts) assert len(heap) == 1 From d153017063f5d54da2addb46a4c10557213a1b92 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 28 Jan 2022 17:24:37 +0000 Subject: [PATCH 6/7] Update distributed/tests/test_worker.py --- distributed/tests/test_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 4ba5a887172..60429b94729 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3161,7 +3161,7 @@ async def test_missing_released_zombie_tasks_2(c, s, a, b): ts = b.tasks[f1.key] assert ts.state == "fetch" - while not ts.state == "missing": + while ts.state != "missing": # If we sleep for a longer time, the worker will spin into an # endless loop of asking the scheduler who_has and trying to connect # to A From 4712d1a9067fc74315eb1346dd892ca161833de8 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 1 Feb 2022 11:56:38 +0100 Subject: [PATCH 7/7] Review comments --- distributed/tests/test_worker.py | 4 ++-- distributed/worker.py | 26 +++++++++++--------------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 60429b94729..64f4eab12fa 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -59,8 +59,8 @@ ) from distributed.worker import ( TaskState, + UniqueTaskHeap, Worker, - _UniqueTaskHeap, error_message, logger, parse_memory_limit, @@ -3486,7 +3486,7 @@ async def test_TaskState__to_dict(c, s, a): def test_unique_task_heap(): - heap = _UniqueTaskHeap() + heap = UniqueTaskHeap() for x in range(10): ts = TaskState(f"f{x}") diff --git a/distributed/worker.py b/distributed/worker.py index 9d90448e8bb..02528770d28 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -117,7 +117,6 @@ "resumed", } READY = {"ready", "constrained"} -FETCH_INTENDED = {"missing", "fetch", "flight", "cancelled", "resumed"} # Worker.status subsets RUNNING = {Status.running, Status.paused, Status.closing_gracefully} @@ -276,7 +275,7 @@ def is_protected(self) -> bool: ) -class _UniqueTaskHeap: +class UniqueTaskHeap(Collection): """A heap of TaskState objects ordered by TaskState.priority Ties are broken by string comparison of the key. Keys are guaranteed to be unique. Iterating over this object returns the elements in priority order. @@ -315,7 +314,7 @@ def __contains__(self, x: object) -> bool: return x in self._known def __iter__(self) -> Iterator[TaskState]: - return iter([ts for _, _, ts in sorted(self._heap)]) + return (ts for _, _, ts in sorted(self._heap)) def __len__(self) -> int: return len(self._known) @@ -399,8 +398,8 @@ class Worker(ServerNode): * **data.disk:** ``{key: object}``: Dictionary mapping keys to actual values stored on disk. Only available if condition for **data** being a zict.Buffer is met. - * **data_needed**: heap(TaskState) - The keys which still require data in order to execute, arranged in a deque + * **data_needed**: UniqueTaskHeap + The tasks which still require data in order to execute, prioritized as a heap * **ready**: [keys] Keys that are ready to run. Stored in a LIFO stack * **constrained**: [keys] @@ -415,7 +414,7 @@ class Worker(ServerNode): long-running clients. * **has_what**: ``{worker: {deps}}`` The data that we care about that we think a worker has - * **pending_data_per_worker**: ``{worker: heap(TaskState)}`` + * **pending_data_per_worker**: ``{worker: UniqueTaskHeap}`` The data on each worker that we still want, prioritized as a heap * **in_flight_tasks**: ``int`` A count of the number of tasks that are coming to us in current @@ -514,10 +513,10 @@ class Worker(ServerNode): tasks: dict[str, TaskState] waiting_for_data_count: int has_what: defaultdict[str, set[str]] # {worker address: {ts.key, ...} - pending_data_per_worker: defaultdict[str, _UniqueTaskHeap] + pending_data_per_worker: defaultdict[str, UniqueTaskHeap] nanny: Nanny | None _lock: threading.Lock - data_needed: _UniqueTaskHeap + data_needed: UniqueTaskHeap in_flight_workers: dict[str, set[str]] # {worker address: {ts.key, ...}} total_out_connections: int total_in_connections: int @@ -666,11 +665,11 @@ def __init__( self.tasks = {} self.waiting_for_data_count = 0 self.has_what = defaultdict(set) - self.pending_data_per_worker = defaultdict(_UniqueTaskHeap) + self.pending_data_per_worker = defaultdict(UniqueTaskHeap) self.nanny = nanny self._lock = threading.Lock() - self.data_needed = _UniqueTaskHeap() + self.data_needed = UniqueTaskHeap() self.in_flight_workers = {} self.total_out_connections = dask.config.get( @@ -681,7 +680,7 @@ def __init__( ) self.comm_threshold_bytes = int(10e6) self.comm_nbytes = 0 - self._missing_dep_flight: set[TaskState] = set() + self._missing_dep_flight = set() self.threads = {} @@ -3132,10 +3131,7 @@ async def gather_dep( self.batched_stream.send( {"op": "missing-data", "errant_worker": worker, "key": d} ) - if not ts.who_has: - recommendations[ts] = "missing" - else: - recommendations[ts] = "fetch" + recommendations[ts] = "fetch" if ts.who_has else "missing" del data, response self.transitions(recommendations, stimulus_id=stimulus_id) self.ensure_computing()