From 2b8ebf1a384cf9e5bd978602bf72b0ee19e732e2 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Mon, 31 Oct 2022 14:11:29 -0600 Subject: [PATCH 1/3] `Scheduler.idle` `SortedDict` -> plain `set` --- distributed/scheduler.py | 70 ++++++++++------------------------------ distributed/stealing.py | 2 +- 2 files changed, 18 insertions(+), 54 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5342d693e91..4925b1b8abd 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -26,7 +26,6 @@ Iterable, Iterator, Mapping, - Sequence, Set, ) from contextlib import suppress @@ -1493,8 +1492,7 @@ class SchedulerState: #: Workers that are currently in running state running: set[WorkerState] #: Workers that are currently in running state and not fully utilized - #: (actually a SortedDict, but the sortedcontainers package isn't annotated) - idle: dict[str, WorkerState] + idle: set[WorkerState] #: Workers that are fully utilized. May include non-running workers. saturated: set[WorkerState] total_nthreads: int @@ -1595,7 +1593,7 @@ def __init__( self.clients["fire-and-forget"] = ClientState("fire-and-forget") self.extensions = {} self.host_info = host_info - self.idle = SortedDict() + self.idle = set() self.n_tasks = 0 self.resources = resources self.saturated = set() @@ -2055,7 +2053,7 @@ def decide_worker_rootish_queuing_disabled( # See root-ish-ness note below in `decide_worker_rootish_queuing_enabled` assert math.isinf(self.WORKER_SATURATION) - pool = self.idle.values() if self.idle else self.running + pool = self.idle or self.running if not pool: return None @@ -2126,7 +2124,7 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None: # Just pick the least busy worker. # NOTE: this will lead to worst-case scheduling with regards to co-assignment. - ws = min(self.idle.values(), key=lambda ws: len(ws.processing) / ws.nthreads) + ws = min(self.idle, key=lambda ws: len(ws.processing) / ws.nthreads) if self.validate: assert not _worker_full(ws, self.WORKER_SATURATION), ( ws, @@ -2165,43 +2163,12 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: # If there were no restrictions, `valid_workers()` didn't subset by `running`. valid_workers = self.running - if ts.dependencies or valid_workers is not None: - ws = decide_worker( - ts, - self.running, - valid_workers, - partial(self.worker_objective, ts), - ) - else: - # TODO if `is_rootish` would always return True for tasks without dependencies, - # we could remove all this logic. The rootish assignment logic would behave - # more or less the same as this, maybe without gauranteed round-robin though? - # This path is only reachable when `ts` doesn't have dependencies, but its - # group is also smaller than the cluster. - - # Fastpath when there are no related tasks or restrictions - worker_pool = self.idle or self.workers - # FIXME idle and workers are SortedDict's declared as dicts - # because sortedcontainers is not annotated - wp_vals = cast("Sequence[WorkerState]", worker_pool.values()) - n_workers: int = len(wp_vals) - if n_workers < 20: # smart but linear in small case - ws = min(wp_vals, key=operator.attrgetter("occupancy")) - assert ws - if ws.occupancy == 0: - # special case to use round-robin; linear search - # for next worker with zero occupancy (or just - # land back where we started). - wp_i: WorkerState - start: int = self.n_tasks % n_workers - i: int - for i in range(n_workers): - wp_i = wp_vals[(i + start) % n_workers] - if wp_i.occupancy == 0: - ws = wp_i - break - else: # dumb but fast in large case - ws = wp_vals[self.n_tasks % n_workers] + ws = decide_worker( + ts, + self.running, + valid_workers, + partial(self.worker_objective, ts), + ) if self.validate and ws is not None: assert self.workers.get(ws.address) is ws @@ -3038,10 +3005,10 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0): else not _worker_full(ws, self.WORKER_SATURATION) ): if ws.status == Status.running: - idle[ws.address] = ws + idle.add(ws) saturated.discard(ws) else: - idle.pop(ws.address, None) + idle.discard(ws) if p > nc: pending: float = occ * (p - nc) / (p * nc) @@ -4719,7 +4686,7 @@ async def remove_worker( self.rpc.remove(address) del self.stream_comms[address] del self.aliases[ws.name] - self.idle.pop(ws.address, None) + self.idle.discard(ws) self.saturated.discard(ws) del self.workers[address] ws.status = Status.closed @@ -4994,22 +4961,19 @@ def validate_state(self, allow_overlap: bool = False) -> None: if not (set(self.workers) == set(self.stream_comms)): raise ValueError("Workers not the same in all collections") - assert self.running.issuperset(self.idle.values()), ( - self.running, - list(self.idle.values()), - ) + assert self.running.issuperset(self.idle), (self.running, self.idle) task_group_counts: defaultdict[str, int] = defaultdict(int) for w, ws in self.workers.items(): assert isinstance(w, str), (type(w), w) assert isinstance(ws, WorkerState), (type(ws), ws) assert ws.address == w if ws.status != Status.running: - assert ws.address not in self.idle + assert ws not in self.idle assert ws.long_running.issubset(ws.processing) if not ws.processing: assert not ws.occupancy if ws.status == Status.running: - assert ws.address in self.idle + assert ws in self.idle assert not ws.needs_what.keys() & ws.has_what actual_needs_what: defaultdict[TaskState, int] = defaultdict(int) for ts in ws.processing: @@ -5307,7 +5271,7 @@ def handle_worker_status_change( self.send_all(client_msgs, worker_msgs) else: self.running.discard(ws) - self.idle.pop(ws.address, None) + self.idle.discard(ws) async def handle_request_refresh_who_has( self, keys: Iterable[str], worker: str, stimulus_id: str diff --git a/distributed/stealing.py b/distributed/stealing.py index b3a36c40f2b..dddf8c5ac94 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -400,7 +400,7 @@ def balance(self) -> None: with log_errors(): i = 0 # Paused and closing workers must never become thieves - potential_thieves = set(s.idle.values()) + potential_thieves = s.idle.copy() if not potential_thieves or len(potential_thieves) == len(s.workers): return victim: WorkerState | None From f8e0ec57e38a24897371427623041de69cf7e4fa Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 2 Nov 2022 21:22:39 -0600 Subject: [PATCH 2/3] `Scheduler.workers` `SortedDict` -> plain `dict` Not sure if it's necessary to actually sort anywhere --- distributed/scheduler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4925b1b8abd..c4de9d98591 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -34,7 +34,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload import psutil -from sortedcontainers import SortedDict, SortedSet +from sortedcontainers import SortedSet from tlz import ( first, groupby, @@ -1485,7 +1485,6 @@ class SchedulerState: ####################### #: Workers currently connected to the scheduler - #: (actually a SortedDict, but the sortedcontainers package isn't annotated) workers: dict[str, WorkerState] #: Worker {name: address} aliases: dict[Hashable, str] @@ -1575,7 +1574,7 @@ def __init__( self, aliases: dict[Hashable, str], clients: dict[str, ClientState], - workers: SortedDict[str, WorkerState], + workers: dict[str, WorkerState], host_info: dict[str, dict[str, Any]], resources: dict[str, dict[str, float]], tasks: dict[str, TaskState], @@ -3458,7 +3457,7 @@ def __init__( clients = {} # Worker state - workers = SortedDict() + workers = {} host_info = {} resources = {} From aa0cde06c0a4e17a4581603287ca220366af1e59 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 2 Nov 2022 21:34:27 -0600 Subject: [PATCH 3/3] Remove `sortedcontainers` entirely `Computation.code` was the only other place it was used. Doesn't seem worth the dependency for that. --- distributed/scheduler.py | 9 ++++----- distributed/tests/test_client.py | 16 ++++++++-------- requirements.txt | 1 - 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c4de9d98591..e275cf05fd1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -34,7 +34,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload import psutil -from sortedcontainers import SortedSet from tlz import ( first, groupby, @@ -854,7 +853,7 @@ class Computation: start: float groups: set[TaskGroup] - code: SortedSet + code: dict[str, None] # a sorted set id: uuid.UUID __slots__ = tuple(__annotations__) @@ -862,7 +861,7 @@ class Computation: def __init__(self): self.start = time() self.groups = set() - self.code = SortedSet() + self.code = {} self.id = uuid.uuid4() @property @@ -4268,8 +4267,8 @@ def update_graph( computation = Computation() self.computations.append(computation) - if code and code not in computation.code: # add new code blocks - computation.code.add(code) + if code: # add new code blocks + computation.code.setdefault(code, None) n = 0 while len(tasks) != n: # walk through new tasks, cancel any bad deps diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c0d5b716220..9e88b376432 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -6985,7 +6985,7 @@ def fetch_comp_code(dask_scheduler): assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 - return comp.code[0] + return first(comp.code) code = client.run_on_scheduler(fetch_comp_code) @@ -7005,7 +7005,7 @@ def fetch_comp_code(dask_scheduler): assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 - return comp.code[0] + return first(comp.code) code = client.run_on_scheduler(fetch_comp_code) assert code == "" @@ -7026,7 +7026,7 @@ async def test_computation_object_code_dask_persist(c, s, a, b): comp = computations[0] assert len(comp.code) == 1 - assert comp.code[0] == test_function_code + assert first(comp.code) == test_function_code @gen_cluster(client=True) @@ -7047,7 +7047,7 @@ def func(x): assert len(comp.code) == 1 - assert comp.code[0] == test_function_code + assert first(comp.code) == test_function_code @gen_cluster(client=True) @@ -7069,7 +7069,7 @@ def func(x): # Code is deduplicated assert len(comp.code) == 1 - assert comp.code[0] == test_function_code + assert first(comp.code) == test_function_code @gen_cluster(client=True) @@ -7091,7 +7091,7 @@ def func(x): # Code is deduplicated assert len(comp.code) == 1 - assert comp.code[0] == test_function_code + assert first(comp.code) == test_function_code @gen_cluster(client=True) @@ -7109,7 +7109,7 @@ async def test_computation_object_code_client_map(c, s, a, b): comp = computations[0] assert len(comp.code) == 1 - assert comp.code[0] == test_function_code + assert first(comp.code) == test_function_code @gen_cluster(client=True) @@ -7127,7 +7127,7 @@ async def test_computation_object_code_client_compute(c, s, a, b): comp = computations[0] assert len(comp.code) == 1 - assert comp.code[0] == test_function_code + assert first(comp.code) == test_function_code @gen_cluster(client=True, Worker=Nanny) diff --git a/requirements.txt b/requirements.txt index f6b6c8f9caf..cb2a2ffc743 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,6 @@ locket >= 1.0.0 msgpack >= 0.6.0 packaging >= 20.0 psutil >= 5.0 -sortedcontainers !=2.0.0, !=2.0.1 tblib >= 1.6.0 toolz >= 0.8.2 tornado >= 6.0.3, <6.2