-
-
Notifications
You must be signed in to change notification settings - Fork 763
Fix decide_worker picking a closing worker #8032
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
882a546
c4e5143
d884e06
5661ac1
d5f1a83
1affb3f
80b36c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2207,7 +2207,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: | |
| restrictions. | ||
|
|
||
| Out of eligible workers holding dependencies of ``ts``, selects the worker | ||
| where, considering worker backlong and data-transfer costs, the task is | ||
| where, considering worker backlog and data-transfer costs, the task is | ||
| estimated to start running the soonest. | ||
|
|
||
| Returns | ||
|
|
@@ -2222,9 +2222,6 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: | |
|
|
||
| valid_workers = self.valid_workers(ts) | ||
| if valid_workers is None and len(self.running) < len(self.workers): | ||
| if not self.running: | ||
| return None | ||
|
|
||
| # If there were no restrictions, `valid_workers()` didn't subset by | ||
| # `running`. | ||
| valid_workers = self.running | ||
|
|
@@ -8194,7 +8191,7 @@ def _task_to_client_msgs(ts: TaskState) -> dict[str, list[dict[str, Any]]]: | |
|
|
||
| def decide_worker( | ||
| ts: TaskState, | ||
| all_workers: Iterable[WorkerState], | ||
| all_workers: set[WorkerState], | ||
| valid_workers: set[WorkerState] | None, | ||
| objective: Callable[[WorkerState], Any], | ||
| ) -> WorkerState | None: | ||
|
|
@@ -8215,12 +8212,13 @@ def decide_worker( | |
| """ | ||
| assert all(dts.who_has for dts in ts.dependencies) | ||
| if ts.actor: | ||
| candidates = set(all_workers) | ||
| candidates = all_workers.copy() | ||
| else: | ||
| candidates = {wws for dts in ts.dependencies for wws in dts.who_has} | ||
| candidates &= all_workers | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This fixes #8019
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FWIW I think this is a situation where an actual |
||
| if valid_workers is None: | ||
| if not candidates: | ||
| candidates = set(all_workers) | ||
| candidates = all_workers.copy() | ||
| else: | ||
| candidates &= valid_workers | ||
| if not candidates: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -48,21 +48,59 @@ def test_submit_after_failed_worker_sync(loop): | |||||||||||||
| assert total.result() == sum(map(inc, range(10))) | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| @pytest.mark.slow() | ||||||||||||||
| @pytest.mark.parametrize("compute_on_failed", [False, True]) | ||||||||||||||
| @gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "500ms"}) | ||||||||||||||
| async def test_submit_after_failed_worker_async(c, s, a, b, compute_on_failed): | ||||||||||||||
| async with Nanny(s.address, nthreads=2) as n: | ||||||||||||||
| await c.wait_for_workers(3) | ||||||||||||||
|
|
||||||||||||||
| L = c.map(inc, range(10)) | ||||||||||||||
| await wait(L) | ||||||||||||||
|
|
||||||||||||||
| kill_task = asyncio.create_task(n.kill()) | ||||||||||||||
| compute_addr = n.worker_address if compute_on_failed else a.address | ||||||||||||||
| total = c.submit(sum, L, workers=[compute_addr], allow_other_workers=True) | ||||||||||||||
| assert await total == sum(range(1, 11)) | ||||||||||||||
| await kill_task | ||||||||||||||
| @pytest.mark.parametrize("when", ["closing", "closed"]) | ||||||||||||||
| @pytest.mark.parametrize("y_on_failed", [False, True]) | ||||||||||||||
| @pytest.mark.parametrize("x_on_failed", [False, True]) | ||||||||||||||
| @gen_cluster( | ||||||||||||||
| client=True, | ||||||||||||||
| nthreads=[("", 1)] * 2, | ||||||||||||||
| config={"distributed.comm.timeouts.connect": "1s"}, | ||||||||||||||
| ) | ||||||||||||||
| async def test_submit_after_failed_worker_async( | ||||||||||||||
| c, s, a, b, x_on_failed, y_on_failed, when, monkeypatch | ||||||||||||||
| ): | ||||||||||||||
| a_ws = s.workers[a.address] | ||||||||||||||
|
|
||||||||||||||
| x = c.submit( | ||||||||||||||
| inc, | ||||||||||||||
| 1, | ||||||||||||||
| key="x", | ||||||||||||||
| workers=[b.address if x_on_failed else a.address], | ||||||||||||||
| allow_other_workers=True, | ||||||||||||||
| ) | ||||||||||||||
| await wait(x) | ||||||||||||||
|
|
||||||||||||||
| if when == "closed": | ||||||||||||||
| await b.close() | ||||||||||||||
| await async_poll_for(lambda: b.address not in s.workers, timeout=5) | ||||||||||||||
| elif when == "closing": | ||||||||||||||
| orig_remove_worker = s.remove_worker | ||||||||||||||
| in_remove_worker = asyncio.Event() | ||||||||||||||
| wait_remove_worker = asyncio.Event() | ||||||||||||||
|
|
||||||||||||||
| async def remove_worker(*args, **kwargs): | ||||||||||||||
| in_remove_worker.set() | ||||||||||||||
| await wait_remove_worker.wait() | ||||||||||||||
| return await orig_remove_worker(*args, **kwargs) | ||||||||||||||
|
|
||||||||||||||
| monkeypatch.setattr(s, "remove_worker", remove_worker) | ||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm OK-ish with using monkeypatch here. However, just for the sake of prosperity, there is also a way to use our RPC mechanism more naturally. Essentially you want to intercept the point in time just when a request handler is called. You can make this very explitc async def new_remove_worker_handler_with_events(self, *args, **kwargs):
in_remove_worker.set()
await wait_remove_worker.wait()
return await self.remove_worker(*args, **kwargs)
s.handlers['unregister'] = new_remove_worker_handler_with_events`Semantically, this overrides the
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're not arriving here from the distributed/distributed/scheduler.py Lines 5700 to 5705 in f0303aa
|
||||||||||||||
| await b.close() | ||||||||||||||
| await in_remove_worker.wait() | ||||||||||||||
| assert s.workers[b.address].status.name == "closing" | ||||||||||||||
|
|
||||||||||||||
| y = c.submit( | ||||||||||||||
| inc, | ||||||||||||||
| x, | ||||||||||||||
| key="y", | ||||||||||||||
| workers=[b.address if y_on_failed else a.address], | ||||||||||||||
| allow_other_workers=True, | ||||||||||||||
| ) | ||||||||||||||
| await async_poll_for(lambda: "y" in s.tasks, timeout=5) | ||||||||||||||
|
|
||||||||||||||
| if when == "closing": | ||||||||||||||
| wait_remove_worker.set() | ||||||||||||||
| assert await y == 3 | ||||||||||||||
| assert s.tasks["y"].who_has == {a_ws} | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| @gen_cluster(client=True, timeout=60) | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this related? At least the new test doesn't seem to care about this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's unreachable because the same condition is already tested on line 2218