From ac9530667e7baa27f67c5c811a7724abb7e7471d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 20 Jul 2023 16:32:57 +0100 Subject: [PATCH 01/10] Client.gather() overhaul --- distributed/client.py | 34 ++++--- distributed/scheduler.py | 58 +++-------- distributed/tests/test_client.py | 67 +++++++++++++ distributed/tests/test_scheduler.py | 65 ++++-------- distributed/tests/test_utils_comm.py | 63 ++++++++++-- distributed/tests/test_worker.py | 26 +++-- distributed/utils_comm.py | 145 ++++++++++++++++++--------- distributed/utils_test.py | 29 ++++++ distributed/worker.py | 67 ++++--------- 9 files changed, 339 insertions(+), 215 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index ec55b7f2c03..ccef901e08e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2299,7 +2299,7 @@ async def wait(k): result = pack_data(unpacked, merge(data, bad_data)) return result - async def _gather_remote(self, direct, local_worker): + async def _gather_remote(self, direct: bool, local_worker: bool) -> dict[str, Any]: """Perform gather with workers or scheduler This method exists to limit and batch many concurrent gathers into a @@ -2311,22 +2311,26 @@ async def _gather_remote(self, direct, local_worker): self._gather_keys = None # clear state, these keys are being sent off self._gather_future = None - if direct or local_worker: # gather directly from workers - who_has = await retry_operation(self.scheduler.who_has, keys=keys) - data2, missing_keys, missing_workers = await gather_from_workers( - who_has, rpc=self.rpc, close=False - ) - response = {"status": "OK", "data": data2} - if missing_keys: - keys2 = [key for key in keys if key not in data2] - response = await retry_operation(self.scheduler.gather, keys=keys2) - if response["status"] == "OK": - response["data"].update(data2) + if not direct and not local_worker: + # ask scheduler to gather data for us + return await retry_operation(self.scheduler.gather, keys=keys) + + # gather directly from workers + async def who_has(keys: list[str]) -> dict[str, Collection[str]]: + return await retry_operation(self.scheduler.who_has, keys=keys) - else: # ask scheduler to gather data for us - response = await retry_operation(self.scheduler.gather, keys=keys) + data, missing_keys = await gather_from_workers( + keys=keys, who_has=who_has, rpc=self.rpc + ) + response: dict[str, Any] = {"status": "OK", "data": data} + if missing_keys: + response = await retry_operation( + self.scheduler.gather, keys=missing_keys + ) + if response["status"] == "OK": + response["data"].update(data) - return response + return response def gather(self, futures, errors="raise", direct=None, asynchronous=None): """Gather futures from distributed memory diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 109a9e6729f..a3c5989e4cb 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5914,54 +5914,24 @@ async def scatter( ) return keys - async def gather(self, keys, serializers=None): + async def gather( + self, keys: Collection[str], serializers: list[str] | None = None + ) -> dict[str, Any]: """Collect data from workers to the scheduler""" - stimulus_id = f"gather-{time()}" - keys = list(keys) - who_has = {} - for key in keys: - ts: TaskState = self.tasks.get(key) - if ts is not None: - who_has[key] = [ws.address for ws in ts.who_has] - else: - who_has[key] = [] - - data, missing_keys, missing_workers = await gather_from_workers( - who_has, rpc=self.rpc, close=False, serializers=serializers + data, missing_keys = await gather_from_workers( + keys=keys, who_has=self.get_who_has, rpc=self.rpc, serializers=serializers ) + self.log_event("all", {"action": "gather", "count": len(keys)}) + if not missing_keys: - result = {"status": "OK", "data": data} - else: - missing_states = [ - (self.tasks[key].state if key in self.tasks else None) - for key in missing_keys - ] - logger.exception( - "Couldn't gather keys %s state: %s workers: %s", - missing_keys, - missing_states, - missing_workers, - ) - result = {"status": "error", "keys": missing_keys} - with log_errors(): - # Remove suspicious workers from the scheduler and shut them down. - await asyncio.gather( - *( - self.remove_worker( - address=worker, close=True, stimulus_id=stimulus_id - ) - for worker in missing_workers - ) - ) - for key, workers in missing_keys.items(): - logger.exception( - "Shut down workers that don't have promised key: %s, %s", - str(workers), - str(key), - ) + return {"status": "OK", "data": data} - self.log_event("all", {"action": "gather", "count": len(keys)}) - return result + missing_states = { + key: self.tasks[key].state if key in self.tasks else "forgotten" + for key in missing_keys + } + logger.error("Couldn't gather keys: %s", missing_states) + return {"status": "error", "keys": list(missing_keys)} @log_errors async def restart(self, client=None, timeout=30, wait_for_workers=True): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 2ed731a160e..8b7f68edf1e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -84,7 +84,9 @@ from distributed.utils import get_mp_context, is_valid_xml, open_port, sync, tmp_text from distributed.utils_test import ( NO_AMM, + BarrierGetData, BlockedGatherDep, + BlockedGetData, TaskStateMetadataPlugin, _UnhashableCallable, async_poll_for, @@ -8430,3 +8432,68 @@ def identity(x): outer_future = c.submit(identity, {"x": inner_future, "y": 2}) result = await outer_future assert result == {"x": 1, "y": 2} + + +@gen_cluster( + client=True, + config=merge(NO_AMM, {"distributed.worker.memory.pause": False}), +) +async def test_replicate_busy(c, s, a, b): + """Client.replicate() receives a 'busy' response from a worker""" + async with BarrierGetData(s.address, barrier_count=2) as w: + x = (await c.scatter({"x": 1}, workers=[w.address]))["x"] + # Throttle to 1 simultaneous connection + w.status = Status.paused + await c.replicate(x) + # either a or b will receive 'busy'. + # After 0.15s, it will fetch the key from either b or a or from w. + assert w.barrier_count in (0, -1) + assert dict(a.data) == dict(b.data) == {"x": 1} + + +@pytest.mark.parametrize("direct", [False, True]) +@gen_cluster( + client=True, + nthreads=[], + config=merge(NO_AMM, {"distributed.worker.memory.pause": False}), +) +async def test_gather_busy(c, s, a, b, direct): + """Client.gather() receives a 'busy' response from a worker""" + async with BarrierGetData(s.address, barrier_count=2) as w: + x = c.submit(inc, 1, key="x", workers=[w.address]) + await wait(x) + # Throttle to 1 simultaneous connection + w.status = Status.paused + + async with Client(s.address, asynchronous=True) as c2: + assert await asyncio.gather( + c.gather(x, direct=direct), + c2.gather(Future("x", client=c2), direct=direct), + ) == [2, 2] + + assert w.barrier_count == -1 + + +@pytest.mark.parametrize("direct", [False, True]) +@gen_cluster(client=True, nthreads=[("", 1)], config=NO_AMM) +async def test_gather_race_vs_AMM(c, s, a, direct): + """Test race condition: + Client.gather() tries to get a key from a worker, but in the meantime the + Active Memory Manager has moved it to another worker + """ + async with BlockedGetData(s.address) as b: + x = c.submit(inc, 1, key="x", workers=[b.address]) + fut = asyncio.create_task(c.gather(x, direct=direct)) + await b.in_get_data.wait() + + # Simulate AMM replicate from b to a, followed by AMM drop on b + # Can't use s.request_acquire_replicas as it would get stuck on b.block_get_data + a.update_data({"x": 3}) + a.batched_send({"op": "add-keys", "keys": ["x"]}) + await async_poll_for(lambda: len(s.tasks["x"].who_has) == 2, timeout=5) + s.request_remove_replicas(b.address, ["x"], stimulus_id="remove") + await async_poll_for(lambda: "x" not in b.data, timeout=5) + + b.block_get_data.set() + + assert await fut == 3 # It's from a; it would be 2 if it were from b diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 4c0728f68d0..83bb2e28883 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -74,7 +74,7 @@ varying, wait_for_state, ) -from distributed.worker import dumps_function, dumps_task, get_worker, secede +from distributed.worker import dumps_function, dumps_task, secede pytestmark = pytest.mark.ci1 @@ -2855,56 +2855,31 @@ async def test_gather_no_workers(c, s, a, b): assert list(res["keys"]) == ["x"] -@gen_cluster(client=True, client_kwargs={"direct_to_workers": False}) -async def test_gather_bad_worker_removed(c, s, a, b): - """ - Upon connection failure or missing expected keys during gather, a worker is - shut down. The tasks should be rescheduled onto different workers, transparently - to `client.gather`. +@gen_cluster( + client=True, + nthreads=[("", 1)], + # This behaviour is independent of retries. + # Disable it to reduce complexity of this test. + config={"distributed.comm.retry.count": 0}, +) +async def test_gather_bad_worker(c, s, a): + """Upon connection failure, the scheduler tries again indefinitely and + transparently, for as long as the batched comms channel is active, to fulfil + `client.gather`. """ - x = c.submit(slowinc, 1, workers=[a.address], allow_other_workers=True) - - def finalizer(*args): - return get_worker().address - - fin = c.submit( - finalizer, x, key="final", workers=[a.address], allow_other_workers=True - ) - - s.rpc = await FlakyConnectionPool(failing_connections=1) - - # This behaviour is independent of retries. Remove them to reduce complexity - # of this setup - with dask.config.set({"distributed.comm.retry.count": 0}): - with captured_logger( - logging.getLogger("distributed.scheduler") - ) as sched_logger, captured_logger( - logging.getLogger("distributed.client") - ) as client_logger: - # Gather using the client (as an ordinary user would) - # Upon a missing key, the client will remove the bad worker and - # reschedule the computations - - # Both tasks are rescheduled onto `b`, since `a` was removed. - assert await fin == b.address - - await a.finished() - assert list(s.workers) == [b.address] - - sched_logger = sched_logger.getvalue() - client_logger = client_logger.getvalue() - assert "Shut down workers that don't have promised key" in sched_logger + x = c.submit(inc, 1, key="x") + s.rpc = await FlakyConnectionPool(failing_connections=3) - assert "Couldn't gather 1 keys, rescheduling" in client_logger + with captured_logger("distributed.scheduler") as sched_logger: + with captured_logger("distributed.client") as client_logger: + assert await c.gather(x, direct=False) == 2 - assert s.tasks[fin.key].who_has == {s.workers[b.address]} - assert a.state.executed_count == 2 - assert b.state.executed_count >= 1 - # ^ leave room for a future switch from `remove_worker` to `retire_workers` + assert sched_logger.getvalue() == "Couldn't gather keys: {'x': 'memory'}\n" * 3 + assert "Couldn't gather 1 keys, rescheduling" in client_logger.getvalue() # Ensure that the communication was done via the scheduler, i.e. we actually hit a # bad connection - assert s.rpc.cnn_count > 0 + assert s.rpc.cnn_count == 4 @gen_cluster(client=True) diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 44b5e52e7ed..a66b632de08 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from unittest import mock import pytest @@ -17,7 +18,7 @@ subs_multiple, unpack_remotedata, ) -from distributed.utils_test import BrokenComm, gen_cluster +from distributed.utils_test import NO_AMM, BlockedGetData, BrokenComm, gen_cluster def test_pack_data(): @@ -46,12 +47,10 @@ async def test_gather_from_workers_permissive(c, s, a, b): rpc = await ConnectionPool() x = await c.scatter({"x": 1}, workers=a.address) - data, missing, bad_workers = await gather_from_workers( - {"x": [a.address], "y": [b.address]}, rpc=rpc - ) + data, missing = await gather_from_workers(["x", "y"], s.get_who_has, rpc=rpc) assert data == {"x": 1} - assert list(missing) == ["y"] + assert missing == {"y"} class BrokenConnectionPool(ConnectionPool): @@ -64,10 +63,58 @@ async def test_gather_from_workers_permissive_flaky(c, s, a, b): x = await c.scatter({"x": 1}, workers=a.address) rpc = await BrokenConnectionPool() - data, missing, bad_workers = await gather_from_workers({"x": [a.address]}, rpc=rpc) + data, missing = await gather_from_workers(["x"], s.get_who_has, rpc=rpc) - assert missing == {"x": [a.address]} - assert bad_workers == [a.address] + assert data == {} + assert missing == {"x"} + + +@gen_cluster(client=True, nthreads=[], config=NO_AMM) +async def test_gather_from_workers_cancelled_error(c, s): + """Something somewhere in the networking stack raises CancelledError while + gather_from_workers is running + + See Also + -------- + test_worker.py::test_gather_dep_cancelled_error + test_worker.py::test_get_data_cancelled_error + https://github.com/dask/distributed/issues/8006 + """ + rpc = await ConnectionPool() + async with BlockedGetData(s.address) as a, BlockedGetData(s.address) as b: + a.block_get_data.set() + b.block_get_data.set() + x = await c.scatter({"x": 1}, broadcast=True) + assert len(s.tasks["x"].who_has) == 2 + a.in_get_data.clear() + b.in_get_data.clear() + a.block_get_data.clear() + b.block_get_data.clear() + + fut = asyncio.create_task( + gather_from_workers(keys=["x"], who_has=s.get_who_has, rpc=rpc) + ) + await asyncio.wait( + [ + asyncio.create_task(a.in_get_data.wait()), + asyncio.create_task(b.in_get_data.wait()), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + + tasks = { + task for task in asyncio.all_tasks() if "get-data-from-" in task.get_name() + } + assert tasks + # There should be only one task but cope with finding more just in case a + # previous test didn't properly clean up + for task in tasks: + task.cancel() + + a.block_get_data.set() + b.block_get_data.set() + # gather_from_workers retries transparently from the other worker + assert await fut == ({"x": 1}, set()) def test_retry_no_exception(cleanup): diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 29b789618c7..cca981b3d29 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -547,7 +547,7 @@ async def test_gather_missing_keys(c, s, a, b): async with rpc(a.address) as aa: resp = await aa.gather(who_has={x.key: [b.address], "y": [b.address]}) - assert resp == {"status": "partial-fail", "keys": {"y": (b.address,)}} + assert resp == {"status": "partial-fail", "keys": ("y",)} assert a.data[x.key] == b.data[x.key] == "x" @@ -563,21 +563,31 @@ async def test_gather_missing_workers(c, s, a, b): async with rpc(a.address) as aa: resp = await aa.gather(who_has={x.key: [b.address], "y": [bad_addr]}) - assert resp == {"status": "partial-fail", "keys": {"y": (bad_addr,)}} + assert resp == {"status": "partial-fail", "keys": ("y",)} assert a.data[x.key] == b.data[x.key] == "x" -@pytest.mark.parametrize("missing_first", [False, True]) -@gen_cluster(client=True, worker_kwargs={"timeout": "100ms"}) -async def test_gather_missing_workers_replicated(c, s, a, b, missing_first): +@pytest.mark.slow +@pytest.mark.parametrize("know_real", [False, True, True, True, True]) # Read below +@gen_cluster(client=True, worker_kwargs={"timeout": "1s"}, config=NO_AMM) +async def test_gather_missing_workers_replicated(c, s, a, b, know_real): """A worker owning a redundant copy of a key is missing. The key is successfully gathered from other workers. + + know_real=False + gather() will try to connect to the bad address, fail, and then query the + scheduler who will respond with the good address. Then gather will successfully + retrieve the key from the good address. + know_real=True + 50% of the times, gather() will try to connect to the bad address, fail, and + immediately connect to the good address. + The other 50% of the times it will directly connect to the good address, + hence why this test is repeated. """ assert b.address.startswith("tcp://127.0.0.1:") x = await c.scatter("x", workers=[b.address]) bad_addr = "tcp://127.0.0.1:12345" - # Order matters! Test both - addrs = [bad_addr, b.address] if missing_first else [b.address, bad_addr] + addrs = [bad_addr, b.address] if know_real else [bad_addr] async with rpc(a.address) as aa: resp = await aa.gather(who_has={x.key: addrs}) assert resp == {"status": "OK"} @@ -3359,6 +3369,7 @@ async def test_gather_dep_cancelled_error(c, s, a): See Also -------- test_get_data_cancelled_error + test_utils_comm.py::test_gather_from_workers_cancelled_error https://github.com/dask/distributed/issues/8006 """ async with BlockedGetData(s.address) as b: @@ -3397,6 +3408,7 @@ async def test_get_data_cancelled_error(c, s, a): See Also -------- test_gather_dep_cancelled_error + test_utils_comm.py::test_gather_from_workers_cancelled_error https://github.com/dask/distributed/issues/8006 """ diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index e78f303b2be..d94a8f2fecd 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -1,10 +1,18 @@ from __future__ import annotations import asyncio +import inspect import logging import random from collections import defaultdict -from collections.abc import Callable, Collection, Coroutine, Mapping +from collections.abc import ( + Awaitable, + Callable, + Collection, + Coroutine, + Iterable, + Mapping, +) from functools import partial from itertools import cycle from typing import Any, TypeVar @@ -22,52 +30,83 @@ async def gather_from_workers( - who_has: Mapping[str, Collection[str]], + keys: Iterable[str], + who_has: Callable[ + [list[str]], + Mapping[str, Collection[str]] | Awaitable[Mapping[str, Collection[str]]], + ], + *, rpc: ConnectionPool, - close: bool = True, serializers: list[str] | None = None, who: str | None = None, -) -> tuple[dict[str, object], dict[str, list[str]], list[str]]: +) -> tuple[dict[str, object], set[str]]: """Gather data directly from peers Parameters ---------- - who_has: dict - Dict mapping keys to sets of workers that may have that key - rpc: callable + keys: + keys of tasks to be gathered + who_has: + function used to refresh who_has from the scheduler. + Accepts a single 'keys' parameter and returns a mapping from keys to workers. + rpc: + RPC channel to use - Returns dict mapping key to value + Returns + ------- + Tuple: + + - Successfully retrieved: ``{task key: task value, ...}`` + - Failed to retrieve from all workers: ``{task key, ...}`` See Also -------- gather _gather + Scheduler.get_who_has """ from distributed.worker import get_data_from_worker - bad_addresses: set[str] = set() - missing_workers = set() - original_who_has = who_has - new_who_has = {k: set(v) for k, v in who_has.items()} + to_gather: dict[str, set[str]] = {k: set() for k in keys} results: dict[str, object] = {} - all_bad_keys: set[str] = set() + missing_keys: set[str] = set() + missing_workers: set[str] = set() + busy_workers: set[str] = set() - while len(results) + len(all_bad_keys) < len(who_has): + while to_gather: d = defaultdict(list) - rev = dict() - bad_keys = set() - for key, addresses in new_who_has.items(): - if key in results: - continue - try: - addr = random.choice(list(addresses - bad_addresses)) - d[addr].append(key) - rev[key] = addr - except IndexError: - bad_keys.add(key) - if bad_keys: - all_bad_keys |= bad_keys - coroutines = { + for key, addresses in to_gather.items(): + addresses -= missing_workers + addresses -= busy_workers + if addresses: + d[random.choice(list(addresses))].append(key) + + if not d: + if busy_workers: + await asyncio.sleep(0.15) + busy_workers.clear() + + new_who_has = who_has(list(to_gather)) + if inspect.isawaitable(new_who_has): + new_who_has = await new_who_has + for key, new_addresses in new_who_has.items(): # type: ignore + addresses = set(new_addresses) - missing_workers + if addresses: + to_gather[key].update(addresses) + d[random.choice(list(addresses))].append(key) + else: + # 1. We failed to connect to all workers reported by the scheduler + # in previous iterations, or + # 2. All workers holding the data have crashed and the task is not + # in memory on the scheduler anymore, or + # 3. The scheduler has forgotten the task + missing_keys.add(key) + del to_gather[key] + + if not d: + break + + tasks = { address: asyncio.create_task( retry_operation( partial( @@ -77,7 +116,6 @@ async def gather_from_workers( address, who=who, serializers=serializers, - max_connections=False, ), operation="get_data_from_worker", ), @@ -85,28 +123,39 @@ async def gather_from_workers( ) for address, keys in d.items() } - response: dict[str, object] = {} - for worker, c in coroutines.items(): + for address, task in tasks.items(): try: - r = await c - except OSError: - missing_workers.add(worker) - except ValueError as e: - logger.info( - "Got an unexpected error while collecting from workers: %s", e + r = await task + except (OSError, asyncio.CancelledError, asyncio.TimeoutError): + # Note: CancelledError and asyncio.TimeoutError are rare conditions + # that can be raised by the network stack. + # See https://github.com/dask/distributed/issues/8006 + missing_workers.add(address) + except Exception: + # For example, deserialization error + logger.exception( + "Unexpected error while collecting tasks %s from %s", + d[address], + address, ) - missing_workers.add(worker) + for key in d[address]: + missing_keys.add(key) + del to_gather[key] + missing_workers.add(address) else: - response.update(r["data"]) - - bad_addresses |= {v for k, v in rev.items() if k not in response} - results.update(response) - - return ( - results, - {k: list(original_who_has[k]) for k in all_bad_keys}, - list(missing_workers), - ) + if r["status"] == "busy": + busy_workers.add(address) + continue + + assert r["status"] == "OK" + for key in d[address]: + if key in r["data"]: + results[key] = r["data"][key] + del to_gather[key] + else: + to_gather[key].remove(address) + + return results, missing_keys class WrappedKey: diff --git a/distributed/utils_test.py b/distributed/utils_test.py index fbceba4d356..52626d487b5 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2211,6 +2211,7 @@ async def test1(s, a, b): See also -------- BlockedGetData + BarrierGetData BlockedExecute """ @@ -2232,6 +2233,7 @@ class BlockedGetData(Worker): See also -------- + BarrierGetData BlockedGatherDep BlockedExecute """ @@ -2281,6 +2283,7 @@ def f(in_task, block_task): -------- BlockedGatherDep BlockedGetData + BarrierGetData """ def __init__(self, *args, **kwargs): @@ -2310,6 +2313,32 @@ async def _maybe_deserialize_task( return await super()._maybe_deserialize_task(ts) +class BarrierGetData(Worker): + """Block get_data RPC call until at least barrier_count connections are going on + in parallel at the same time + + See also + -------- + BlockedGatherDep + BlockedGetData + BlockedExecute + """ + + def __init__(self, *args, barrier_count, **kwargs): + # TODO just use asyncio.Barrier (needs Python >=3.11) + self.barrier_count = barrier_count + self.wait_get_data = asyncio.Event() + super().__init__(*args, **kwargs) + + async def get_data(self, comm, *args, **kwargs): + self.barrier_count -= 1 + if self.barrier_count > 0: + await self.wait_get_data.wait() + else: + self.wait_get_data.set() + return await super().get_data(comm, *args, **kwargs) + + @contextmanager def freeze_data_fetching(w: Worker, *, jump_start: bool = False) -> Iterator[None]: """Prevent any task from transitioning from fetch to flight on the worker while diff --git a/distributed/worker.py b/distributed/worker.py index fe52f7960b2..26309d17081 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -39,7 +39,6 @@ TypedDict, TypeVar, cast, - overload, ) from tlz import keymap, pluck @@ -1307,23 +1306,31 @@ def keys(self) -> list[str]: return list(self.data) async def gather(self, who_has: dict[str, list[str]]) -> dict[str, Any]: + """Endpoint used by Scheduler.replicate()""" + first = True who_has = { k: [coerce_to_address(addr) for addr in v] for k, v in who_has.items() if k not in self.data } - result, missing_keys, missing_workers = await gather_from_workers( - who_has=who_has, rpc=self.rpc, who=self.address + + async def refresh_who_has(keys: list[str]) -> Mapping[str, Collection[str]]: + nonlocal first + if first: + first = False + return who_has + return await retry_operation(self.scheduler.who_has, keys=keys) + + result, missing_keys = await gather_from_workers( + keys=who_has, + who_has=refresh_who_has, + rpc=self.rpc, + who=self.address, ) self.update_data(data=result) if missing_keys: - logger.warning( - "Could not find data: %s on workers: %s (who_has: %s)", - missing_keys, - missing_workers, - who_has, - ) - return {"status": "partial-fail", "keys": missing_keys} + logger.error("Could not find data: %s", missing_keys) + return {"status": "partial-fail", "keys": list(missing_keys)} else: return {"status": "OK"} @@ -1707,17 +1714,11 @@ async def batched_send_connect(): async def get_data( self, comm: Comm, - keys: Collection[str] | None = None, + keys: Collection[str], who: str | None = None, serializers: list[str] | None = None, - max_connections: int | None = None, ) -> GetDataBusy | Literal[Status.dont_reply]: - if max_connections is None: - max_connections = self.transfer_outgoing_count_limit - - if keys is None: - keys = set() - + max_connections = self.transfer_outgoing_count_limit # Allow same-host connections more liberally if ( max_connections @@ -2845,41 +2846,12 @@ def secede(): ) -@overload -async def get_data_from_worker( - rpc: ConnectionPool, - keys: Collection[str], - worker: str, - *, - who: str | None = None, - max_connections: Literal[False], - serializers: list[str] | None = None, - deserializers: list[str] | None = None, -) -> GetDataSuccess: - ... - - -@overload -async def get_data_from_worker( - rpc: ConnectionPool, - keys: Collection[str], - worker: str, - *, - who: str | None = None, - max_connections: bool | int | None = None, - serializers: list[str] | None = None, - deserializers: list[str] | None = None, -) -> GetDataBusy | GetDataSuccess: - ... - - async def get_data_from_worker( rpc: ConnectionPool, keys: Collection[str], worker: str, *, who: str | None = None, - max_connections: bool | int | None = None, serializers: list[str] | None = None, deserializers: list[str] | None = None, ) -> GetDataBusy | GetDataSuccess: @@ -2909,7 +2881,6 @@ async def get_data_from_worker( op="get_data", keys=keys, who=who, - max_connections=max_connections, ) try: status = response["status"] From b1d782e49f09e16b922a417eafeeb419bd605a93 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 26 Jul 2023 16:23:44 +0100 Subject: [PATCH 02/10] Rename who_has callable to get_who_has --- distributed/client.py | 4 ++-- distributed/scheduler.py | 2 +- distributed/utils_comm.py | 14 +++++++------- distributed/worker.py | 10 +++------- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 87a34ad01d0..69518486c4e 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2316,11 +2316,11 @@ async def _gather_remote(self, direct: bool, local_worker: bool) -> dict[str, An return await retry_operation(self.scheduler.gather, keys=keys) # gather directly from workers - async def who_has(keys: list[str]) -> dict[str, Collection[str]]: + async def get_who_has(keys: list[str]) -> dict[str, Collection[str]]: return await retry_operation(self.scheduler.who_has, keys=keys) data, missing_keys = await gather_from_workers( - keys=keys, who_has=who_has, rpc=self.rpc + keys, get_who_has, rpc=self.rpc ) response: dict[str, Any] = {"status": "OK", "data": data} if missing_keys: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3a348e17d17..1161fa44d58 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5920,7 +5920,7 @@ async def gather( ) -> dict[str, Any]: """Collect data from workers to the scheduler""" data, missing_keys = await gather_from_workers( - keys=keys, who_has=self.get_who_has, rpc=self.rpc, serializers=serializers + keys, self.get_who_has, rpc=self.rpc, serializers=serializers ) self.log_event("all", {"action": "gather", "count": len(keys)}) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index d94a8f2fecd..379bfe9904b 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -31,7 +31,7 @@ async def gather_from_workers( keys: Iterable[str], - who_has: Callable[ + get_who_has: Callable[ [list[str]], Mapping[str, Collection[str]] | Awaitable[Mapping[str, Collection[str]]], ], @@ -46,8 +46,8 @@ async def gather_from_workers( ---------- keys: keys of tasks to be gathered - who_has: - function used to refresh who_has from the scheduler. + get_who_has: + function used to fetch the who_has mapping from the scheduler. Accepts a single 'keys' parameter and returns a mapping from keys to workers. rpc: RPC channel to use @@ -86,10 +86,10 @@ async def gather_from_workers( await asyncio.sleep(0.15) busy_workers.clear() - new_who_has = who_has(list(to_gather)) - if inspect.isawaitable(new_who_has): - new_who_has = await new_who_has - for key, new_addresses in new_who_has.items(): # type: ignore + who_has = get_who_has(list(to_gather)) + if inspect.isawaitable(who_has): + who_has = await who_has + for key, new_addresses in who_has.items(): # type: ignore addresses = set(new_addresses) - missing_workers if addresses: to_gather[key].update(addresses) diff --git a/distributed/worker.py b/distributed/worker.py index 26309d17081..626ee03dee8 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1314,7 +1314,7 @@ async def gather(self, who_has: dict[str, list[str]]) -> dict[str, Any]: if k not in self.data } - async def refresh_who_has(keys: list[str]) -> Mapping[str, Collection[str]]: + async def get_who_has(keys: list[str]) -> Mapping[str, Collection[str]]: nonlocal first if first: first = False @@ -1323,7 +1323,7 @@ async def refresh_who_has(keys: list[str]) -> Mapping[str, Collection[str]]: result, missing_keys = await gather_from_workers( keys=who_has, - who_has=refresh_who_has, + get_who_has=get_who_has, rpc=self.rpc, who=self.address, ) @@ -1720,11 +1720,7 @@ async def get_data( ) -> GetDataBusy | Literal[Status.dont_reply]: max_connections = self.transfer_outgoing_count_limit # Allow same-host connections more liberally - if ( - max_connections - and comm - and get_address_host(comm.peer_address) == get_address_host(self.address) - ): + if get_address_host(comm.peer_address) == get_address_host(self.address): max_connections = max_connections * 2 if self.status == Status.paused: From 47be3d2a85787a73988e180e4aca53aabedcfdce Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Aug 2023 12:27:25 +0100 Subject: [PATCH 03/10] Refactor without callback --- distributed/client.py | 13 +++---- distributed/scheduler.py | 39 ++++++++++++++++----- distributed/utils_comm.py | 72 ++++++++++++--------------------------- distributed/worker.py | 61 ++++++++++++++++++++------------- 4 files changed, 94 insertions(+), 91 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index dcfa054ed7a..6c245c0a4bc 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2336,20 +2336,17 @@ async def _gather_remote(self, direct: bool, local_worker: bool) -> dict[str, An return await retry_operation(self.scheduler.gather, keys=keys) # gather directly from workers - async def get_who_has(keys: list[str]) -> dict[str, Collection[str]]: - return await retry_operation(self.scheduler.who_has, keys=keys) - - data, missing_keys = await gather_from_workers( - keys, get_who_has, rpc=self.rpc + who_has = await retry_operation(self.scheduler.who_has, keys=keys) + data, missing_keys, failed_keys, _ = await gather_from_workers( + who_has, rpc=self.rpc ) response: dict[str, Any] = {"status": "OK", "data": data} - if missing_keys: + if missing_keys or failed_keys: response = await retry_operation( - self.scheduler.gather, keys=missing_keys + self.scheduler.gather, keys=missing_keys + failed_keys ) if response["status"] == "OK": response["data"].update(data) - return response def gather(self, futures, errors="raise", direct=None, asynchronous=None): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 57f90b43b22..b821376b642 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5922,20 +5922,43 @@ async def gather( self, keys: Collection[str], serializers: list[str] | None = None ) -> dict[str, Any]: """Collect data from workers to the scheduler""" - data, missing_keys = await gather_from_workers( - keys, self.get_who_has, rpc=self.rpc, serializers=serializers - ) + data = {} + missing_keys = list(keys) + failed_keys: list[str] = [] + missing_workers: set[str] = set() + + while missing_keys: + who_has = {} + for key, workers in self.get_who_has(missing_keys).items(): + valid_workers = set(workers) - missing_workers + if valid_workers: + who_has[key] = valid_workers + else: + failed_keys.append(key) + + ( + new_data, + missing_keys, + new_failed_keys, + new_missing_workers, + ) = await gather_from_workers( + who_has, rpc=self.rpc, serializers=serializers + ) + data.update(new_data) + failed_keys += new_failed_keys + missing_workers.update(new_missing_workers) + self.log_event("all", {"action": "gather", "count": len(keys)}) - if not missing_keys: + if not failed_keys: return {"status": "OK", "data": data} - missing_states = { + failed_states = { key: self.tasks[key].state if key in self.tasks else "forgotten" - for key in missing_keys + for key in failed_keys } - logger.error("Couldn't gather keys: %s", missing_states) - return {"status": "error", "keys": list(missing_keys)} + logger.error("Couldn't gather keys: %s", failed_states) + return {"status": "error", "keys": list(failed_keys)} @log_errors async def restart(self, client=None, timeout=30, wait_for_workers=True): diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 379bfe9904b..2a8347ead12 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -1,18 +1,10 @@ from __future__ import annotations import asyncio -import inspect import logging import random from collections import defaultdict -from collections.abc import ( - Awaitable, - Callable, - Collection, - Coroutine, - Iterable, - Mapping, -) +from collections.abc import Callable, Collection, Coroutine, Mapping from functools import partial from itertools import cycle from typing import Any, TypeVar @@ -30,25 +22,18 @@ async def gather_from_workers( - keys: Iterable[str], - get_who_has: Callable[ - [list[str]], - Mapping[str, Collection[str]] | Awaitable[Mapping[str, Collection[str]]], - ], - *, + who_has: Mapping[str, Collection[str]], rpc: ConnectionPool, + *, serializers: list[str] | None = None, who: str | None = None, -) -> tuple[dict[str, object], set[str]]: +) -> tuple[dict[str, object], list[str], list[str], list[str]]: """Gather data directly from peers Parameters ---------- - keys: - keys of tasks to be gathered - get_who_has: - function used to fetch the who_has mapping from the scheduler. - Accepts a single 'keys' parameter and returns a mapping from keys to workers. + who_has: + mapping from keys to worker addresses rpc: RPC channel to use @@ -56,8 +41,10 @@ async def gather_from_workers( ------- Tuple: - - Successfully retrieved: ``{task key: task value, ...}`` - - Failed to retrieve from all workers: ``{task key, ...}`` + - Successfully retrieved: ``{key: value, ...}`` + - Keys that were not available on any worker: ``[key, ...]`` + - Keys that raised exception; e.g. failed to deserialize: ``[key, ...]`` + - Workers that failed to respond: ``[address, ...]`` See Also -------- @@ -67,9 +54,9 @@ async def gather_from_workers( """ from distributed.worker import get_data_from_worker - to_gather: dict[str, set[str]] = {k: set() for k in keys} - results: dict[str, object] = {} - missing_keys: set[str] = set() + to_gather = {k: set(v) for k, v in who_has.items()} + data: dict[str, object] = {} + failed_keys: list[str] = [] missing_workers: set[str] = set() busy_workers: set[str] = set() @@ -77,34 +64,17 @@ async def gather_from_workers( d = defaultdict(list) for key, addresses in to_gather.items(): addresses -= missing_workers - addresses -= busy_workers - if addresses: - d[random.choice(list(addresses))].append(key) + ready_addresses = addresses - busy_workers + if ready_addresses: + d[random.choice(list(ready_addresses))].append(key) if not d: if busy_workers: await asyncio.sleep(0.15) busy_workers.clear() + continue - who_has = get_who_has(list(to_gather)) - if inspect.isawaitable(who_has): - who_has = await who_has - for key, new_addresses in who_has.items(): # type: ignore - addresses = set(new_addresses) - missing_workers - if addresses: - to_gather[key].update(addresses) - d[random.choice(list(addresses))].append(key) - else: - # 1. We failed to connect to all workers reported by the scheduler - # in previous iterations, or - # 2. All workers holding the data have crashed and the task is not - # in memory on the scheduler anymore, or - # 3. The scheduler has forgotten the task - missing_keys.add(key) - del to_gather[key] - - if not d: - break + return data, list(to_gather), failed_keys, list(missing_workers) tasks = { address: asyncio.create_task( @@ -139,7 +109,7 @@ async def gather_from_workers( address, ) for key in d[address]: - missing_keys.add(key) + failed_keys.append(key) del to_gather[key] missing_workers.add(address) else: @@ -150,12 +120,12 @@ async def gather_from_workers( assert r["status"] == "OK" for key in d[address]: if key in r["data"]: - results[key] = r["data"][key] + data[key] = r["data"][key] del to_gather[key] else: to_gather[key].remove(address) - return results, missing_keys + return data, [], failed_keys, list(missing_workers) class WrappedKey: diff --git a/distributed/worker.py b/distributed/worker.py index 951d6d94feb..1ce8b961ba5 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1306,31 +1306,44 @@ def keys(self) -> list[str]: return list(self.data) async def gather(self, who_has: dict[str, list[str]]) -> dict[str, Any]: - """Endpoint used by Scheduler.replicate()""" - first = True - who_has = { - k: [coerce_to_address(addr) for addr in v] - for k, v in who_has.items() - if k not in self.data - } + """Endpoint used by Scheduler.rebalance() and Scheduler.replicate()""" + missing_keys = [k for k in who_has if k not in self.data] + failed_keys = [] + missing_workers: set[str] = set() + stimulus_id = f"gather-{time()}" + + while missing_keys: + to_gather = {} + for k in missing_keys: + workers = set(who_has[k]) - missing_workers + if workers: + to_gather[k] = workers + else: + failed_keys.append(k) + if not to_gather: + break - async def get_who_has(keys: list[str]) -> Mapping[str, Collection[str]]: - nonlocal first - if first: - first = False - return who_has - return await retry_operation(self.scheduler.who_has, keys=keys) - - result, missing_keys = await gather_from_workers( - keys=who_has, - get_who_has=get_who_has, - rpc=self.rpc, - who=self.address, - ) - self.update_data(data=result) - if missing_keys: - logger.error("Could not find data: %s", missing_keys) - return {"status": "partial-fail", "keys": list(missing_keys)} + ( + data, + missing_keys, + new_failed_keys, + new_missing_workers, + ) = await gather_from_workers( + who_has=to_gather, rpc=self.rpc, who=self.address + ) + self.update_data(data, stimulus_id=stimulus_id) + del data + failed_keys += new_failed_keys + missing_workers.update(new_missing_workers) + + if missing_keys: + who_has = await retry_operation( + self.scheduler.who_has, keys=missing_keys + ) + + if failed_keys: + logger.error("Could not find data: %s", failed_keys) + return {"status": "partial-fail", "keys": list(failed_keys)} else: return {"status": "OK"} From 3e0fe9593157295e57eb64a7f8b5ec9c368c41ab Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Aug 2023 15:17:31 +0100 Subject: [PATCH 04/10] Revert handling of CancelledError --- distributed/tests/test_utils_comm.py | 48 ---------------------------- distributed/tests/test_worker.py | 2 -- distributed/utils_comm.py | 5 +-- 3 files changed, 1 insertion(+), 54 deletions(-) diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index a66b632de08..cc94e79a3f6 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -69,54 +69,6 @@ async def test_gather_from_workers_permissive_flaky(c, s, a, b): assert missing == {"x"} -@gen_cluster(client=True, nthreads=[], config=NO_AMM) -async def test_gather_from_workers_cancelled_error(c, s): - """Something somewhere in the networking stack raises CancelledError while - gather_from_workers is running - - See Also - -------- - test_worker.py::test_gather_dep_cancelled_error - test_worker.py::test_get_data_cancelled_error - https://github.com/dask/distributed/issues/8006 - """ - rpc = await ConnectionPool() - async with BlockedGetData(s.address) as a, BlockedGetData(s.address) as b: - a.block_get_data.set() - b.block_get_data.set() - x = await c.scatter({"x": 1}, broadcast=True) - assert len(s.tasks["x"].who_has) == 2 - a.in_get_data.clear() - b.in_get_data.clear() - a.block_get_data.clear() - b.block_get_data.clear() - - fut = asyncio.create_task( - gather_from_workers(keys=["x"], who_has=s.get_who_has, rpc=rpc) - ) - await asyncio.wait( - [ - asyncio.create_task(a.in_get_data.wait()), - asyncio.create_task(b.in_get_data.wait()), - ], - return_when=asyncio.FIRST_COMPLETED, - ) - - tasks = { - task for task in asyncio.all_tasks() if "get-data-from-" in task.get_name() - } - assert tasks - # There should be only one task but cope with finding more just in case a - # previous test didn't properly clean up - for task in tasks: - task.cancel() - - a.block_get_data.set() - b.block_get_data.set() - # gather_from_workers retries transparently from the other worker - assert await fut == ({"x": 1}, set()) - - def test_retry_no_exception(cleanup): n_calls = 0 retval = object() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index cca981b3d29..e99f2db1658 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3369,7 +3369,6 @@ async def test_gather_dep_cancelled_error(c, s, a): See Also -------- test_get_data_cancelled_error - test_utils_comm.py::test_gather_from_workers_cancelled_error https://github.com/dask/distributed/issues/8006 """ async with BlockedGetData(s.address) as b: @@ -3408,7 +3407,6 @@ async def test_get_data_cancelled_error(c, s, a): See Also -------- test_gather_dep_cancelled_error - test_utils_comm.py::test_gather_from_workers_cancelled_error https://github.com/dask/distributed/issues/8006 """ diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 2a8347ead12..f6f644d7000 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -96,10 +96,7 @@ async def gather_from_workers( for address, task in tasks.items(): try: r = await task - except (OSError, asyncio.CancelledError, asyncio.TimeoutError): - # Note: CancelledError and asyncio.TimeoutError are rare conditions - # that can be raised by the network stack. - # See https://github.com/dask/distributed/issues/8006 + except OSError: missing_workers.add(address) except Exception: # For example, deserialization error From 1e9d58476adf40f24270db4fb8fedbfa2dc37d9d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Aug 2023 15:28:35 +0100 Subject: [PATCH 05/10] revert --- distributed/tests/test_utils_comm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index cc94e79a3f6..56e72683bb9 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio from unittest import mock import pytest @@ -18,7 +17,7 @@ subs_multiple, unpack_remotedata, ) -from distributed.utils_test import NO_AMM, BlockedGetData, BrokenComm, gen_cluster +from distributed.utils_test import BrokenComm, gen_cluster def test_pack_data(): From 86c8fa250b05f913916d33e81c0433cf0b4c03c4 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Aug 2023 15:30:37 +0100 Subject: [PATCH 06/10] Revert no shutdown of workers --- distributed/scheduler.py | 15 +++++++ distributed/tests/test_scheduler.py | 65 ++++++++++++++++++++--------- 2 files changed, 60 insertions(+), 20 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b821376b642..12da0690082 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5922,6 +5922,7 @@ async def gather( self, keys: Collection[str], serializers: list[str] | None = None ) -> dict[str, Any]: """Collect data from workers to the scheduler""" + stimulus_id = f"gather-{time()}" data = {} missing_keys = list(keys) failed_keys: list[str] = [] @@ -5958,6 +5959,20 @@ async def gather( for key in failed_keys } logger.error("Couldn't gather keys: %s", failed_states) + + if missing_workers: + with log_errors(): + # Remove suspicious workers from the scheduler and shut them down. + await asyncio.gather( + *( + self.remove_worker( + address=worker, close=True, stimulus_id=stimulus_id + ) + for worker in missing_workers + ) + ) + logger.error("Shut down unresponsive workers:: %s", missing_workers) + return {"status": "error", "keys": list(failed_keys)} @log_errors diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 7870565ae18..08c44e5b1e3 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -74,7 +74,7 @@ varying, wait_for_state, ) -from distributed.worker import dumps_function, dumps_task, secede +from distributed.worker import dumps_function, dumps_task, get_worker, secede pytestmark = pytest.mark.ci1 @@ -2865,31 +2865,56 @@ async def test_gather_no_workers(c, s, a, b): assert list(res["keys"]) == ["x"] -@gen_cluster( - client=True, - nthreads=[("", 1)], - # This behaviour is independent of retries. - # Disable it to reduce complexity of this test. - config={"distributed.comm.retry.count": 0}, -) -async def test_gather_bad_worker(c, s, a): - """Upon connection failure, the scheduler tries again indefinitely and - transparently, for as long as the batched comms channel is active, to fulfil - `client.gather`. +@gen_cluster(client=True, client_kwargs={"direct_to_workers": False}) +async def test_gather_bad_worker_removed(c, s, a, b): """ - x = c.submit(inc, 1, key="x") - s.rpc = await FlakyConnectionPool(failing_connections=3) + Upon connection failure or missing expected keys during gather, a worker is + shut down. The tasks should be rescheduled onto different workers, transparently + to `client.gather`. + """ + x = c.submit(slowinc, 1, workers=[a.address], allow_other_workers=True) + + def finalizer(*args): + return get_worker().address + + fin = c.submit( + finalizer, x, key="final", workers=[a.address], allow_other_workers=True + ) + + s.rpc = await FlakyConnectionPool(failing_connections=1) + + # This behaviour is independent of retries. Remove them to reduce complexity + # of this setup + with dask.config.set({"distributed.comm.retry.count": 0}): + with captured_logger( + logging.getLogger("distributed.scheduler") + ) as sched_logger, captured_logger( + logging.getLogger("distributed.client") + ) as client_logger: + # Gather using the client (as an ordinary user would) + # Upon a missing key, the client will remove the bad worker and + # reschedule the computations + + # Both tasks are rescheduled onto `b`, since `a` was removed. + assert await fin == b.address + + await a.finished() + assert list(s.workers) == [b.address] + + sched_logger = sched_logger.getvalue() + client_logger = client_logger.getvalue() + assert "Shut down unresponsive workers" in sched_logger - with captured_logger("distributed.scheduler") as sched_logger: - with captured_logger("distributed.client") as client_logger: - assert await c.gather(x, direct=False) == 2 + assert "Couldn't gather 1 keys, rescheduling" in client_logger - assert sched_logger.getvalue() == "Couldn't gather keys: {'x': 'memory'}\n" * 3 - assert "Couldn't gather 1 keys, rescheduling" in client_logger.getvalue() + assert s.tasks[fin.key].who_has == {s.workers[b.address]} + assert a.state.executed_count == 2 + assert b.state.executed_count >= 1 + # ^ leave room for a future switch from `remove_worker` to `retire_workers` # Ensure that the communication was done via the scheduler, i.e. we actually hit a # bad connection - assert s.rpc.cnn_count == 4 + assert s.rpc.cnn_count > 0 @gen_cluster(client=True) From 4acc370638a5da81e212169c5925675faf34407d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Aug 2023 16:19:02 +0100 Subject: [PATCH 07/10] tweaks --- distributed/tests/test_client.py | 41 ----------- distributed/tests/test_utils_comm.py | 105 +++++++++++++++++++++++++-- distributed/utils_comm.py | 1 - 3 files changed, 98 insertions(+), 49 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index ca59eb81a19..80ef12a2269 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -84,7 +84,6 @@ from distributed.utils import get_mp_context, is_valid_xml, open_port, sync, tmp_text from distributed.utils_test import ( NO_AMM, - BarrierGetData, BlockedGatherDep, BlockedGetData, TaskStateMetadataPlugin, @@ -8445,46 +8444,6 @@ def identity(x): assert result == {"x": 1, "y": 2} -@gen_cluster( - client=True, - config=merge(NO_AMM, {"distributed.worker.memory.pause": False}), -) -async def test_replicate_busy(c, s, a, b): - """Client.replicate() receives a 'busy' response from a worker""" - async with BarrierGetData(s.address, barrier_count=2) as w: - x = (await c.scatter({"x": 1}, workers=[w.address]))["x"] - # Throttle to 1 simultaneous connection - w.status = Status.paused - await c.replicate(x) - # either a or b will receive 'busy'. - # After 0.15s, it will fetch the key from either b or a or from w. - assert w.barrier_count in (0, -1) - assert dict(a.data) == dict(b.data) == {"x": 1} - - -@pytest.mark.parametrize("direct", [False, True]) -@gen_cluster( - client=True, - nthreads=[], - config=merge(NO_AMM, {"distributed.worker.memory.pause": False}), -) -async def test_gather_busy(c, s, a, b, direct): - """Client.gather() receives a 'busy' response from a worker""" - async with BarrierGetData(s.address, barrier_count=2) as w: - x = c.submit(inc, 1, key="x", workers=[w.address]) - await wait(x) - # Throttle to 1 simultaneous connection - w.status = Status.paused - - async with Client(s.address, asynchronous=True) as c2: - assert await asyncio.gather( - c.gather(x, direct=direct), - c2.gather(Future("x", client=c2), direct=direct), - ) == [2, 2] - - assert w.barrier_count == -1 - - @pytest.mark.parametrize("direct", [False, True]) @gen_cluster(client=True, nthreads=[("", 1)], config=NO_AMM) async def test_gather_race_vs_AMM(c, s, a, direct): diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 56e72683bb9..ee0eaff0899 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -1,14 +1,17 @@ from __future__ import annotations +import asyncio +import random from unittest import mock import pytest from dask.optimization import SubgraphCallable +from distributed import wait from distributed.compatibility import asyncio_run from distributed.config import get_loop_factory -from distributed.core import ConnectionPool +from distributed.core import ConnectionPool, Status from distributed.utils_comm import ( WrappedKey, gather_from_workers, @@ -17,7 +20,7 @@ subs_multiple, unpack_remotedata, ) -from distributed.utils_test import BrokenComm, gen_cluster +from distributed.utils_test import BarrierGetData, BrokenComm, gen_cluster, inc def test_pack_data(): @@ -41,31 +44,119 @@ def test_subs_multiple(): assert subs_multiple(dsk, data) == {"a": (sum, [1, 2])} +@gen_cluster(client=True, nthreads=[("", 1)] * 10) +async def test_gather_from_workers_missing_replicas(c, s, *workers): + """When a key is replicated on multiple workers, but the who_has is slightly + obsolete, gather_from_workers, retries fetching from all known holders of a replica + until it finds the key + """ + a = random.choice(workers) + x = await c.scatter({"x": 1}, workers=a.address) + assert len(s.workers) == 10 + assert len(s.tasks["x"].who_has) == 1 + + rpc = await ConnectionPool() + data, missing, failed, bad_workers = await gather_from_workers( + {"x": [w.address for w in workers]}, rpc=rpc + ) + + assert data == {"x": 1} + assert missing == [] + assert failed == [] + assert bad_workers == [] + + @gen_cluster(client=True) async def test_gather_from_workers_permissive(c, s, a, b): + """gather_from_workers fetches multiple keys, of which some are missing. + Test that the available data is returned with a note for missing data. + """ rpc = await ConnectionPool() x = await c.scatter({"x": 1}, workers=a.address) - data, missing = await gather_from_workers(["x", "y"], s.get_who_has, rpc=rpc) + data, missing, failed, bad_workers = await gather_from_workers( + {"x": [a.address], "y": [b.address]}, rpc=rpc + ) assert data == {"x": 1} - assert missing == {"y"} + assert missing == ["y"] + assert failed == [] + assert bad_workers == [] class BrokenConnectionPool(ConnectionPool): - async def connect(self, *args, **kwargs): + async def connect(self, address, *args, **kwargs): return BrokenComm() @gen_cluster(client=True) async def test_gather_from_workers_permissive_flaky(c, s, a, b): + """gather_from_workers fails to connect to a worker""" x = await c.scatter({"x": 1}, workers=a.address) rpc = await BrokenConnectionPool() - data, missing = await gather_from_workers(["x"], s.get_who_has, rpc=rpc) + data, missing, failed, bad_workers = await gather_from_workers( + {"x": [a.address]}, rpc=rpc + ) assert data == {} - assert missing == {"x"} + assert missing == ["x"] + assert failed == [] + assert bad_workers == [a.address] + + +@gen_cluster( + client=True, + nthreads=[], + config={"distributed.worker.memory.pause": False}, +) +async def test_gather_from_workers_busy(c, s): + """gather_from_workers receives a 'busy' response from a worker""" + async with BarrierGetData(s.address, barrier_count=2) as w: + x = await c.scatter({"x": 1}, workers=[w.address]) + await wait(x) + # Throttle to 1 simultaneous connection + w.status = Status.paused + + rpc1 = await ConnectionPool() + rpc2 = await ConnectionPool() + out1, out2 = await asyncio.gather( + gather_from_workers({"x": [w.address]}, rpc=rpc1), + gather_from_workers({"x": [w.address]}, rpc=rpc2), + ) + assert w.barrier_count == -1 # w.get_data() has been hit 3 times + assert out1 == out2 == ({"x": 1}, [], [], []) + + +@pytest.mark.parametrize("when", ["pickle", "unpickle"]) +@gen_cluster(client=True) +async def test_gather_from_workers_serialization_error(c, s, a, b, when): + """A task fails to (de)serialize. Tasks from other workers are fetched + successfully. + """ + + class BadReduce: + def __reduce__(self): + if when == "pickle": + 1 / 0 + else: + return lambda: 1 / 0, () + + rpc = await ConnectionPool() + x = c.submit(BadReduce, key="x", workers=[a.address]) + y = c.submit(inc, 1, key="y", workers=[a.address]) + z = c.submit(inc, 2, key="z", workers=[b.address]) + await wait([x, y, z]) + data, missing, failed, bad_workers = await gather_from_workers( + {"x": [a.address], "y": [a.address], "z": [b.address]}, rpc=rpc + ) + + assert data == {"z": 3} + assert missing == [] + # x and y were serialized together with a single call to pickle; can't tell which + # raised + assert failed == ["x", "y"] + assert bad_workers == [] def test_retry_no_exception(cleanup): diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index f6f644d7000..1ba22e3dea8 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -108,7 +108,6 @@ async def gather_from_workers( for key in d[address]: failed_keys.append(key) del to_gather[key] - missing_workers.add(address) else: if r["status"] == "busy": busy_workers.add(address) From 1285cbf44712af3e68e1f9871f625e7b377497bd Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Aug 2023 16:28:06 +0100 Subject: [PATCH 08/10] Revert cosmetic refactor --- distributed/client.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 6c245c0a4bc..4e424a7386a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2331,23 +2331,23 @@ async def _gather_remote(self, direct: bool, local_worker: bool) -> dict[str, An self._gather_keys = None # clear state, these keys are being sent off self._gather_future = None - if not direct and not local_worker: - # ask scheduler to gather data for us - return await retry_operation(self.scheduler.gather, keys=keys) - - # gather directly from workers - who_has = await retry_operation(self.scheduler.who_has, keys=keys) - data, missing_keys, failed_keys, _ = await gather_from_workers( - who_has, rpc=self.rpc - ) - response: dict[str, Any] = {"status": "OK", "data": data} - if missing_keys or failed_keys: - response = await retry_operation( - self.scheduler.gather, keys=missing_keys + failed_keys + if direct or local_worker: # gather directly from workers + who_has = await retry_operation(self.scheduler.who_has, keys=keys) + data, missing_keys, failed_keys, _ = await gather_from_workers( + who_has, rpc=self.rpc ) - if response["status"] == "OK": - response["data"].update(data) - return response + response: dict[str, Any] = {"status": "OK", "data": data} + if missing_keys or failed_keys: + response = await retry_operation( + self.scheduler.gather, keys=missing_keys + failed_keys + ) + if response["status"] == "OK": + response["data"].update(data) + + else: # ask scheduler to gather data for us + response = await retry_operation(self.scheduler.gather, keys=keys) + + return response def gather(self, futures, errors="raise", direct=None, asynchronous=None): """Gather futures from distributed memory From 3a67188c70f3ff8207d8b530c0260d5a4eea3400 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 4 Aug 2023 11:35:14 +0100 Subject: [PATCH 09/10] stress flaky tests --- .github/workflows/tests.yaml | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 29c9d3f14c4..af89ecfe346 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -197,8 +197,16 @@ jobs: set -o pipefail mkdir reports - pytest distributed \ - -m "not avoid_ci and ${{ matrix.partition }}" --runslow \ + pytest \ + distributed/diagnostics/tests/test_task_stream.py::test_TaskStreamPlugin \ + distributed/tests/test_client.py::test_file_descriptors_dont_leak \ + distributed/tests/test_failed_workers.py \ + distributed/tests/test_scheduler.py::test_tell_workers_when_peers_have_left \ + distributed/tests/test_steal.py::test_dont_steal_fast_tasks_compute_time \ + distributed/tests/test_stress.py::test_chaos_rechunk \ + distributed/shuffle/tests/test_shuffle.py \ + --count=10 \ + --runslow \ --leaks=fds,processes,threads \ --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \ --cov=distributed --cov-report=xml \ From 53d3319a8620156ee968de036654316c35d81479 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 4 Aug 2023 11:35:14 +0100 Subject: [PATCH 10/10] Revert "stress flaky tests" This reverts commit 3a67188c70f3ff8207d8b530c0260d5a4eea3400. --- .github/workflows/tests.yaml | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index af89ecfe346..29c9d3f14c4 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -197,16 +197,8 @@ jobs: set -o pipefail mkdir reports - pytest \ - distributed/diagnostics/tests/test_task_stream.py::test_TaskStreamPlugin \ - distributed/tests/test_client.py::test_file_descriptors_dont_leak \ - distributed/tests/test_failed_workers.py \ - distributed/tests/test_scheduler.py::test_tell_workers_when_peers_have_left \ - distributed/tests/test_steal.py::test_dont_steal_fast_tasks_compute_time \ - distributed/tests/test_stress.py::test_chaos_rechunk \ - distributed/shuffle/tests/test_shuffle.py \ - --count=10 \ - --runslow \ + pytest distributed \ + -m "not avoid_ci and ${{ matrix.partition }}" --runslow \ --leaks=fds,processes,threads \ --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \ --cov=distributed --cov-report=xml \