diff --git a/distributed/client.py b/distributed/client.py
index 86ca5f506e0..e49a14d90dd 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -4294,6 +4294,35 @@ def collections_to_dsk(collections, *args, **kwargs):
"""Convert many collections into a single dask graph, after optimization"""
return collections_to_dsk(collections, *args, **kwargs)
+ async def _story(self, keys=(), on_error="raise"):
+ assert on_error in ("raise", "ignore")
+
+ try:
+ flat_stories = list(await self.scheduler.get_story(keys=keys))
+ except Exception:
+ if on_error == "raise":
+ raise
+ elif on_error == "ignore":
+ flat_stories = []
+ else:
+ raise ValueError(f"on_error not in {'raise', 'ignore'}")
+
+ responses = await self.scheduler.broadcast(
+ msg={"op": "get_story", "keys": keys}, on_error=on_error
+ )
+
+ for stories in responses.values():
+ if isinstance(stories, (tuple, list)):
+ flat_stories.extend(s for s in stories)
+ else:
+ flat_stories.append(stories)
+
+ return flat_stories
+
+ def story(self, *keys_or_stimulus_ids, on_error="raise"):
+ """Returns a cluster-wide story for the given keys or simtulus_id's"""
+ return self.sync(self._story, keys=keys_or_stimulus_ids, on_error=on_error)
+
def get_task_stream(
self,
start=None,
diff --git a/distributed/http/templates/task.html b/distributed/http/templates/task.html
index 0b5c10695e0..f10aaad5602 100644
--- a/distributed/http/templates/task.html
+++ b/distributed/http/templates/task.html
@@ -118,16 +118,18 @@
Transition Log
Key |
Start |
Finish |
+ Stimulus ID |
Recommended Key |
Recommended Action |
- {% for key, start, finish, recommendations, transition_time in scheduler.story(Task) %}
+ {% for key, start, finish, recommendations, stimulus_id, transition_time in scheduler.story(Task) %}
| {{ fromtimestamp(transition_time) }} |
{{key}} |
{{ start }} |
{{ finish }} |
+ {{ stimulus_id }} |
|
|
@@ -137,6 +139,7 @@ Transition Log
|
|
|
+ |
{{key2}} |
{{ rec }} |
diff --git a/distributed/nanny.py b/distributed/nanny.py
index db2371523e8..91873bb3a8a 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -290,7 +290,10 @@ async def _unregister(self, timeout=10):
allowed_errors = (TimeoutError, CommClosedError, EnvironmentError, RPCClosed)
with suppress(allowed_errors):
await asyncio.wait_for(
- self.scheduler.unregister(address=self.worker_address), timeout
+ self.scheduler.unregister(
+ address=self.worker_address, stimulus_id=f"nanny-close-{time()}"
+ ),
+ timeout,
)
@property
diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py
index 0e0ae003b5f..f81a8327747 100644
--- a/distributed/protocol/core.py
+++ b/distributed/protocol/core.py
@@ -116,6 +116,7 @@ def _encode_default(obj):
return msgpack_encode_default(obj)
frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True)
+
return frames
except Exception:
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 0b22d2ff2ef..1035676bbcc 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -110,6 +110,7 @@
DEFAULT_DATA_SIZE = parse_bytes(
dask.config.get("distributed.scheduler.default-data-size")
)
+STIMULUS_ID_UNSET = ""
DEFAULT_EXTENSIONS = {
"locks": LockExtension,
@@ -1617,7 +1618,7 @@ def new_task(
# State Transitions #
#####################
- def _transition(self, key, finish: str, *args, **kwargs):
+ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs):
"""Transition a key from its current state to the finish state
Examples
@@ -1652,14 +1653,16 @@ def _transition(self, key, finish: str, *args, **kwargs):
start_finish = (start, finish)
func = self.transitions_table.get(start_finish)
if func is not None:
- recommendations, client_msgs, worker_msgs = func(key, *args, **kwargs) # type: ignore
+ recommendations, client_msgs, worker_msgs = func(
+ key, stimulus_id, *args, **kwargs
+ ) # type: ignore
self.transition_counter += 1
elif "released" not in start_finish:
assert not args and not kwargs, (args, kwargs, start_finish)
a_recs: dict
a_cmsgs: dict
a_wmsgs: dict
- a: tuple = self._transition(key, "released")
+ a: tuple = self._transition(key, "released", stimulus_id)
a_recs, a_cmsgs, a_wmsgs = a
v = a_recs.get(key, finish)
@@ -1667,7 +1670,7 @@ def _transition(self, key, finish: str, *args, **kwargs):
b_recs: dict
b_cmsgs: dict
b_wmsgs: dict
- b: tuple = func(key) # type: ignore
+ b: tuple = func(key, stimulus_id) # type: ignore
b_recs, b_cmsgs, b_wmsgs = b
recommendations.update(a_recs)
@@ -1702,13 +1705,20 @@ def _transition(self, key, finish: str, *args, **kwargs):
else:
raise RuntimeError("Impossible transition from %r to %r" % start_finish)
+ if not stimulus_id:
+ stimulus_id = STIMULUS_ID_UNSET
+
finish2 = ts._state
# FIXME downcast antipattern
scheduler = pep484_cast(Scheduler, self)
scheduler.transition_log.append(
- (key, start, finish2, recommendations, time())
+ (key, start, finish2, recommendations, stimulus_id, time())
)
if self.validate:
+ if stimulus_id == STIMULUS_ID_UNSET:
+ raise RuntimeError(
+ "stimulus_id not set during Scheduler transition"
+ )
logger.debug(
"Transitioned %r %s->%s (actual: %s). Consequence: %s",
key,
@@ -1752,7 +1762,13 @@ def _transition(self, key, finish: str, *args, **kwargs):
pdb.set_trace()
raise
- def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict):
+ def _transitions(
+ self,
+ recommendations: dict,
+ client_msgs: dict,
+ worker_msgs: dict,
+ stimulus_id: str,
+ ):
"""Process transitions until none are left
This includes feedback from previous transitions and continues until we
@@ -1770,7 +1786,7 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di
key, finish = recommendations.popitem()
keys.add(key)
- new = self._transition(key, finish)
+ new = self._transition(key, finish, stimulus_id)
new_recs, new_cmsgs, new_wmsgs = new
recommendations.update(new_recs)
@@ -1793,7 +1809,7 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di
for key in keys:
scheduler.validate_key(key)
- def transition_released_waiting(self, key):
+ def transition_released_waiting(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -1848,7 +1864,7 @@ def transition_released_waiting(self, key):
pdb.set_trace()
raise
- def transition_no_worker_waiting(self, key):
+ def transition_no_worker_waiting(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -1896,7 +1912,13 @@ def transition_no_worker_waiting(self, key):
raise
def transition_no_worker_memory(
- self, key, nbytes=None, type=None, typename: str = None, worker=None
+ self,
+ key,
+ stimulus_id,
+ nbytes=None,
+ type=None,
+ typename: str = None,
+ worker=None,
):
try:
ws: WorkerState = self.workers[worker]
@@ -2051,7 +2073,7 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> float:
return total_duration
- def transition_waiting_processing(self, key):
+ def transition_waiting_processing(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -2096,7 +2118,14 @@ def transition_waiting_processing(self, key):
raise
def transition_waiting_memory(
- self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs
+ self,
+ key,
+ stimulus_id,
+ nbytes=None,
+ type=None,
+ typename: str = None,
+ worker=None,
+ **kwargs,
):
try:
ws: WorkerState = self.workers[worker]
@@ -2138,6 +2167,7 @@ def transition_waiting_memory(
def transition_processing_memory(
self,
key,
+ stimulus_id,
nbytes=None,
type=None,
typename: str = None,
@@ -2231,7 +2261,7 @@ def transition_processing_memory(
pdb.set_trace()
raise
- def transition_memory_released(self, key, safe: bool = False):
+ def transition_memory_released(self, key, stimulus_id, safe: bool = False):
ws: WorkerState
try:
ts: TaskState = self.tasks[key]
@@ -2269,7 +2299,7 @@ def transition_memory_released(self, key, safe: bool = False):
worker_msg = {
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"memory-released-{time()}",
+ "stimulus_id": stimulus_id,
}
for ws in ts.who_has:
worker_msgs[ws.address] = [worker_msg]
@@ -2301,7 +2331,7 @@ def transition_memory_released(self, key, safe: bool = False):
pdb.set_trace()
raise
- def transition_released_erred(self, key):
+ def transition_released_erred(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -2346,7 +2376,7 @@ def transition_released_erred(self, key):
pdb.set_trace()
raise
- def transition_erred_released(self, key):
+ def transition_erred_released(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -2372,7 +2402,7 @@ def transition_erred_released(self, key):
w_msg = {
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"erred-released-{time()}",
+ "stimulus_id": stimulus_id,
}
for ws_addr in ts.erred_on:
worker_msgs[ws_addr] = [w_msg]
@@ -2394,7 +2424,7 @@ def transition_erred_released(self, key):
pdb.set_trace()
raise
- def transition_waiting_released(self, key):
+ def transition_waiting_released(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
recommendations: dict = {}
@@ -2431,7 +2461,7 @@ def transition_waiting_released(self, key):
pdb.set_trace()
raise
- def transition_processing_released(self, key):
+ def transition_processing_released(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -2451,7 +2481,7 @@ def transition_processing_released(self, key):
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"processing-released-{time()}",
+ "stimulus_id": stimulus_id,
}
]
@@ -2485,6 +2515,7 @@ def transition_processing_released(self, key):
def transition_processing_erred(
self,
key: str,
+ stimulus_id: str,
cause: str = None,
exception=None,
traceback=None,
@@ -2571,7 +2602,7 @@ def transition_processing_erred(
pdb.set_trace()
raise
- def transition_no_worker_released(self, key):
+ def transition_no_worker_released(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
dts: TaskState
@@ -2613,7 +2644,7 @@ def remove_key(self, key):
ts.exception_blame = ts.exception = ts.traceback = None
self.task_metadata.pop(key, None)
- def transition_memory_forgotten(self, key):
+ def transition_memory_forgotten(self, key, stimulus_id):
ws: WorkerState
try:
ts: TaskState = self.tasks[key]
@@ -2641,7 +2672,7 @@ def transition_memory_forgotten(self, key):
for ws in ts.who_has:
ws.actors.discard(ts)
- _propagate_forgotten(self, ts, recommendations, worker_msgs)
+ _propagate_forgotten(self, ts, recommendations, worker_msgs, stimulus_id)
client_msgs = _task_to_client_msgs(self, ts)
self.remove_key(key)
@@ -2655,7 +2686,7 @@ def transition_memory_forgotten(self, key):
pdb.set_trace()
raise
- def transition_released_forgotten(self, key):
+ def transition_released_forgotten(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
recommendations: dict = {}
@@ -2679,7 +2710,7 @@ def transition_released_forgotten(self, key):
else:
assert 0, (ts,)
- _propagate_forgotten(self, ts, recommendations, worker_msgs)
+ _propagate_forgotten(self, ts, recommendations, worker_msgs, stimulus_id)
client_msgs = _task_to_client_msgs(self, ts)
self.remove_key(key)
@@ -3303,6 +3334,7 @@ def __init__(
"get_cluster_state": self.get_cluster_state,
"dump_cluster_state_to_url": self.dump_cluster_state_to_url,
"benchmark_hardware": self.benchmark_hardware,
+ "get_story": self.get_story,
}
connection_limit = get_fileno_limit() / 2
@@ -3626,7 +3658,7 @@ async def close(self, fast=False, close_workers=False):
setproctitle("dask-scheduler [closed]")
disable_gc_diagnosis()
- async def close_worker(self, worker: str, safe: bool = False):
+ async def close_worker(self, worker: str, stimulus_id: str, safe: bool = False):
"""Remove a worker from the cluster
This both removes the worker from our local state and also sends a
@@ -3638,7 +3670,7 @@ async def close_worker(self, worker: str, safe: bool = False):
self.log_event(worker, {"action": "close-worker"})
# FIXME: This does not handle nannies
self.worker_send(worker, {"op": "close", "report": False})
- await self.remove_worker(address=worker, safe=safe)
+ await self.remove_worker(address=worker, safe=safe, stimulus_id=stimulus_id)
###########
# Stimuli #
@@ -3775,8 +3807,10 @@ async def add_worker(
versions=None,
nanny=None,
extra=None,
+ stimulus_id=None,
):
"""Add a new worker to the cluster"""
+ stimulus_id = stimulus_id or f"add-worker-{time()}"
with log_errors():
address = self.coerce_address(address, resolve_address)
address = normalize_address(address)
@@ -3876,12 +3910,15 @@ async def add_worker(
t: tuple = self._transition(
key,
"memory",
+ stimulus_id,
worker=address,
nbytes=nbytes[key],
typename=types[key],
)
recommendations, client_msgs, worker_msgs = t
- self._transitions(recommendations, client_msgs, worker_msgs)
+ self._transitions(
+ recommendations, client_msgs, worker_msgs, stimulus_id
+ )
recommendations = {}
else:
already_released_keys.append(key)
@@ -3900,7 +3937,9 @@ async def add_worker(
recommendations.update(self.bulk_schedule_after_adding_worker(ws))
if recommendations:
- self._transitions(recommendations, client_msgs, worker_msgs)
+ self._transitions(
+ recommendations, client_msgs, worker_msgs, stimulus_id
+ )
self.send_all(client_msgs, worker_msgs)
@@ -3928,7 +3967,7 @@ async def add_worker(
if comm:
await comm.write(msg)
- await self.handle_worker(comm=comm, worker=address)
+ await self.handle_worker(comm=comm, worker=address, stimulus_id=stimulus_id)
async def add_nanny(self, comm):
msg = {
@@ -3992,6 +4031,7 @@ def update_graph_hlg(
fifo_timeout,
annotations,
code=code,
+ stimulus_id=f"update-graph-hlg-{time()}",
)
def update_graph(
@@ -4011,12 +4051,14 @@ def update_graph(
fifo_timeout=0,
annotations=None,
code=None,
+ stimulus_id=None,
):
"""
Add new computations to the internal dask graph
This happens whenever the Client calls submit, map, get, or compute.
"""
+ stimulus_id = stimulus_id or f"update-graph-{time()}"
start = time()
fifo_timeout = parse_timedelta(fifo_timeout)
keys = set(keys)
@@ -4268,7 +4310,7 @@ def update_graph(
except Exception as e:
logger.exception(e)
- self.transitions(recommendations)
+ self.transitions(recommendations, stimulus_id)
for ts in touched_tasks:
if ts.state in ("memory", "erred"):
@@ -4280,7 +4322,7 @@ def update_graph(
# TODO: balance workers
- def stimulus_task_finished(self, key=None, worker=None, **kwargs):
+ def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs):
"""Mark that a task has finished execution on a particular worker"""
logger.debug("Stimulus task finished %s, %s", key, worker)
@@ -4303,14 +4345,16 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"already-released-or-forgotten-{time()}",
+ "stimulus_id": stimulus_id,
}
]
elif ts.state == "memory":
self.add_keys(worker=worker, keys=[key])
else:
ts.metadata.update(kwargs["metadata"])
- r: tuple = self._transition(key, "memory", worker=worker, **kwargs)
+ r: tuple = self._transition(
+ key, "memory", stimulus_id, worker=worker, **kwargs
+ )
recommendations, client_msgs, worker_msgs = r
if ts.state == "memory":
@@ -4318,7 +4362,13 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
return recommendations, client_msgs, worker_msgs
def stimulus_task_erred(
- self, key=None, worker=None, exception=None, traceback=None, **kwargs
+ self,
+ key=None,
+ worker=None,
+ exception=None,
+ stimulus_id=None,
+ traceback=None,
+ **kwargs,
):
"""Mark that a task has erred on a particular worker"""
logger.debug("Stimulus task erred %s, %s", key, worker)
@@ -4329,11 +4379,12 @@ def stimulus_task_erred(
if ts.retries > 0:
ts.retries -= 1
- return self._transition(key, "waiting")
+ return self._transition(key, "waiting", stimulus_id)
else:
return self._transition(
key,
"erred",
+ stimulus_id,
cause=key,
exception=exception,
traceback=traceback,
@@ -4362,7 +4413,7 @@ def stimulus_retry(self, keys, client=None):
roots.append(key)
recommendations: dict = {key: "waiting" for key in roots}
- self.transitions(recommendations)
+ self.transitions(recommendations, f"stimulus-retry-{time()}")
if self.validate:
for key in seen:
@@ -4370,7 +4421,7 @@ def stimulus_retry(self, keys, client=None):
return tuple(seen)
- async def remove_worker(self, address, safe=False, close=True):
+ async def remove_worker(self, address, safe=False, close=True, stimulus_id=None):
"""
Remove worker from cluster
@@ -4378,6 +4429,7 @@ async def remove_worker(self, address, safe=False, close=True):
appears to be unresponsive. This may send its tasks back to a released
state.
"""
+ stimulus_id = stimulus_id or f"remove-worker-{time()}"
with log_errors():
if self.status == Status.closed:
return
@@ -4438,7 +4490,9 @@ async def remove_worker(self, address, safe=False, close=True):
e = pickle.dumps(
KilledWorker(task=k, last_worker=ws.clean()), protocol=4
)
- r = self.transition(k, "erred", exception=e, cause=k)
+ r = self.transition(
+ k, "erred", exception=e, cause=k, stimulus_id=stimulus_id
+ )
recommendations.update(r)
logger.info(
"Task %s marked as failed because %d workers died"
@@ -4455,7 +4509,7 @@ async def remove_worker(self, address, safe=False, close=True):
else: # pure data
recommendations[ts.key] = "forgotten"
- self.transitions(recommendations)
+ self.transitions(recommendations, stimulus_id=stimulus_id)
for plugin in list(self.plugins.values()):
try:
@@ -4535,15 +4589,16 @@ def client_desires_keys(self, keys=None, client=None):
if ts.state in ("memory", "erred"):
self.report_on_key(ts=ts, client=client)
- def client_releases_keys(self, keys=None, client=None):
+ def client_releases_keys(self, keys=None, client=None, stimulus_id=None):
"""Remove keys from client desired list"""
+ stimulus_id = stimulus_id or f"client-releases-keys-{time()}"
if not isinstance(keys, list):
keys = list(keys)
cs: ClientState = self.clients[client]
recommendations: dict = {}
_client_releases_keys(self, keys=keys, cs=cs, recommendations=recommendations)
- self.transitions(recommendations)
+ self.transitions(recommendations, stimulus_id)
def client_heartbeat(self, client=None):
"""Handle heartbeats from Client"""
@@ -4740,6 +4795,7 @@ async def add_client(self, comm: Comm, client: str, versions: dict) -> None:
We listen to all future messages from this Comm.
"""
+ stimulus_id = f"add-client-{time()}"
assert client is not None
comm.name = "Scheduler->Client"
logger.info("Receive client connection: %s", client)
@@ -4768,7 +4824,7 @@ async def add_client(self, comm: Comm, client: str, versions: dict) -> None:
try:
await self.handle_stream(comm=comm, extra={"client": client})
finally:
- self.remove_client(client=client)
+ self.remove_client(client=client, stimulus_id=stimulus_id)
logger.debug("Finished handling client %s", client)
finally:
if not comm.closed():
@@ -4782,8 +4838,9 @@ async def add_client(self, comm: Comm, client: str, versions: dict) -> None:
except TypeError: # comm becomes None during GC
pass
- def remove_client(self, client: str) -> None:
+ def remove_client(self, client: str, stimulus_id: str = None) -> None:
"""Remove client from network"""
+ stimulus_id = stimulus_id or f"remove-client-{time()}"
if self.status == Status.running:
logger.info("Remove client %s", client)
self.log_event(["all", client], {"action": "remove-client", "client": client})
@@ -4794,7 +4851,9 @@ def remove_client(self, client: str) -> None:
pass
else:
self.client_releases_keys(
- keys=[ts.key for ts in cs.wants_what], client=cs.client_key
+ keys=[ts.key for ts in cs.wants_what],
+ client=cs.client_key,
+ stimulus_id=stimulus_id,
)
del self.clients[client]
@@ -4830,24 +4889,28 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: float = -1):
def handle_uncaught_error(self, **msg):
logger.exception(clean_exception(**msg)[1])
- def handle_task_finished(self, key=None, worker=None, **msg):
+ def handle_task_finished(self, key=None, worker=None, stimulus_id=None, **msg):
if worker not in self.workers:
return
validate_key(key)
- r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg)
+
+ r: tuple = self.stimulus_task_finished(
+ key=key, worker=worker, stimulus_id=stimulus_id, **msg
+ )
recommendations, client_msgs, worker_msgs = r
- self._transitions(recommendations, client_msgs, worker_msgs)
+ self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)
- def handle_task_erred(self, key=None, **msg):
- r: tuple = self.stimulus_task_erred(key=key, **msg)
+ def handle_task_erred(self, key=None, stimulus_id=None, **msg):
+ r: tuple = self.stimulus_task_erred(key=key, stimulus_id=stimulus_id, **msg)
recommendations, client_msgs, worker_msgs = r
- self._transitions(recommendations, client_msgs, worker_msgs)
-
+ self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)
- def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
+ def handle_missing_data(
+ self, key=None, errant_worker=None, stimulus_id=None, **kwargs
+ ):
"""Signal that `errant_worker` does not hold `key`
This may either indicate that `errant_worker` is dead or that we may be
@@ -4864,6 +4927,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
errant_worker : str, optional
Address of the worker supposed to hold a replica, by default None
"""
+ assert stimulus_id
logger.debug("handle missing data key=%s worker=%s", key, errant_worker)
self.log_event(errant_worker, {"action": "missing-data", "key": key})
ts: TaskState = self.tasks.get(key)
@@ -4875,11 +4939,11 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
self.remove_replica(ts, ws)
if ts.state == "memory" and not ts.who_has:
if ts.run_spec:
- self.transitions({key: "released"})
+ self.transitions({key: "released"}, stimulus_id)
else:
- self.transitions({key: "forgotten"})
+ self.transitions({key: "forgotten"}, stimulus_id)
- def release_worker_data(self, key, worker):
+ def release_worker_data(self, key, worker, stimulus_id):
ws: WorkerState = self.workers.get(worker)
ts: TaskState = self.tasks.get(key)
if not ws or not ts:
@@ -4890,7 +4954,7 @@ def release_worker_data(self, key, worker):
if not ts.who_has:
recommendations[ts.key] = "released"
if recommendations:
- self.transitions(recommendations)
+ self.transitions(recommendations, stimulus_id)
def handle_long_running(self, key=None, worker=None, compute_duration=None):
"""A task has seceded from the thread pool
@@ -4932,7 +4996,9 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None):
ws.long_running.add(ts)
self.check_idle_saturated(ws)
- def handle_worker_status_change(self, status: str, worker: str) -> None:
+ def handle_worker_status_change(
+ self, status: str, worker: str, stimulus_id: str
+ ) -> None:
ws: WorkerState = self.workers.get(worker) # type: ignore
if not ws:
return
@@ -4956,13 +5022,13 @@ def handle_worker_status_change(self, status: str, worker: str) -> None:
if recs:
client_msgs: dict = {}
worker_msgs: dict = {}
- self._transitions(recs, client_msgs, worker_msgs)
+ self._transitions(recs, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)
else:
self.running.discard(ws)
- async def handle_worker(self, comm=None, worker=None):
+ async def handle_worker(self, comm=None, worker=None, stimulus_id=None):
"""
Listen to responses from a single worker
@@ -4972,6 +5038,7 @@ async def handle_worker(self, comm=None, worker=None):
--------
Scheduler.handle_client: Equivalent coroutine for clients
"""
+ assert stimulus_id
comm.name = "Scheduler connection to worker"
worker_comm = self.stream_comms[worker]
worker_comm.start(comm)
@@ -4981,7 +5048,7 @@ async def handle_worker(self, comm=None, worker=None):
finally:
if worker in self.stream_comms:
worker_comm.abort()
- await self.remove_worker(address=worker)
+ await self.remove_worker(address=worker, stimulus_id=stimulus_id)
def add_plugin(
self,
@@ -5119,7 +5186,11 @@ def worker_send(self, worker, msg):
try:
stream_comms[worker].send(msg)
except (CommClosedError, AttributeError):
- self.loop.add_callback(self.remove_worker, address=worker)
+ self.loop.add_callback(
+ self.remove_worker,
+ address=worker,
+ stimulus_id=f"worker-send-comm-fail-{time()}",
+ )
def client_send(self, client, msg):
"""Send message to client"""
@@ -5164,7 +5235,11 @@ def send_all(self, client_msgs: dict, worker_msgs: dict):
# worker already gone
pass
except (CommClosedError, AttributeError):
- self.loop.add_callback(self.remove_worker, address=worker)
+ self.loop.add_callback(
+ self.remove_worker,
+ address=worker,
+ stimulus_id=f"send-all-comm-fail-{time()}",
+ )
############################
# Less common interactions #
@@ -5221,6 +5296,7 @@ async def scatter(
async def gather(self, keys, serializers=None):
"""Collect data from workers to the scheduler"""
+ stimulus_id = f"gather-{time()}"
keys = list(keys)
who_has = {}
for key in keys:
@@ -5252,7 +5328,9 @@ async def gather(self, keys, serializers=None):
# reconnect.
await asyncio.gather(
*(
- self.remove_worker(address=worker, close=False)
+ self.remove_worker(
+ address=worker, close=False, stimulus_id=stimulus_id
+ )
for worker in missing_workers
)
)
@@ -5290,6 +5368,7 @@ def clear_task_state(self):
async def restart(self, client=None, timeout=30):
"""Restart all workers. Reset local state."""
+ stimulus_id = f"restart-{time()}"
with log_errors():
n_workers = len(self.workers)
@@ -5306,7 +5385,9 @@ async def restart(self, client=None, timeout=30):
try:
# Ask the worker to close if it doesn't have a nanny,
# otherwise the nanny will kill it anyway
- await self.remove_worker(address=addr, close=addr not in nannies)
+ await self.remove_worker(
+ address=addr, close=addr not in nannies, stimulus_id=stimulus_id
+ )
except Exception:
logger.info(
"Exception while restarting. This is normal", exc_info=True
@@ -5501,7 +5582,7 @@ async def gather_on_worker(
return keys_failed
async def delete_worker_data(
- self, worker_address: str, keys: "Collection[str]"
+ self, worker_address: str, keys: "Collection[str]", stimulus_id: str
) -> None:
"""Delete data from a worker and update the corresponding worker/task states
@@ -5538,7 +5619,7 @@ async def delete_worker_data(
self.remove_replica(ts, ws)
if not ts.who_has:
# Last copy deleted
- self.transitions({key: "released"})
+ self.transitions({key: "released"}, stimulus_id)
self.log_event(ws.address, {"action": "remove-worker-data", "keys": keys})
@@ -5547,6 +5628,7 @@ async def rebalance(
comm=None,
keys: "Iterable[Hashable]" = None,
workers: "Iterable[str]" = None,
+ stimulus_id: str = None,
) -> dict:
"""Rebalance keys so that each worker ends up with roughly the same process
memory (managed+unmanaged).
@@ -5613,6 +5695,7 @@ async def rebalance(
All other workers will be ignored. The mean cluster occupancy will be
calculated only using the allowed workers.
"""
+ stimulus_id = stimulus_id or f"rebalance-{time()}"
with log_errors():
if workers is not None:
wss = [self.workers[w] for w in workers]
@@ -5637,7 +5720,7 @@ async def rebalance(
return {"status": "OK"}
async with self._lock:
- result = await self._rebalance_move_data(msgs)
+ result = await self._rebalance_move_data(msgs, stimulus_id)
if result["status"] == "partial-fail" and keys is None:
# Only return failed keys if the client explicitly asked for them
result = {"status": "OK"}
@@ -5834,7 +5917,7 @@ def _rebalance_find_msgs(
return msgs
async def _rebalance_move_data(
- self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]"
+ self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]", stimulus_id: str
) -> dict:
"""Perform the actual transfer of data across the network in rebalance().
Takes in input the output of _rebalance_find_msgs(), that is a list of tuples:
@@ -5870,7 +5953,7 @@ async def _rebalance_move_data(
# Note: this never raises exceptions
await asyncio.gather(
- *(self.delete_worker_data(r, v) for r, v in to_senders.items())
+ *(self.delete_worker_data(r, v, stimulus_id) for r, v in to_senders.items())
)
for r, v in to_recipients.items():
@@ -5900,6 +5983,7 @@ async def replicate(
branching_factor=2,
delete=True,
lock=True,
+ stimulus_id=None,
):
"""Replicate data throughout cluster
@@ -5922,6 +6006,7 @@ async def replicate(
--------
Scheduler.rebalance
"""
+ stimulus_id = stimulus_id or f"replicate-{time()}"
assert branching_factor > 0
async with self._lock if lock else empty_context:
if workers is not None:
@@ -5956,7 +6041,9 @@ async def replicate(
# Note: this never raises exceptions
await asyncio.gather(
*[
- self.delete_worker_data(ws.address, [t.key for t in tasks])
+ self.delete_worker_data(
+ ws.address, [t.key for t in tasks], stimulus_id
+ )
for ws, tasks in del_worker_tasks.items()
]
)
@@ -6144,6 +6231,7 @@ async def retire_workers(
names: "list | None" = None,
close_workers: bool = False,
remove: bool = True,
+ stimulus_id: str = None,
**kwargs,
) -> dict:
"""Gracefully retire workers from cluster
@@ -6177,6 +6265,7 @@ async def retire_workers(
--------
Scheduler.workers_to_close
"""
+ stimulus_id = stimulus_id or f"retire-workers-{time()}"
with log_errors():
# This lock makes retire_workers, rebalance, and replicate mutually
# exclusive and will no longer be necessary once rebalance and replicate are
@@ -6231,7 +6320,11 @@ async def retire_workers(
ws.status = Status.closing_gracefully
self.running.discard(ws)
self.stream_comms[ws.address].send(
- {"op": "worker-status-change", "status": ws.status.name}
+ {
+ "op": "worker-status-change",
+ "status": ws.status.name,
+ "stimulus_id": stimulus_id,
+ }
)
coros.append(
@@ -6241,6 +6334,7 @@ async def retire_workers(
prev_status=prev_status,
close_workers=close_workers,
remove=remove,
+ stimulus_id=stimulus_id,
)
)
@@ -6267,6 +6361,7 @@ async def _track_retire_worker(
prev_status: Status,
close_workers: bool,
remove: bool,
+ stimulus_id: str,
) -> tuple: # tuple[str | None, dict]
while not policy.done():
if policy.no_recipients:
@@ -6274,7 +6369,11 @@ async def _track_retire_worker(
# conditions and we can wait for a scheduler->worker->scheduler
# round-trip.
self.stream_comms[ws.address].send(
- {"op": "worker-status-change", "status": prev_status.name}
+ {
+ "op": "worker-status-change",
+ "status": prev_status.name,
+ "stimulus_id": stimulus_id,
+ }
)
return None, {}
@@ -6288,9 +6387,13 @@ async def _track_retire_worker(
)
if close_workers and ws.address in self.workers:
- await self.close_worker(worker=ws.address, safe=True)
+ await self.close_worker(
+ worker=ws.address, safe=True, stimulus_id=stimulus_id
+ )
if remove:
- await self.remove_worker(address=ws.address, safe=True)
+ await self.remove_worker(
+ address=ws.address, safe=True, stimulus_id=stimulus_id
+ )
logger.info("Retired worker %s", ws.address)
return ws.address, ws.identity()
@@ -6727,7 +6830,7 @@ async def unregister_nanny_plugin(self, comm, name):
)
return responses
- def transition(self, key, finish: str, *args, **kwargs):
+ def transition(self, key, finish: str, *args, stimulus_id: str, **kwargs):
"""Transition a key from its current state to the finish state
Examples
@@ -6743,12 +6846,12 @@ def transition(self, key, finish: str, *args, **kwargs):
--------
Scheduler.transitions: transitive version of this function
"""
- a: tuple = self._transition(key, finish, *args, **kwargs)
+ a: tuple = self._transition(key, finish, stimulus_id, *args, **kwargs)
recommendations, client_msgs, worker_msgs = a
self.send_all(client_msgs, worker_msgs)
return recommendations
- def transitions(self, recommendations: dict):
+ def transitions(self, recommendations: dict, stimulus_id: str):
"""Process transitions until none are left
This includes feedback from previous transitions and continues until we
@@ -6756,7 +6859,7 @@ def transitions(self, recommendations: dict):
"""
client_msgs: dict = {}
worker_msgs: dict = {}
- self._transitions(recommendations, client_msgs, worker_msgs)
+ self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)
def story(self, *keys):
@@ -6764,6 +6867,9 @@ def story(self, *keys):
keys = {key.key if isinstance(key, TaskState) else key for key in keys}
return scheduler_story(keys, self.transition_log)
+ async def get_story(self, keys=None):
+ return self.story(*keys)
+
transition_story = story
def reschedule(self, key=None, worker=None):
@@ -6784,7 +6890,7 @@ def reschedule(self, key=None, worker=None):
return
if worker and ts.processing_on.address != worker:
return
- self.transitions({key: "released"})
+ self.transitions({key: "released"}, f"reschedule-{time()}")
#####################
# Utility functions #
@@ -7217,7 +7323,9 @@ async def check_worker_ttl(self):
self.worker_ttl,
ws,
)
- await self.remove_worker(address=ws.address)
+ await self.remove_worker(
+ address=ws.address, stimulus_id=f"check-worker-ttl-{time()}"
+ )
def check_idle(self):
if any([ws.processing for ws in self.workers.values()]) or self.unrunnable:
@@ -7427,7 +7535,11 @@ def _add_to_memory(
def _propagate_forgotten(
- state: SchedulerState, ts: TaskState, recommendations: dict, worker_msgs: dict
+ state: SchedulerState,
+ ts: TaskState,
+ recommendations: dict,
+ worker_msgs: dict,
+ stimulus_id: str,
):
ts.state = "forgotten"
key: str = ts.key
@@ -7458,7 +7570,7 @@ def _propagate_forgotten(
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"propagate-forgotten-{time()}",
+ "stimulus_id": stimulus_id,
}
]
state.remove_all_replicas(ts)
diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index a18594dc6e4..94fdc3bc11f 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -6,7 +6,7 @@
from distributed.core import CommClosedError
from distributed.utils_test import (
_LockedCommPool,
- assert_worker_story,
+ assert_story,
gen_cluster,
inc,
slowinc,
@@ -82,7 +82,7 @@ def f(ev):
while "f1" in a.tasks:
await asyncio.sleep(0.01)
- assert_worker_story(
+ assert_story(
a.story("f1"),
[
("f1", "compute-task"),
@@ -160,7 +160,7 @@ async def wait_and_raise(*args, **kwargs):
await wait_for_state(fut1.key, "flight", b)
# Close in scheduler to ensure we transition and reschedule task properly
- await s.close_worker(worker=a.address)
+ await s.close_worker(worker=a.address, stimulus_id="test")
await wait_for_state(fut1.key, "resumed", b)
lock.release()
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index abb9dcfd16e..0cbe97a50a2 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -77,6 +77,7 @@
from distributed.utils_test import (
TaskStateMetadataPlugin,
_UnhashableCallable,
+ assert_story,
async_wait_for,
asyncinc,
captured_logger,
@@ -4563,7 +4564,7 @@ def test_auto_normalize_collection_sync(c):
def assert_no_data_loss(scheduler):
- for key, start, finish, recommendations, _ in scheduler.transition_log:
+ for key, start, finish, recommendations, _, _ in scheduler.transition_log:
if start == "memory" and finish == "released":
for k, v in recommendations.items():
assert not (k == key and v == "waiting")
@@ -7494,3 +7495,58 @@ async def test_wait_for_workers_updates_info(c, s):
async with Worker(s.address):
await c.wait_for_workers(1)
assert c.scheduler_info()["workers"]
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_client_story(c, s, *workers):
+ f = c.submit(inc, 1)
+ assert await f == 2
+ story = await c.story(f.key)
+
+ assert_story(
+ story,
+ [
+ (f.key, "released", "waiting", {f.key: "processing"}),
+ (f.key, "waiting", "processing", {}),
+ (f.key, "processing", "memory", {}),
+ (f.key, "compute-task"),
+ (f.key, "released", "waiting", "waiting", {f.key: "ready"}),
+ (f.key, "waiting", "ready", "ready", {f.key: "executing"}),
+ (f.key, "ready", "executing", "executing", {}),
+ (f.key, "put-in-memory"),
+ (f.key, "executing", "memory", "memory", {}),
+ ],
+ ordered_timestamps=False,
+ )
+
+ stimulus_ids = {ev[-2] for ev in story}
+ assert len(stimulus_ids) == 3
+
+
+class WorkerBrokenStory(Worker):
+ async def get_story(self, *args, **kw):
+ raise CommClosedError
+
+
+@gen_cluster(client=True, Worker=WorkerBrokenStory)
+@pytest.mark.parametrize("on_error", ["ignore", "raise"])
+async def test_client_story_failed_worker(c, s, a, b, on_error):
+ f = c.submit(inc, 1)
+ coro = c.story(f.key, on_error=on_error)
+ await f
+
+ if on_error == "raise":
+ with pytest.raises(CommClosedError) as e:
+ await coro
+ elif on_error == "ignore":
+ story = await coro
+ assert_story(
+ story,
+ [
+ (f.key, "released", "waiting", {f.key: "processing"}),
+ (f.key, "waiting", "processing", {}),
+ (f.key, "processing", "memory", {}),
+ ],
+ )
+ else:
+ raise ValueError(on_error)
diff --git a/distributed/tests/test_cluster_dump.py b/distributed/tests/test_cluster_dump.py
index c3912116c46..040bbc53eb7 100644
--- a/distributed/tests/test_cluster_dump.py
+++ b/distributed/tests/test_cluster_dump.py
@@ -8,7 +8,7 @@
import distributed
from distributed.cluster_dump import DumpArtefact, _tuple_to_list, write_state
-from distributed.utils_test import assert_worker_story, gen_cluster, gen_test, inc
+from distributed.utils_test import assert_story, gen_cluster, gen_test, inc
@pytest.mark.parametrize(
@@ -140,7 +140,7 @@ async def test_cluster_dump_story(c, s, a, b, tmp_path):
assert story.keys() == {"f1", "f2"}
for k, task_story in story.items():
- assert_worker_story(
+ assert_story(
task_story,
[
(k, "compute-task"),
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 7ee7e2fe0d9..45b26dd1234 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -37,6 +37,7 @@
from distributed.utils import TimeoutError
from distributed.utils_test import (
BrokenComm,
+ assert_story,
captured_logger,
cluster,
dec,
@@ -810,7 +811,7 @@ async def test_story(c, s, a, b):
story = s.story(x.key)
assert all(line in s.transition_log for line in story)
assert len(story) < len(s.transition_log)
- assert all(x.key == line[0] or x.key in line[-2] for line in story)
+ assert all(x.key == line[0] or x.key in line[3] for line in story)
assert len(s.story(x.key, y.key)) > len(story)
@@ -1233,7 +1234,7 @@ def f(dask_scheduler=None):
async def test_close_worker(c, s, a, b):
assert len(s.workers) == 2
- await s.close_worker(worker=a.address)
+ await s.close_worker(worker=a.address, stimulus_id="test")
assert len(s.workers) == 1
assert a.address not in s.workers
@@ -1251,7 +1252,7 @@ async def test_close_nanny(c, s, a, b):
assert a.process.is_alive()
a_worker_address = a.worker_address
start = time()
- await s.close_worker(worker=a_worker_address)
+ await s.close_worker(worker=a_worker_address, stimulus_id="test")
assert len(s.workers) == 1
assert a_worker_address not in s.workers
@@ -3098,7 +3099,9 @@ async def test_rebalance_dead_recipient(client, s, a, b, c):
await c.close()
assert s.workers.keys() == {a.address, b.address}
- out = await s._rebalance_move_data([(a_ws, b_ws, x_ts), (a_ws, c_ws, y_ts)])
+ out = await s._rebalance_move_data(
+ [(a_ws, b_ws, x_ts), (a_ws, c_ws, y_ts)], stimulus_id="test"
+ )
assert out == {"status": "partial-fail", "keys": [y.key]}
assert a.data == {y.key: "y"}
assert b.data == {x.key: "x"}
@@ -3117,7 +3120,7 @@ async def test_delete_worker_data(c, s, a, b):
assert b.data == {y.key: "y"}
assert s.tasks.keys() == {x.key, y.key, z.key}
- await s.delete_worker_data(a.address, [x.key, y.key])
+ await s.delete_worker_data(a.address, [x.key, y.key], stimulus_id="test")
assert a.data == {z.key: "z"}
assert b.data == {y.key: "y"}
assert s.tasks.keys() == {y.key, z.key}
@@ -3131,8 +3134,8 @@ async def test_delete_worker_data_double_delete(c, s, a):
"""
x, y = await c.scatter(["x", "y"])
await asyncio.gather(
- s.delete_worker_data(a.address, [x.key]),
- s.delete_worker_data(a.address, [x.key]),
+ s.delete_worker_data(a.address, [x.key], stimulus_id="test"),
+ s.delete_worker_data(a.address, [x.key], stimulus_id="test"),
)
assert a.data == {y.key: "y"}
a_ws = s.workers[a.address]
@@ -3147,7 +3150,7 @@ async def test_delete_worker_data_bad_worker(s, a, b):
"""
await a.close()
assert s.workers.keys() == {b.address}
- await s.delete_worker_data(a.address, ["x"])
+ await s.delete_worker_data(a.address, ["x"], stimulus_id="test")
@pytest.mark.parametrize("bad_first", [False, True])
@@ -3162,7 +3165,7 @@ async def test_delete_worker_data_bad_task(c, s, a, bad_first):
assert s.tasks.keys() == {x.key, y.key}
keys = ["notexist", x.key] if bad_first else [x.key, "notexist"]
- await s.delete_worker_data(a.address, keys)
+ await s.delete_worker_data(a.address, keys, stimulus_id="test")
assert a.data == {y.key: "y"}
assert s.tasks.keys() == {y.key}
assert s.workers[a.address].nbytes == s.tasks[y.key].nbytes
@@ -3253,7 +3256,7 @@ async def test_worker_reconnect_task_memory(c, s, a):
await res
assert ("no-worker", "memory") in {
- (start, finish) for (_, start, finish, _, _) in s.transition_log
+ (start, finish) for (_, start, finish, _, _, _) in s.transition_log
}
@@ -3277,7 +3280,7 @@ async def test_worker_reconnect_task_memory_with_resources(c, s, a):
await res
assert ("no-worker", "memory") in {
- (start, finish) for (_, start, finish, _, _) in s.transition_log
+ (start, finish) for (_, start, finish, _, _, _) in s.transition_log
}
@@ -3538,3 +3541,84 @@ async def test_repr(s, a):
repr(ws_b)
== f""
)
+
+
+@gen_cluster(client=True)
+async def test_stimulus_success(c, s, a, b):
+ f = c.submit(inc, 1)
+ key = f.key
+
+ await f
+ await c.close()
+
+ stories = s.story(key)
+
+ assert_story(
+ stories,
+ [
+ (key, "released", "waiting", {key: "processing"}),
+ (key, "waiting", "processing", {}),
+ (key, "processing", "memory", {}),
+ (
+ key,
+ "memory",
+ "forgotten",
+ {},
+ ),
+ ],
+ )
+
+ stimulus_ids = {s[-2] for s in stories}
+ assert len(stimulus_ids) == 3
+
+
+@gen_cluster(client=True)
+async def test_stimulus_retry(c, s, a, b):
+ def task():
+ assert dask.config.get("foo")
+
+ with dask.config.set(foo=False):
+ f = c.submit(task)
+ with pytest.raises(AssertionError):
+ await f
+
+ with dask.config.set(foo=True):
+ await f.retry()
+ await f
+
+ story = s.story(f.key)
+ stimulus_ids = {s[-2] for s in story}
+ assert len(stimulus_ids) == 4
+
+ assert_story(
+ story,
+ [
+ (f.key, "released", "waiting", {f.key: "processing"}),
+ (f.key, "waiting", "processing", {}),
+ (
+ f.key,
+ "processing",
+ "erred",
+ {},
+ ),
+ (
+ f.key,
+ "erred",
+ "released",
+ {},
+ ),
+ (
+ f.key,
+ "released",
+ "waiting",
+ {f.key: "processing"},
+ ),
+ (
+ f.key,
+ "waiting",
+ "processing",
+ {},
+ ),
+ (f.key, "processing", "memory", {}),
+ ],
+ )
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index c88e8b592ed..85a2b754616 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -21,7 +21,7 @@
from distributed.utils_test import (
_LockedCommPool,
_UnhashableCallable,
- assert_worker_story,
+ assert_story,
check_process_leak,
cluster,
dump_cluster_state,
@@ -406,7 +406,7 @@ async def inner_test(c, s, a, b):
assert "workers" in state
-def test_assert_worker_story():
+def test_assert_story():
now = time()
story = [
("foo", "id1", now - 600),
@@ -414,38 +414,38 @@ def test_assert_worker_story():
("baz", {1: 2}, "id2", now),
]
# strict=False
- assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})])
- assert_worker_story(story, [])
- assert_worker_story(story, [("foo",)])
- assert_worker_story(story, [("foo",), ("bar",)])
- assert_worker_story(story, [("baz", lambda d: d[1] == 2)])
+ assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})])
+ assert_story(story, [])
+ assert_story(story, [("foo",)])
+ assert_story(story, [("foo",), ("bar",)])
+ assert_story(story, [("baz", lambda d: d[1] == 2)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo", "nomatch")])
+ assert_story(story, [("foo", "nomatch")])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz",)])
+ assert_story(story, [("baz",)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz", {1: 3})])
+ assert_story(story, [("baz", {1: 3})])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)])
+ assert_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz", lambda d: d[1] == 3)])
+ assert_story(story, [("baz", lambda d: d[1] == 3)])
with pytest.raises(KeyError): # Faulty lambda
- assert_worker_story(story, [("baz", lambda d: d[2] == 1)])
- assert_worker_story([], [])
- assert_worker_story([("foo", "id1", now)], [("foo",)])
+ assert_story(story, [("baz", lambda d: d[2] == 1)])
+ assert_story([], [])
+ assert_story([("foo", "id1", now)], [("foo",)])
with pytest.raises(AssertionError):
- assert_worker_story([], [("foo",)])
+ assert_story([], [("foo",)])
# strict=True
- assert_worker_story([], [], strict=True)
- assert_worker_story([("foo", "id1", now)], [("foo",)])
- assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True)
+ assert_story([], [], strict=True)
+ assert_story([("foo", "id1", now)], [("foo",)])
+ assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("bar",)], strict=True)
+ assert_story(story, [("foo",), ("bar",)], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("baz", {1: 2})], strict=True)
+ assert_story(story, [("foo",), ("baz", {1: 2})], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [], strict=True)
+ assert_story(story, [], strict=True)
@pytest.mark.parametrize(
@@ -466,11 +466,11 @@ def test_assert_worker_story():
),
],
)
-def test_assert_worker_story_malformed_story(story_factory):
+def test_assert_story_malformed_story(story_factory):
# defer the calls to time() to when the test runs rather than collection
story = story_factory()
with pytest.raises(AssertionError, match="Malformed story event"):
- assert_worker_story(story, [])
+ assert_story(story, [])
@gen_cluster()
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index d56b768555c..ec44f5e1356 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -48,7 +48,7 @@
from distributed.utils_test import (
TaskStateMetadataPlugin,
_LockedCommPool,
- assert_worker_story,
+ assert_story,
captured_logger,
dec,
div,
@@ -1406,7 +1406,7 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None:
"""Test that an in-memory key was transferred from worker w_from to worker w_to by
the Active Memory Manager and it was not recalculated on w_to
"""
- assert_worker_story(
+ assert_story(
w_to.story(key),
[
(key, "ensure-task-exists", "released"),
@@ -1745,7 +1745,8 @@ async def test_story_with_deps(c, s, a, b):
# Story now includes randomized stimulus_ids and timestamps.
stimulus_ids = {ev[-2] for ev in story}
- assert len(stimulus_ids) == 2, stimulus_ids
+ assert len(stimulus_ids) == 3
+
# This is a simple transition log
expected = [
("res", "compute-task"),
@@ -1755,7 +1756,7 @@ async def test_story_with_deps(c, s, a, b):
("res", "put-in-memory"),
("res", "executing", "memory", "memory", {}),
]
- assert_worker_story(story, expected, strict=True)
+ assert_story(story, expected, strict=True)
story = b.story("dep")
stimulus_ids = {ev[-2] for ev in story}
@@ -1770,7 +1771,7 @@ async def test_story_with_deps(c, s, a, b):
("dep", "put-in-memory"),
("dep", "flight", "memory", "memory", {"res": "ready"}),
]
- assert_worker_story(story, expected, strict=True)
+ assert_story(story, expected, strict=True)
@gen_cluster(client=True, nthreads=[("", 1)])
@@ -2708,7 +2709,7 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b):
while fut1.key not in b.tasks or b.tasks[fut1.key].state == "flight":
await asyncio.sleep(0)
- s.handle_missing_data(key="f1", errant_worker=a.address)
+ s.handle_missing_data(key="f1", errant_worker=a.address, stimulus_id="test")
await fut2
@@ -2753,7 +2754,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b):
# same communication channel
for fut in (futA, futB):
- assert_worker_story(
+ assert_story(
b.story(fut.key),
[
("gather-dependencies", a.address, {fut.key}),
@@ -2812,7 +2813,7 @@ def __getstate__(self):
assert await y == 123
story = await c.run(lambda dask_worker: dask_worker.story("x"))
- assert_worker_story(
+ assert_story(
story[b],
[
("x", "ensure-task-exists", "released"),
@@ -2989,7 +2990,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
await f2
- assert_worker_story(a.story(f1.key), [(f1.key, "missing-dep")])
+ assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
assert a.tasks[f1.key].suspicious_count == 0
assert s.tasks[f1.key].suspicious == 0
@@ -3069,7 +3070,7 @@ async def test_missing_released_zombie_tasks_2(c, s, a, b):
while b.tasks:
await asyncio.sleep(0.01)
- assert_worker_story(
+ assert_story(
b.story(ts),
[("f1", "missing", "released", "released", {"f1": "forgotten"})],
)
@@ -3181,7 +3182,7 @@ async def test_task_flight_compute_oserror(c, s, a, b):
("f1", "put-in-memory"),
("f1", "executing", "memory", "memory", {}),
]
- assert_worker_story(sum_story, expected_sum_story, strict=True)
+ assert_story(sum_story, expected_sum_story, strict=True)
@gen_cluster(client=True)
diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index d8337ace8e5..feb99ee995a 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -115,15 +115,19 @@ def test_slots(cls):
def test_sendmsg_to_dict():
# Arbitrary sample class
- smsg = ReleaseWorkerDataMsg(key="x")
- assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"}
+ smsg = ReleaseWorkerDataMsg(key="x", stimulus_id="test")
+ assert smsg.to_dict() == {
+ "op": "release-worker-data",
+ "key": "x",
+ "stimulus_id": "test",
+ }
def test_merge_recs_instructions():
x = TaskState("x")
y = TaskState("y")
- instr1 = RescheduleMsg(key="foo", worker="a")
- instr2 = RescheduleMsg(key="bar", worker="b")
+ instr1 = RescheduleMsg(key="foo", worker="a", stimulus_id="test")
+ instr2 = RescheduleMsg(key="bar", worker="b", stimulus_id="test")
assert merge_recs_instructions(
({x: "memory"}, [instr1]),
({y: "released"}, [instr2]),
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index 31e6cbe08fe..c85199a4bd8 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -1875,8 +1875,12 @@ def xfail_ssl_issue5601():
raise
-def assert_worker_story(
- story: list[tuple], expect: list[tuple], *, strict: bool = False
+def assert_story(
+ story: list[tuple],
+ expect: list[tuple],
+ *,
+ strict: bool = False,
+ ordered_timestamps: bool = True,
) -> None:
"""Test the output of ``Worker.story``
@@ -1906,6 +1910,10 @@ def assert_worker_story(
If True, the story must contain exactly as many events as expect.
If False (the default), the story may contain more events than expect; extra
events are ignored.
+ ordered_timestamps: bool, optional
+ If False, timestamps are not required to be monotically increasing.
+ Useful for asserting stories composed from the scheduler and
+ multiple workers
"""
now = time()
prev_ts = 0.0
@@ -1915,7 +1923,8 @@ def assert_worker_story(
assert isinstance(ev, tuple)
assert isinstance(ev[-2], str) and ev[-2] # stimulus_id
assert isinstance(ev[-1], float) # timestamp
- assert prev_ts <= ev[-1] # Timestamps are monotonic ascending
+ if ordered_timestamps:
+ assert prev_ts <= ev[-1] # Timestamps are monotonic ascending
# Timestamps are within the last hour. It's been observed that a timestamp
# generated in a Nanny process can be a few milliseconds in the future.
assert now - 3600 < ev[-1] <= now + 1
@@ -1941,7 +1950,7 @@ def assert_worker_story(
break
except StopIteration:
raise AssertionError(
- f"assert_worker_story({strict=}) failed\n"
+ f"assert_story({strict=}) failed\n"
f"story:\n{_format_story(story)}\n"
f"expect:\n{_format_story(expect)}"
) from None
diff --git a/distributed/worker.py b/distributed/worker.py
index 4271abc35c1..708d409f54a 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -754,6 +754,7 @@ def __init__(
"benchmark_disk": self.benchmark_disk,
"benchmark_memory": self.benchmark_memory,
"benchmark_network": self.benchmark_network,
+ "get_story": self.get_story,
}
stream_handlers = {
@@ -927,21 +928,25 @@ def status(self, value):
"""
prev_status = self.status
ServerNode.status.__set__(self, value)
- self._send_worker_status_change()
+ self._send_worker_status_change(f"worker-status-change-{time()}")
if prev_status == Status.paused and value == Status.running:
self.handle_stimulus(UnpauseEvent(stimulus_id=f"set-status-{time()}"))
- def _send_worker_status_change(self) -> None:
+ def _send_worker_status_change(self, stimulus_id: str) -> None:
if (
self.batched_stream
and self.batched_stream.comm
and not self.batched_stream.comm.closed()
):
self.batched_stream.send(
- {"op": "worker-status-change", "status": self._status.name}
+ {
+ "op": "worker-status-change",
+ "status": self._status.name,
+ "stimulus_id": stimulus_id,
+ },
)
elif self._status != Status.closed:
- self.loop.call_later(0.05, self._send_worker_status_change)
+ self.loop.call_later(0.05, self._send_worker_status_change, stimulus_id)
async def get_metrics(self) -> dict:
try:
@@ -1083,6 +1088,7 @@ async def _register_with_scheduler(self):
versions=get_versions(),
metrics=await self.get_metrics(),
extra=await self.get_startup_information(),
+ stimulus_id=f"worker-connect-{time()}",
),
serializers=["msgpack"],
)
@@ -1476,7 +1482,9 @@ async def close(
if report and self.contact_address is not None:
await asyncio.wait_for(
self.scheduler.unregister(
- address=self.contact_address, safe=safe
+ address=self.contact_address,
+ safe=safe,
+ stimulus_id=f"worker-close-{time()}",
),
timeout,
)
@@ -1541,7 +1549,10 @@ async def close_gracefully(self, restart=None):
# Scheduler.retire_workers will set the status to closing_gracefully and push it
# back to this worker.
await self.scheduler.retire_workers(
- workers=[self.address], close_workers=False, remove=False
+ workers=[self.address],
+ close_workers=False,
+ remove=False,
+ stimulus_id=f"worker-close-gracefully-{time()}",
)
await self.close(safe=True, nanny=not restart)
@@ -1896,7 +1907,9 @@ def handle_compute_task(
pass
elif ts.state == "memory":
recommendations[ts] = "memory"
- instructions.append(self._get_task_finished_msg(ts))
+ instructions.append(
+ self._get_task_finished_msg(ts, stimulus_id=stimulus_id)
+ )
elif ts.state in {
"released",
"fetch",
@@ -2035,7 +2048,7 @@ def transition_memory_released(
recs, instructions = self.transition_generic_released(
ts, stimulus_id=stimulus_id
)
- instructions.append(ReleaseWorkerDataMsg(ts.key))
+ instructions.append(ReleaseWorkerDataMsg(key=ts.key, stimulus_id=stimulus_id))
return recs, instructions
def transition_waiting_constrained(
@@ -2058,7 +2071,7 @@ def transition_long_running_rescheduled(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
recs: Recs = {ts: "released"}
- smsg = RescheduleMsg(key=ts.key, worker=self.address)
+ smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id)
return recs, [smsg]
def transition_executing_rescheduled(
@@ -2069,7 +2082,14 @@ def transition_executing_rescheduled(
self._executing.discard(ts)
return merge_recs_instructions(
- ({ts: "released"}, [RescheduleMsg(key=ts.key, worker=self.address)]),
+ (
+ {ts: "released"},
+ [
+ RescheduleMsg(
+ key=ts.key, worker=self.address, stimulus_id=stimulus_id
+ )
+ ],
+ ),
self._ensure_computing(),
)
@@ -2147,6 +2167,7 @@ def transition_generic_error(
traceback_text=traceback_text,
thread=self.threads.get(ts.key),
startstops=ts.startstops,
+ stimulus_id=stimulus_id,
)
return {}, [smsg]
@@ -2336,7 +2357,9 @@ def transition_generic_memory(
else:
if self.validate:
assert ts.key in self.data or ts.key in self.actors
- instructions.append(self._get_task_finished_msg(ts))
+ instructions.append(
+ self._get_task_finished_msg(ts, stimulus_id=stimulus_id)
+ )
return recs, instructions
@@ -2455,7 +2478,10 @@ def transition_executing_long_running(
self.long_running.add(ts.key)
return merge_recs_instructions(
- ({}, [LongRunningMsg(key=ts.key, compute_duration=compute_duration)]),
+ (
+ {},
+ [LongRunningMsg(key=ts.key, compute_duration=compute_duration)],
+ ),
self._ensure_computing(),
)
@@ -2686,6 +2712,9 @@ def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]:
keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks}
return worker_story(keys, self.log)
+ async def get_story(self, keys=None):
+ return self.story(*keys)
+
def stimulus_story(
self, *keys_or_tasks: str | TaskState
) -> list[StateMachineEvent]:
@@ -2754,7 +2783,9 @@ def ensure_communicating(self) -> None:
for el in skipped_worker_in_flight:
self.data_needed.push(el)
- def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg:
+ def _get_task_finished_msg(
+ self, ts: TaskState, stimulus_id: str
+ ) -> TaskFinishedMsg:
if ts.key not in self.data and ts.key not in self.actors:
raise RuntimeError(f"Task {ts} not ready")
typ = ts.type
@@ -2780,6 +2811,7 @@ def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg:
metadata=ts.metadata,
thread=self.threads.get(ts.key),
startstops=ts.startstops,
+ stimulus_id=stimulus_id,
)
def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs:
@@ -3087,7 +3119,12 @@ async def gather_dep(
self.has_what[worker].discard(ts.key)
self.log.append((d, "missing-dep", stimulus_id, time()))
self.batched_stream.send(
- {"op": "missing-data", "errant_worker": worker, "key": d}
+ {
+ "op": "missing-data",
+ "errant_worker": worker,
+ "key": d,
+ "stimulus_id": stimulus_id,
+ }
)
recommendations[ts] = "fetch" if ts.who_has else "missing"
del data, response
@@ -3190,7 +3227,7 @@ def handle_steal_request(self, key: str, stimulus_id: str) -> None:
# `transition_constrained_executing`
self.transition(ts, "released", stimulus_id=stimulus_id)
- def handle_worker_status_change(self, status: str) -> None:
+ def handle_worker_status_change(self, status: str, stimulus_id: str) -> None:
new_status = Status.lookup[status] # type: ignore
if (
@@ -3201,7 +3238,7 @@ def handle_worker_status_change(self, status: str) -> None:
"Invalid Worker.status transition: %s -> %s", self._status, new_status
)
# Reiterate the current status to the scheduler to restore sync
- self._send_worker_status_change()
+ self._send_worker_status_change(stimulus_id)
else:
# Update status and send confirmation to the Scheduler (see status.setter)
self.status = new_status
@@ -3555,7 +3592,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
stop=result["stop"],
nbytes=result["nbytes"],
type=result["type"],
- stimulus_id=stimulus_id,
+ stimulus_id=f"task-finished-{time()}",
)
if isinstance(result["actual-exception"], Reschedule):
@@ -3582,7 +3619,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
traceback=result["traceback"],
exception_text=result["exception_text"],
traceback_text=result["traceback_text"],
- stimulus_id=stimulus_id,
+ stimulus_id=f"task-erred-{time()}",
)
except Exception as exc:
@@ -3596,7 +3633,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
traceback=msg["traceback"],
exception_text=msg["exception_text"],
traceback_text=msg["traceback_text"],
- stimulus_id=stimulus_id,
+ stimulus_id=f"task-erred-{time()}",
)
@functools.singledispatchmethod
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 5e993fe5041..012b34bba9f 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -289,6 +289,7 @@ class TaskFinishedMsg(SendMessageToScheduler):
metadata: dict
thread: int | None
startstops: list[StartStop]
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore
def to_dict(self) -> dict[str, Any]:
@@ -308,6 +309,7 @@ class TaskErredMsg(SendMessageToScheduler):
traceback_text: str
thread: int | None
startstops: list[StartStop]
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore
def to_dict(self) -> dict[str, Any]:
@@ -320,8 +322,9 @@ def to_dict(self) -> dict[str, Any]:
class ReleaseWorkerDataMsg(SendMessageToScheduler):
op = "release-worker-data"
- __slots__ = ("key",)
+ __slots__ = ("key", "stimulus_id")
key: str
+ stimulus_id: str
# Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception
@@ -329,9 +332,10 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler):
class RescheduleMsg(SendMessageToScheduler):
op = "reschedule"
- __slots__ = ("key", "worker")
+ __slots__ = ("key", "worker", "stimulus_id")
key: str
worker: str
+ stimulus_id: str
@dataclass
@@ -424,6 +428,7 @@ class ExecuteSuccessEvent(StateMachineEvent):
stop: float
nbytes: int
type: type | None
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore
def to_loggable(self, *, handled: float) -> StateMachineEvent:
@@ -446,6 +451,7 @@ class ExecuteFailureEvent(StateMachineEvent):
traceback: Serialize | None
exception_text: str
traceback_text: str
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore
def _after_from_dict(self) -> None: