diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c5f4c0fba6b..c66f82d898d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -26,7 +26,6 @@ Iterable, Iterator, Mapping, - Sequence, Set, ) from contextlib import suppress @@ -1307,8 +1306,7 @@ class SchedulerState: #: Workers that are currently in running state running: set[WorkerState] #: Workers that are currently in running state and not fully utilized - #: (actually a SortedDict, but the sortedcontainers package isn't annotated) - idle: dict[str, WorkerState] + idle: set[WorkerState] #: Workers that are fully utilized. May include non-running workers. saturated: set[WorkerState] total_nthreads: int @@ -1408,7 +1406,7 @@ def __init__( self.clients["fire-and-forget"] = ClientState("fire-and-forget") self.extensions = {} self.host_info = host_info - self.idle = SortedDict() + self.idle = set() self.n_tasks = 0 self.resources = resources self.saturated = set() @@ -1600,32 +1598,12 @@ def _transition( b_recs, b_cmsgs, b_wmsgs = func(self, key, stimulus_id) recommendations.update(a_recs) - for c, new_msgs in a_cmsgs.items(): - msgs = client_msgs.get(c) - if msgs is not None: - msgs.extend(new_msgs) - else: - client_msgs[c] = new_msgs - for w, new_msgs in a_wmsgs.items(): - msgs = worker_msgs.get(w) - if msgs is not None: - msgs.extend(new_msgs) - else: - worker_msgs[w] = new_msgs + update_msgs(client_msgs, a_cmsgs) + update_msgs(worker_msgs, a_wmsgs) recommendations.update(b_recs) - for c, new_msgs in b_cmsgs.items(): - msgs = client_msgs.get(c) - if msgs is not None: - msgs.extend(new_msgs) - else: - client_msgs[c] = new_msgs - for w, new_msgs in b_wmsgs.items(): - msgs = worker_msgs.get(w) - if msgs is not None: - msgs.extend(new_msgs) - else: - worker_msgs[w] = new_msgs + update_msgs(client_msgs, b_cmsgs) + update_msgs(worker_msgs, b_wmsgs) start = "released" else: @@ -1705,19 +1683,13 @@ def _transitions( new = self._transition(key, finish, stimulus_id) new_recs, new_cmsgs, new_wmsgs = new + # Put recommendations at end of dict, so they're processed in the next cycle + for k in new_recs: + if k != key: + recommendations.pop(k, None) recommendations.update(new_recs) - for c, new_msgs in new_cmsgs.items(): - msgs = client_msgs.get(c) - if msgs is not None: - msgs.extend(new_msgs) - else: - client_msgs[c] = new_msgs - for w, new_msgs in new_wmsgs.items(): - msgs = worker_msgs.get(w) - if msgs is not None: - msgs.extend(new_msgs) - else: - worker_msgs[w] = new_msgs + update_msgs(client_msgs, new_cmsgs) + update_msgs(worker_msgs, new_wmsgs) if self.validate: # FIXME downcast antipattern @@ -1873,7 +1845,7 @@ def decide_worker_rootish_queuing_disabled( # See root-ish-ness note below in `decide_worker_rootish_queuing_enabled` assert math.isinf(self.WORKER_SATURATION) - pool = self.idle.values() if self.idle else self.running + pool = self.idle or self.running if not pool: return None @@ -1902,30 +1874,9 @@ def decide_worker_rootish_queuing_disabled( return ws - def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None: - """Pick a worker for a runnable root-ish task, if not all are busy. - - Picks the least-busy worker out of the ``idle`` workers (idle workers have fewer - tasks running than threads, as set by ``distributed.scheduler.worker-saturation``). - It does not consider the location of dependencies, since they'll end up on every - worker anyway. - - If all workers are full, returns None, meaning the task should transition to - ``queued``. The scheduler will wait to send it to a worker until a thread opens - up. This ensures that downstream tasks always run before new root tasks are - started. - - This does not try to schedule sibling tasks on the same worker; in fact, it - usually does the opposite. Even though this increases subsequent data transfer, - it typically reduces overall memory use by eliminating root task overproduction. - - Returns - ------- - ws: WorkerState | None - The worker to assign the task to. If there are no idle workers, returns - None, in which case the task should be transitioned to ``queued``. - - """ + def decide_worker_from_family( + self, family: tuple[set[TaskState], set[TaskState]] | None + ) -> WorkerState: if self.validate: # We don't `assert self.is_rootish(ts)` here, because that check is dependent on # cluster size. It's possible a task looked root-ish when it was queued, but the @@ -1933,25 +1884,51 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None: # If `is_rootish` changes to a static definition, then add that assertion here # (and actually pass in the task). assert not math.isinf(self.WORKER_SATURATION) + assert self.idle + + if family: + siblings, downstream = family + # If any tasks are in memory or processing, use the non-saturated worker that holds the most data already. + # Ignoring saturated workers avoids a 'dogpile' in the case of unusual graph structures. + # ^ TODO test this + candidates: defaultdict[WorkerState, int] = defaultdict(lambda: 0) + sws: WorkerState | None + for ts in siblings: + for sws in ts.who_has: + if sws.status == Status.running and not _worker_full( + sws, self.WORKER_SATURATION + ): + candidates[sws] += ts.get_nbytes() + if ( + (sws := ts.processing_on) # NOTE: exclusive with `ts.who_has` + and sws.status == Status.running + and not _worker_full(sws, self.WORKER_SATURATION) + ): + # NOTE: siblings processing on different workers is a rare case + tg = ts.group + nbytes_estimate = ( + round(tg.nbytes_total / nmem) + if (nmem := tg.states["memory"]) + else DEFAULT_DATA_SIZE + ) + candidates[sws] += nbytes_estimate + if candidates: + ws, _ = max(candidates.items(), key=operator.itemgetter(1)) + logger.info( + f"Scheduling family on sibling worker {ws}, {candidates=}, {family=}" + ) + return ws - if not self.idle: - # All workers busy? Task gets/stays queued. - return None - - # Just pick the least busy worker. - # NOTE: this will lead to worst-case scheduling with regards to co-assignment. - ws = min(self.idle.values(), key=lambda ws: len(ws.processing) / ws.nthreads) + # No siblings are anywhere else (or no family at all). Pick the least busy worker. + ws = min(self.idle, key=lambda ws: len(ws.processing) / ws.nthreads) if self.validate: + assert self.workers.get(ws.address) is ws assert not _worker_full(ws, self.WORKER_SATURATION), ( ws, _task_slots_available(ws, self.WORKER_SATURATION), ) assert ws in self.running, (ws, self.running) - if self.validate and ws is not None: - assert self.workers.get(ws.address) is ws - assert ws in self.running, (ws, self.running) - return ws def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: @@ -1994,13 +1971,11 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: # group is also smaller than the cluster. # Fastpath when there are no related tasks or restrictions - worker_pool = self.idle or self.workers - # FIXME idle and workers are SortedDict's declared as dicts - # because sortedcontainers is not annotated - wp_vals = cast("Sequence[WorkerState]", worker_pool.values()) - n_workers: int = len(wp_vals) + # FIXME making a list here is silly, but so is this whole code path + worker_pool = list(self.idle or self.workers.values()) + n_workers: int = len(worker_pool) if n_workers < 20: # smart but linear in small case - ws = min(wp_vals, key=operator.attrgetter("occupancy")) + ws = min(worker_pool, key=operator.attrgetter("occupancy")) assert ws if ws.occupancy == 0: # special case to use round-robin; linear search @@ -2010,12 +1985,12 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: start: int = self.n_tasks % n_workers i: int for i in range(n_workers): - wp_i = wp_vals[(i + start) % n_workers] + wp_i = worker_pool[(i + start) % n_workers] if wp_i.occupancy == 0: ws = wp_i break else: # dumb but fast in large case - ws = wp_vals[self.n_tasks % n_workers] + ws = worker_pool[self.n_tasks % n_workers] if self.validate and ws is not None: assert self.workers.get(ws.address) is ws @@ -2040,8 +2015,7 @@ def transition_waiting_processing(self, key, stimulus_id): if not (ws := self.decide_worker_rootish_queuing_disabled(ts)): return {ts.key: "no-worker"}, {}, {} else: - if not (ws := self.decide_worker_rootish_queuing_enabled()): - return {ts.key: "queued"}, {}, {} + return _queueable_to_processing(self, ts) else: if not (ws := self.decide_worker_non_rootish(ts)): return {ts.key: "no-worker"}, {}, {} @@ -2660,20 +2634,12 @@ def transition_queued_released(self, key, stimulus_id): def transition_queued_processing(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] - recommendations: Recs = {} - client_msgs: dict = {} - worker_msgs: dict = {} if self.validate: assert not ts.actor, f"Actors can't be queued: {ts}" assert ts in self.queued - if ws := self.decide_worker_rootish_queuing_enabled(): - self.queued.discard(ts) - worker_msgs = _add_to_processing(self, ts, ws) - # If no worker, task just stays `queued` - - return recommendations, client_msgs, worker_msgs + return _queueable_to_processing(self, ts) except Exception as e: logger.exception(e) if LOG_PDB: @@ -2897,10 +2863,10 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0): else not _worker_full(ws, self.WORKER_SATURATION) ): if ws.status == Status.running: - idle[ws.address] = ws + idle.add(ws) saturated.discard(ws) else: - idle.pop(ws.address, None) + idle.discard(ws) if p > nc: pending: float = occ * (p - nc) / (p * nc) @@ -4614,7 +4580,7 @@ async def remove_worker( self.rpc.remove(address) del self.stream_comms[address] del self.aliases[ws.name] - self.idle.pop(ws.address, None) + self.idle.discard(ws) self.saturated.discard(ws) del self.workers[address] ws.status = Status.closed @@ -4890,21 +4856,21 @@ def validate_state(self, allow_overlap: bool = False) -> None: if not (set(self.workers) == set(self.stream_comms)): raise ValueError("Workers not the same in all collections") - assert self.running.issuperset(self.idle.values()), ( + assert self.running.issuperset(self.idle), ( self.running, - list(self.idle.values()), + self.idle, ) for w, ws in self.workers.items(): assert isinstance(w, str), (type(w), w) assert isinstance(ws, WorkerState), (type(ws), ws) assert ws.address == w if ws.status != Status.running: - assert ws.address not in self.idle + assert ws not in self.idle assert ws.long_running.issubset(ws.processing) if not ws.processing: assert not ws.occupancy if ws.status == Status.running: - assert ws.address in self.idle + assert ws in self.idle assert (ws.status == Status.running) == (ws in self.running) for ws in self.running: @@ -5203,7 +5169,7 @@ def handle_worker_status_change( self.send_all(client_msgs, worker_msgs) else: self.running.discard(ws) - self.idle.pop(ws.address, None) + self.idle.discard(ws) async def handle_request_refresh_who_has( self, keys: Iterable[str], worker: str, stimulus_id: str @@ -7708,6 +7674,115 @@ def _validate_ready(state: SchedulerState, ts: TaskState) -> None: assert all(dts.who_has for dts in ts.dependencies) +def _queueable_to_processing( + state: SchedulerState, ts: TaskState +) -> tuple[Recs, dict[str, list[Any]], dict[str, list[Any]]]: + "Common logic for transitioning a queueable (root-ish) task to processing, along with the rest of its family" + # Fastpath: skip family search and everything else if no workers are free. + # Since all siblings are assigned at once, this means `ts` belongs to a family we haven't processed yet. + # This means that we've already filled the cluster with root task families, and shouldn't schedule any more. + if not state.idle: + return {ts.key: "queued"}, {}, {} + + # TODO maxsize as config/what?! + fam = family(ts, maxsize=20, widely_shared_cutoff=len(state.workers)) + ws = state.decide_worker_from_family(fam) + # ^ NOTE: This is all we need to for good (re)scheduling when the cluster changes size. + + if ts.state == "queued": + state.queued.remove(ts) + worker_msgs: dict[str, list[Any]] = _add_to_processing(state, ts, ws) + + recommendations: Recs = {} + if fam: + # Schedule other tasks that will all need to be in memory + # with this task on the same worker at once. + siblings, downstream = fam + sorted_siblings = sorted(siblings, key=operator.attrgetter("priority")) + + # Check that we're parallelism-constrained before saturating a worker + if len(siblings) > ws.nthreads: + # In the somewhat rare case that we have more workers than families, + # parallelism is abundant. We should give up some co-assignment so that we + # don't leave workers idle. This is quite hard to determine statically, + # because we have no idea a) how many total root tasks there are, and b) how + # many total root families there are. + + # We can _very brittlely_ guess via `TaskGroup`s. _We generally assume the + # graph structure of root tasks is homogeneous_, which is typically true + # with Dask collections, but certainly not true in general. + # The downstreams' TaskGroups give a guess as to the number of families (a + # downstream is effectively the output of a family). + # If we assume every family is the same size as this one (also typically the + # case with collections, but not true in general), then each family will + # under-schedule by this factor, fully filling all workers. + est_total_families = ( + (sum(len(dts.group) for dts in downstream) // len(downstream)) + if downstream + else 0 + ) + family_saturation = est_total_families / len(state.workers) + if family_saturation < 1: + max_siblings = round(len(siblings) * family_saturation) + sorted_siblings = sorted_siblings[:max_siblings] + + # TODO what about when the first task in a family that's already on a worker completes, + # but not the whole family? We want to wait for the whole family to be done before assigning + # another one. We don't want to take a slot that could be used for the downstream task in the future. + + # If workers could be responsible with memory, this would be okay. Because if everything in the previous + # family is done execpt one input to the downstream, then yes, we might as well get started on a new family + # while we're waiting, as long as we have the memory capacity to do so. + for fts in sorted_siblings: + assert fts is not ts + + # Rare: siblings already running, or ran, somewhere else. + # Since all siblings are scheduled onto the same worker at the same time, they'll also + # usually share the same fate if that worker dies, and all be re-scheduled at once too. + # The exceptions are: + # - Some in-memory tasks could have been replicated to other workers, but not all. + # - Non-commutative families (siblings set is different depending on which root you start from). + # - Scale up/down could cross the `widely_shared_cutoff`, leading to different assessment of a family. + # TODO tests for these cases + if fts.state == "processing": + logger.info(f"Skipping processing {fts}, {fts.processing_on=}, {ws=}") + continue + + if fts.state == "memory": + # only can happen in the case of rescheduling, and replicas already exist + logger.info(f"Skipping in memory {fts}, {ws in fts.who_has=}, {ws=}") + continue + + if fts.state == "released": + # FIXME if `fts` just went `memory->released`, `waiting_on` will inaccurately be empty. + # 1) this manual transition feels like bad practice + # 2) we can't be certain it shouldn't actually to to `forgotten` (without duplicating logic from `memory->released`) + # 3) it's just kinda weird that tasks can be in this broken `released` state at all. + # it feels degenerate to me. i kinda don't think `released->waiting` should be a transition, but rather + # a shared helper function like `handle_released_task` or something. + # TODO add a test that triggers the need for this + state._transition(fts.key, "waiting", "qtp") + elif fts.state == "queued": + if state.validate: + assert fts in state.queued + state.queued.discard(fts) + + assert fts.state in ("waiting", "queued"), (fts, ts) + + # When `fts` is not runnable yet, that means it's waiting for deps. So it + # should just schedule near those deps using `decide_worker_non_rootish`. + # Unless they're widely-shared, in which case it should schedule near its + # family. In which case it should look root-ish, so it should come back here. + if not fts.waiting_on: + update_msgs(worker_msgs, _add_to_processing(state, fts, ws)) + + # This recommendation will be a no-op. It's just to remove any existing + # recommendation for the key from the recommendations queue. + recommendations[fts.key] = "processing" + + return recommendations, {}, worker_msgs + + def _add_to_processing( state: SchedulerState, ts: TaskState, ws: WorkerState ) -> dict[str, list]: @@ -8205,6 +8280,127 @@ def _worker_full(ws: WorkerState, saturation_factor: float) -> bool: return _task_slots_available(ws, saturation_factor) <= 0 +def _previous_in_linear_chain(ts: TaskState, cutoff: int) -> TaskState | None: + if len(ts.dependents) != 1 or len(ts.dependencies) > cutoff: + return None + + prev: TaskState | None = None + for dts in ts.dependencies: + if len(dts.dependents) > cutoff: # widely-shared; ignore it + continue + if prev: + return None + prev = dts + + return prev + + +def _next_in_linear_chain(ts: TaskState, cutoff: int) -> TaskState | None: + if len(ts.dependents) != 1 or len(ts.dependencies) > cutoff: + return None + + # Check if this is part of a linear chain: + # exactly 1 dependency, excluding widely-shared tasks. + non_widely_shared = False + for dts in ts.dependencies: + if len(dts.dependents) > cutoff: # widely-shared; ignore it + continue + if non_widely_shared: + return None + non_widely_shared = True + + return next(iter(ts.dependents)) + + +def family( + ts: TaskState, maxsize: int, widely_shared_cutoff: int +) -> tuple[set[TaskState], set[TaskState]] | None: + """ + All tasks in a family must be in memory at once to compute at least one common dependency. + + Returns these ``sibling`` tasks, and the set of ``downstream`` (dependent) tasks that + the siblings will used to compute. + + All ``siblings`` and ``downstream`` should be scheduled onto the same worker. + All ``siblings`` should be be scheduled at once---there's no benefit to queuing them. + + For the purposes of identifying families: + + * Linear chains are collapsed (traversed up and down) + * Widely-shared tasks (tasks with > ``widely_shared_cutoff`` dependents) are ignored + + If the task's family is too large, or empty, returns None. + That is, if ``siblings`` would be larger than ``maxsize``, or ``downstream`` would be + larger than ``widely_shared_cutoff``, or ``ts.dependents`` is empty, returns None. + """ + # TODO potentially could be useful to distinguish between 'too big' + # and empty family---we might want to schedule them differently. + # That is, maybe `None` and `(set(), set())` might not be synonymous. + if not ts.dependents or len(ts.dependents) > min(widely_shared_cutoff, maxsize): + return None + + siblings: set[TaskState] = set() + downstream: set[TaskState] = set() + # TODO maintain `seen` set to avoid repeated traversal? Should we add on the way down, or just back up? + for dts in ts.dependents: # TODO even support multiple dependents? + # Traverse down linear chains + while ndts := _next_in_linear_chain(dts, widely_shared_cutoff): + # TODO check seen + dts = ndts + + if dts in downstream: + # No need to traverse from a task we've already seen + continue + + if dts in siblings: + # `siblings` and `downstream` are exclusive, and `siblings` takes priority + siblings.remove(dts) + continue + + sibs = dts.dependencies + if len(sibs) == 1: + # Faster path: this is just a linear chain, so no siblings besides `ts` + downstream.add(dts) + continue + + if len(sibs) < maxsize: + downstream.add(dts) + for sts in sibs: + # Traverse linear chains _back_ to root tasks + while ( + sts is not ts + and sts not in downstream + and (psts := _previous_in_linear_chain(sts, widely_shared_cutoff)) + ): + # TODO check seen + sts = psts + + if ( + sts is ts + or len(sts.dependents) + >= widely_shared_cutoff # ignore widely-shared siblings + or sts in downstream # a->b, b->c, a->c. downstream takes priority. + ): + continue + + siblings.add(sts) + if len(siblings) > maxsize: + return None + + # NOTE: `downstream` isn't used for scheduling yet, since family scheduling + # only applies to root tasks. But it would be used for STA. + return siblings, downstream + + +def update_msgs(msgs: dict[str, list[Any]], new: dict[str, list[Any]]) -> None: + for k, new_msgs in new.items(): + m = msgs.get(k) + if m is not None: + m.extend(new_msgs) + else: + msgs[k] = new_msgs + + class KilledWorker(Exception): def __init__(self, task: str, last_worker: WorkerState, allowed_failures: int): super().__init__(task, last_worker, allowed_failures) diff --git a/distributed/stealing.py b/distributed/stealing.py index cdbcce30c4e..49b30405696 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -400,7 +400,7 @@ def balance(self) -> None: with log_errors(): i = 0 # Paused and closing workers must never become thieves - potential_thieves = set(s.idle.values()) + potential_thieves = s.idle.copy() if not potential_thieves or len(potential_thieves) == len(s.workers): return victim: WorkerState | None diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 1df0ceb1619..775784bca11 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -17,7 +17,7 @@ import cloudpickle import psutil import pytest -from tlz import concat, first, merge, valmap +from tlz import concat, first, merge, partition, valmap from tornado.ioloop import IOLoop, PeriodicCallback import dask @@ -136,6 +136,7 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c): assert x.key in a.data or x.key in b.data +@pytest.mark.parametrize("sat", [float("inf"), 1.0]) @pytest.mark.parametrize("ndeps", [0, 1, 4]) @pytest.mark.parametrize( "nthreads", @@ -144,13 +145,13 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c): [("127.0.0.1", 3), ("127.0.0.1", 2), ("127.0.0.1", 1)], ], ) -def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads): +def test_decide_worker_coschedule_order_neighbors(sat, ndeps, nthreads): @gen_cluster( client=True, nthreads=nthreads, config={ "distributed.scheduler.work-stealing": False, - "distributed.scheduler.worker-saturation": float("inf"), + "distributed.scheduler.worker-saturation": sat, }, ) async def test_decide_worker_coschedule_order_neighbors_(c, s, *workers): @@ -200,6 +201,7 @@ def random(**kwargs): **trivial_deps, ) + # dask.visualize(x, x.sum(axis=1, split_every=20), optimize_graph=True) xx, xsum = dask.persist(x, x.sum(axis=1, split_every=20)) await xsum @@ -251,7 +253,6 @@ def random(**kwargs): test_decide_worker_coschedule_order_neighbors_() -@pytest.mark.slow @gen_cluster( nthreads=[("", 2)] * 4, client=True, @@ -360,7 +361,7 @@ async def test_queued_paused_unpaused(c, s, a, b, queue): f1s = c.map(slowinc, range(16)) f2s = c.map(slowinc, f1s) - final = c.submit(sum, *f2s) + final = c.submit(sum, f2s) del f1s, f2s while not a.data or not b.data: @@ -409,6 +410,100 @@ async def test_queued_remove_add_worker(c, s, a, b): await wait(fs) +@gen_cluster( + client=True, + nthreads=[("", 1)] * 6, + config={"distributed.scheduler.worker-saturation": 1.0}, +) +async def test_utilization_over_co_assignment(c, s, *workers): + event = Event() + roots = [delayed(event.wait)(5, dask_key_name=f"r-{i}") for i in range(6)] + aggs = [ + delayed(list)(rs, dask_key_name=f"a-{i}") + for i, rs in enumerate(partition(2, roots)) + ] + fs = c.compute(aggs) + + await async_wait_for(lambda: any(w.state.tasks for w in workers), timeout=5) + + # All workers should be used, even though it breaks up co-assignment + assert not s.idle + rts = [s.tasks[r.key] for r in roots] + assert {ts.processing_on for ts in rts} == set(s.workers.values()) + + await event.set() + await wait(fs) + + +@gen_cluster( + client=True, + nthreads=[("", 2)] * 2, + config={"distributed.scheduler.worker-saturation": 1.0}, +) +async def test_co_assign_scale_up(c, s, a, b): + event = Event() + devent = delayed(event) + roots = [devent.wait(5, dask_key_name=f"r-{i}") for i in range(16)] + aggs = [ + delayed(list)(rs, dask_key_name=f"a-{i}") + for i, rs in enumerate(partition(4, roots)) + ] + fs = c.compute(aggs) + + await async_wait_for(lambda: s.queued, timeout=5) + + # Each family of roots should be processing on the same worker, or not at all + for agg in aggs: + tss = s.tasks[agg.key].dependencies + proc = [ts.processing_on for ts in tss] + assert proc == proc[:1] * len(proc) + + async with Worker(s.address, nthreads=2) as w: + await async_wait_for(lambda: w.state.tasks, timeout=5) + assert len(w.state.tasks) == 5 # 4 `r` + the Event + + # Each family of roots should be processing on the same worker, or not at all + for agg in aggs: + tss = s.tasks[agg.key].dependencies + proc = [ts.processing_on for ts in tss] + assert len(set(proc)) == 1, proc + + await event.set() + await wait(fs) + + +@gen_cluster( + client=True, + nthreads=[("", 2)] * 3, + config={"distributed.scheduler.worker-saturation": 1.0}, +) +async def test_co_assign_scale_down(c, s, *workers): + event = Event() + roots = [delayed(event.wait)(5, dask_key_name=f"r-{i}") for i in range(16)] + aggs = [ + delayed(list)(rs, dask_key_name=f"a-{i}") + for i, rs in enumerate(partition(4, roots)) + ] + # pin roots so we can check where they are at the end + fs = c.compute(aggs + roots) + + await async_wait_for(lambda: s.queued, timeout=5) + + await workers[0].close() + await event.set() + await wait(fs) + + for r in roots: + ts = s.tasks[r.key] + assert len(ts.who_has) == 1, ts.who_has + + for w in workers: + assert not w.transfer_incoming_log + + +# TODO test _where_ tasks get assigned on scale-down. They should prefer to go near their siblings. + + @pytest.mark.parametrize( "saturation, expected_task_counts", [ diff --git a/distributed/tests/test_scheduler_family.py b/distributed/tests/test_scheduler_family.py new file mode 100644 index 00000000000..4204cc93c3e --- /dev/null +++ b/distributed/tests/test_scheduler_family.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import operator + +from tlz import partition, partition_all + +import dask + +from distributed.scheduler import family +from distributed.utils_test import async_wait_for, gen_cluster, slowidentity, slowinc + +ident = dask.delayed(slowidentity, pure=True) +inc = dask.delayed(slowinc, pure=True) +add = dask.delayed(operator.add, pure=True) +dsum = dask.delayed(sum, pure=True) + + +async def submit_delayed(client, scheduler, x): + "Submit a delayed object or list of them; wait until tasks are processed on scheduler" + # dask.visualize(x, optimize_graph=True, collapse_outputs=True) + fs = client.compute(x) + await async_wait_for(lambda: scheduler.tasks, 5) + try: + key = fs.key + except AttributeError: + key = fs[0].key + await async_wait_for(lambda: scheduler.tasks[key].state != "released", 5) + return fs + + +@gen_cluster(nthreads=[], client=True) +async def test_family(c, s): + r""" + z z z + / | / | / | + x | x | x | + / \ | / \ | / \ | + a b c a b c a b c + """ + ax = [dask.delayed(i, name=f"a-{i}") for i in range(3)] + bx = [dask.delayed(i, name=f"b-{i}") for i in range(3)] + cx = [dask.delayed(i, name=f"c-{i}") for i in range(3)] + + zs = [(a + b) + c for a, b, c in zip(ax, bx, cx)] + + _ = await submit_delayed(c, s, zs) + + a1 = s.tasks["a-1"] + b1 = s.tasks["b-1"] + c1 = s.tasks["c-1"] + + fam = family(a1, 1000, 1000) + assert fam + sibs, downstream = fam + assert sibs == {b1} + assert len(downstream) == 1 + add_1_1 = next(iter(downstream)) + + fam = family(c1, 1000, 1000) + assert fam + sibs, downstream = fam + assert sibs == {add_1_1} + assert len(downstream) == 1 # don't know keys + add_1_2 = next(iter(downstream)) + + fam = family(add_1_1, 1000, 1000) + assert fam + sibs, downstream = fam + assert sibs == {c1} + assert downstream == {add_1_2} + + assert family(add_1_2, 1000, 1000) is None + + +@gen_cluster(nthreads=[], client=True) +async def test_family_linear_chains(c, s): + r""" + final + / \ \ + / \ \ + -------------- + | s2 | s2 s2 + | / \ | / \ / \ + |/______ \ | | | | | + ------------- | | | | | | | + | s1 | | | | s1 | s1 | + | / | \ | | | | / | \ | / | \ | + | x | | | | | | x | | | x | | | + | | | | | | | | | | | | | | | | + | x | x | | x | x | x x x | x x + | | | | | | | | | | | | | | | | + | a b c | | d | a b c d a b c d + ------------- ------ / / / / / / / / + \ \ \ \ \ \ / / / / / / / / + r + """ + root = dask.delayed(0, name="root") + ax = [ + ident(ident(inc(root, dask_key_name=f"a-{i}"))) for i in range(3) # 2-chain(z) + ] + bx = [inc(root, dask_key_name=f"b-{i}") for i in range(3)] # 0-chain + cx = [ident(inc(root, dask_key_name=f"c-{i}")) for i in range(3)] # 1-chain + s1x = [ + dsum([a, b, c], dask_key_name=f"s1-{i}") + for i, (a, b, c) in enumerate(zip(ax, bx, cx)) + ] + + dx = [ident(inc(root, dask_key_name=f"d-{i}")) for i in range(3)] # 1-chain + s2x = [ + add(s1, d, dask_key_name=f"s2-{i}") for i, (s1, d) in enumerate(zip(s1x, dx)) + ] + + final = dsum(s2x, dask_key_name="final") + + _ = await submit_delayed(c, s, final) + + root = s.tasks["root"] + a1 = s.tasks["a-1"] + b1 = s.tasks["b-1"] + c1 = s.tasks["c-1"] + d1 = s.tasks["d-1"] + s1_1 = s.tasks["s1-1"] + s2_1 = s.tasks["s2-1"] + final = s.tasks["final"] + + await async_wait_for(lambda: final.state == "waiting", 5) + + # `a` traverses chains up and down to find `b` and `c` + # Does *not* include `d`: `d` is not required to compute `s1` + fam = family(a1, 1000, 4) + assert fam + sibs, downstream = fam + assert sibs == {b1, c1} + assert downstream == {s1_1} + + # `d` traverses chains up to find `s2` + # does not traverse down past `s2` + fam = family(d1, 1000, 4) + assert fam + sibs, downstream = fam + assert sibs == {s1_1} + assert downstream == {s2_1} + + # Don't traverse a linear chain past self + mid_chain = next(iter(a1.dependents)) + fam = family(mid_chain, 1000, 4) + assert fam + sibs, downstream = fam + assert sibs == {b1, c1} + assert downstream == {s1_1} + + # `root` has no family with small widely-shared cutoff + assert family(root, 1000, 4) is None + + # With large cutoff, `root` has no siblings. + # But the `s1` and `s2` tasks are all considered downstream, if you + # collapse the linear chains (which include `a`, `b`, `c`, `d`). + + # Note that `s1`s could be considered both siblings and downstreams + # (siblings, since they need to be in memory along with root to compute `s2`). + # But tasks that meet this criteria are explicitly labeled as only downstream. + fam = family(root, 1000, 1000) + assert fam + sibs, downstream = fam + assert sibs == set() + assert {ts.key for ts in downstream} == { + "s1-0", + "s1-1", + "s1-2", + "s2-0", + "s2-1", + "s2-2", + } + + +@gen_cluster(nthreads=[], client=True) +async def test_family_linear_chains_plus_widely_shared(c, s): + r""" + s s s + /|\ /|\ /\ + a a a a a a a a + |\|\|\|\|/|/|/| + | | | | s | | | + r r r r r r r r + """ + shared = dask.delayed(0, name="shared") + roots = [dask.delayed(i, name=f"r-{i}") for i in range(8)] + ax = [add(r, shared, dask_key_name=f"a-{i}") for i, r in enumerate(roots)] + sx = [ + dsum(axs, dask_key_name=f"s-{i}") for i, axs in enumerate(partition_all(3, ax)) + ] + + _ = await submit_delayed(c, s, sx) + + r0 = s.tasks["r-0"] + r1 = s.tasks["r-1"] + r2 = s.tasks["r-2"] + s0 = s.tasks["s-0"] + + fam = family(r0, 1000, 4) + assert fam + sibs, downstream = fam + assert sibs == {r1, r2} + assert downstream == {s0} + + +@gen_cluster(nthreads=[], client=True) +async def test_family_triangle(c, s): + r""" + z + /| + y | + \ | + x + """ + x = dask.delayed(0, name="x") + y = inc(x, dask_key_name="y") + z = add(x, y, dask_key_name="z") + + _ = await submit_delayed(c, s, z) + + x = s.tasks["x"] + y = s.tasks["y"] + z = s.tasks["z"] + + fam = family(x, 1000, 1000) + assert fam + sibs, downstream = fam + assert sibs == set() + assert downstream == {z} # `y` is just a linear chain, not downstream + + fam = family(y, 1000, 1000) + assert fam + sibs, downstream = fam + assert sibs == {x} + assert downstream == {z} + + +@gen_cluster(nthreads=[], client=True) +async def test_family_wide_gather_downstream(c, s): + r""" + s + / / / /|\ \ \ + i i i i i i i i + | | | | | | | | + r r r r r r r r + """ + roots = [dask.delayed(i, name=f"r-{i}") for i in range(8)] + incs = [inc(r, dask_key_name=f"i-{i}") for i, r in enumerate(roots)] + sum = dsum(incs, dask_key_name="sum") + + _ = await submit_delayed(c, s, sum) + + rts = [s.tasks[r.key] for r in roots] + sts = s.tasks["sum"] + + fam = family(rts[0], 4, 1000) + assert fam + sibs, downstream = fam + assert sibs == set() + assert downstream == set() # `sum` not downstream because it's too large + + fam = family(rts[0], 1000, 1000) + assert fam + sibs, downstream = fam + assert sibs == set(rts[1:]) + assert downstream == {sts} + + +# TODO test family commutativity. Given any node X in any graph, calculate `family(X)`. +# For each sibling S, `family(S)` should give the same family, regardless of the +# starting node. +# EXECPT THIS ISN'T TRUE + + +@gen_cluster(nthreads=[], client=True) +async def test_family_non_commutative(c, s): + roots = [dask.delayed(i, name=f"r-{i}") for i in range(16)] + aggs = [dsum(rs) for rs in partition(4, roots)] + extra = dsum([roots[::4]], dask_key_name="extra") + + _ = await submit_delayed(c, s, aggs + [extra]) + + rts = [s.tasks[r.key] for r in roots] + ats = [s.tasks[a.key] for a in aggs] + ets = s.tasks["extra"] + + fam = family(rts[0], 1000, 1000) + assert fam + sibs, downstream = fam + assert sibs == set(rts[1:4]) | {rts[4], rts[8], rts[12]} + assert downstream == {ats[0], ets} + + fam = family(rts[1], 1000, 1000) + assert fam + sibs, downstream = fam + assert sibs == {rts[0], rts[2], rts[3]} + assert downstream == {ats[0]}