Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 23 additions & 61 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
Iterable,
Iterator,
Mapping,
Sequence,
Set,
)
from contextlib import suppress
Expand All @@ -35,7 +34,6 @@
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload

import psutil
from sortedcontainers import SortedDict, SortedSet
from tlz import (
first,
groupby,
Expand Down Expand Up @@ -855,15 +853,15 @@ class Computation:

start: float
groups: set[TaskGroup]
code: SortedSet
code: dict[str, None] # a sorted set
id: uuid.UUID

__slots__ = tuple(__annotations__)

def __init__(self):
self.start = time()
self.groups = set()
self.code = SortedSet()
self.code = {}
self.id = uuid.uuid4()

@property
Expand Down Expand Up @@ -1486,15 +1484,13 @@ class SchedulerState:
#######################

#: Workers currently connected to the scheduler
#: (actually a SortedDict, but the sortedcontainers package isn't annotated)
workers: dict[str, WorkerState]
#: Worker {name: address}
aliases: dict[Hashable, str]
#: Workers that are currently in running state
running: set[WorkerState]
#: Workers that are currently in running state and not fully utilized
#: (actually a SortedDict, but the sortedcontainers package isn't annotated)
idle: dict[str, WorkerState]
idle: set[WorkerState]
#: Workers that are fully utilized. May include non-running workers.
saturated: set[WorkerState]
total_nthreads: int
Expand Down Expand Up @@ -1577,7 +1573,7 @@ def __init__(
self,
aliases: dict[Hashable, str],
clients: dict[str, ClientState],
workers: SortedDict[str, WorkerState],
workers: dict[str, WorkerState],
host_info: dict[str, dict[str, Any]],
resources: dict[str, dict[str, float]],
tasks: dict[str, TaskState],
Expand All @@ -1595,7 +1591,7 @@ def __init__(
self.clients["fire-and-forget"] = ClientState("fire-and-forget")
self.extensions = {}
self.host_info = host_info
self.idle = SortedDict()
self.idle = set()
self.n_tasks = 0
self.resources = resources
self.saturated = set()
Expand Down Expand Up @@ -2055,7 +2051,7 @@ def decide_worker_rootish_queuing_disabled(
# See root-ish-ness note below in `decide_worker_rootish_queuing_enabled`
assert math.isinf(self.WORKER_SATURATION)

pool = self.idle.values() if self.idle else self.running
pool = self.idle or self.running
if not pool:
return None

Expand Down Expand Up @@ -2126,7 +2122,7 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None:

# Just pick the least busy worker.
# NOTE: this will lead to worst-case scheduling with regards to co-assignment.
ws = min(self.idle.values(), key=lambda ws: len(ws.processing) / ws.nthreads)
ws = min(self.idle, key=lambda ws: len(ws.processing) / ws.nthreads)
if self.validate:
assert not _worker_full(ws, self.WORKER_SATURATION), (
ws,
Expand Down Expand Up @@ -2165,43 +2161,12 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
# If there were no restrictions, `valid_workers()` didn't subset by `running`.
valid_workers = self.running

if ts.dependencies or valid_workers is not None:
ws = decide_worker(
ts,
self.running,
valid_workers,
partial(self.worker_objective, ts),
)
else:
# TODO if `is_rootish` would always return True for tasks without dependencies,
# we could remove all this logic. The rootish assignment logic would behave
# more or less the same as this, maybe without guaranteed round-robin though?
# This path is only reachable when `ts` doesn't have dependencies, but its
# group is also smaller than the cluster.

# Fastpath when there are no related tasks or restrictions
worker_pool = self.idle or self.workers
# FIXME idle and workers are SortedDict's declared as dicts
# because sortedcontainers is not annotated
wp_vals = cast("Sequence[WorkerState]", worker_pool.values())
n_workers: int = len(wp_vals)
if n_workers < 20: # smart but linear in small case
ws = min(wp_vals, key=operator.attrgetter("occupancy"))
assert ws
if ws.occupancy == 0:
# special case to use round-robin; linear search
# for next worker with zero occupancy (or just
# land back where we started).
wp_i: WorkerState
start: int = self.n_tasks % n_workers
i: int
for i in range(n_workers):
wp_i = wp_vals[(i + start) % n_workers]
if wp_i.occupancy == 0:
ws = wp_i
break
else: # dumb but fast in large case
ws = wp_vals[self.n_tasks % n_workers]
Comment on lines -2183 to -2204

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, if it's just about this code path, we can drop the sorted set entirely by putting workers in a list as well as a set. That'd be a log N operation whenever a worker leaves the cluster (to remove that worker from the list), which is much better than an N log N operation every time we iterate over the workers.

ws = decide_worker(
ts,
self.running,
valid_workers,
partial(self.worker_objective, ts),
)

if self.validate and ws is not None:
assert self.workers.get(ws.address) is ws
Expand Down Expand Up @@ -3038,10 +3003,10 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
else not _worker_full(ws, self.WORKER_SATURATION)
):
if ws.status == Status.running:
idle[ws.address] = ws
idle.add(ws)
saturated.discard(ws)
else:
idle.pop(ws.address, None)
idle.discard(ws)

if p > nc:
pending: float = occ * (p - nc) / (p * nc)
Expand Down Expand Up @@ -3491,7 +3456,7 @@ def __init__(
clients = {}

# Worker state
workers = SortedDict()
workers = {}

host_info = {}
resources = {}
Expand Down Expand Up @@ -4302,8 +4267,8 @@ def update_graph(
computation = Computation()
self.computations.append(computation)

if code and code not in computation.code: # add new code blocks
computation.code.add(code)
if code: # add new code blocks
computation.code.setdefault(code, None)

n = 0
while len(tasks) != n: # walk through new tasks, cancel any bad deps
Expand Down Expand Up @@ -4719,7 +4684,7 @@ async def remove_worker(
self.rpc.remove(address)
del self.stream_comms[address]
del self.aliases[ws.name]
self.idle.pop(ws.address, None)
self.idle.discard(ws)
self.saturated.discard(ws)
del self.workers[address]
ws.status = Status.closed
Expand Down Expand Up @@ -4994,22 +4959,19 @@ def validate_state(self, allow_overlap: bool = False) -> None:
if not (set(self.workers) == set(self.stream_comms)):
raise ValueError("Workers not the same in all collections")

assert self.running.issuperset(self.idle.values()), (
self.running,
list(self.idle.values()),
)
assert self.running.issuperset(self.idle), (self.running, self.idle)
task_group_counts: defaultdict[str, int] = defaultdict(int)
for w, ws in self.workers.items():
assert isinstance(w, str), (type(w), w)
assert isinstance(ws, WorkerState), (type(ws), ws)
assert ws.address == w
if ws.status != Status.running:
assert ws.address not in self.idle
assert ws not in self.idle
assert ws.long_running.issubset(ws.processing)
if not ws.processing:
assert not ws.occupancy
if ws.status == Status.running:
assert ws.address in self.idle
assert ws in self.idle
assert not ws.needs_what.keys() & ws.has_what
actual_needs_what: defaultdict[TaskState, int] = defaultdict(int)
for ts in ws.processing:
Expand Down Expand Up @@ -5307,7 +5269,7 @@ def handle_worker_status_change(
self.send_all(client_msgs, worker_msgs)
else:
self.running.discard(ws)
self.idle.pop(ws.address, None)
self.idle.discard(ws)

async def handle_request_refresh_who_has(
self, keys: Iterable[str], worker: str, stimulus_id: str
Expand Down
2 changes: 1 addition & 1 deletion distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def balance(self) -> None:
with log_errors():
i = 0
# Paused and closing workers must never become thieves
potential_thieves = set(s.idle.values())
potential_thieves = s.idle.copy()
if not potential_thieves or len(potential_thieves) == len(s.workers):
return
victim: WorkerState | None
Expand Down
16 changes: 8 additions & 8 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6985,7 +6985,7 @@ def fetch_comp_code(dask_scheduler):
assert len(computations) == 1
comp = computations[0]
assert len(comp.code) == 1
return comp.code[0]
return first(comp.code)

code = client.run_on_scheduler(fetch_comp_code)

Expand All @@ -7005,7 +7005,7 @@ def fetch_comp_code(dask_scheduler):
assert len(computations) == 1
comp = computations[0]
assert len(comp.code) == 1
return comp.code[0]
return first(comp.code)

code = client.run_on_scheduler(fetch_comp_code)
assert code == "<Code not available>"
Expand All @@ -7026,7 +7026,7 @@ async def test_computation_object_code_dask_persist(c, s, a, b):
comp = computations[0]
assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert first(comp.code) == test_function_code


@gen_cluster(client=True)
Expand All @@ -7047,7 +7047,7 @@ def func(x):

assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert first(comp.code) == test_function_code


@gen_cluster(client=True)
Expand All @@ -7069,7 +7069,7 @@ def func(x):
# Code is deduplicated
assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert first(comp.code) == test_function_code


@gen_cluster(client=True)
Expand All @@ -7091,7 +7091,7 @@ def func(x):
# Code is deduplicated
assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert first(comp.code) == test_function_code


@gen_cluster(client=True)
Expand All @@ -7109,7 +7109,7 @@ async def test_computation_object_code_client_map(c, s, a, b):
comp = computations[0]
assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert first(comp.code) == test_function_code


@gen_cluster(client=True)
Expand All @@ -7127,7 +7127,7 @@ async def test_computation_object_code_client_compute(c, s, a, b):
comp = computations[0]
assert len(comp.code) == 1

assert comp.code[0] == test_function_code
assert first(comp.code) == test_function_code


@gen_cluster(client=True, Worker=Nanny)
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ locket >= 1.0.0
msgpack >= 0.6.0
packaging >= 20.0
psutil >= 5.0
sortedcontainers !=2.0.0, !=2.0.1
tblib >= 1.6.0
toolz >= 0.8.2
tornado >= 6.0.3, <6.2
Expand Down