From 882a546fd1f8b0a0b96b334e2790a24fc25ad32c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 25 Jul 2023 00:14:35 +0100 Subject: [PATCH 1/6] Fix decide_worker picking a closing worker --- .github/workflows/tests.yaml | 5 +- distributed/scheduler.py | 8 ++- distributed/tests/test_failed_workers.py | 63 +++++++++++++++++++----- 3 files changed, 56 insertions(+), 20 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 29c9d3f14c4..9c23a9976a1 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -197,8 +197,9 @@ jobs: set -o pipefail mkdir reports - pytest distributed \ - -m "not avoid_ci and ${{ matrix.partition }}" --runslow \ + # TEMP DO NOT MERGE + pytest distributed/tests/test_failed_workers.py \ + -m "not avoid_ci" --runslow \ --leaks=fds,processes,threads \ --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \ --cov=distributed --cov-report=xml \ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ca612cd9ba6..1a2e5304b22 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2205,7 +2205,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: restrictions. Out of eligible workers holding dependencies of ``ts``, selects the worker - where, considering worker backlong and data-transfer costs, the task is + where, considering worker backlog and data-transfer costs, the task is estimated to start running the soonest. Returns @@ -2220,9 +2220,6 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: valid_workers = self.valid_workers(ts) if valid_workers is None and len(self.running) < len(self.workers): - if not self.running: - return None - # If there were no restrictions, `valid_workers()` didn't subset by # `running`. valid_workers = self.running @@ -8181,7 +8178,7 @@ def _task_to_client_msgs(ts: TaskState) -> dict[str, list[dict[str, Any]]]: def decide_worker( ts: TaskState, - all_workers: Iterable[WorkerState], + all_workers: set[WorkerState], valid_workers: set[WorkerState] | None, objective: Callable[[WorkerState], Any], ) -> WorkerState | None: @@ -8205,6 +8202,7 @@ def decide_worker( candidates = set(all_workers) else: candidates = {wws for dts in ts.dependencies for wws in dts.who_has} + candidates &= all_workers if valid_workers is None: if not candidates: candidates = set(all_workers) diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index f553d164601..2a5c8e383f5 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -17,6 +17,7 @@ from distributed import Client, Nanny, profile, wait from distributed.comm import CommClosedError from distributed.compatibility import MACOS +from distributed.core import Status from distributed.metrics import time from distributed.utils import CancelledError, sync from distributed.utils_test import ( @@ -48,21 +49,57 @@ def test_submit_after_failed_worker_sync(loop): assert total.result() == sum(map(inc, range(10))) -@pytest.mark.slow() -@pytest.mark.parametrize("compute_on_failed", [False, True]) -@gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "500ms"}) -async def test_submit_after_failed_worker_async(c, s, a, b, compute_on_failed): - async with Nanny(s.address, nthreads=2) as n: - await c.wait_for_workers(3) +@pytest.mark.repeat(20) # TEMP DO NOT MERGE +@pytest.mark.slow +@pytest.mark.parametrize("wait_closing", [False, True]) +@pytest.mark.parametrize("y_on_failed", [False, True]) +@pytest.mark.parametrize("x_on_failed", [False, True]) +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_submit_after_failed_worker_async( + c, s, a, b, x_on_failed, y_on_failed, wait_closing +): + a_ws = s.workers[a.address] + b_ws = s.workers[b.address] + + L = c.map( + inc, + range(10), + workers=[b.address if x_on_failed else a.address], + allow_other_workers=True, + ) - L = c.map(inc, range(10)) - await wait(L) + await wait(L) + total = c.submit( + sum, + L, + key="y", + workers=[b.address if y_on_failed else a.address], + allow_other_workers=True, + ) + + done_update_graph = False + if wait_closing: + in_update_graph = asyncio.Event() + + async def update_graph(*args, **kwargs): + in_update_graph.set() + await async_poll_for( + lambda: b_ws.status == Status.closing, timeout=5, period=0 + ) + s.update_graph(*args, **kwargs) + nonlocal done_update_graph + done_update_graph = True + + s.stream_handlers["update-graph"] = update_graph + await in_update_graph.wait() - kill_task = asyncio.create_task(n.kill()) - compute_addr = n.worker_address if compute_on_failed else a.address - total = c.submit(sum, L, workers=[compute_addr], allow_other_workers=True) - assert await total == sum(range(1, 11)) - await kill_task + await b.close() + assert await total == sum(range(1, 11)) + if wait_closing: + # Without this we may never know if s.update_graph raised, + # or if the monkey-patch worked to begin with + assert done_update_graph + assert s.tasks["y"].who_has == {a_ws} @gen_cluster(client=True, timeout=60) From c4e51434a1309efda6e0301b5f898983ae4c9e0e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 25 Jul 2023 17:03:01 +0100 Subject: [PATCH 2/6] Redesign test --- distributed/tests/test_failed_workers.py | 59 +++++++++++++----------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 2a5c8e383f5..acfd8cc294b 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -17,7 +17,6 @@ from distributed import Client, Nanny, profile, wait from distributed.comm import CommClosedError from distributed.compatibility import MACOS -from distributed.core import Status from distributed.metrics import time from distributed.utils import CancelledError, sync from distributed.utils_test import ( @@ -32,6 +31,7 @@ inc, slowadd, slowinc, + wait_for_state, ) from distributed.worker_state_machine import FreeKeysEvent @@ -50,16 +50,18 @@ def test_submit_after_failed_worker_sync(loop): @pytest.mark.repeat(20) # TEMP DO NOT MERGE -@pytest.mark.slow -@pytest.mark.parametrize("wait_closing", [False, True]) +@pytest.mark.parametrize("when", ["closing", "closed"]) @pytest.mark.parametrize("y_on_failed", [False, True]) @pytest.mark.parametrize("x_on_failed", [False, True]) -@gen_cluster(client=True, nthreads=[("", 1)] * 2) +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.comm.timeouts.connect": "1s"}, +) async def test_submit_after_failed_worker_async( - c, s, a, b, x_on_failed, y_on_failed, wait_closing + c, s, a, b, x_on_failed, y_on_failed, when, monkeypatch ): a_ws = s.workers[a.address] - b_ws = s.workers[b.address] L = c.map( inc, @@ -67,8 +69,26 @@ async def test_submit_after_failed_worker_async( workers=[b.address if x_on_failed else a.address], allow_other_workers=True, ) - await wait(L) + + if when == "closed": + await b.close() + await async_poll_for(lambda: b.address not in s.workers, timeout=5) + elif when == "closing": + orig_remove_worker = s.remove_worker + in_remove_worker = asyncio.Event() + wait_remove_worker = asyncio.Event() + + async def remove_worker(*args, **kwargs): + in_remove_worker.set() + await wait_remove_worker.wait() + return await orig_remove_worker(*args, **kwargs) + + monkeypatch.setattr(s, "remove_worker", remove_worker) + await b.close() + await in_remove_worker.wait() + assert s.workers[b.address].status.name == "closing" + total = c.submit( sum, L, @@ -77,28 +97,11 @@ async def test_submit_after_failed_worker_async( allow_other_workers=True, ) - done_update_graph = False - if wait_closing: - in_update_graph = asyncio.Event() - - async def update_graph(*args, **kwargs): - in_update_graph.set() - await async_poll_for( - lambda: b_ws.status == Status.closing, timeout=5, period=0 - ) - s.update_graph(*args, **kwargs) - nonlocal done_update_graph - done_update_graph = True - - s.stream_handlers["update-graph"] = update_graph - await in_update_graph.wait() - - await b.close() + await wait_for_state("y", "processing", s, interval=0) + assert s.tasks["y"].processing_on is a_ws + if when == "closing": + wait_remove_worker.set() assert await total == sum(range(1, 11)) - if wait_closing: - # Without this we may never know if s.update_graph raised, - # or if the monkey-patch worked to begin with - assert done_update_graph assert s.tasks["y"].who_has == {a_ws} From d884e06b3a98a3d493c07d208ba15b2122fba863 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 25 Jul 2023 17:11:23 +0100 Subject: [PATCH 3/6] read-only annotations --- distributed/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 1a2e5304b22..2cdb577b0c4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -8178,8 +8178,8 @@ def _task_to_client_msgs(ts: TaskState) -> dict[str, list[dict[str, Any]]]: def decide_worker( ts: TaskState, - all_workers: set[WorkerState], - valid_workers: set[WorkerState] | None, + all_workers: Set[WorkerState], + valid_workers: Set[WorkerState] | None, objective: Callable[[WorkerState], Any], ) -> WorkerState | None: """ From 5661ac17e7c2f016bd6b3a4c8c14b03ec131f13a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 26 Jul 2023 14:19:33 +0100 Subject: [PATCH 4/6] Revert temp changes --- .github/workflows/tests.yaml | 5 ++--- distributed/tests/test_failed_workers.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 9c23a9976a1..29c9d3f14c4 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -197,9 +197,8 @@ jobs: set -o pipefail mkdir reports - # TEMP DO NOT MERGE - pytest distributed/tests/test_failed_workers.py \ - -m "not avoid_ci" --runslow \ + pytest distributed \ + -m "not avoid_ci and ${{ matrix.partition }}" --runslow \ --leaks=fds,processes,threads \ --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \ --cov=distributed --cov-report=xml \ diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index acfd8cc294b..5c539899597 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -49,7 +49,6 @@ def test_submit_after_failed_worker_sync(loop): assert total.result() == sum(map(inc, range(10))) -@pytest.mark.repeat(20) # TEMP DO NOT MERGE @pytest.mark.parametrize("when", ["closing", "closed"]) @pytest.mark.parametrize("y_on_failed", [False, True]) @pytest.mark.parametrize("x_on_failed", [False, True]) From d5f1a83288e232dca65f5118fd61a33c5dae0921 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 26 Jul 2023 15:18:09 +0100 Subject: [PATCH 5/6] lint --- distributed/scheduler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 2cdb577b0c4..7601b418365 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -8178,8 +8178,8 @@ def _task_to_client_msgs(ts: TaskState) -> dict[str, list[dict[str, Any]]]: def decide_worker( ts: TaskState, - all_workers: Set[WorkerState], - valid_workers: Set[WorkerState] | None, + all_workers: set[WorkerState], + valid_workers: set[WorkerState] | None, objective: Callable[[WorkerState], Any], ) -> WorkerState | None: """ @@ -8199,13 +8199,13 @@ def decide_worker( """ assert all(dts.who_has for dts in ts.dependencies) if ts.actor: - candidates = set(all_workers) + candidates = all_workers.copy() else: candidates = {wws for dts in ts.dependencies for wws in dts.who_has} candidates &= all_workers if valid_workers is None: if not candidates: - candidates = set(all_workers) + candidates = all_workers.copy() else: candidates &= valid_workers if not candidates: From 80b36c986dc7f32b2ef6cf1411792be110cb0b88 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 31 Jul 2023 12:48:22 +0100 Subject: [PATCH 6/6] Code review --- distributed/tests/test_failed_workers.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 5c539899597..8d7eec43606 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -31,7 +31,6 @@ inc, slowadd, slowinc, - wait_for_state, ) from distributed.worker_state_machine import FreeKeysEvent @@ -62,13 +61,14 @@ async def test_submit_after_failed_worker_async( ): a_ws = s.workers[a.address] - L = c.map( + x = c.submit( inc, - range(10), + 1, + key="x", workers=[b.address if x_on_failed else a.address], allow_other_workers=True, ) - await wait(L) + await wait(x) if when == "closed": await b.close() @@ -88,19 +88,18 @@ async def remove_worker(*args, **kwargs): await in_remove_worker.wait() assert s.workers[b.address].status.name == "closing" - total = c.submit( - sum, - L, + y = c.submit( + inc, + x, key="y", workers=[b.address if y_on_failed else a.address], allow_other_workers=True, ) + await async_poll_for(lambda: "y" in s.tasks, timeout=5) - await wait_for_state("y", "processing", s, interval=0) - assert s.tasks["y"].processing_on is a_ws if when == "closing": wait_remove_worker.set() - assert await total == sum(range(1, 11)) + assert await y == 3 assert s.tasks["y"].who_has == {a_ws}