diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9bdd2337b0e..9416144503f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2207,7 +2207,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: restrictions. Out of eligible workers holding dependencies of ``ts``, selects the worker - where, considering worker backlong and data-transfer costs, the task is + where, considering worker backlog and data-transfer costs, the task is estimated to start running the soonest. Returns @@ -2222,9 +2222,6 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: valid_workers = self.valid_workers(ts) if valid_workers is None and len(self.running) < len(self.workers): - if not self.running: - return None - # If there were no restrictions, `valid_workers()` didn't subset by # `running`. valid_workers = self.running @@ -8194,7 +8191,7 @@ def _task_to_client_msgs(ts: TaskState) -> dict[str, list[dict[str, Any]]]: def decide_worker( ts: TaskState, - all_workers: Iterable[WorkerState], + all_workers: set[WorkerState], valid_workers: set[WorkerState] | None, objective: Callable[[WorkerState], Any], ) -> WorkerState | None: @@ -8215,12 +8212,13 @@ def decide_worker( """ assert all(dts.who_has for dts in ts.dependencies) if ts.actor: - candidates = set(all_workers) + candidates = all_workers.copy() else: candidates = {wws for dts in ts.dependencies for wws in dts.who_has} + candidates &= all_workers if valid_workers is None: if not candidates: - candidates = set(all_workers) + candidates = all_workers.copy() else: candidates &= valid_workers if not candidates: diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index f553d164601..8d7eec43606 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -48,21 +48,59 @@ def test_submit_after_failed_worker_sync(loop): assert total.result() == sum(map(inc, range(10))) -@pytest.mark.slow() -@pytest.mark.parametrize("compute_on_failed", [False, True]) -@gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "500ms"}) -async def test_submit_after_failed_worker_async(c, s, a, b, compute_on_failed): - async with Nanny(s.address, nthreads=2) as n: - await c.wait_for_workers(3) - - L = c.map(inc, range(10)) - await wait(L) - - kill_task = asyncio.create_task(n.kill()) - compute_addr = n.worker_address if compute_on_failed else a.address - total = c.submit(sum, L, workers=[compute_addr], allow_other_workers=True) - assert await total == sum(range(1, 11)) - await kill_task +@pytest.mark.parametrize("when", ["closing", "closed"]) +@pytest.mark.parametrize("y_on_failed", [False, True]) +@pytest.mark.parametrize("x_on_failed", [False, True]) +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.comm.timeouts.connect": "1s"}, +) +async def test_submit_after_failed_worker_async( + c, s, a, b, x_on_failed, y_on_failed, when, monkeypatch +): + a_ws = s.workers[a.address] + + x = c.submit( + inc, + 1, + key="x", + workers=[b.address if x_on_failed else a.address], + allow_other_workers=True, + ) + await wait(x) + + if when == "closed": + await b.close() + await async_poll_for(lambda: b.address not in s.workers, timeout=5) + elif when == "closing": + orig_remove_worker = s.remove_worker + in_remove_worker = asyncio.Event() + wait_remove_worker = asyncio.Event() + + async def remove_worker(*args, **kwargs): + in_remove_worker.set() + await wait_remove_worker.wait() + return await orig_remove_worker(*args, **kwargs) + + monkeypatch.setattr(s, "remove_worker", remove_worker) + await b.close() + await in_remove_worker.wait() + assert s.workers[b.address].status.name == "closing" + + y = c.submit( + inc, + x, + key="y", + workers=[b.address if y_on_failed else a.address], + allow_other_workers=True, + ) + await async_poll_for(lambda: "y" in s.tasks, timeout=5) + + if when == "closing": + wait_remove_worker.set() + assert await y == 3 + assert s.tasks["y"].who_has == {a_ws} @gen_cluster(client=True, timeout=60)