Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 47 additions & 12 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2078,18 +2078,17 @@ def transition_released_waiting(self, key: str, stimulus_id: str) -> RecsMsgs:

def transition_no_worker_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
ts = self.tasks[key]
worker_msgs: Msgs = {}

if self.validate:
assert not ts.actor, f"Actors can't be in `no-worker`: {ts}"
assert ts in self.unrunnable

if ws := self.decide_worker_non_rootish(ts):
self.unrunnable.discard(ts)
worker_msgs = self._add_to_processing(ts, ws)
return self._add_to_processing(ts, ws, stimulus_id=stimulus_id)
# If no worker, task just stays in `no-worker`

return {}, {}, worker_msgs
return {}, {}, {}

def decide_worker_rootish_queuing_disabled(
self, ts: TaskState
Expand Down Expand Up @@ -2295,8 +2294,7 @@ def transition_waiting_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
if not (ws := self.decide_worker_non_rootish(ts)):
return {ts.key: "no-worker"}, {}, {}

worker_msgs = self._add_to_processing(ts, ws)
return {}, {}, worker_msgs
return self._add_to_processing(ts, ws, stimulus_id=stimulus_id)

def transition_waiting_memory(
self,
Expand Down Expand Up @@ -2751,19 +2749,16 @@ def transition_queued_released(self, key: str, stimulus_id: str) -> RecsMsgs:

def transition_queued_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
ts = self.tasks[key]
recommendations: Recs = {}
worker_msgs: Msgs = {}

if self.validate:
assert not ts.actor, f"Actors can't be queued: {ts}"
assert ts in self.queued

if ws := self.decide_worker_rootish_queuing_enabled():
self.queued.discard(ts)
worker_msgs = self._add_to_processing(ts, ws)
return self._add_to_processing(ts, ws, stimulus_id=stimulus_id)
# If no worker, task just stays `queued`

return recommendations, {}, worker_msgs
return {}, {}, {}

def _remove_key(self, key: str) -> None:
ts = self.tasks.pop(key)
Expand Down Expand Up @@ -3144,7 +3139,9 @@ def _validate_ready(self, ts: TaskState) -> None:
assert ts not in self.queued
assert all(dts.who_has for dts in ts.dependencies)

def _add_to_processing(self, ts: TaskState, ws: WorkerState) -> Msgs:
def _add_to_processing(
self, ts: TaskState, ws: WorkerState, stimulus_id: str
) -> RecsMsgs:
"""Set a task as processing on a worker and return the worker messages to send"""
if self.validate:
self._validate_ready(ts)
Expand All @@ -3161,7 +3158,45 @@ def _add_to_processing(self, ts: TaskState, ws: WorkerState) -> Msgs:
if ts.actor:
ws.actors.add(ts)

return {ws.address: [self._task_to_msg(ts)]}
ndep_bytes = sum(dts.nbytes for dts in ts.dependencies)
if (
ws.memory_limit
and ndep_bytes > ws.memory_limit
and dask.config.get("distributed.worker.memory.terminate")
):
# Note
# ----
# This is a crude safety system, only meant to prevent order-of-magnitude
# fat-finger errors.
#
# For collection finalizers and in general most concat operations, it takes
# a lot less to kill off the worker; you'll just need
# ndep_bytes * 2 > ws.memory_limit * terminate threshold.
#
# In heterogeneous clusters with workers mounting different amounts of
# memory, the user is expected to manually set host/worker/resource
# restrictions.
msg = (
f"Task {ts.key} has {format_bytes(ndep_bytes)} worth of input "
f"dependencies, but worker {ws.address} has memory_limit set to "
f"{format_bytes(ws.memory_limit)}."
)
if ts.prefix.name == "finalize":
msg += (
" It seems like you called client.compute() on a huge collection. "
"Consider writing to distributed storage or slicing/reducing first."
)
logger.error(msg)
return self._transition(
ts.key,
"erred",
exception=pickle.dumps(MemoryError(msg)),
cause=ts.key,
stimulus_id=stimulus_id,
worker=ws.address,
)

return {}, {}, {ws.address: [self._task_to_msg(ts)]}

def _exit_processing_common(self, ts: TaskState) -> WorkerState | None:
"""Remove *ts* from the set of processing tasks.
Expand Down
39 changes: 38 additions & 1 deletion distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math
import operator
import pickle
import random
import re
import sys
from collections.abc import Collection
Expand All @@ -22,7 +23,7 @@
from tornado.ioloop import IOLoop

import dask
from dask import delayed
from dask import bag, delayed
from dask.core import flatten
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
from dask.utils import parse_timedelta, tmpfile, typename
Expand Down Expand Up @@ -4472,3 +4473,39 @@ async def test_scatter_creates_ts(c, s, a, b):
await a.close()
assert await x2 == 2
assert s.tasks["x"].run_spec is not None


@pytest.mark.parametrize("finalize", [False, True])
@gen_cluster(
client=True,
nthreads=[("", 1)] * 4,
worker_kwargs={"memory_limit": "100 kB"},
config={
"distributed.worker.memory.target": False,
"distributed.worker.memory.spill": False,
"distributed.worker.memory.pause": False,
},
)
async def test_refuse_to_schedule_huge_task(c, s, *workers, finalize):
"""If the total size of a task's input grossly exceed the memory available on the
worker, the scheduler must refuse to compute it
"""
bg = bag.from_sequence(
[random.randbytes(30_000) for _ in range(4)],
npartitions=4,
)
match = r"worth of input dependencies, but worker .* has memory_limit set to"
if finalize:
fut = c.compute(bg)
match += r".* you called client.compute()"
else:
bg = bg.repartition(npartitions=1).persist()
fut = list(c.futures_of(bg))[0]

with pytest.raises(MemoryError, match=match):
await fut

# The task never reached the workers
for w in workers:
for ev in w.state.log:
assert fut.key not in ev
1 change: 1 addition & 0 deletions distributed/tests/test_worker_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ def f(ev):
"distributed.worker.memory.target": False,
"distributed.worker.memory.spill": False,
"distributed.worker.memory.pause": False,
"distributed.worker.memory.terminate": False,
},
),
)
Expand Down