diff --git a/distributed/client.py b/distributed/client.py index 86ca5f506e0..8da3f3b91df 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -209,6 +209,7 @@ def __init__(self, key, client=None, inform=True, state=None): "op": "client-desires-keys", "keys": [stringify(key)], "client": self.client.id, + "stimulus_id": f"client-desires-keys-{time()}", } ) @@ -472,6 +473,7 @@ def __setstate__(self, state): "tasks": {}, "keys": [stringify(self.key)], "client": c.id, + "stimulus_id": f"stimulus-id-{time()}", } ) @@ -1265,6 +1267,7 @@ async def _ensure_connected(self, timeout=None): "client": self.id, "reply": False, "versions": version_module.get_versions(), + "stimulus_id": f"client-ensure-connected-{time()}", } ) except Exception: @@ -1371,9 +1374,9 @@ def _dec_ref(self, key): self.refcount[key] -= 1 if self.refcount[key] == 0: del self.refcount[key] - self._release_key(key) + self._release_key(key, f"client-release-key-{time()}") - def _release_key(self, key): + def _release_key(self, key, stimulus_id: str): """Release key from distributed memory""" logger.debug("Release key %s", key) st = self.futures.pop(key, None) @@ -1381,7 +1384,12 @@ def _release_key(self, key): st.cancel() if self.status != "closed": self._send_to_scheduler( - {"op": "client-releases-keys", "keys": [key], "client": self.id} + { + "op": "client-releases-keys", + "keys": [key], + "client": self.id, + "stimulus_id": stimulus_id, + } ) async def _handle_report(self): @@ -1506,7 +1514,9 @@ async def _close(self, fast=False): and self.scheduler_comm.comm and not self.scheduler_comm.comm.closed() ): - self._send_to_scheduler({"op": "close-client"}) + self._send_to_scheduler( + {"op": "close-client", "stimulus_id": f"client-close-{time()}"} + ) self._send_to_scheduler({"op": "close-stream"}) current_task = asyncio.current_task() @@ -1527,8 +1537,10 @@ async def _close(self, fast=False): ): await self.scheduler_comm.close() + stimulus_id = f"client-close-{time()}" + for key in list(self.futures): - self._release_key(key=key) + self._release_key(key=key, stimulus_id=stimulus_id) if self._start_arg is None: with suppress(AttributeError): @@ -2110,12 +2122,20 @@ async def _gather_remote(self, direct, local_worker): response = {"status": "OK", "data": data2} if missing_keys: keys2 = [key for key in keys if key not in data2] - response = await retry_operation(self.scheduler.gather, keys=keys2) + response = await retry_operation( + self.scheduler.gather, + keys=keys2, + stimulus_id=f"client-gather-remote-{time()}", + ) if response["status"] == "OK": response["data"].update(data2) else: # ask scheduler to gather data for us - response = await retry_operation(self.scheduler.gather, keys=keys) + response = await retry_operation( + self.scheduler.gather, + keys=keys, + stimulus_id=f"client-gather-remote-{time()}", + ) return response @@ -2201,6 +2221,8 @@ async def _scatter( d = await self._scatter(keymap(stringify, data), workers, broadcast) return {k: d[stringify(k)] for k in data} + stimulus_id = f"client-scatter-{time()}" + if isinstance(data, type(range(0))): data = list(data) input_type = type(data) @@ -2242,6 +2264,7 @@ async def _scatter( who_has={key: [local_worker.address] for key in data}, nbytes=valmap(sizeof, data), client=self.id, + stimulus_id=stimulus_id, ) else: @@ -2264,7 +2287,10 @@ async def _scatter( ) await self.scheduler.update_data( - who_has=who_has, nbytes=nbytes, client=self.id + who_has=who_has, + nbytes=nbytes, + client=self.id, + stimulus_id=stimulus_id, ) else: await self.scheduler.scatter( @@ -2273,6 +2299,7 @@ async def _scatter( client=self.id, broadcast=broadcast, timeout=timeout, + stimulus_id=stimulus_id, ) out = {k: Future(k, self, inform=False) for k in data} @@ -2396,7 +2423,12 @@ def scatter( async def _cancel(self, futures, force=False): keys = list({stringify(f.key) for f in futures_of(futures)}) - await self.scheduler.cancel(keys=keys, client=self.id, force=force) + await self.scheduler.cancel( + keys=keys, + client=self.id, + force=force, + stimulus_id=f"client-cancel-{time()}", + ) for k in keys: st = self.futures.pop(k, None) if st is not None: @@ -2423,7 +2455,9 @@ def cancel(self, futures, asynchronous=None, force=False): async def _retry(self, futures): keys = list({stringify(f.key) for f in futures_of(futures)}) - response = await self.scheduler.retry(keys=keys, client=self.id) + response = await self.scheduler.retry( + keys=keys, client=self.id, stimulus_id=f"client-retry-{time()}" + ) for key in response: st = self.futures[key] st.retry() @@ -2922,6 +2956,7 @@ def _graph_to_futures( "fifo_timeout": fifo_timeout, "actors": actors, "code": self._get_computation_code(), + "stimulus_id": f"client-update-graph-hlg-{time()}", } ) return futures @@ -3347,7 +3382,13 @@ async def _restart(self, timeout=no_default): if timeout is not None: timeout = parse_timedelta(timeout, "s") - self._send_to_scheduler({"op": "restart", "timeout": timeout}) + self._send_to_scheduler( + { + "op": "restart", + "timeout": timeout, + "stimulus_id": f"client-restart-{time()}", + } + ) self._restart_event = asyncio.Event() try: await asyncio.wait_for(self._restart_event.wait(), timeout) @@ -3424,7 +3465,9 @@ async def _rebalance(self, futures=None, workers=None): keys = list({stringify(f.key) for f in self.futures_of(futures)}) else: keys = None - result = await self.scheduler.rebalance(keys=keys, workers=workers) + result = await self.scheduler.rebalance( + keys=keys, workers=workers, stimulus_id=f"client-rebalance-{time()}" + ) if result["status"] == "partial-fail": raise KeyError(f"Could not rebalance keys: {result['keys']}") assert result["status"] == "OK", result @@ -3459,7 +3502,11 @@ async def _replicate(self, futures, n=None, workers=None, branching_factor=2): await _wait(futures) keys = {stringify(f.key) for f in futures} await self.scheduler.replicate( - keys=list(keys), n=n, workers=workers, branching_factor=branching_factor + keys=list(keys), + n=n, + workers=workers, + branching_factor=branching_factor, + stimulus_id=f"client-replicate-{time()}", ) def replicate(self, futures, n=None, workers=None, branching_factor=2, **kwargs): @@ -4177,6 +4224,7 @@ def retire_workers( self.scheduler.retire_workers, workers=workers, close_workers=close_workers, + stimulus_id=f"client-retire-workers-{time()}", **kwargs, ) @@ -5138,6 +5186,7 @@ def fire_and_forget(obj): "op": "client-desires-keys", "keys": [stringify(future.key)], "client": "fire-and-forget", + "stimulus_id": f"client-fire-and-forget-{time()}", } ) diff --git a/distributed/core.py b/distributed/core.py index 4b37721b8ba..ff8d48859c3 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -13,6 +13,7 @@ from collections import defaultdict from collections.abc import Container from contextlib import suppress +from contextvars import copy_context from enum import Enum from functools import partial from typing import Callable, ClassVar, TypedDict, TypeVar @@ -619,7 +620,9 @@ async def handle_stream(self, comm, extra=None, every_cycle=()): break handler = self.stream_handlers[op] if is_coroutine_function(handler): - self.loop.add_callback(handler, **merge(extra, msg)) + self.loop.add_callback( + copy_context().run, handler, **merge(extra, msg) + ) await gen.sleep(0) else: handler(**merge(extra, msg)) @@ -629,7 +632,7 @@ async def handle_stream(self, comm, extra=None, every_cycle=()): for func in every_cycle: if is_coroutine_function(func): - self.loop.add_callback(func) + self.loop.add_callback(copy_context().run, func) else: func() diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index d17f82fc893..e966b0f9a40 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -7,6 +7,7 @@ from dask.utils import parse_timedelta from distributed.deploy.adaptive_core import AdaptiveCore +from distributed.metrics import time from distributed.protocol import pickle from distributed.utils import log_errors @@ -193,6 +194,7 @@ async def scale_down(self, workers): names=workers, remove=True, close_workers=True, + stimulus_id=f"scale-down-{time()}", ) # close workers more forcefully diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 3a9fa16a3e7..bc2a811a858 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -19,6 +19,7 @@ from distributed.core import CommClosedError, Status, rpc from distributed.deploy.adaptive import Adaptive from distributed.deploy.cluster import Cluster +from distributed.metrics import time from distributed.scheduler import Scheduler from distributed.security import Security from distributed.utils import NoOpAwaitable, TimeoutError, import_term, silence_logging @@ -318,7 +319,10 @@ async def _correct_state_internal(self): to_close = set(self.workers) - set(self.worker_spec) if to_close: if self.scheduler.status == Status.running: - await self.scheduler_comm.retire_workers(workers=list(to_close)) + await self.scheduler_comm.retire_workers( + workers=list(to_close), + stimulus_id=f"spec-cluster-correct-internal-state-{time()}", + ) tasks = [ asyncio.create_task(self.workers[w].close()) for w in to_close diff --git a/distributed/diagnostics/tests/test_progress.py b/distributed/diagnostics/tests/test_progress.py index 872b1e4c7f6..591e7b29dc0 100644 --- a/distributed/diagnostics/tests/test_progress.py +++ b/distributed/diagnostics/tests/test_progress.py @@ -209,7 +209,7 @@ async def test_group_timing(c, s, a, b): ] ) - await s.restart() + await s.handle_restart(stimulus_id="test") assert len(p.time) == 2 assert len(p.nthreads) == 2 assert len(p.compute) == 0 diff --git a/distributed/diagnostics/tests/test_widgets.py b/distributed/diagnostics/tests/test_widgets.py index e47c1bd5bc9..844c9b5fba4 100644 --- a/distributed/diagnostics/tests/test_widgets.py +++ b/distributed/diagnostics/tests/test_widgets.py @@ -145,7 +145,7 @@ async def test_multi_progressbar_widget(c, s, a, b): @gen_cluster() async def test_multi_progressbar_widget_after_close(s, a, b): - s.update_graph( + s.handle_update_graph( tasks=valmap( dumps_task, { @@ -166,6 +166,7 @@ async def test_multi_progressbar_widget_after_close(s, a, b): "y-2": {"y-1"}, "e": {"y-2"}, }, + stimulus_id="test", ) p = MultiProgressWidget(["x-1", "x-2", "x-3"], scheduler=s.address) @@ -231,7 +232,7 @@ def test_progressbar_cancel(client): @gen_cluster() async def test_multibar_complete(s, a, b): - s.update_graph( + s.handle_update_graph( tasks=valmap( dumps_task, { @@ -252,6 +253,7 @@ async def test_multibar_complete(s, a, b): "y-2": {"y-1"}, "e": {"y-2"}, }, + stimulus_id="test", ) p = MultiProgressWidget(["e"], scheduler=s.address, complete=True) diff --git a/distributed/http/templates/task.html b/distributed/http/templates/task.html index 0b5c10695e0..f10aaad5602 100644 --- a/distributed/http/templates/task.html +++ b/distributed/http/templates/task.html @@ -118,16 +118,18 @@

Transition Log

Key Start Finish + Stimulus ID Recommended Key Recommended Action - {% for key, start, finish, recommendations, transition_time in scheduler.story(Task) %} + {% for key, start, finish, recommendations, stimulus_id, transition_time in scheduler.story(Task) %} {{ fromtimestamp(transition_time) }} {{key}} {{ start }} {{ finish }} + {{ stimulus_id }} @@ -137,6 +139,7 @@

Transition Log

+ {{key2}} {{ rec }} diff --git a/distributed/nanny.py b/distributed/nanny.py index db2371523e8..325329e7524 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -290,7 +290,10 @@ async def _unregister(self, timeout=10): allowed_errors = (TimeoutError, CommClosedError, EnvironmentError, RPCClosed) with suppress(allowed_errors): await asyncio.wait_for( - self.scheduler.unregister(address=self.worker_address), timeout + self.scheduler.unregister( + address=self.worker_address, stimulus_id=f"close-nanny-{time()}" + ), + timeout, ) @property diff --git a/distributed/scheduler.py b/distributed/scheduler.py index fc96d5d02ec..42ed631cfba 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -25,6 +25,7 @@ Set, ) from contextlib import suppress +from contextvars import copy_context from datetime import timedelta from functools import partial from numbers import Number @@ -85,9 +86,11 @@ from distributed.stealing import WorkStealing from distributed.stories import scheduler_story from distributed.utils import ( + STIMULUS_ID, All, TimeoutError, empty_context, + expect_stimulus, get_fileno_limit, key_split, key_split_group, @@ -2373,16 +2376,21 @@ def _transition(self, key, finish: str, *args, **kwargs): finish2 = ts._state # FIXME downcast antipattern scheduler = pep484_cast(Scheduler, self) + stimulus_id = STIMULUS_ID.get("") scheduler.transition_log.append( - (key, start, finish2, recommendations, time()) + (key, start, finish2, recommendations, stimulus_id, time()) ) if parent._validate: + if stimulus_id == "": + raise LookupError(STIMULUS_ID.name) + logger.debug( - "Transitioned %r %s->%s (actual: %s). Consequence: %s", + "Transitioned %r %s->%s (actual: %s) from %s. Consequence: %s", key, start, finish2, ts._state, + stimulus_id, dict(recommendations), ) if self.plugins: @@ -2857,7 +2865,7 @@ def transition_processing_memory( { "op": "cancel-compute", "key": key, - "stimulus_id": f"processing-memory-{time()}", + "stimulus_id": STIMULUS_ID.get(), } ] @@ -2945,7 +2953,7 @@ def transition_memory_released(self, key, safe: bint = False): worker_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": f"memory-released-{time()}", + "stimulus_id": STIMULUS_ID.get(), } for ws in ts._who_has: worker_msgs[ws._address] = [worker_msg] @@ -3048,7 +3056,7 @@ def transition_erred_released(self, key): w_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": f"erred-released-{time()}", + "stimulus_id": STIMULUS_ID.get(), } for ws_addr in ts._erred_on: worker_msgs[ws_addr] = [w_msg] @@ -3127,7 +3135,7 @@ def transition_processing_released(self, key): { "op": "free-keys", "keys": [key], - "stimulus_id": f"processing-released-{time()}", + "stimulus_id": STIMULUS_ID.get(), } ] @@ -3901,37 +3909,37 @@ def __init__( worker_handlers = { "task-finished": self.handle_task_finished, "task-erred": self.handle_task_erred, - "release-worker-data": self.release_worker_data, - "add-keys": self.add_keys, + "release-worker-data": self.handle_release_worker_data, + "add-keys": self.handle_add_keys, "missing-data": self.handle_missing_data, "long-running": self.handle_long_running, - "reschedule": self.reschedule, + "reschedule": self.handle_reschedule, "keep-alive": lambda *args, **kwargs: None, "log-event": self.log_worker_event, "worker-status-change": self.handle_worker_status_change, } client_handlers = { - "update-graph": self.update_graph, - "update-graph-hlg": self.update_graph_hlg, - "client-desires-keys": self.client_desires_keys, - "update-data": self.update_data, + "update-graph": self.handle_update_graph, + "update-graph-hlg": self.handle_update_graph_hlg, + "client-desires-keys": self.handle_client_desires_keys, + "update-data": self.handle_update_data, "report-key": self.report_on_key, - "client-releases-keys": self.client_releases_keys, + "client-releases-keys": self.handle_client_releases_keys, "heartbeat-client": self.client_heartbeat, - "close-client": self.remove_client, - "restart": self.restart, + "close-client": self.handle_remove_client, + "restart": self.handle_restart, "subscribe-topic": self.subscribe_topic, "unsubscribe-topic": self.unsubscribe_topic, } self.handlers = { - "register-client": self.add_client, - "scatter": self.scatter, - "register-worker": self.add_worker, + "register-client": self.handle_add_client, + "scatter": self.handle_scatter, + "register-worker": self.handle_add_worker, "register_nanny": self.add_nanny, - "unregister": self.remove_worker, - "gather": self.gather, + "unregister": self.handle_remove_worker, + "gather": self.handle_gather, "cancel": self.stimulus_cancel, "retry": self.stimulus_retry, "feed": self.feed, @@ -3954,12 +3962,12 @@ def __init__( "nbytes": self.get_nbytes, "versions": self.versions, "add_keys": self.add_keys, - "rebalance": self.rebalance, - "replicate": self.replicate, + "rebalance": self.handle_rebalance, + "replicate": self.handle_replicate, "run_function": self.run_function, - "update_data": self.update_data, + "update_data": self.handle_update_data, "set_resources": self.add_resources, - "retire_workers": self.retire_workers, + "retire_workers": self.handle_retire_workers, "get_metadata": self.get_metadata, "set_metadata": self.set_metadata, "set_restrictions": self.set_restrictions, @@ -4548,6 +4556,7 @@ async def add_worker( recommendations: dict = {} client_msgs: dict = {} worker_msgs: dict = {} + if nbytes: assert isinstance(nbytes, dict) already_released_keys = [] @@ -4578,7 +4587,7 @@ async def add_worker( { "op": "remove-replicas", "keys": already_released_keys, - "stimulus_id": f"reconnect-already-released-{time()}", + "stimulus_id": STIMULUS_ID.get(), } ) @@ -5003,7 +5012,7 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): { "op": "free-keys", "keys": [key], - "stimulus_id": f"already-released-or-forgotten-{time()}", + "stimulus_id": STIMULUS_ID.get(), } ] elif ts._state == "memory": @@ -5042,6 +5051,7 @@ def stimulus_task_erred( **kwargs, ) + @expect_stimulus(sync=True) def stimulus_retry(self, keys, client=None): parent: SchedulerState = cast(SchedulerState, self) logger.info("Client %s requests to retry %d keys", client, len(keys)) @@ -5188,6 +5198,7 @@ def remove_worker_from_events(): return "OK" + @expect_stimulus(sync=True) def stimulus_cancel(self, comm, keys=None, client=None, force=False): """Stop execution on a list of keys""" logger.info("Client %s requests to cancel %d keys", client, len(keys)) @@ -5562,6 +5573,7 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) + @expect_stimulus(sync=True) def handle_task_finished(self, key=None, worker=None, **msg): parent: SchedulerState = cast(SchedulerState, self) if worker not in parent._workers_dv: @@ -5578,6 +5590,7 @@ def handle_task_finished(self, key=None, worker=None, **msg): self.send_all(client_msgs, worker_msgs) + @expect_stimulus(sync=True) def handle_task_erred(self, key=None, **msg): parent: SchedulerState = cast(SchedulerState, self) recommendations: dict @@ -5589,6 +5602,7 @@ def handle_task_erred(self, key=None, **msg): self.send_all(client_msgs, worker_msgs) + @expect_stimulus(sync=True) def handle_missing_data(self, key=None, errant_worker=None, **kwargs): """Signal that `errant_worker` does not hold `key` @@ -5636,6 +5650,7 @@ def release_worker_data(self, key, worker): if recommendations: self.transitions(recommendations) + @expect_stimulus(sync=True) def handle_long_running(self, key=None, worker=None, compute_duration=None): """A task has seceded from the thread pool @@ -5678,6 +5693,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): ws._long_running.add(ts) self.check_idle_saturated(ws) + @expect_stimulus(sync=True) def handle_worker_status_change(self, status: str, worker: str) -> None: parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState = parent._workers_dv.get(worker) # type: ignore @@ -5872,7 +5888,9 @@ def worker_send(self, worker, msg): try: stream_comms[worker].send(msg) except (CommClosedError, AttributeError): - self.loop.add_callback(self.remove_worker, address=worker) + self.loop.add_callback( + copy_context().run, self.remove_worker, address=worker + ) def client_send(self, client, msg): """Send message to client""" @@ -5917,7 +5935,9 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): # worker already gone pass except (CommClosedError, AttributeError): - self.loop.add_callback(self.remove_worker, address=worker) + self.loop.add_callback( + copy_context().run, self.remove_worker, address=worker + ) ############################ # Less common interactions # @@ -7026,7 +7046,11 @@ async def retire_workers( ws.status = Status.closing_gracefully self.running.discard(ws) self.stream_comms[ws.address].send( - {"op": "worker-status-change", "status": ws.status.name} + { + "op": "worker-status-change", + "status": ws.status.name, + "stimulus_id": STIMULUS_ID.get(), + } ) coros.append( @@ -7071,7 +7095,11 @@ async def _track_retire_worker( # conditions and we can wait for a scheduler->worker->scheduler # round-trip. self.stream_comms[ws.address].send( - {"op": "worker-status-change", "status": prev_status.name} + { + "op": "worker-status-change", + "status": prev_status.name, + "stimulus_id": STIMULUS_ID.get(), + } ) return None, {} @@ -7092,7 +7120,7 @@ async def _track_retire_worker( logger.info("Retired worker %s", ws._address) return ws._address, ws.identity() - def add_keys(self, worker=None, keys=(), stimulus_id=None): + def add_keys(self, worker=None, keys=()): """ Learn that a worker has certain keys @@ -7113,14 +7141,12 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None): redundant_replicas.append(key) if redundant_replicas: - if not stimulus_id: - stimulus_id = f"redundant-replicas-{time()}" self.worker_send( worker, { "op": "remove-replicas", "keys": redundant_replicas, - "stimulus_id": stimulus_id, + "stimulus_id": STIMULUS_ID.get(), }, ) @@ -8073,7 +8099,9 @@ async def check_worker_ttl(self): self.worker_ttl, ws, ) - await self.remove_worker(address=ws._address) + await self.handle_remove_worker( + address=ws._address, stimulus_id=f"check-worker-ttl-{time()}" + ) def check_idle(self): parent: SchedulerState = cast(SchedulerState, self) @@ -8204,6 +8232,27 @@ def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str): } ) + handle_update_graph_hlg = expect_stimulus(sync=True)(update_graph_hlg) + handle_update_graph = expect_stimulus(sync=True)(update_graph) + handle_add_worker = expect_stimulus(sync=False)(add_worker) + handle_remove_worker = expect_stimulus(sync=False)(remove_worker) + handle_cancel_key = expect_stimulus(sync=True)(cancel_key) + handle_client_desires_keys = expect_stimulus(sync=True)(client_desires_keys) + handle_client_releases_keys = expect_stimulus(sync=True)(client_releases_keys) + handle_add_client = expect_stimulus(sync=False)(add_client) + handle_remove_client = expect_stimulus(sync=True)(remove_client) + handle_release_worker_data = expect_stimulus(sync=True)(release_worker_data) + handle_scatter = expect_stimulus(sync=False)(scatter) + handle_gather = expect_stimulus(sync=False)(gather) + handle_restart = expect_stimulus(sync=False)(restart) + handle_delete_worker_data = expect_stimulus(sync=False)(delete_worker_data) + handle_rebalance = expect_stimulus(sync=False)(rebalance) + handle_replicate = expect_stimulus(sync=False)(replicate) + handle_retire_workers = expect_stimulus(sync=False)(retire_workers) + handle_add_keys = expect_stimulus(sync=True)(add_keys) + handle_update_data = expect_stimulus(sync=True)(update_data) + handle_reschedule = expect_stimulus(sync=True)(reschedule) + @cfunc @exceptval(check=False) @@ -8337,7 +8386,7 @@ def _propagate_forgotten( { "op": "free-keys", "keys": [key], - "stimulus_id": f"propagate-forgotten-{time()}", + "stimulus_id": STIMULUS_ID.get(), } ] state.remove_all_replicas(ts) @@ -8380,7 +8429,7 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> "key": ts._key, "priority": ts._priority, "duration": duration, - "stimulus_id": f"compute-task-{time()}", + "stimulus_id": STIMULUS_ID.get(), "who_has": {}, } if ts._resource_restrictions: diff --git a/distributed/stealing.py b/distributed/stealing.py index 54ef0098c63..be7d757c5f7 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -17,7 +17,12 @@ from distributed.comm.addressing import get_address_host from distributed.core import CommClosedError, Status from distributed.diagnostics.plugin import SchedulerPlugin -from distributed.utils import log_errors, recursive_to_dict +from distributed.utils import ( + STIMULUS_ID, + expect_stimulus, + log_errors, + recursive_to_dict, +) # Stealing requires multiple network bounces and if successful also task # submission which may include code serialization. Therefore, be very @@ -79,7 +84,7 @@ def __init__(self, scheduler): self.in_flight_occupancy = defaultdict(lambda: 0) self._in_flight_event = asyncio.Event() - self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm + self.scheduler.stream_handlers["steal-response"] = self.handle_move_task_confirm async def start(self, scheduler=None): """Start the background coroutine to balance the tasks on the cluster. @@ -265,7 +270,7 @@ def move_task_request(self, ts, victim, thief) -> str: pdb.set_trace() raise - async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): + async def move_task_confirm(self, *, key, state, worker=None): try: ts = self.scheduler.tasks[key] except KeyError: @@ -273,12 +278,12 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): return try: d = self.in_flight.pop(ts) - if d["stimulus_id"] != stimulus_id: - self.log(("stale-response", key, state, worker, stimulus_id)) + if d["stimulus_id"] != STIMULUS_ID.get(): + self.log(("stale-response", key, state, worker, STIMULUS_ID.get())) self.in_flight[ts] = d return except KeyError: - self.log(("already-aborted", key, state, worker, stimulus_id)) + self.log(("already-aborted", key, state, worker, STIMULUS_ID.get())) return thief = d["thief"] @@ -297,7 +302,7 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): assert ts.processing_on == victim try: - _log_msg = [key, state, victim.address, thief.address, stimulus_id] + _log_msg = [key, state, victim.address, thief.address, STIMULUS_ID.get()] if ts.state != "processing": self.scheduler._reevaluate_occupancy_worker(thief) @@ -348,6 +353,8 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): self.scheduler.check_idle_saturated(thief) self.scheduler.check_idle_saturated(victim) + handle_move_task_confirm = expect_stimulus(sync=False)(move_task_confirm) + def balance(self): s = self.scheduler diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 1a0df078376..8ca0dbded04 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -4,9 +4,10 @@ import distributed from distributed import Event from distributed.core import CommClosedError +from distributed.utils import STIMULUS_ID from distributed.utils_test import ( _LockedCommPool, - assert_worker_story, + assert_story, gen_cluster, inc, slowinc, @@ -82,7 +83,7 @@ def f(ev): while "f1" in a.tasks: await asyncio.sleep(0.01) - assert_worker_story( + assert_story( a.story("f1"), [ ("f1", "compute-task"), @@ -160,7 +161,12 @@ async def wait_and_raise(*args, **kwargs): await wait_for_state(fut1.key, "flight", b) # Close in scheduler to ensure we transition and reschedule task properly - await s.close_worker(worker=a.address) + try: + token = STIMULUS_ID.set("test") + await s.close_worker(worker=a.address) + finally: + STIMULUS_ID.reset(token) + await wait_for_state(fut1.key, "resumed", b) lock.release() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 5044c487bba..9671c7b492b 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2060,15 +2060,15 @@ async def test_forget_simple(c, s, a, b): assert set(s.tasks) == {x.key, y.key, z.key} - s.client_releases_keys(keys=[x.key], client=c.id) + s.handle_client_releases_keys(keys=[x.key], client=c.id, stimulus_id="test") assert x.key in s.tasks - s.client_releases_keys(keys=[z.key], client=c.id) + s.handle_client_releases_keys(keys=[z.key], client=c.id, stimulus_id="test") assert x.key not in s.tasks assert z.key not in s.tasks assert not s.tasks[y.key].dependents - s.client_releases_keys(keys=[y.key], client=c.id) + s.handle_client_releases_keys(keys=[y.key], client=c.id, stimulus_id="test") assert not s.tasks @@ -2084,20 +2084,20 @@ async def test_forget_complex(e, s, A, B): assert set(s.tasks) == {f.key for f in [ab, ac, cd, acab, a, b, c, d]} - s.client_releases_keys(keys=[ab.key], client=e.id) + s.handle_client_releases_keys(keys=[ab.key], client=e.id, stimulus_id="test") assert set(s.tasks) == {f.key for f in [ab, ac, cd, acab, a, b, c, d]} - s.client_releases_keys(keys=[b.key], client=e.id) + s.handle_client_releases_keys(keys=[b.key], client=e.id, stimulus_id="test") assert set(s.tasks) == {f.key for f in [ac, cd, acab, a, c, d]} - s.client_releases_keys(keys=[acab.key], client=e.id) + s.handle_client_releases_keys(keys=[acab.key], client=e.id, stimulus_id="test") assert set(s.tasks) == {f.key for f in [ac, cd, a, c, d]} assert b.key not in s.tasks while b.key in A.data or b.key in B.data: await asyncio.sleep(0.01) - s.client_releases_keys(keys=[ac.key], client=e.id) + s.handle_client_releases_keys(keys=[ac.key], client=e.id, stimulus_id="test") assert set(s.tasks) == {f.key for f in [cd, a, c, d]} @@ -2117,7 +2117,7 @@ async def test_forget_in_flight(e, s, A, B): await asyncio.sleep(0.01) s.validate_state() - s.client_releases_keys(keys=[y.key], client=e.id) + s.handle_client_releases_keys(keys=[y.key], client=e.id, stimulus_id="test") s.validate_state() for k in [acab.key, ab.key, b.key]: @@ -2136,21 +2136,21 @@ async def test_forget_errors(c, s, a, b): assert y.key in s.exceptions_blame assert z.key in s.exceptions_blame - s.client_releases_keys(keys=[z.key], client=c.id) + s.handle_client_releases_keys(keys=[z.key], client=c.id, stimulus_id="test") assert x.key in s.exceptions assert x.key in s.exceptions_blame assert y.key in s.exceptions_blame assert z.key not in s.exceptions_blame - s.client_releases_keys(keys=[x.key], client=c.id) + s.handle_client_releases_keys(keys=[x.key], client=c.id, stimulus_id="test") assert x.key in s.exceptions assert x.key in s.exceptions_blame assert y.key in s.exceptions_blame assert z.key not in s.exceptions_blame - s.client_releases_keys(keys=[y.key], client=c.id) + s.handle_client_releases_keys(keys=[y.key], client=c.id, stimulus_id="test") assert x.key not in s.exceptions assert x.key not in s.exceptions_blame @@ -4355,7 +4355,7 @@ async def test_scatter_type(c, s, a, b): async def test_retire_workers_2(c, s, a, b): [x] = await c.scatter([1], workers=a.address) - await s.retire_workers(workers=[a.address]) + await s.handle_retire_workers(workers=[a.address], stimulus_id="test") assert b.data == {x.key: 1} assert s.who_has == {x.key: {b.address}} assert s.has_what == {b.address: {x.key}} @@ -4367,7 +4367,9 @@ async def test_retire_workers_2(c, s, a, b): async def test_retire_many_workers(c, s, *workers): futures = await c.scatter(list(range(100))) - await s.retire_workers(workers=[w.address for w in workers[:7]]) + await s.handle_retire_workers( + workers=[w.address for w in workers[:7]], stimulus_id="test" + ) results = await c.gather(futures) assert results == list(range(100)) @@ -4565,7 +4567,7 @@ def test_auto_normalize_collection_sync(c): def assert_no_data_loss(scheduler): - for key, start, finish, recommendations, _ in scheduler.transition_log: + for key, start, finish, recommendations, _, _ in scheduler.transition_log: if start == "memory" and finish == "released": for k, v in recommendations.items(): assert not (k == key and v == "waiting") diff --git a/distributed/tests/test_cluster_dump.py b/distributed/tests/test_cluster_dump.py index b01cf2611ca..1762929d378 100644 --- a/distributed/tests/test_cluster_dump.py +++ b/distributed/tests/test_cluster_dump.py @@ -8,7 +8,7 @@ import distributed from distributed.cluster_dump import DumpArtefact, _tuple_to_list, write_state -from distributed.utils_test import assert_worker_story, gen_cluster, gen_test, inc +from distributed.utils_test import assert_story, gen_cluster, gen_test, inc @pytest.mark.parametrize( @@ -140,7 +140,7 @@ async def test_cluster_dump_story(c, s, a, b, tmp_path): assert story.keys() == {"f1", "f2"} for k, task_story in story.items(): - assert_worker_story( + assert_story( task_story, [ (k, "compute-task"), diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 2750868dc01..c4fc650472d 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -37,6 +37,7 @@ from distributed.utils import TimeoutError from distributed.utils_test import ( BrokenComm, + assert_story, captured_logger, cluster, dec, @@ -292,22 +293,23 @@ async def test_no_workers(client, s): @gen_cluster(nthreads=[]) async def test_retire_workers_empty(s): - await s.retire_workers(workers=[]) + await s.handle_retire_workers(workers=[], stimulus_id="test") @gen_cluster() async def test_remove_client(s, a, b): - s.update_graph( + s.handle_update_graph( tasks={"x": dumps_task((inc, 1)), "y": dumps_task((inc, "x"))}, dependencies={"x": [], "y": ["x"]}, keys=["y"], client="ident", + stimulus_id="test", ) assert s.tasks assert s.dependencies - s.remove_client(client="ident") + s.handle_remove_client(client="ident", stimulus_id="test") assert not s.tasks assert not s.dependencies @@ -324,14 +326,15 @@ async def test_server_listens_to_other_ops(s, a, b): @gen_cluster() async def test_remove_worker_from_scheduler(s, a, b): dsk = {("x-%d" % i): (inc, i) for i in range(20)} - s.update_graph( + s.handle_update_graph( tasks=valmap(dumps_task, dsk), keys=list(dsk), dependencies={k: set() for k in dsk}, + stimulus_id="test", ) assert a.address in s.stream_comms - await s.remove_worker(address=a.address) + await s.handle_remove_worker(address=a.address, stimulus_id="test") assert a.address not in s.nthreads assert len(s.workers[b.address].processing) == len(dsk) # b owns everything @@ -390,11 +393,12 @@ async def test_add_worker(s, a, b): w.data["y"] = 1 dsk = {("x-%d" % i): (inc, i) for i in range(10)} - s.update_graph( + s.handle_update_graph( tasks=valmap(dumps_task, dsk), keys=list(dsk), client="client", dependencies={k: set() for k in dsk}, + stimulus_id="test", ) s.validate_state() await w @@ -539,8 +543,12 @@ async def test_delete(c, s, a): async def test_filtered_communication(s, a, b): c = await connect(s.address) f = await connect(s.address) - await c.write({"op": "register-client", "client": "c", "versions": {}}) - await f.write({"op": "register-client", "client": "f", "versions": {}}) + await c.write( + {"op": "register-client", "client": "c", "versions": {}, "stimulus_id": "test"} + ) + await f.write( + {"op": "register-client", "client": "f", "versions": {}, "stimulus_id": "test"} + ) await c.read() await f.read() @@ -553,6 +561,7 @@ async def test_filtered_communication(s, a, b): "dependencies": {"x": [], "y": ["x"]}, "client": "c", "keys": ["y"], + "stimulus_id": "test", } ) @@ -566,6 +575,7 @@ async def test_filtered_communication(s, a, b): "dependencies": {"x": [], "z": ["x"]}, "client": "f", "keys": ["z"], + "stimulus_id": "test", } ) (msg,) = await c.read() @@ -605,16 +615,17 @@ def test_dumps_task(): @gen_cluster() async def test_ready_remove_worker(s, a, b): - s.update_graph( + s.handle_update_graph( tasks={"x-%d" % i: dumps_task((inc, i)) for i in range(20)}, keys=["x-%d" % i for i in range(20)], client="client", dependencies={"x-%d" % i: [] for i in range(20)}, + stimulus_id="test", ) assert all(len(w.processing) > w.nthreads for w in s.workers.values()) - await s.remove_worker(address=a.address) + await s.handle_remove_worker(address=a.address, stimulus_id="test") assert set(s.workers) == {b.address} assert all(len(w.processing) > w.nthreads for w in s.workers.values()) @@ -625,7 +636,7 @@ async def test_restart(c, s, a, b): futures = c.map(inc, range(20)) await wait(futures) - await s.restart() + await s.handle_restart(stimulus_id="test") assert len(s.workers) == 2 @@ -779,7 +790,7 @@ async def test_file_descriptors_dont_leak(s): @gen_cluster() async def test_update_graph_culls(s, a, b): - s.update_graph( + s.handle_update_graph( tasks={ "x": dumps_task((inc, 1)), "y": dumps_task((inc, "x")), @@ -788,6 +799,7 @@ async def test_update_graph_culls(s, a, b): keys=["y"], dependencies={"y": "x", "x": [], "z": []}, client="client", + stimulus_id="test", ) assert "z" not in s.tasks assert "z" not in s.dependencies @@ -810,7 +822,7 @@ async def test_story(c, s, a, b): story = s.story(x.key) assert all(line in s.transition_log for line in story) assert len(story) < len(s.transition_log) - assert all(x.key == line[0] or x.key in line[-2] for line in story) + assert all(x.key == line[0] or x.key in line[3] for line in story) assert len(s.story(x.key, y.key)) > len(story) @@ -821,7 +833,9 @@ async def test_story(c, s, a, b): @gen_cluster(client=True, nthreads=[]) async def test_scatter_no_workers(c, s, direct): with pytest.raises(TimeoutError): - await s.scatter(data={"x": 1}, client="alice", timeout=0.1) + await s.handle_scatter( + data={"x": 1}, client="alice", timeout=0.1, stimulus_id="test" + ) start = time() with pytest.raises(TimeoutError): @@ -856,7 +870,7 @@ async def test_retire_workers(c, s, a, b): assert s.workers_to_close() == [a.address] - workers = await s.retire_workers() + workers = await s.handle_retire_workers(stimulus_id="test") assert list(workers) == [a.address] assert workers[a.address]["nthreads"] == a.nthreads assert list(s.nthreads) == [b.address] @@ -865,22 +879,22 @@ async def test_retire_workers(c, s, a, b): assert s.workers[b.address].has_what == {s.tasks[x.key], s.tasks[y.key]} - workers = await s.retire_workers() + workers = await s.handle_retire_workers(stimulus_id="test") assert not workers @gen_cluster(client=True) async def test_retire_workers_n(c, s, a, b): - await s.retire_workers(n=1, close_workers=True) + await s.handle_retire_workers(n=1, close_workers=True, stimulus_id="test") assert len(s.workers) == 1 - await s.retire_workers(n=0, close_workers=True) + await s.handle_retire_workers(n=0, close_workers=True, stimulus_id="test") assert len(s.workers) == 1 - await s.retire_workers(n=1, close_workers=True) + await s.handle_retire_workers(n=1, close_workers=True, stimulus_id="test") assert len(s.workers) == 0 - await s.retire_workers(n=0, close_workers=True) + await s.handle_retire_workers(n=0, close_workers=True, stimulus_id="test") assert len(s.workers) == 0 while not ( @@ -944,7 +958,7 @@ async def test_retire_workers_no_suspicious_tasks(c, s, a, b): slowinc, 100, delay=0.5, workers=a.address, allow_other_workers=True ) await asyncio.sleep(0.2) - await s.retire_workers(workers=[a.address]) + await s.handle_retire_workers(workers=[a.address], stimulus_id="test") assert all(ts.suspicious == 0 for ts in s.tasks.values()) assert all(tp.suspicious == 0 for tp in s.task_prefixes.values()) @@ -1277,7 +1291,7 @@ async def test_close_nanny(c, s, a, b): @gen_cluster(client=True) async def test_retire_workers_close(c, s, a, b): - await s.retire_workers(close_workers=True) + await s.handle_retire_workers(close_workers=True, stimulus_id="test") assert not s.workers while a.status != Status.closed and b.status != Status.closed: await asyncio.sleep(0.01) @@ -1286,7 +1300,7 @@ async def test_retire_workers_close(c, s, a, b): @gen_cluster(client=True, Worker=Nanny) async def test_retire_nannies_close(c, s, a, b): nannies = [a, b] - await s.retire_workers(close_workers=True, remove=True) + await s.handle_retire_workers(close_workers=True, remove=True, stimulus_id="test") assert not s.workers start = time() @@ -1327,12 +1341,13 @@ async def test_scheduler_file(): @gen_cluster(client=True, nthreads=[]) async def test_non_existent_worker(c, s): with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}): - await s.add_worker( + await s.handle_add_worker( address="127.0.0.1:5738", status="running", nthreads=2, nbytes={}, host_info={}, + stimulus_id="test", ) futures = c.map(inc, range(10)) await asyncio.sleep(0.300) @@ -1481,7 +1496,7 @@ async def test_reschedule(c, s, a, b): await asyncio.sleep(0.001) for future in x: - s.reschedule(key=future.key) + s.handle_reschedule(key=future.key, stimulus_id="test") # Worker b gets more of the original tasks await wait(x) @@ -1492,7 +1507,7 @@ async def test_reschedule(c, s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) async def test_reschedule_warns(c, s, a, b): with captured_logger(logging.getLogger("distributed.scheduler")) as sched: - s.reschedule(key="__this-key-does-not-exist__") + s.handle_reschedule(key="__this-key-does-not-exist__", stimulus_id="test") assert "not found on the scheduler" in sched.getvalue() assert "Aborting reschedule" in sched.getvalue() @@ -1794,7 +1809,7 @@ async def f(dask_worker): assert s.bandwidth_workers - await s.restart() + await s.handle_restart(stimulus_id="test") assert not s.bandwidth_workers @@ -1914,7 +1929,7 @@ async def test_retire_names_str(c, s): futures = c.map(inc, range(10)) await wait(futures) assert a.data and b.data - await s.retire_workers(names=[0]) + await s.handle_retire_workers(names=[0], stimulus_id="test") assert all(f.done() for f in futures) assert len(b.data) == 10 @@ -2150,7 +2165,7 @@ async def test_gather_failing_cnn_error(c, s, a, b): x = await c.scatter({"x": 1}, workers=a.address) s.rpc = await FlakyConnectionPool(failing_connections=10) - res = await s.gather(keys=["x"]) + res = await s.handle_gather(keys=["x"], stimulus_id="test") assert res["status"] == "error" assert list(res["keys"]) == ["x"] @@ -3117,7 +3132,7 @@ async def test_delete_worker_data(c, s, a, b): assert b.data == {y.key: "y"} assert s.tasks.keys() == {x.key, y.key, z.key} - await s.delete_worker_data(a.address, [x.key, y.key]) + await s.handle_delete_worker_data(a.address, [x.key, y.key], stimulus_id="test") assert a.data == {z.key: "z"} assert b.data == {y.key: "y"} assert s.tasks.keys() == {y.key, z.key} @@ -3131,8 +3146,8 @@ async def test_delete_worker_data_double_delete(c, s, a): """ x, y = await c.scatter(["x", "y"]) await asyncio.gather( - s.delete_worker_data(a.address, [x.key]), - s.delete_worker_data(a.address, [x.key]), + s.handle_delete_worker_data(a.address, [x.key], stimulus_id="test"), + s.handle_delete_worker_data(a.address, [x.key], stimulus_id="test"), ) assert a.data == {y.key: "y"} a_ws = s.workers[a.address] @@ -3147,7 +3162,7 @@ async def test_delete_worker_data_bad_worker(s, a, b): """ await a.close() assert s.workers.keys() == {b.address} - await s.delete_worker_data(a.address, ["x"]) + await s.handle_delete_worker_data(a.address, ["x"], stimulus_id="test") @pytest.mark.parametrize("bad_first", [False, True]) @@ -3162,7 +3177,7 @@ async def test_delete_worker_data_bad_task(c, s, a, bad_first): assert s.tasks.keys() == {x.key, y.key} keys = ["notexist", x.key] if bad_first else [x.key, "notexist"] - await s.delete_worker_data(a.address, keys) + await s.handle_delete_worker_data(a.address, keys, stimulus_id="test") assert a.data == {y.key: "y"} assert s.tasks.keys() == {y.key} assert s.workers[a.address].nbytes == s.tasks[y.key].nbytes @@ -3247,13 +3262,13 @@ async def test_worker_reconnect_task_memory(c, s, a): while not a.executing_count and not a.data: await asyncio.sleep(0.001) - await s.remove_worker(address=a.address, close=False) + await s.handle_remove_worker(address=a.address, close=False, stimulus_id="test") while not res.done(): await a.heartbeat() await res assert ("no-worker", "memory") in { - (start, finish) for (_, start, finish, _, _) in s.transition_log + (start, finish) for (_, start, finish, _, _, _) in s.transition_log } @@ -3271,13 +3286,13 @@ async def test_worker_reconnect_task_memory_with_resources(c, s, a): while not b.executing_count and not b.data: await asyncio.sleep(0.001) - await s.remove_worker(address=b.address, close=False) + await s.handle_remove_worker(address=b.address, close=False, stimulus_id="test") while not res.done(): await b.heartbeat() await res assert ("no-worker", "memory") in { - (start, finish) for (_, start, finish, _, _) in s.transition_log + (start, finish) for (_, start, finish, _, _, _) in s.transition_log } @@ -3538,3 +3553,40 @@ async def test_repr(s, a): repr(ws_b) == f"" ) + + +@gen_cluster(client=True) +async def test_stimuli(c, s, a, b): + f = c.submit(inc, 1) + key = f.key + + await f + await c.close() + + assert_story( + s.story(key), + [ + (key, "released", "waiting", {key: "processing"}), + (key, "waiting", "processing", {}), + (key, "processing", "memory", {}), + ( + key, + "memory", + "forgotten", + {}, + ), + ], + ) + + stimuli = [ + "client-update-graph-hlg", + "client-update-graph-hlg", + "task-finished", + "client-close", + ] + + stories = s.story(key) + assert len(stories) == len(stimuli) + + for stimulus_id, story in zip(stimuli, stories): + assert story[-2].startswith(stimulus_id), (story[-2], stimulus_id) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 55f1bbf043a..74e012341db 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1107,7 +1107,7 @@ async def test_steal_reschedule_reset_in_flight_occupancy(c, s, *workers): steal.move_task_request(victim_ts, wsA, wsB) - s.reschedule(victim_key) + s.handle_reschedule(victim_key, stimulus_id="test") await c.gather(futs1) del futs1 @@ -1173,7 +1173,7 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers): steal.move_task_request(victim_ts, wsA, wsB) s.set_restrictions(worker={victim_key: [wsB.address]}) - s.reschedule(victim_key) + s.handle_reschedule(victim_key, stimulus_id="test") assert wsB == victim_ts.processing_on # move_task_request is not responsible for respecting worker restrictions steal.move_task_request(victim_ts, wsB, wsC) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index c920366c4c2..6b4a583b9ac 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -19,6 +19,7 @@ from distributed.metrics import time from distributed.utils import ( LRU, + STIMULUS_ID, All, Log, Logs, @@ -27,6 +28,7 @@ _maybe_complex, ensure_bytes, ensure_ip, + expect_stimulus, format_dashboard_link, get_ip_interface, get_traceback, @@ -781,3 +783,16 @@ def __repr__(self): ], } assert recursive_to_dict(info) == expect + + +def test_expect_stimulus(): + @expect_stimulus(sync=True) + def fn(x): + return STIMULUS_ID.get() + + assert fn(1, stimulus_id="test") == "test" + + with pytest.raises(ValueError): + assert fn(1) == 1 + + assert fn(**{"x": 1, "stimulus_id": "test"}) == "test" diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index dc481aae17f..c0d377acc01 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -21,7 +21,7 @@ from distributed.utils_test import ( _LockedCommPool, _UnhashableCallable, - assert_worker_story, + assert_story, check_process_leak, cluster, dump_cluster_state, @@ -406,7 +406,7 @@ async def inner_test(c, s, a, b): assert "workers" in state -def test_assert_worker_story(): +def test_assert_story(): now = time() story = [ ("foo", "id1", now - 600), @@ -414,38 +414,38 @@ def test_assert_worker_story(): ("baz", {1: 2}, "id2", now), ] # strict=False - assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})]) - assert_worker_story(story, []) - assert_worker_story(story, [("foo",)]) - assert_worker_story(story, [("foo",), ("bar",)]) - assert_worker_story(story, [("baz", lambda d: d[1] == 2)]) + assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})]) + assert_story(story, []) + assert_story(story, [("foo",)]) + assert_story(story, [("foo",), ("bar",)]) + assert_story(story, [("baz", lambda d: d[1] == 2)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo", "nomatch")]) + assert_story(story, [("foo", "nomatch")]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz",)]) + assert_story(story, [("baz",)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz", {1: 3})]) + assert_story(story, [("baz", {1: 3})]) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)]) + assert_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz", lambda d: d[1] == 3)]) + assert_story(story, [("baz", lambda d: d[1] == 3)]) with pytest.raises(KeyError): # Faulty lambda - assert_worker_story(story, [("baz", lambda d: d[2] == 1)]) - assert_worker_story([], []) - assert_worker_story([("foo", "id1", now)], [("foo",)]) + assert_story(story, [("baz", lambda d: d[2] == 1)]) + assert_story([], []) + assert_story([("foo", "id1", now)], [("foo",)]) with pytest.raises(AssertionError): - assert_worker_story([], [("foo",)]) + assert_story([], [("foo",)]) # strict=True - assert_worker_story([], [], strict=True) - assert_worker_story([("foo", "id1", now)], [("foo",)]) - assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True) + assert_story([], [], strict=True) + assert_story([("foo", "id1", now)], [("foo",)]) + assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("bar",)], strict=True) + assert_story(story, [("foo",), ("bar",)], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("baz", {1: 2})], strict=True) + assert_story(story, [("foo",), ("baz", {1: 2})], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [], strict=True) + assert_story(story, [], strict=True) @pytest.mark.parametrize( @@ -466,11 +466,11 @@ def test_assert_worker_story(): ), ], ) -def test_assert_worker_story_malformed_story(story_factory): +def test_assert_story_malformed_story(story_factory): # defer the calls to time() to when the test runs rather than collection story = story_factory() with pytest.raises(AssertionError, match="Malformed story event"): - assert_worker_story(story, []) + assert_story(story, []) @gen_cluster() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index aa8549b07ba..14521c72c6f 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -4,6 +4,7 @@ import importlib import logging import os +import re import sys import threading import traceback @@ -41,11 +42,11 @@ from distributed.metrics import time from distributed.protocol import pickle from distributed.scheduler import Scheduler -from distributed.utils import TimeoutError +from distributed.utils import STIMULUS_ID, TimeoutError from distributed.utils_test import ( TaskStateMetadataPlugin, _LockedCommPool, - assert_worker_story, + assert_story, captured_logger, dec, div, @@ -1403,7 +1404,7 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None: """Test that an in-memory key was transferred from worker w_from to worker w_to by the Active Memory Manager and it was not recalculated on w_to """ - assert_worker_story( + assert_story( w_to.story(key), [ (key, "ensure-task-exists", "released"), @@ -1687,7 +1688,12 @@ async def test_story_with_deps(c, s, a, b): # Story now includes randomized stimulus_ids and timestamps. stimulus_ids = {ev[-2] for ev in story} - assert len(stimulus_ids) == 3, stimulus_ids + assert {sid[: re.search(r"\d", sid).start()] for sid in stimulus_ids} == { + "ensure-computing-", + "ensure-communicating-", + "task-finished-", + } + # This is a simple transition log expected = [ ("res", "compute-task"), @@ -1697,7 +1703,7 @@ async def test_story_with_deps(c, s, a, b): ("res", "put-in-memory"), ("res", "executing", "memory", "memory", {}), ] - assert_worker_story(story, expected, strict=True) + assert_story(story, expected, strict=True) story = b.story("dep") stimulus_ids = {ev[-2] for ev in story} @@ -1712,7 +1718,7 @@ async def test_story_with_deps(c, s, a, b): ("dep", "put-in-memory"), ("dep", "flight", "memory", "memory", {"res": "ready"}), ] - assert_worker_story(story, expected, strict=True) + assert_story(story, expected, strict=True) @gen_cluster(client=True) @@ -2541,7 +2547,12 @@ def __call__(self, *args, **kwargs): ts = s.tasks[fut.key] a.handle_steal_request(fut.key, stimulus_id="test") - stealing_ext.scheduler.send_task_to_worker(b.address, ts) + + try: + token = STIMULUS_ID.set("test") + stealing_ext.scheduler.send_task_to_worker(b.address, ts) + finally: + STIMULUS_ID.reset(token) fut2 = c.submit(inc, fut, workers=[a.address]) fut3 = c.submit(inc, fut2, workers=[a.address]) @@ -2621,7 +2632,7 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b): while fut1.key not in b.tasks or b.tasks[fut1.key].state == "flight": await asyncio.sleep(0) - s.handle_missing_data(key="f1", errant_worker=a.address) + s.handle_missing_data(key="f1", errant_worker=a.address, stimulus_id="test") await fut2 @@ -2666,7 +2677,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b): # same communication channel for fut in (futA, futB): - assert_worker_story( + assert_story( b.story(fut.key), [ ("gather-dependencies", a.address, {fut.key}), @@ -2725,7 +2736,7 @@ def __getstate__(self): assert await y == 123 story = await c.run(lambda dask_worker: dask_worker.story("x")) - assert_worker_story( + assert_story( story[b], [ ("x", "ensure-task-exists", "released"), @@ -2902,7 +2913,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers): await f2 - assert_worker_story(a.story(f1.key), [(f1.key, "missing-dep")]) + assert_story(a.story(f1.key), [(f1.key, "missing-dep")]) assert a.tasks[f1.key].suspicious_count == 0 assert s.tasks[f1.key].suspicious == 0 @@ -2982,7 +2993,7 @@ async def test_missing_released_zombie_tasks_2(c, s, a, b): while b.tasks: await asyncio.sleep(0.01) - assert_worker_story( + assert_story( b.story(ts), [("f1", "missing", "released", "released", {"f1": "forgotten"})], ) @@ -2999,7 +3010,7 @@ async def test_worker_status_sync(s, a): while ws.status != Status.running: await asyncio.sleep(0.01) - await s.retire_workers() + await s.handle_retire_workers(stimulus_id="test") while ws.status != Status.closed: await asyncio.sleep(0.01) @@ -3094,7 +3105,7 @@ async def test_task_flight_compute_oserror(c, s, a, b): ("f1", "put-in-memory"), ("f1", "executing", "memory", "memory", {}), ] - assert_worker_story(sum_story, expected_sum_story, strict=True) + assert_story(sum_story, expected_sum_story, strict=True) @gen_cluster(client=True) @@ -3274,7 +3285,9 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker( while not mocked_gather.call_args: await asyncio.sleep(0) - await s.remove_worker(address=x.address, safe=True, close=close_worker) + await s.handle_remove_worker( + address=x.address, safe=True, close=close_worker, stimulus_id="test" + ) await _wait_for_state(fut2_key, b, intermediate_state) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 78597a37e67..f03966ba656 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -107,5 +107,9 @@ def test_slots(cls): def test_sendmsg_to_dict(): # Arbitrary sample class - smsg = ReleaseWorkerDataMsg(key="x") - assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"} + smsg = ReleaseWorkerDataMsg(key="x", stimulus_id="test") + assert smsg.to_dict() == { + "op": "release-worker-data", + "key": "x", + "stimulus_id": "test", + } diff --git a/distributed/utils.py b/distributed/utils.py index cd0d93d08dd..0eeb63eddf4 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1437,6 +1437,45 @@ def __getattr__(name): raise AttributeError(f"module {__name__} has no attribute {name}") +STIMULUS_ID: ContextVar[str] = ContextVar("stimulus_id") + + +def expect_stimulus(sync: bool = True): + def decorator(fn): + if sync: + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + try: + token = STIMULUS_ID.set(kwargs.pop("stimulus_id")) + except KeyError: + raise ValueError(f"{fn} missing stimulus_id") + + try: + return fn(*args, **kwargs) + finally: + STIMULUS_ID.reset(token) + + else: + + @functools.wraps(fn) + # async def wrapper(*args, stimulus_id: str, **kwargs): + async def wrapper(*args, **kwargs): + try: + token = STIMULUS_ID.set(kwargs.pop("stimulus_id")) + except KeyError: + raise ValueError(f"{fn} missing stimulus_id") + + try: + return await fn(*args, **kwargs) + finally: + STIMULUS_ID.reset(token) + + return wrapper + + return decorator + + # Used internally by recursive_to_dict to stop infinite recursion. If an object has # already been encountered, a string representation will be returned instead. This is # necessary since we have multiple cyclic referencing data structures. diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 31e6cbe08fe..53d4f1403ac 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1875,7 +1875,7 @@ def xfail_ssl_issue5601(): raise -def assert_worker_story( +def assert_story( story: list[tuple], expect: list[tuple], *, strict: bool = False ) -> None: """Test the output of ``Worker.story`` @@ -1941,7 +1941,7 @@ def assert_worker_story( break except StopIteration: raise AssertionError( - f"assert_worker_story({strict=}) failed\n" + f"assert_story({strict=}) failed\n" f"story:\n{_format_story(story)}\n" f"expect:\n{_format_story(expect)}" ) from None diff --git a/distributed/worker.py b/distributed/worker.py index 7a0628763dd..38e1e341577 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -919,22 +919,26 @@ def status(self, value): """ prev_status = self.status ServerNode.status.__set__(self, value) - self._send_worker_status_change() + self._send_worker_status_change(f"worker-status-change-{time()}") if prev_status == Status.paused and value == Status.running: self.ensure_computing() self.ensure_communicating() - def _send_worker_status_change(self) -> None: + def _send_worker_status_change(self, stimulus_id: str) -> None: if ( self.batched_stream and self.batched_stream.comm and not self.batched_stream.comm.closed() ): self.batched_stream.send( - {"op": "worker-status-change", "status": self._status.name} + { + "op": "worker-status-change", + "status": self._status.name, + "stimulus_id": stimulus_id, + }, ) elif self._status != Status.closed: - self.loop.call_later(0.05, self._send_worker_status_change) + self.loop.call_later(0.05, self._send_worker_status_change, stimulus_id) async def get_metrics(self) -> dict: try: @@ -1075,6 +1079,7 @@ async def _register_with_scheduler(self): versions=get_versions(), metrics=await self.get_metrics(), extra=await self.get_startup_information(), + stimulus_id=f"worker-connect-{time()}", ), serializers=["msgpack"], ) @@ -1457,7 +1462,9 @@ async def close( if report and self.contact_address is not None: await asyncio.wait_for( self.scheduler.unregister( - address=self.contact_address, safe=safe + address=self.contact_address, + safe=safe, + stimulus_id=f"worker-close-{time()}", ), timeout, ) @@ -1522,7 +1529,10 @@ async def close_gracefully(self, restart=None): # Scheduler.retire_workers will set the status to closing_gracefully and push it # back to this worker. await self.scheduler.retire_workers( - workers=[self.address], close_workers=False, remove=False + workers=[self.address], + close_workers=False, + remove=False, + stimulus_id=f"worker-close-gracefully-{time()}", ) await self.close(safe=True, nanny=not restart) @@ -1877,7 +1887,9 @@ def handle_compute_task( pass elif ts.state == "memory": recommendations[ts] = "memory" - instructions.append(self._get_task_finished_msg(ts)) + instructions.append( + self._get_task_finished_msg(ts, stimulus_id=stimulus_id) + ) elif ts.state in { "released", "fetch", @@ -2005,7 +2017,7 @@ def transition_memory_released( recs, instructions = self.transition_generic_released( ts, stimulus_id=stimulus_id ) - instructions.append(ReleaseWorkerDataMsg(ts.key)) + instructions.append(ReleaseWorkerDataMsg(key=ts.key, stimulus_id=stimulus_id)) return recs, instructions def transition_waiting_constrained( @@ -2028,7 +2040,7 @@ def transition_long_running_rescheduled( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, worker=self.address) + smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id) return recs, [smsg] def transition_executing_rescheduled( @@ -2039,7 +2051,7 @@ def transition_executing_rescheduled( self._executing.discard(ts) recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, worker=self.address) + smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id) return recs, [smsg] def transition_waiting_ready( @@ -2116,6 +2128,7 @@ def transition_generic_error( traceback_text=traceback_text, thread=self.threads.get(ts.key), startstops=ts.startstops, + stimulus_id=stimulus_id, ) return {}, [smsg] @@ -2300,7 +2313,7 @@ def transition_generic_memory( return recs, [] if self.validate: assert ts.key in self.data or ts.key in self.actors - smsg = self._get_task_finished_msg(ts) + smsg = self._get_task_finished_msg(ts, stimulus_id=stimulus_id) return recs, [smsg] def transition_executing_memory( @@ -2417,7 +2430,9 @@ def transition_executing_long_running( ts.state = "long-running" self._executing.discard(ts) self.long_running.add(ts.key) - smsg = LongRunningMsg(key=ts.key, compute_duration=compute_duration) + smsg = LongRunningMsg( + key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id + ) self.io_loop.add_callback(self.ensure_computing) return {}, [smsg] @@ -2701,7 +2716,9 @@ def ensure_communicating(self) -> None: for el in skipped_worker_in_flight: self.data_needed.push(el) - def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: + def _get_task_finished_msg( + self, ts: TaskState, stimulus_id: str + ) -> TaskFinishedMsg: if ts.key not in self.data and ts.key not in self.actors: raise RuntimeError(f"Task {ts} not ready") typ = ts.type @@ -2727,6 +2744,7 @@ def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: metadata=ts.metadata, thread=self.threads.get(ts.key), startstops=ts.startstops, + stimulus_id=stimulus_id, ) def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: @@ -3030,7 +3048,12 @@ async def gather_dep( self.has_what[worker].discard(ts.key) self.log.append((d, "missing-dep", stimulus_id, time())) self.batched_stream.send( - {"op": "missing-data", "errant_worker": worker, "key": d} + { + "op": "missing-data", + "errant_worker": worker, + "key": d, + "stimulus_id": stimulus_id, + } ) recommendations[ts] = "fetch" if ts.who_has else "missing" del data, response @@ -3135,7 +3158,7 @@ def handle_steal_request(self, key: str, stimulus_id: str) -> None: # `transition_constrained_executing` self.transition(ts, "released", stimulus_id=stimulus_id) - def handle_worker_status_change(self, status: str) -> None: + def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: new_status = Status.lookup[status] # type: ignore if ( @@ -3146,7 +3169,7 @@ def handle_worker_status_change(self, status: str) -> None: "Invalid Worker.status transition: %s -> %s", self._status, new_status ) # Reiterate the current status to the scheduler to restore sync - self._send_worker_status_change() + self._send_worker_status_change(stimulus_id) else: # Update status and send confirmation to the Scheduler (see status.setter) self.status = new_status @@ -3483,6 +3506,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No if result["op"] == "task-finished": if self.digests is not None: self.digests["task-duration"].add(result["stop"] - result["start"]) + new_stimulus_id = f"{result['op']}-{time()}" return ExecuteSuccessEvent( key=key, value=result["result"], @@ -3490,7 +3514,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No stop=result["stop"], nbytes=result["nbytes"], type=result["type"], - stimulus_id=stimulus_id, + stimulus_id=new_stimulus_id, ) if isinstance(result["actual-exception"], Reschedule): @@ -3517,7 +3541,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No traceback=result["traceback"], exception_text=result["exception_text"], traceback_text=result["traceback_text"], - stimulus_id=stimulus_id, + stimulus_id=f"task-erred-{time()}", ) except Exception as exc: @@ -3531,7 +3555,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No traceback=msg["traceback"], exception_text=msg["exception_text"], traceback_text=msg["traceback_text"], - stimulus_id=stimulus_id, + stimulus_id=f"task-erred-{time()}", ) @functools.singledispatchmethod @@ -3596,7 +3620,7 @@ def _(self, ev: ExecuteFailureEvent) -> RecsInstrs: ev.traceback, ev.exception_text, ev.traceback_text, - ) + ), }, [] @handle_event.register diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 8ae454417c9..966be0a3527 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -288,6 +288,7 @@ class TaskFinishedMsg(SendMessageToScheduler): metadata: dict thread: int | None startstops: list[StartStop] + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def to_dict(self) -> dict[str, Any]: @@ -307,6 +308,7 @@ class TaskErredMsg(SendMessageToScheduler): traceback_text: str thread: int | None startstops: list[StartStop] + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def to_dict(self) -> dict[str, Any]: @@ -319,8 +321,9 @@ def to_dict(self) -> dict[str, Any]: class ReleaseWorkerDataMsg(SendMessageToScheduler): op = "release-worker-data" - __slots__ = ("key",) + __slots__ = ("key", "stimulus_id") key: str + stimulus_id: str # Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception @@ -328,18 +331,21 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler): class RescheduleMsg(SendMessageToScheduler): op = "reschedule" - __slots__ = ("key", "worker") + # Not to be confused with the distributed.Reschedule Exception + __slots__ = ("key", "worker", "stimulus_id") key: str worker: str + stimulus_id: str @dataclass class LongRunningMsg(SendMessageToScheduler): op = "long-running" - __slots__ = ("key", "compute_duration") + __slots__ = ("key", "compute_duration", "stimulus_id") key: str compute_duration: float + stimulus_id: str @dataclass @@ -365,6 +371,7 @@ class ExecuteSuccessEvent(StateMachineEvent): stop: float nbytes: int type: type | None + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore @@ -377,6 +384,7 @@ class ExecuteFailureEvent(StateMachineEvent): traceback: Serialize | None exception_text: str traceback_text: str + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore