diff --git a/distributed/client.py b/distributed/client.py index 86ca5f506e0..e49a14d90dd 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4294,6 +4294,35 @@ def collections_to_dsk(collections, *args, **kwargs): """Convert many collections into a single dask graph, after optimization""" return collections_to_dsk(collections, *args, **kwargs) + async def _story(self, keys=(), on_error="raise"): + assert on_error in ("raise", "ignore") + + try: + flat_stories = list(await self.scheduler.get_story(keys=keys)) + except Exception: + if on_error == "raise": + raise + elif on_error == "ignore": + flat_stories = [] + else: + raise ValueError(f"on_error not in {'raise', 'ignore'}") + + responses = await self.scheduler.broadcast( + msg={"op": "get_story", "keys": keys}, on_error=on_error + ) + + for stories in responses.values(): + if isinstance(stories, (tuple, list)): + flat_stories.extend(s for s in stories) + else: + flat_stories.append(stories) + + return flat_stories + + def story(self, *keys_or_stimulus_ids, on_error="raise"): + """Returns a cluster-wide story for the given keys or simtulus_id's""" + return self.sync(self._story, keys=keys_or_stimulus_ids, on_error=on_error) + def get_task_stream( self, start=None, diff --git a/distributed/http/templates/task.html b/distributed/http/templates/task.html index 0b5c10695e0..f10aaad5602 100644 --- a/distributed/http/templates/task.html +++ b/distributed/http/templates/task.html @@ -118,16 +118,18 @@

Transition Log

Key Start Finish + Stimulus ID Recommended Key Recommended Action - {% for key, start, finish, recommendations, transition_time in scheduler.story(Task) %} + {% for key, start, finish, recommendations, stimulus_id, transition_time in scheduler.story(Task) %} {{ fromtimestamp(transition_time) }} {{key}} {{ start }} {{ finish }} + {{ stimulus_id }} @@ -137,6 +139,7 @@

Transition Log

+ {{key2}} {{ rec }} diff --git a/distributed/nanny.py b/distributed/nanny.py index db2371523e8..91873bb3a8a 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -290,7 +290,10 @@ async def _unregister(self, timeout=10): allowed_errors = (TimeoutError, CommClosedError, EnvironmentError, RPCClosed) with suppress(allowed_errors): await asyncio.wait_for( - self.scheduler.unregister(address=self.worker_address), timeout + self.scheduler.unregister( + address=self.worker_address, stimulus_id=f"nanny-close-{time()}" + ), + timeout, ) @property diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 0e0ae003b5f..f81a8327747 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -116,6 +116,7 @@ def _encode_default(obj): return msgpack_encode_default(obj) frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True) + return frames except Exception: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 0b22d2ff2ef..1035676bbcc 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -110,6 +110,7 @@ DEFAULT_DATA_SIZE = parse_bytes( dask.config.get("distributed.scheduler.default-data-size") ) +STIMULUS_ID_UNSET = "" DEFAULT_EXTENSIONS = { "locks": LockExtension, @@ -1617,7 +1618,7 @@ def new_task( # State Transitions # ##################### - def _transition(self, key, finish: str, *args, **kwargs): + def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs): """Transition a key from its current state to the finish state Examples @@ -1652,14 +1653,16 @@ def _transition(self, key, finish: str, *args, **kwargs): start_finish = (start, finish) func = self.transitions_table.get(start_finish) if func is not None: - recommendations, client_msgs, worker_msgs = func(key, *args, **kwargs) # type: ignore + recommendations, client_msgs, worker_msgs = func( + key, stimulus_id, *args, **kwargs + ) # type: ignore self.transition_counter += 1 elif "released" not in start_finish: assert not args and not kwargs, (args, kwargs, start_finish) a_recs: dict a_cmsgs: dict a_wmsgs: dict - a: tuple = self._transition(key, "released") + a: tuple = self._transition(key, "released", stimulus_id) a_recs, a_cmsgs, a_wmsgs = a v = a_recs.get(key, finish) @@ -1667,7 +1670,7 @@ def _transition(self, key, finish: str, *args, **kwargs): b_recs: dict b_cmsgs: dict b_wmsgs: dict - b: tuple = func(key) # type: ignore + b: tuple = func(key, stimulus_id) # type: ignore b_recs, b_cmsgs, b_wmsgs = b recommendations.update(a_recs) @@ -1702,13 +1705,20 @@ def _transition(self, key, finish: str, *args, **kwargs): else: raise RuntimeError("Impossible transition from %r to %r" % start_finish) + if not stimulus_id: + stimulus_id = STIMULUS_ID_UNSET + finish2 = ts._state # FIXME downcast antipattern scheduler = pep484_cast(Scheduler, self) scheduler.transition_log.append( - (key, start, finish2, recommendations, time()) + (key, start, finish2, recommendations, stimulus_id, time()) ) if self.validate: + if stimulus_id == STIMULUS_ID_UNSET: + raise RuntimeError( + "stimulus_id not set during Scheduler transition" + ) logger.debug( "Transitioned %r %s->%s (actual: %s). Consequence: %s", key, @@ -1752,7 +1762,13 @@ def _transition(self, key, finish: str, *args, **kwargs): pdb.set_trace() raise - def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict): + def _transitions( + self, + recommendations: dict, + client_msgs: dict, + worker_msgs: dict, + stimulus_id: str, + ): """Process transitions until none are left This includes feedback from previous transitions and continues until we @@ -1770,7 +1786,7 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di key, finish = recommendations.popitem() keys.add(key) - new = self._transition(key, finish) + new = self._transition(key, finish, stimulus_id) new_recs, new_cmsgs, new_wmsgs = new recommendations.update(new_recs) @@ -1793,7 +1809,7 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di for key in keys: scheduler.validate_key(key) - def transition_released_waiting(self, key): + def transition_released_waiting(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -1848,7 +1864,7 @@ def transition_released_waiting(self, key): pdb.set_trace() raise - def transition_no_worker_waiting(self, key): + def transition_no_worker_waiting(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -1896,7 +1912,13 @@ def transition_no_worker_waiting(self, key): raise def transition_no_worker_memory( - self, key, nbytes=None, type=None, typename: str = None, worker=None + self, + key, + stimulus_id, + nbytes=None, + type=None, + typename: str = None, + worker=None, ): try: ws: WorkerState = self.workers[worker] @@ -2051,7 +2073,7 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> float: return total_duration - def transition_waiting_processing(self, key): + def transition_waiting_processing(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -2096,7 +2118,14 @@ def transition_waiting_processing(self, key): raise def transition_waiting_memory( - self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs + self, + key, + stimulus_id, + nbytes=None, + type=None, + typename: str = None, + worker=None, + **kwargs, ): try: ws: WorkerState = self.workers[worker] @@ -2138,6 +2167,7 @@ def transition_waiting_memory( def transition_processing_memory( self, key, + stimulus_id, nbytes=None, type=None, typename: str = None, @@ -2231,7 +2261,7 @@ def transition_processing_memory( pdb.set_trace() raise - def transition_memory_released(self, key, safe: bool = False): + def transition_memory_released(self, key, stimulus_id, safe: bool = False): ws: WorkerState try: ts: TaskState = self.tasks[key] @@ -2269,7 +2299,7 @@ def transition_memory_released(self, key, safe: bool = False): worker_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": f"memory-released-{time()}", + "stimulus_id": stimulus_id, } for ws in ts.who_has: worker_msgs[ws.address] = [worker_msg] @@ -2301,7 +2331,7 @@ def transition_memory_released(self, key, safe: bool = False): pdb.set_trace() raise - def transition_released_erred(self, key): + def transition_released_erred(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -2346,7 +2376,7 @@ def transition_released_erred(self, key): pdb.set_trace() raise - def transition_erred_released(self, key): + def transition_erred_released(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -2372,7 +2402,7 @@ def transition_erred_released(self, key): w_msg = { "op": "free-keys", "keys": [key], - "stimulus_id": f"erred-released-{time()}", + "stimulus_id": stimulus_id, } for ws_addr in ts.erred_on: worker_msgs[ws_addr] = [w_msg] @@ -2394,7 +2424,7 @@ def transition_erred_released(self, key): pdb.set_trace() raise - def transition_waiting_released(self, key): + def transition_waiting_released(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] recommendations: dict = {} @@ -2431,7 +2461,7 @@ def transition_waiting_released(self, key): pdb.set_trace() raise - def transition_processing_released(self, key): + def transition_processing_released(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -2451,7 +2481,7 @@ def transition_processing_released(self, key): { "op": "free-keys", "keys": [key], - "stimulus_id": f"processing-released-{time()}", + "stimulus_id": stimulus_id, } ] @@ -2485,6 +2515,7 @@ def transition_processing_released(self, key): def transition_processing_erred( self, key: str, + stimulus_id: str, cause: str = None, exception=None, traceback=None, @@ -2571,7 +2602,7 @@ def transition_processing_erred( pdb.set_trace() raise - def transition_no_worker_released(self, key): + def transition_no_worker_released(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] dts: TaskState @@ -2613,7 +2644,7 @@ def remove_key(self, key): ts.exception_blame = ts.exception = ts.traceback = None self.task_metadata.pop(key, None) - def transition_memory_forgotten(self, key): + def transition_memory_forgotten(self, key, stimulus_id): ws: WorkerState try: ts: TaskState = self.tasks[key] @@ -2641,7 +2672,7 @@ def transition_memory_forgotten(self, key): for ws in ts.who_has: ws.actors.discard(ts) - _propagate_forgotten(self, ts, recommendations, worker_msgs) + _propagate_forgotten(self, ts, recommendations, worker_msgs, stimulus_id) client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) @@ -2655,7 +2686,7 @@ def transition_memory_forgotten(self, key): pdb.set_trace() raise - def transition_released_forgotten(self, key): + def transition_released_forgotten(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] recommendations: dict = {} @@ -2679,7 +2710,7 @@ def transition_released_forgotten(self, key): else: assert 0, (ts,) - _propagate_forgotten(self, ts, recommendations, worker_msgs) + _propagate_forgotten(self, ts, recommendations, worker_msgs, stimulus_id) client_msgs = _task_to_client_msgs(self, ts) self.remove_key(key) @@ -3303,6 +3334,7 @@ def __init__( "get_cluster_state": self.get_cluster_state, "dump_cluster_state_to_url": self.dump_cluster_state_to_url, "benchmark_hardware": self.benchmark_hardware, + "get_story": self.get_story, } connection_limit = get_fileno_limit() / 2 @@ -3626,7 +3658,7 @@ async def close(self, fast=False, close_workers=False): setproctitle("dask-scheduler [closed]") disable_gc_diagnosis() - async def close_worker(self, worker: str, safe: bool = False): + async def close_worker(self, worker: str, stimulus_id: str, safe: bool = False): """Remove a worker from the cluster This both removes the worker from our local state and also sends a @@ -3638,7 +3670,7 @@ async def close_worker(self, worker: str, safe: bool = False): self.log_event(worker, {"action": "close-worker"}) # FIXME: This does not handle nannies self.worker_send(worker, {"op": "close", "report": False}) - await self.remove_worker(address=worker, safe=safe) + await self.remove_worker(address=worker, safe=safe, stimulus_id=stimulus_id) ########### # Stimuli # @@ -3775,8 +3807,10 @@ async def add_worker( versions=None, nanny=None, extra=None, + stimulus_id=None, ): """Add a new worker to the cluster""" + stimulus_id = stimulus_id or f"add-worker-{time()}" with log_errors(): address = self.coerce_address(address, resolve_address) address = normalize_address(address) @@ -3876,12 +3910,15 @@ async def add_worker( t: tuple = self._transition( key, "memory", + stimulus_id, worker=address, nbytes=nbytes[key], typename=types[key], ) recommendations, client_msgs, worker_msgs = t - self._transitions(recommendations, client_msgs, worker_msgs) + self._transitions( + recommendations, client_msgs, worker_msgs, stimulus_id + ) recommendations = {} else: already_released_keys.append(key) @@ -3900,7 +3937,9 @@ async def add_worker( recommendations.update(self.bulk_schedule_after_adding_worker(ws)) if recommendations: - self._transitions(recommendations, client_msgs, worker_msgs) + self._transitions( + recommendations, client_msgs, worker_msgs, stimulus_id + ) self.send_all(client_msgs, worker_msgs) @@ -3928,7 +3967,7 @@ async def add_worker( if comm: await comm.write(msg) - await self.handle_worker(comm=comm, worker=address) + await self.handle_worker(comm=comm, worker=address, stimulus_id=stimulus_id) async def add_nanny(self, comm): msg = { @@ -3992,6 +4031,7 @@ def update_graph_hlg( fifo_timeout, annotations, code=code, + stimulus_id=f"update-graph-hlg-{time()}", ) def update_graph( @@ -4011,12 +4051,14 @@ def update_graph( fifo_timeout=0, annotations=None, code=None, + stimulus_id=None, ): """ Add new computations to the internal dask graph This happens whenever the Client calls submit, map, get, or compute. """ + stimulus_id = stimulus_id or f"update-graph-{time()}" start = time() fifo_timeout = parse_timedelta(fifo_timeout) keys = set(keys) @@ -4268,7 +4310,7 @@ def update_graph( except Exception as e: logger.exception(e) - self.transitions(recommendations) + self.transitions(recommendations, stimulus_id) for ts in touched_tasks: if ts.state in ("memory", "erred"): @@ -4280,7 +4322,7 @@ def update_graph( # TODO: balance workers - def stimulus_task_finished(self, key=None, worker=None, **kwargs): + def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): """Mark that a task has finished execution on a particular worker""" logger.debug("Stimulus task finished %s, %s", key, worker) @@ -4303,14 +4345,16 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): { "op": "free-keys", "keys": [key], - "stimulus_id": f"already-released-or-forgotten-{time()}", + "stimulus_id": stimulus_id, } ] elif ts.state == "memory": self.add_keys(worker=worker, keys=[key]) else: ts.metadata.update(kwargs["metadata"]) - r: tuple = self._transition(key, "memory", worker=worker, **kwargs) + r: tuple = self._transition( + key, "memory", stimulus_id, worker=worker, **kwargs + ) recommendations, client_msgs, worker_msgs = r if ts.state == "memory": @@ -4318,7 +4362,13 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): return recommendations, client_msgs, worker_msgs def stimulus_task_erred( - self, key=None, worker=None, exception=None, traceback=None, **kwargs + self, + key=None, + worker=None, + exception=None, + stimulus_id=None, + traceback=None, + **kwargs, ): """Mark that a task has erred on a particular worker""" logger.debug("Stimulus task erred %s, %s", key, worker) @@ -4329,11 +4379,12 @@ def stimulus_task_erred( if ts.retries > 0: ts.retries -= 1 - return self._transition(key, "waiting") + return self._transition(key, "waiting", stimulus_id) else: return self._transition( key, "erred", + stimulus_id, cause=key, exception=exception, traceback=traceback, @@ -4362,7 +4413,7 @@ def stimulus_retry(self, keys, client=None): roots.append(key) recommendations: dict = {key: "waiting" for key in roots} - self.transitions(recommendations) + self.transitions(recommendations, f"stimulus-retry-{time()}") if self.validate: for key in seen: @@ -4370,7 +4421,7 @@ def stimulus_retry(self, keys, client=None): return tuple(seen) - async def remove_worker(self, address, safe=False, close=True): + async def remove_worker(self, address, safe=False, close=True, stimulus_id=None): """ Remove worker from cluster @@ -4378,6 +4429,7 @@ async def remove_worker(self, address, safe=False, close=True): appears to be unresponsive. This may send its tasks back to a released state. """ + stimulus_id = stimulus_id or f"remove-worker-{time()}" with log_errors(): if self.status == Status.closed: return @@ -4438,7 +4490,9 @@ async def remove_worker(self, address, safe=False, close=True): e = pickle.dumps( KilledWorker(task=k, last_worker=ws.clean()), protocol=4 ) - r = self.transition(k, "erred", exception=e, cause=k) + r = self.transition( + k, "erred", exception=e, cause=k, stimulus_id=stimulus_id + ) recommendations.update(r) logger.info( "Task %s marked as failed because %d workers died" @@ -4455,7 +4509,7 @@ async def remove_worker(self, address, safe=False, close=True): else: # pure data recommendations[ts.key] = "forgotten" - self.transitions(recommendations) + self.transitions(recommendations, stimulus_id=stimulus_id) for plugin in list(self.plugins.values()): try: @@ -4535,15 +4589,16 @@ def client_desires_keys(self, keys=None, client=None): if ts.state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) - def client_releases_keys(self, keys=None, client=None): + def client_releases_keys(self, keys=None, client=None, stimulus_id=None): """Remove keys from client desired list""" + stimulus_id = stimulus_id or f"client-releases-keys-{time()}" if not isinstance(keys, list): keys = list(keys) cs: ClientState = self.clients[client] recommendations: dict = {} _client_releases_keys(self, keys=keys, cs=cs, recommendations=recommendations) - self.transitions(recommendations) + self.transitions(recommendations, stimulus_id) def client_heartbeat(self, client=None): """Handle heartbeats from Client""" @@ -4740,6 +4795,7 @@ async def add_client(self, comm: Comm, client: str, versions: dict) -> None: We listen to all future messages from this Comm. """ + stimulus_id = f"add-client-{time()}" assert client is not None comm.name = "Scheduler->Client" logger.info("Receive client connection: %s", client) @@ -4768,7 +4824,7 @@ async def add_client(self, comm: Comm, client: str, versions: dict) -> None: try: await self.handle_stream(comm=comm, extra={"client": client}) finally: - self.remove_client(client=client) + self.remove_client(client=client, stimulus_id=stimulus_id) logger.debug("Finished handling client %s", client) finally: if not comm.closed(): @@ -4782,8 +4838,9 @@ async def add_client(self, comm: Comm, client: str, versions: dict) -> None: except TypeError: # comm becomes None during GC pass - def remove_client(self, client: str) -> None: + def remove_client(self, client: str, stimulus_id: str = None) -> None: """Remove client from network""" + stimulus_id = stimulus_id or f"remove-client-{time()}" if self.status == Status.running: logger.info("Remove client %s", client) self.log_event(["all", client], {"action": "remove-client", "client": client}) @@ -4794,7 +4851,9 @@ def remove_client(self, client: str) -> None: pass else: self.client_releases_keys( - keys=[ts.key for ts in cs.wants_what], client=cs.client_key + keys=[ts.key for ts in cs.wants_what], + client=cs.client_key, + stimulus_id=stimulus_id, ) del self.clients[client] @@ -4830,24 +4889,28 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: float = -1): def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) - def handle_task_finished(self, key=None, worker=None, **msg): + def handle_task_finished(self, key=None, worker=None, stimulus_id=None, **msg): if worker not in self.workers: return validate_key(key) - r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) + + r: tuple = self.stimulus_task_finished( + key=key, worker=worker, stimulus_id=stimulus_id, **msg + ) recommendations, client_msgs, worker_msgs = r - self._transitions(recommendations, client_msgs, worker_msgs) + self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) - def handle_task_erred(self, key=None, **msg): - r: tuple = self.stimulus_task_erred(key=key, **msg) + def handle_task_erred(self, key=None, stimulus_id=None, **msg): + r: tuple = self.stimulus_task_erred(key=key, stimulus_id=stimulus_id, **msg) recommendations, client_msgs, worker_msgs = r - self._transitions(recommendations, client_msgs, worker_msgs) - + self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) - def handle_missing_data(self, key=None, errant_worker=None, **kwargs): + def handle_missing_data( + self, key=None, errant_worker=None, stimulus_id=None, **kwargs + ): """Signal that `errant_worker` does not hold `key` This may either indicate that `errant_worker` is dead or that we may be @@ -4864,6 +4927,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): errant_worker : str, optional Address of the worker supposed to hold a replica, by default None """ + assert stimulus_id logger.debug("handle missing data key=%s worker=%s", key, errant_worker) self.log_event(errant_worker, {"action": "missing-data", "key": key}) ts: TaskState = self.tasks.get(key) @@ -4875,11 +4939,11 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): self.remove_replica(ts, ws) if ts.state == "memory" and not ts.who_has: if ts.run_spec: - self.transitions({key: "released"}) + self.transitions({key: "released"}, stimulus_id) else: - self.transitions({key: "forgotten"}) + self.transitions({key: "forgotten"}, stimulus_id) - def release_worker_data(self, key, worker): + def release_worker_data(self, key, worker, stimulus_id): ws: WorkerState = self.workers.get(worker) ts: TaskState = self.tasks.get(key) if not ws or not ts: @@ -4890,7 +4954,7 @@ def release_worker_data(self, key, worker): if not ts.who_has: recommendations[ts.key] = "released" if recommendations: - self.transitions(recommendations) + self.transitions(recommendations, stimulus_id) def handle_long_running(self, key=None, worker=None, compute_duration=None): """A task has seceded from the thread pool @@ -4932,7 +4996,9 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): ws.long_running.add(ts) self.check_idle_saturated(ws) - def handle_worker_status_change(self, status: str, worker: str) -> None: + def handle_worker_status_change( + self, status: str, worker: str, stimulus_id: str + ) -> None: ws: WorkerState = self.workers.get(worker) # type: ignore if not ws: return @@ -4956,13 +5022,13 @@ def handle_worker_status_change(self, status: str, worker: str) -> None: if recs: client_msgs: dict = {} worker_msgs: dict = {} - self._transitions(recs, client_msgs, worker_msgs) + self._transitions(recs, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) else: self.running.discard(ws) - async def handle_worker(self, comm=None, worker=None): + async def handle_worker(self, comm=None, worker=None, stimulus_id=None): """ Listen to responses from a single worker @@ -4972,6 +5038,7 @@ async def handle_worker(self, comm=None, worker=None): -------- Scheduler.handle_client: Equivalent coroutine for clients """ + assert stimulus_id comm.name = "Scheduler connection to worker" worker_comm = self.stream_comms[worker] worker_comm.start(comm) @@ -4981,7 +5048,7 @@ async def handle_worker(self, comm=None, worker=None): finally: if worker in self.stream_comms: worker_comm.abort() - await self.remove_worker(address=worker) + await self.remove_worker(address=worker, stimulus_id=stimulus_id) def add_plugin( self, @@ -5119,7 +5186,11 @@ def worker_send(self, worker, msg): try: stream_comms[worker].send(msg) except (CommClosedError, AttributeError): - self.loop.add_callback(self.remove_worker, address=worker) + self.loop.add_callback( + self.remove_worker, + address=worker, + stimulus_id=f"worker-send-comm-fail-{time()}", + ) def client_send(self, client, msg): """Send message to client""" @@ -5164,7 +5235,11 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): # worker already gone pass except (CommClosedError, AttributeError): - self.loop.add_callback(self.remove_worker, address=worker) + self.loop.add_callback( + self.remove_worker, + address=worker, + stimulus_id=f"send-all-comm-fail-{time()}", + ) ############################ # Less common interactions # @@ -5221,6 +5296,7 @@ async def scatter( async def gather(self, keys, serializers=None): """Collect data from workers to the scheduler""" + stimulus_id = f"gather-{time()}" keys = list(keys) who_has = {} for key in keys: @@ -5252,7 +5328,9 @@ async def gather(self, keys, serializers=None): # reconnect. await asyncio.gather( *( - self.remove_worker(address=worker, close=False) + self.remove_worker( + address=worker, close=False, stimulus_id=stimulus_id + ) for worker in missing_workers ) ) @@ -5290,6 +5368,7 @@ def clear_task_state(self): async def restart(self, client=None, timeout=30): """Restart all workers. Reset local state.""" + stimulus_id = f"restart-{time()}" with log_errors(): n_workers = len(self.workers) @@ -5306,7 +5385,9 @@ async def restart(self, client=None, timeout=30): try: # Ask the worker to close if it doesn't have a nanny, # otherwise the nanny will kill it anyway - await self.remove_worker(address=addr, close=addr not in nannies) + await self.remove_worker( + address=addr, close=addr not in nannies, stimulus_id=stimulus_id + ) except Exception: logger.info( "Exception while restarting. This is normal", exc_info=True @@ -5501,7 +5582,7 @@ async def gather_on_worker( return keys_failed async def delete_worker_data( - self, worker_address: str, keys: "Collection[str]" + self, worker_address: str, keys: "Collection[str]", stimulus_id: str ) -> None: """Delete data from a worker and update the corresponding worker/task states @@ -5538,7 +5619,7 @@ async def delete_worker_data( self.remove_replica(ts, ws) if not ts.who_has: # Last copy deleted - self.transitions({key: "released"}) + self.transitions({key: "released"}, stimulus_id) self.log_event(ws.address, {"action": "remove-worker-data", "keys": keys}) @@ -5547,6 +5628,7 @@ async def rebalance( comm=None, keys: "Iterable[Hashable]" = None, workers: "Iterable[str]" = None, + stimulus_id: str = None, ) -> dict: """Rebalance keys so that each worker ends up with roughly the same process memory (managed+unmanaged). @@ -5613,6 +5695,7 @@ async def rebalance( All other workers will be ignored. The mean cluster occupancy will be calculated only using the allowed workers. """ + stimulus_id = stimulus_id or f"rebalance-{time()}" with log_errors(): if workers is not None: wss = [self.workers[w] for w in workers] @@ -5637,7 +5720,7 @@ async def rebalance( return {"status": "OK"} async with self._lock: - result = await self._rebalance_move_data(msgs) + result = await self._rebalance_move_data(msgs, stimulus_id) if result["status"] == "partial-fail" and keys is None: # Only return failed keys if the client explicitly asked for them result = {"status": "OK"} @@ -5834,7 +5917,7 @@ def _rebalance_find_msgs( return msgs async def _rebalance_move_data( - self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]" + self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]", stimulus_id: str ) -> dict: """Perform the actual transfer of data across the network in rebalance(). Takes in input the output of _rebalance_find_msgs(), that is a list of tuples: @@ -5870,7 +5953,7 @@ async def _rebalance_move_data( # Note: this never raises exceptions await asyncio.gather( - *(self.delete_worker_data(r, v) for r, v in to_senders.items()) + *(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items()) ) for r, v in to_recipients.items(): @@ -5900,6 +5983,7 @@ async def replicate( branching_factor=2, delete=True, lock=True, + stimulus_id=None, ): """Replicate data throughout cluster @@ -5922,6 +6006,7 @@ async def replicate( -------- Scheduler.rebalance """ + stimulus_id = stimulus_id or f"replicate-{time()}" assert branching_factor > 0 async with self._lock if lock else empty_context: if workers is not None: @@ -5956,7 +6041,9 @@ async def replicate( # Note: this never raises exceptions await asyncio.gather( *[ - self.delete_worker_data(ws.address, [t.key for t in tasks]) + self.delete_worker_data( + ws.address, [t.key for t in tasks], stimulus_id + ) for ws, tasks in del_worker_tasks.items() ] ) @@ -6144,6 +6231,7 @@ async def retire_workers( names: "list | None" = None, close_workers: bool = False, remove: bool = True, + stimulus_id: str = None, **kwargs, ) -> dict: """Gracefully retire workers from cluster @@ -6177,6 +6265,7 @@ async def retire_workers( -------- Scheduler.workers_to_close """ + stimulus_id = stimulus_id or f"retire-workers-{time()}" with log_errors(): # This lock makes retire_workers, rebalance, and replicate mutually # exclusive and will no longer be necessary once rebalance and replicate are @@ -6231,7 +6320,11 @@ async def retire_workers( ws.status = Status.closing_gracefully self.running.discard(ws) self.stream_comms[ws.address].send( - {"op": "worker-status-change", "status": ws.status.name} + { + "op": "worker-status-change", + "status": ws.status.name, + "stimulus_id": stimulus_id, + } ) coros.append( @@ -6241,6 +6334,7 @@ async def retire_workers( prev_status=prev_status, close_workers=close_workers, remove=remove, + stimulus_id=stimulus_id, ) ) @@ -6267,6 +6361,7 @@ async def _track_retire_worker( prev_status: Status, close_workers: bool, remove: bool, + stimulus_id: str, ) -> tuple: # tuple[str | None, dict] while not policy.done(): if policy.no_recipients: @@ -6274,7 +6369,11 @@ async def _track_retire_worker( # conditions and we can wait for a scheduler->worker->scheduler # round-trip. self.stream_comms[ws.address].send( - {"op": "worker-status-change", "status": prev_status.name} + { + "op": "worker-status-change", + "status": prev_status.name, + "stimulus_id": stimulus_id, + } ) return None, {} @@ -6288,9 +6387,13 @@ async def _track_retire_worker( ) if close_workers and ws.address in self.workers: - await self.close_worker(worker=ws.address, safe=True) + await self.close_worker( + worker=ws.address, safe=True, stimulus_id=stimulus_id + ) if remove: - await self.remove_worker(address=ws.address, safe=True) + await self.remove_worker( + address=ws.address, safe=True, stimulus_id=stimulus_id + ) logger.info("Retired worker %s", ws.address) return ws.address, ws.identity() @@ -6727,7 +6830,7 @@ async def unregister_nanny_plugin(self, comm, name): ) return responses - def transition(self, key, finish: str, *args, **kwargs): + def transition(self, key, finish: str, *args, stimulus_id: str, **kwargs): """Transition a key from its current state to the finish state Examples @@ -6743,12 +6846,12 @@ def transition(self, key, finish: str, *args, **kwargs): -------- Scheduler.transitions: transitive version of this function """ - a: tuple = self._transition(key, finish, *args, **kwargs) + a: tuple = self._transition(key, finish, stimulus_id, *args, **kwargs) recommendations, client_msgs, worker_msgs = a self.send_all(client_msgs, worker_msgs) return recommendations - def transitions(self, recommendations: dict): + def transitions(self, recommendations: dict, stimulus_id: str): """Process transitions until none are left This includes feedback from previous transitions and continues until we @@ -6756,7 +6859,7 @@ def transitions(self, recommendations: dict): """ client_msgs: dict = {} worker_msgs: dict = {} - self._transitions(recommendations, client_msgs, worker_msgs) + self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) def story(self, *keys): @@ -6764,6 +6867,9 @@ def story(self, *keys): keys = {key.key if isinstance(key, TaskState) else key for key in keys} return scheduler_story(keys, self.transition_log) + async def get_story(self, keys=None): + return self.story(*keys) + transition_story = story def reschedule(self, key=None, worker=None): @@ -6784,7 +6890,7 @@ def reschedule(self, key=None, worker=None): return if worker and ts.processing_on.address != worker: return - self.transitions({key: "released"}) + self.transitions({key: "released"}, f"reschedule-{time()}") ##################### # Utility functions # @@ -7217,7 +7323,9 @@ async def check_worker_ttl(self): self.worker_ttl, ws, ) - await self.remove_worker(address=ws.address) + await self.remove_worker( + address=ws.address, stimulus_id=f"check-worker-ttl-{time()}" + ) def check_idle(self): if any([ws.processing for ws in self.workers.values()]) or self.unrunnable: @@ -7427,7 +7535,11 @@ def _add_to_memory( def _propagate_forgotten( - state: SchedulerState, ts: TaskState, recommendations: dict, worker_msgs: dict + state: SchedulerState, + ts: TaskState, + recommendations: dict, + worker_msgs: dict, + stimulus_id: str, ): ts.state = "forgotten" key: str = ts.key @@ -7458,7 +7570,7 @@ def _propagate_forgotten( { "op": "free-keys", "keys": [key], - "stimulus_id": f"propagate-forgotten-{time()}", + "stimulus_id": stimulus_id, } ] state.remove_all_replicas(ts) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index a18594dc6e4..94fdc3bc11f 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -6,7 +6,7 @@ from distributed.core import CommClosedError from distributed.utils_test import ( _LockedCommPool, - assert_worker_story, + assert_story, gen_cluster, inc, slowinc, @@ -82,7 +82,7 @@ def f(ev): while "f1" in a.tasks: await asyncio.sleep(0.01) - assert_worker_story( + assert_story( a.story("f1"), [ ("f1", "compute-task"), @@ -160,7 +160,7 @@ async def wait_and_raise(*args, **kwargs): await wait_for_state(fut1.key, "flight", b) # Close in scheduler to ensure we transition and reschedule task properly - await s.close_worker(worker=a.address) + await s.close_worker(worker=a.address, stimulus_id="test") await wait_for_state(fut1.key, "resumed", b) lock.release() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index abb9dcfd16e..0cbe97a50a2 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -77,6 +77,7 @@ from distributed.utils_test import ( TaskStateMetadataPlugin, _UnhashableCallable, + assert_story, async_wait_for, asyncinc, captured_logger, @@ -4563,7 +4564,7 @@ def test_auto_normalize_collection_sync(c): def assert_no_data_loss(scheduler): - for key, start, finish, recommendations, _ in scheduler.transition_log: + for key, start, finish, recommendations, _, _ in scheduler.transition_log: if start == "memory" and finish == "released": for k, v in recommendations.items(): assert not (k == key and v == "waiting") @@ -7494,3 +7495,58 @@ async def test_wait_for_workers_updates_info(c, s): async with Worker(s.address): await c.wait_for_workers(1) assert c.scheduler_info()["workers"] + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_client_story(c, s, *workers): + f = c.submit(inc, 1) + assert await f == 2 + story = await c.story(f.key) + + assert_story( + story, + [ + (f.key, "released", "waiting", {f.key: "processing"}), + (f.key, "waiting", "processing", {}), + (f.key, "processing", "memory", {}), + (f.key, "compute-task"), + (f.key, "released", "waiting", "waiting", {f.key: "ready"}), + (f.key, "waiting", "ready", "ready", {f.key: "executing"}), + (f.key, "ready", "executing", "executing", {}), + (f.key, "put-in-memory"), + (f.key, "executing", "memory", "memory", {}), + ], + ordered_timestamps=False, + ) + + stimulus_ids = {ev[-2] for ev in story} + assert len(stimulus_ids) == 3 + + +class WorkerBrokenStory(Worker): + async def get_story(self, *args, **kw): + raise CommClosedError + + +@gen_cluster(client=True, Worker=WorkerBrokenStory) +@pytest.mark.parametrize("on_error", ["ignore", "raise"]) +async def test_client_story_failed_worker(c, s, a, b, on_error): + f = c.submit(inc, 1) + coro = c.story(f.key, on_error=on_error) + await f + + if on_error == "raise": + with pytest.raises(CommClosedError) as e: + await coro + elif on_error == "ignore": + story = await coro + assert_story( + story, + [ + (f.key, "released", "waiting", {f.key: "processing"}), + (f.key, "waiting", "processing", {}), + (f.key, "processing", "memory", {}), + ], + ) + else: + raise ValueError(on_error) diff --git a/distributed/tests/test_cluster_dump.py b/distributed/tests/test_cluster_dump.py index c3912116c46..040bbc53eb7 100644 --- a/distributed/tests/test_cluster_dump.py +++ b/distributed/tests/test_cluster_dump.py @@ -8,7 +8,7 @@ import distributed from distributed.cluster_dump import DumpArtefact, _tuple_to_list, write_state -from distributed.utils_test import assert_worker_story, gen_cluster, gen_test, inc +from distributed.utils_test import assert_story, gen_cluster, gen_test, inc @pytest.mark.parametrize( @@ -140,7 +140,7 @@ async def test_cluster_dump_story(c, s, a, b, tmp_path): assert story.keys() == {"f1", "f2"} for k, task_story in story.items(): - assert_worker_story( + assert_story( task_story, [ (k, "compute-task"), diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 7ee7e2fe0d9..45b26dd1234 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -37,6 +37,7 @@ from distributed.utils import TimeoutError from distributed.utils_test import ( BrokenComm, + assert_story, captured_logger, cluster, dec, @@ -810,7 +811,7 @@ async def test_story(c, s, a, b): story = s.story(x.key) assert all(line in s.transition_log for line in story) assert len(story) < len(s.transition_log) - assert all(x.key == line[0] or x.key in line[-2] for line in story) + assert all(x.key == line[0] or x.key in line[3] for line in story) assert len(s.story(x.key, y.key)) > len(story) @@ -1233,7 +1234,7 @@ def f(dask_scheduler=None): async def test_close_worker(c, s, a, b): assert len(s.workers) == 2 - await s.close_worker(worker=a.address) + await s.close_worker(worker=a.address, stimulus_id="test") assert len(s.workers) == 1 assert a.address not in s.workers @@ -1251,7 +1252,7 @@ async def test_close_nanny(c, s, a, b): assert a.process.is_alive() a_worker_address = a.worker_address start = time() - await s.close_worker(worker=a_worker_address) + await s.close_worker(worker=a_worker_address, stimulus_id="test") assert len(s.workers) == 1 assert a_worker_address not in s.workers @@ -3098,7 +3099,9 @@ async def test_rebalance_dead_recipient(client, s, a, b, c): await c.close() assert s.workers.keys() == {a.address, b.address} - out = await s._rebalance_move_data([(a_ws, b_ws, x_ts), (a_ws, c_ws, y_ts)]) + out = await s._rebalance_move_data( + [(a_ws, b_ws, x_ts), (a_ws, c_ws, y_ts)], stimulus_id="test" + ) assert out == {"status": "partial-fail", "keys": [y.key]} assert a.data == {y.key: "y"} assert b.data == {x.key: "x"} @@ -3117,7 +3120,7 @@ async def test_delete_worker_data(c, s, a, b): assert b.data == {y.key: "y"} assert s.tasks.keys() == {x.key, y.key, z.key} - await s.delete_worker_data(a.address, [x.key, y.key]) + await s.delete_worker_data(a.address, [x.key, y.key], stimulus_id="test") assert a.data == {z.key: "z"} assert b.data == {y.key: "y"} assert s.tasks.keys() == {y.key, z.key} @@ -3131,8 +3134,8 @@ async def test_delete_worker_data_double_delete(c, s, a): """ x, y = await c.scatter(["x", "y"]) await asyncio.gather( - s.delete_worker_data(a.address, [x.key]), - s.delete_worker_data(a.address, [x.key]), + s.delete_worker_data(a.address, [x.key], stimulus_id="test"), + s.delete_worker_data(a.address, [x.key], stimulus_id="test"), ) assert a.data == {y.key: "y"} a_ws = s.workers[a.address] @@ -3147,7 +3150,7 @@ async def test_delete_worker_data_bad_worker(s, a, b): """ await a.close() assert s.workers.keys() == {b.address} - await s.delete_worker_data(a.address, ["x"]) + await s.delete_worker_data(a.address, ["x"], stimulus_id="test") @pytest.mark.parametrize("bad_first", [False, True]) @@ -3162,7 +3165,7 @@ async def test_delete_worker_data_bad_task(c, s, a, bad_first): assert s.tasks.keys() == {x.key, y.key} keys = ["notexist", x.key] if bad_first else [x.key, "notexist"] - await s.delete_worker_data(a.address, keys) + await s.delete_worker_data(a.address, keys, stimulus_id="test") assert a.data == {y.key: "y"} assert s.tasks.keys() == {y.key} assert s.workers[a.address].nbytes == s.tasks[y.key].nbytes @@ -3253,7 +3256,7 @@ async def test_worker_reconnect_task_memory(c, s, a): await res assert ("no-worker", "memory") in { - (start, finish) for (_, start, finish, _, _) in s.transition_log + (start, finish) for (_, start, finish, _, _, _) in s.transition_log } @@ -3277,7 +3280,7 @@ async def test_worker_reconnect_task_memory_with_resources(c, s, a): await res assert ("no-worker", "memory") in { - (start, finish) for (_, start, finish, _, _) in s.transition_log + (start, finish) for (_, start, finish, _, _, _) in s.transition_log } @@ -3538,3 +3541,84 @@ async def test_repr(s, a): repr(ws_b) == f"" ) + + +@gen_cluster(client=True) +async def test_stimulus_success(c, s, a, b): + f = c.submit(inc, 1) + key = f.key + + await f + await c.close() + + stories = s.story(key) + + assert_story( + stories, + [ + (key, "released", "waiting", {key: "processing"}), + (key, "waiting", "processing", {}), + (key, "processing", "memory", {}), + ( + key, + "memory", + "forgotten", + {}, + ), + ], + ) + + stimulus_ids = {s[-2] for s in stories} + assert len(stimulus_ids) == 3 + + +@gen_cluster(client=True) +async def test_stimulus_retry(c, s, a, b): + def task(): + assert dask.config.get("foo") + + with dask.config.set(foo=False): + f = c.submit(task) + with pytest.raises(AssertionError): + await f + + with dask.config.set(foo=True): + await f.retry() + await f + + story = s.story(f.key) + stimulus_ids = {s[-2] for s in story} + assert len(stimulus_ids) == 4 + + assert_story( + story, + [ + (f.key, "released", "waiting", {f.key: "processing"}), + (f.key, "waiting", "processing", {}), + ( + f.key, + "processing", + "erred", + {}, + ), + ( + f.key, + "erred", + "released", + {}, + ), + ( + f.key, + "released", + "waiting", + {f.key: "processing"}, + ), + ( + f.key, + "waiting", + "processing", + {}, + ), + (f.key, "processing", "memory", {}), + ], + ) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index c88e8b592ed..85a2b754616 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -21,7 +21,7 @@ from distributed.utils_test import ( _LockedCommPool, _UnhashableCallable, - assert_worker_story, + assert_story, check_process_leak, cluster, dump_cluster_state, @@ -406,7 +406,7 @@ async def inner_test(c, s, a, b): assert "workers" in state -def test_assert_worker_story(): +def test_assert_story(): now = time() story = [ ("foo", "id1", now - 600), @@ -414,38 +414,38 @@ def test_assert_worker_story(): ("baz", {1: 2}, "id2", now), ] # strict=False - assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})]) - assert_worker_story(story, []) - assert_worker_story(story, [("foo",)]) - assert_worker_story(story, [("foo",), ("bar",)]) - assert_worker_story(story, [("baz", lambda d: d[1] == 2)]) + assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})]) + assert_story(story, []) + assert_story(story, [("foo",)]) + assert_story(story, [("foo",), ("bar",)]) + assert_story(story, [("baz", lambda d: d[1] == 2)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo", "nomatch")]) + assert_story(story, [("foo", "nomatch")]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz",)]) + assert_story(story, [("baz",)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz", {1: 3})]) + assert_story(story, [("baz", {1: 3})]) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)]) + assert_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)]) with pytest.raises(AssertionError): - assert_worker_story(story, [("baz", lambda d: d[1] == 3)]) + assert_story(story, [("baz", lambda d: d[1] == 3)]) with pytest.raises(KeyError): # Faulty lambda - assert_worker_story(story, [("baz", lambda d: d[2] == 1)]) - assert_worker_story([], []) - assert_worker_story([("foo", "id1", now)], [("foo",)]) + assert_story(story, [("baz", lambda d: d[2] == 1)]) + assert_story([], []) + assert_story([("foo", "id1", now)], [("foo",)]) with pytest.raises(AssertionError): - assert_worker_story([], [("foo",)]) + assert_story([], [("foo",)]) # strict=True - assert_worker_story([], [], strict=True) - assert_worker_story([("foo", "id1", now)], [("foo",)]) - assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True) + assert_story([], [], strict=True) + assert_story([("foo", "id1", now)], [("foo",)]) + assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("bar",)], strict=True) + assert_story(story, [("foo",), ("bar",)], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [("foo",), ("baz", {1: 2})], strict=True) + assert_story(story, [("foo",), ("baz", {1: 2})], strict=True) with pytest.raises(AssertionError): - assert_worker_story(story, [], strict=True) + assert_story(story, [], strict=True) @pytest.mark.parametrize( @@ -466,11 +466,11 @@ def test_assert_worker_story(): ), ], ) -def test_assert_worker_story_malformed_story(story_factory): +def test_assert_story_malformed_story(story_factory): # defer the calls to time() to when the test runs rather than collection story = story_factory() with pytest.raises(AssertionError, match="Malformed story event"): - assert_worker_story(story, []) + assert_story(story, []) @gen_cluster() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index d56b768555c..ec44f5e1356 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -48,7 +48,7 @@ from distributed.utils_test import ( TaskStateMetadataPlugin, _LockedCommPool, - assert_worker_story, + assert_story, captured_logger, dec, div, @@ -1406,7 +1406,7 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None: """Test that an in-memory key was transferred from worker w_from to worker w_to by the Active Memory Manager and it was not recalculated on w_to """ - assert_worker_story( + assert_story( w_to.story(key), [ (key, "ensure-task-exists", "released"), @@ -1745,7 +1745,8 @@ async def test_story_with_deps(c, s, a, b): # Story now includes randomized stimulus_ids and timestamps. stimulus_ids = {ev[-2] for ev in story} - assert len(stimulus_ids) == 2, stimulus_ids + assert len(stimulus_ids) == 3 + # This is a simple transition log expected = [ ("res", "compute-task"), @@ -1755,7 +1756,7 @@ async def test_story_with_deps(c, s, a, b): ("res", "put-in-memory"), ("res", "executing", "memory", "memory", {}), ] - assert_worker_story(story, expected, strict=True) + assert_story(story, expected, strict=True) story = b.story("dep") stimulus_ids = {ev[-2] for ev in story} @@ -1770,7 +1771,7 @@ async def test_story_with_deps(c, s, a, b): ("dep", "put-in-memory"), ("dep", "flight", "memory", "memory", {"res": "ready"}), ] - assert_worker_story(story, expected, strict=True) + assert_story(story, expected, strict=True) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -2708,7 +2709,7 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b): while fut1.key not in b.tasks or b.tasks[fut1.key].state == "flight": await asyncio.sleep(0) - s.handle_missing_data(key="f1", errant_worker=a.address) + s.handle_missing_data(key="f1", errant_worker=a.address, stimulus_id="test") await fut2 @@ -2753,7 +2754,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b): # same communication channel for fut in (futA, futB): - assert_worker_story( + assert_story( b.story(fut.key), [ ("gather-dependencies", a.address, {fut.key}), @@ -2812,7 +2813,7 @@ def __getstate__(self): assert await y == 123 story = await c.run(lambda dask_worker: dask_worker.story("x")) - assert_worker_story( + assert_story( story[b], [ ("x", "ensure-task-exists", "released"), @@ -2989,7 +2990,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers): await f2 - assert_worker_story(a.story(f1.key), [(f1.key, "missing-dep")]) + assert_story(a.story(f1.key), [(f1.key, "missing-dep")]) assert a.tasks[f1.key].suspicious_count == 0 assert s.tasks[f1.key].suspicious == 0 @@ -3069,7 +3070,7 @@ async def test_missing_released_zombie_tasks_2(c, s, a, b): while b.tasks: await asyncio.sleep(0.01) - assert_worker_story( + assert_story( b.story(ts), [("f1", "missing", "released", "released", {"f1": "forgotten"})], ) @@ -3181,7 +3182,7 @@ async def test_task_flight_compute_oserror(c, s, a, b): ("f1", "put-in-memory"), ("f1", "executing", "memory", "memory", {}), ] - assert_worker_story(sum_story, expected_sum_story, strict=True) + assert_story(sum_story, expected_sum_story, strict=True) @gen_cluster(client=True) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index d8337ace8e5..feb99ee995a 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -115,15 +115,19 @@ def test_slots(cls): def test_sendmsg_to_dict(): # Arbitrary sample class - smsg = ReleaseWorkerDataMsg(key="x") - assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"} + smsg = ReleaseWorkerDataMsg(key="x", stimulus_id="test") + assert smsg.to_dict() == { + "op": "release-worker-data", + "key": "x", + "stimulus_id": "test", + } def test_merge_recs_instructions(): x = TaskState("x") y = TaskState("y") - instr1 = RescheduleMsg(key="foo", worker="a") - instr2 = RescheduleMsg(key="bar", worker="b") + instr1 = RescheduleMsg(key="foo", worker="a", stimulus_id="test") + instr2 = RescheduleMsg(key="bar", worker="b", stimulus_id="test") assert merge_recs_instructions( ({x: "memory"}, [instr1]), ({y: "released"}, [instr2]), diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 31e6cbe08fe..c85199a4bd8 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1875,8 +1875,12 @@ def xfail_ssl_issue5601(): raise -def assert_worker_story( - story: list[tuple], expect: list[tuple], *, strict: bool = False +def assert_story( + story: list[tuple], + expect: list[tuple], + *, + strict: bool = False, + ordered_timestamps: bool = True, ) -> None: """Test the output of ``Worker.story`` @@ -1906,6 +1910,10 @@ def assert_worker_story( If True, the story must contain exactly as many events as expect. If False (the default), the story may contain more events than expect; extra events are ignored. + ordered_timestamps: bool, optional + If False, timestamps are not required to be monotically increasing. + Useful for asserting stories composed from the scheduler and + multiple workers """ now = time() prev_ts = 0.0 @@ -1915,7 +1923,8 @@ def assert_worker_story( assert isinstance(ev, tuple) assert isinstance(ev[-2], str) and ev[-2] # stimulus_id assert isinstance(ev[-1], float) # timestamp - assert prev_ts <= ev[-1] # Timestamps are monotonic ascending + if ordered_timestamps: + assert prev_ts <= ev[-1] # Timestamps are monotonic ascending # Timestamps are within the last hour. It's been observed that a timestamp # generated in a Nanny process can be a few milliseconds in the future. assert now - 3600 < ev[-1] <= now + 1 @@ -1941,7 +1950,7 @@ def assert_worker_story( break except StopIteration: raise AssertionError( - f"assert_worker_story({strict=}) failed\n" + f"assert_story({strict=}) failed\n" f"story:\n{_format_story(story)}\n" f"expect:\n{_format_story(expect)}" ) from None diff --git a/distributed/worker.py b/distributed/worker.py index 4271abc35c1..708d409f54a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -754,6 +754,7 @@ def __init__( "benchmark_disk": self.benchmark_disk, "benchmark_memory": self.benchmark_memory, "benchmark_network": self.benchmark_network, + "get_story": self.get_story, } stream_handlers = { @@ -927,21 +928,25 @@ def status(self, value): """ prev_status = self.status ServerNode.status.__set__(self, value) - self._send_worker_status_change() + self._send_worker_status_change(f"worker-status-change-{time()}") if prev_status == Status.paused and value == Status.running: self.handle_stimulus(UnpauseEvent(stimulus_id=f"set-status-{time()}")) - def _send_worker_status_change(self) -> None: + def _send_worker_status_change(self, stimulus_id: str) -> None: if ( self.batched_stream and self.batched_stream.comm and not self.batched_stream.comm.closed() ): self.batched_stream.send( - {"op": "worker-status-change", "status": self._status.name} + { + "op": "worker-status-change", + "status": self._status.name, + "stimulus_id": stimulus_id, + }, ) elif self._status != Status.closed: - self.loop.call_later(0.05, self._send_worker_status_change) + self.loop.call_later(0.05, self._send_worker_status_change, stimulus_id) async def get_metrics(self) -> dict: try: @@ -1083,6 +1088,7 @@ async def _register_with_scheduler(self): versions=get_versions(), metrics=await self.get_metrics(), extra=await self.get_startup_information(), + stimulus_id=f"worker-connect-{time()}", ), serializers=["msgpack"], ) @@ -1476,7 +1482,9 @@ async def close( if report and self.contact_address is not None: await asyncio.wait_for( self.scheduler.unregister( - address=self.contact_address, safe=safe + address=self.contact_address, + safe=safe, + stimulus_id=f"worker-close-{time()}", ), timeout, ) @@ -1541,7 +1549,10 @@ async def close_gracefully(self, restart=None): # Scheduler.retire_workers will set the status to closing_gracefully and push it # back to this worker. await self.scheduler.retire_workers( - workers=[self.address], close_workers=False, remove=False + workers=[self.address], + close_workers=False, + remove=False, + stimulus_id=f"worker-close-gracefully-{time()}", ) await self.close(safe=True, nanny=not restart) @@ -1896,7 +1907,9 @@ def handle_compute_task( pass elif ts.state == "memory": recommendations[ts] = "memory" - instructions.append(self._get_task_finished_msg(ts)) + instructions.append( + self._get_task_finished_msg(ts, stimulus_id=stimulus_id) + ) elif ts.state in { "released", "fetch", @@ -2035,7 +2048,7 @@ def transition_memory_released( recs, instructions = self.transition_generic_released( ts, stimulus_id=stimulus_id ) - instructions.append(ReleaseWorkerDataMsg(ts.key)) + instructions.append(ReleaseWorkerDataMsg(key=ts.key, stimulus_id=stimulus_id)) return recs, instructions def transition_waiting_constrained( @@ -2058,7 +2071,7 @@ def transition_long_running_rescheduled( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, worker=self.address) + smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id) return recs, [smsg] def transition_executing_rescheduled( @@ -2069,7 +2082,14 @@ def transition_executing_rescheduled( self._executing.discard(ts) return merge_recs_instructions( - ({ts: "released"}, [RescheduleMsg(key=ts.key, worker=self.address)]), + ( + {ts: "released"}, + [ + RescheduleMsg( + key=ts.key, worker=self.address, stimulus_id=stimulus_id + ) + ], + ), self._ensure_computing(), ) @@ -2147,6 +2167,7 @@ def transition_generic_error( traceback_text=traceback_text, thread=self.threads.get(ts.key), startstops=ts.startstops, + stimulus_id=stimulus_id, ) return {}, [smsg] @@ -2336,7 +2357,9 @@ def transition_generic_memory( else: if self.validate: assert ts.key in self.data or ts.key in self.actors - instructions.append(self._get_task_finished_msg(ts)) + instructions.append( + self._get_task_finished_msg(ts, stimulus_id=stimulus_id) + ) return recs, instructions @@ -2455,7 +2478,10 @@ def transition_executing_long_running( self.long_running.add(ts.key) return merge_recs_instructions( - ({}, [LongRunningMsg(key=ts.key, compute_duration=compute_duration)]), + ( + {}, + [LongRunningMsg(key=ts.key, compute_duration=compute_duration)], + ), self._ensure_computing(), ) @@ -2686,6 +2712,9 @@ def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]: keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} return worker_story(keys, self.log) + async def get_story(self, keys=None): + return self.story(*keys) + def stimulus_story( self, *keys_or_tasks: str | TaskState ) -> list[StateMachineEvent]: @@ -2754,7 +2783,9 @@ def ensure_communicating(self) -> None: for el in skipped_worker_in_flight: self.data_needed.push(el) - def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: + def _get_task_finished_msg( + self, ts: TaskState, stimulus_id: str + ) -> TaskFinishedMsg: if ts.key not in self.data and ts.key not in self.actors: raise RuntimeError(f"Task {ts} not ready") typ = ts.type @@ -2780,6 +2811,7 @@ def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: metadata=ts.metadata, thread=self.threads.get(ts.key), startstops=ts.startstops, + stimulus_id=stimulus_id, ) def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: @@ -3087,7 +3119,12 @@ async def gather_dep( self.has_what[worker].discard(ts.key) self.log.append((d, "missing-dep", stimulus_id, time())) self.batched_stream.send( - {"op": "missing-data", "errant_worker": worker, "key": d} + { + "op": "missing-data", + "errant_worker": worker, + "key": d, + "stimulus_id": stimulus_id, + } ) recommendations[ts] = "fetch" if ts.who_has else "missing" del data, response @@ -3190,7 +3227,7 @@ def handle_steal_request(self, key: str, stimulus_id: str) -> None: # `transition_constrained_executing` self.transition(ts, "released", stimulus_id=stimulus_id) - def handle_worker_status_change(self, status: str) -> None: + def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: new_status = Status.lookup[status] # type: ignore if ( @@ -3201,7 +3238,7 @@ def handle_worker_status_change(self, status: str) -> None: "Invalid Worker.status transition: %s -> %s", self._status, new_status ) # Reiterate the current status to the scheduler to restore sync - self._send_worker_status_change() + self._send_worker_status_change(stimulus_id) else: # Update status and send confirmation to the Scheduler (see status.setter) self.status = new_status @@ -3555,7 +3592,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No stop=result["stop"], nbytes=result["nbytes"], type=result["type"], - stimulus_id=stimulus_id, + stimulus_id=f"task-finished-{time()}", ) if isinstance(result["actual-exception"], Reschedule): @@ -3582,7 +3619,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No traceback=result["traceback"], exception_text=result["exception_text"], traceback_text=result["traceback_text"], - stimulus_id=stimulus_id, + stimulus_id=f"task-erred-{time()}", ) except Exception as exc: @@ -3596,7 +3633,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No traceback=msg["traceback"], exception_text=msg["exception_text"], traceback_text=msg["traceback_text"], - stimulus_id=stimulus_id, + stimulus_id=f"task-erred-{time()}", ) @functools.singledispatchmethod diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 5e993fe5041..012b34bba9f 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -289,6 +289,7 @@ class TaskFinishedMsg(SendMessageToScheduler): metadata: dict thread: int | None startstops: list[StartStop] + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def to_dict(self) -> dict[str, Any]: @@ -308,6 +309,7 @@ class TaskErredMsg(SendMessageToScheduler): traceback_text: str thread: int | None startstops: list[StartStop] + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def to_dict(self) -> dict[str, Any]: @@ -320,8 +322,9 @@ def to_dict(self) -> dict[str, Any]: class ReleaseWorkerDataMsg(SendMessageToScheduler): op = "release-worker-data" - __slots__ = ("key",) + __slots__ = ("key", "stimulus_id") key: str + stimulus_id: str # Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception @@ -329,9 +332,10 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler): class RescheduleMsg(SendMessageToScheduler): op = "reschedule" - __slots__ = ("key", "worker") + __slots__ = ("key", "worker", "stimulus_id") key: str worker: str + stimulus_id: str @dataclass @@ -424,6 +428,7 @@ class ExecuteSuccessEvent(StateMachineEvent): stop: float nbytes: int type: type | None + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def to_loggable(self, *, handled: float) -> StateMachineEvent: @@ -446,6 +451,7 @@ class ExecuteFailureEvent(StateMachineEvent): traceback: Serialize | None exception_text: str traceback_text: str + stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def _after_from_dict(self) -> None: