diff --git a/distributed/scheduler.py b/distributed/scheduler.py index cf240240cfb..43f68b69d47 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2167,10 +2167,7 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None: _task_slots_available(ws, self.WORKER_SATURATION), ) assert ws in self.running, (ws, self.running) - - if self.validate and ws is not None: assert self.workers.get(ws.address) is ws - assert ws in self.running, (ws, self.running) return ws @@ -7878,7 +7875,11 @@ def _exit_processing_common( state.release_resources(ts, ws) # If a slot has opened up for a queued task, schedule it. - if state.queued and not _worker_full(ws, state.WORKER_SATURATION): + if ( + state.queued + and ws.status == Status.running + and not _worker_full(ws, state.WORKER_SATURATION) + ): qts = state.queued.peek() if state.validate: assert qts.state == "queued", qts.state diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 7df6796066c..de2e512d2c3 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -408,6 +408,42 @@ async def test_queued_remove_add_worker(c, s, a, b): await wait(fs) +@gen_cluster( + client=True, + nthreads=[("", 1)], + config={ + "distributed.scheduler.worker-saturation": 1.0, + "distributed.worker.memory.pause": False, + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": False, + }, +) +async def test_queued_dont_try_non_running_worker(c, s, a): + "When a slot opens on a non-running worker, don't consider scheduling a queued task" + events = [Event() for _ in range(5)] + fs = c.map(lambda ev: ev.wait(), events, key=[f"w-{i}" for i in range(len(events))]) + + await async_wait_for(lambda: s.queued, timeout=5) + + a.status = Status.paused + + await async_wait_for(lambda: not s.running, timeout=5) + + assert len(a.state.executing) == 1 + a_key: str = next(iter(a.state.executing)).key + a_task = s.tasks[a_key] + a_event = events[int(a_key[2])] + + front_of_queue = s.queued.peek() + + assert a_task.state == "processing" + await a_event.set() + await wait_for_state(a_key, "memory", s) + + story = s.story(front_of_queue) + assert story[-1][1:2] != ["queued", "queued"], story + + @pytest.mark.parametrize( "saturation_config, expected_task_counts", [