Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2207,7 +2207,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
restrictions.

Out of eligible workers holding dependencies of ``ts``, selects the worker
where, considering worker backlong and data-transfer costs, the task is
where, considering worker backlog and data-transfer costs, the task is
estimated to start running the soonest.

Returns
Expand All @@ -2222,9 +2222,6 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:

valid_workers = self.valid_workers(ts)
if valid_workers is None and len(self.running) < len(self.workers):
if not self.running:
return None
Comment on lines -2225 to -2226

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this related? At least the new test doesn't seem to care about this.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unreachable because the same condition is already tested on line 2218


# If there were no restrictions, `valid_workers()` didn't subset by
# `running`.
valid_workers = self.running
Expand Down Expand Up @@ -8194,7 +8191,7 @@ def _task_to_client_msgs(ts: TaskState) -> dict[str, list[dict[str, Any]]]:

def decide_worker(
ts: TaskState,
all_workers: Iterable[WorkerState],
all_workers: set[WorkerState],
valid_workers: set[WorkerState] | None,
objective: Callable[[WorkerState], Any],
) -> WorkerState | None:
Expand All @@ -8215,12 +8212,13 @@ def decide_worker(
"""
assert all(dts.who_has for dts in ts.dependencies)
if ts.actor:
candidates = set(all_workers)
candidates = all_workers.copy()
else:
candidates = {wws for dts in ts.dependencies for wws in dts.who_has}
candidates &= all_workers

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes #8019

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW I think this is a situation where an actual decide_worker unit test would be appropriate

if valid_workers is None:
if not candidates:
candidates = set(all_workers)
candidates = all_workers.copy()
else:
candidates &= valid_workers
if not candidates:
Expand Down
68 changes: 53 additions & 15 deletions distributed/tests/test_failed_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,59 @@ def test_submit_after_failed_worker_sync(loop):
assert total.result() == sum(map(inc, range(10)))


@pytest.mark.slow()
@pytest.mark.parametrize("compute_on_failed", [False, True])
@gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "500ms"})
async def test_submit_after_failed_worker_async(c, s, a, b, compute_on_failed):
async with Nanny(s.address, nthreads=2) as n:
await c.wait_for_workers(3)

L = c.map(inc, range(10))
await wait(L)

kill_task = asyncio.create_task(n.kill())
compute_addr = n.worker_address if compute_on_failed else a.address
total = c.submit(sum, L, workers=[compute_addr], allow_other_workers=True)
assert await total == sum(range(1, 11))
await kill_task
@pytest.mark.parametrize("when", ["closing", "closed"])
@pytest.mark.parametrize("y_on_failed", [False, True])
@pytest.mark.parametrize("x_on_failed", [False, True])
@gen_cluster(
client=True,
nthreads=[("", 1)] * 2,
config={"distributed.comm.timeouts.connect": "1s"},
)
async def test_submit_after_failed_worker_async(
c, s, a, b, x_on_failed, y_on_failed, when, monkeypatch
):
a_ws = s.workers[a.address]

x = c.submit(
inc,
1,
key="x",
workers=[b.address if x_on_failed else a.address],
allow_other_workers=True,
)
await wait(x)

if when == "closed":
await b.close()
await async_poll_for(lambda: b.address not in s.workers, timeout=5)
elif when == "closing":
orig_remove_worker = s.remove_worker
in_remove_worker = asyncio.Event()
wait_remove_worker = asyncio.Event()

async def remove_worker(*args, **kwargs):
in_remove_worker.set()
await wait_remove_worker.wait()
return await orig_remove_worker(*args, **kwargs)

monkeypatch.setattr(s, "remove_worker", remove_worker)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm OK-ish with using monkeypatch here. However, just for the sake of prosperity, there is also a way to use our RPC mechanism more naturally. Essentially you want to intercept the point in time just when a request handler is called. You can make this very explitc

async def new_remove_worker_handler_with_events(self, *args, **kwargs):
    in_remove_worker.set()
    await wait_remove_worker.wait()
    return await self.remove_worker(*args, **kwargs)
s.handlers['unregister'] = new_remove_worker_handler_with_events`

Semantically, this overrides the unregister handler and replaces it with a new handler.
However, in the end, it's the same thing just the way the patch is installed is different.

@crusaderky crusaderky Jul 31, 2023

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not arriving here from the unregister handler. We're arriving from

finally:
if worker in self.stream_comms:
worker_comm.abort()
await self.remove_worker(
worker, stimulus_id=f"handle-worker-cleanup-{time()}"
)

await b.close()
await in_remove_worker.wait()
assert s.workers[b.address].status.name == "closing"

y = c.submit(
inc,
x,
key="y",
workers=[b.address if y_on_failed else a.address],
allow_other_workers=True,
)
await async_poll_for(lambda: "y" in s.tasks, timeout=5)

if when == "closing":
wait_remove_worker.set()
assert await y == 3
assert s.tasks["y"].who_has == {a_ws}


@gen_cluster(client=True, timeout=60)
Expand Down