From f3881d22b4c2355eaba01b0cffa8855a96f89b87 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 4 Sep 2023 15:04:07 +0100 Subject: [PATCH] Auto-fail tasks with deps larger than the worker memory --- distributed/scheduler.py | 59 ++++++++++++++++++++----- distributed/tests/test_scheduler.py | 39 +++++++++++++++- distributed/tests/test_worker_memory.py | 1 + 3 files changed, 86 insertions(+), 13 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 02425902c43..4d905ce7256 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2078,7 +2078,6 @@ def transition_released_waiting(self, key: str, stimulus_id: str) -> RecsMsgs: def transition_no_worker_processing(self, key: str, stimulus_id: str) -> RecsMsgs: ts = self.tasks[key] - worker_msgs: Msgs = {} if self.validate: assert not ts.actor, f"Actors can't be in `no-worker`: {ts}" @@ -2086,10 +2085,10 @@ def transition_no_worker_processing(self, key: str, stimulus_id: str) -> RecsMsg if ws := self.decide_worker_non_rootish(ts): self.unrunnable.discard(ts) - worker_msgs = self._add_to_processing(ts, ws) + return self._add_to_processing(ts, ws, stimulus_id=stimulus_id) # If no worker, task just stays in `no-worker` - return {}, {}, worker_msgs + return {}, {}, {} def decide_worker_rootish_queuing_disabled( self, ts: TaskState @@ -2295,8 +2294,7 @@ def transition_waiting_processing(self, key: str, stimulus_id: str) -> RecsMsgs: if not (ws := self.decide_worker_non_rootish(ts)): return {ts.key: "no-worker"}, {}, {} - worker_msgs = self._add_to_processing(ts, ws) - return {}, {}, worker_msgs + return self._add_to_processing(ts, ws, stimulus_id=stimulus_id) def transition_waiting_memory( self, @@ -2751,8 +2749,6 @@ def transition_queued_released(self, key: str, stimulus_id: str) -> RecsMsgs: def transition_queued_processing(self, key: str, stimulus_id: str) -> RecsMsgs: ts = self.tasks[key] - recommendations: Recs = {} - worker_msgs: Msgs = {} if self.validate: assert not 
ts.actor, f"Actors can't be queued: {ts}" @@ -2760,10 +2756,9 @@ def transition_queued_processing(self, key: str, stimulus_id: str) -> RecsMsgs: if ws := self.decide_worker_rootish_queuing_enabled(): self.queued.discard(ts) - worker_msgs = self._add_to_processing(ts, ws) + return self._add_to_processing(ts, ws, stimulus_id=stimulus_id) # If no worker, task just stays `queued` - - return recommendations, {}, worker_msgs + return {}, {}, {} def _remove_key(self, key: str) -> None: ts = self.tasks.pop(key) @@ -3144,7 +3139,9 @@ def _validate_ready(self, ts: TaskState) -> None: assert ts not in self.queued assert all(dts.who_has for dts in ts.dependencies) - def _add_to_processing(self, ts: TaskState, ws: WorkerState) -> Msgs: + def _add_to_processing( + self, ts: TaskState, ws: WorkerState, stimulus_id: str + ) -> RecsMsgs: """Set a task as processing on a worker and return the worker messages to send""" if self.validate: self._validate_ready(ts) @@ -3161,7 +3158,45 @@ def _add_to_processing(self, ts: TaskState, ws: WorkerState) -> Msgs: if ts.actor: ws.actors.add(ts) - return {ws.address: [self._task_to_msg(ts)]} + ndep_bytes = sum(dts.nbytes for dts in ts.dependencies) + if ( + ws.memory_limit + and ndep_bytes > ws.memory_limit + and dask.config.get("distributed.worker.memory.terminate") + ): + # Note + # ---- + # This is a crude safety system, only meant to prevent order-of-magnitude + # fat-finger errors. + # + # For collection finalizers and in general most concat operations, it takes + # a lot less to kill off the worker; you'll just need + # ndep_bytes * 2 > ws.memory_limit * terminate threshold. + # + # In heterogeneous clusters with workers mounting different amounts of + # memory, the user is expected to manually set host/worker/resource + # restrictions. + msg = ( + f"Task {ts.key} has {format_bytes(ndep_bytes)} worth of input " + f"dependencies, but worker {ws.address} has memory_limit set to " + f"{format_bytes(ws.memory_limit)}." 
+ ) + if ts.prefix.name == "finalize": + msg += ( + " It seems like you called client.compute() on a huge collection. " + "Consider writing to distributed storage or slicing/reducing first." + ) + logger.error(msg) + return self._transition( + ts.key, + "erred", + exception=pickle.dumps(MemoryError(msg)), + cause=ts.key, + stimulus_id=stimulus_id, + worker=ws.address, + ) + + return {}, {}, {ws.address: [self._task_to_msg(ts)]} def _exit_processing_common(self, ts: TaskState) -> WorkerState | None: """Remove *ts* from the set of processing tasks. diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index ec3bcac49e4..cb16ab39f3a 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -7,6 +7,7 @@ import math import operator import pickle +import random import re import sys from collections.abc import Collection @@ -22,7 +23,7 @@ from tornado.ioloop import IOLoop import dask -from dask import delayed +from dask import bag, delayed from dask.core import flatten from dask.highlevelgraph import HighLevelGraph, MaterializedLayer from dask.utils import parse_timedelta, tmpfile, typename @@ -4472,3 +4473,39 @@ async def test_scatter_creates_ts(c, s, a, b): await a.close() assert await x2 == 2 assert s.tasks["x"].run_spec is not None + + +@pytest.mark.parametrize("finalize", [False, True]) +@gen_cluster( + client=True, + nthreads=[("", 1)] * 4, + worker_kwargs={"memory_limit": "100 kB"}, + config={ + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": False, + }, +) +async def test_refuse_to_schedule_huge_task(c, s, *workers, finalize): + """If the total size of a task's input grossly exceeds the memory available on the + worker, the scheduler must refuse to compute it + """ + bg = bag.from_sequence( + [random.randbytes(30_000) for _ in range(4)], + npartitions=4, + ) + match = r"worth of input dependencies, but worker .* has 
memory_limit set to" + if finalize: + fut = c.compute(bg) + match += r".* you called client.compute()" + else: + bg = bg.repartition(npartitions=1).persist() + fut = list(c.futures_of(bg))[0] + + with pytest.raises(MemoryError, match=match): + await fut + + # The task never reached the workers + for w in workers: + for ev in w.state.log: + assert fut.key not in ev diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index e9b2a779a83..2e02c98a53a 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -637,6 +637,7 @@ def f(ev): "distributed.worker.memory.target": False, "distributed.worker.memory.spill": False, "distributed.worker.memory.pause": False, + "distributed.worker.memory.terminate": False, }, ), )