diff --git a/distributed/client.py b/distributed/client.py
index 86ca5f506e0..8da3f3b91df 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -209,6 +209,7 @@ def __init__(self, key, client=None, inform=True, state=None):
"op": "client-desires-keys",
"keys": [stringify(key)],
"client": self.client.id,
+ "stimulus_id": f"client-desires-keys-{time()}",
}
)
@@ -472,6 +473,7 @@ def __setstate__(self, state):
"tasks": {},
"keys": [stringify(self.key)],
"client": c.id,
+ "stimulus_id": f"client-future-setstate-{time()}",
}
)
@@ -1265,6 +1267,7 @@ async def _ensure_connected(self, timeout=None):
"client": self.id,
"reply": False,
"versions": version_module.get_versions(),
+ "stimulus_id": f"client-ensure-connected-{time()}",
}
)
except Exception:
@@ -1371,9 +1374,9 @@ def _dec_ref(self, key):
self.refcount[key] -= 1
if self.refcount[key] == 0:
del self.refcount[key]
- self._release_key(key)
+ self._release_key(key, f"client-release-key-{time()}")
- def _release_key(self, key):
+ def _release_key(self, key, stimulus_id: str):
"""Release key from distributed memory"""
logger.debug("Release key %s", key)
st = self.futures.pop(key, None)
@@ -1381,7 +1384,12 @@ def _release_key(self, key):
st.cancel()
if self.status != "closed":
self._send_to_scheduler(
- {"op": "client-releases-keys", "keys": [key], "client": self.id}
+ {
+ "op": "client-releases-keys",
+ "keys": [key],
+ "client": self.id,
+ "stimulus_id": stimulus_id,
+ }
)
async def _handle_report(self):
@@ -1506,7 +1514,9 @@ async def _close(self, fast=False):
and self.scheduler_comm.comm
and not self.scheduler_comm.comm.closed()
):
- self._send_to_scheduler({"op": "close-client"})
+ self._send_to_scheduler(
+ {"op": "close-client", "stimulus_id": f"client-close-{time()}"}
+ )
self._send_to_scheduler({"op": "close-stream"})
current_task = asyncio.current_task()
@@ -1527,8 +1537,10 @@ async def _close(self, fast=False):
):
await self.scheduler_comm.close()
+ stimulus_id = f"client-close-{time()}"
+
for key in list(self.futures):
- self._release_key(key=key)
+ self._release_key(key=key, stimulus_id=stimulus_id)
if self._start_arg is None:
with suppress(AttributeError):
@@ -2110,12 +2122,20 @@ async def _gather_remote(self, direct, local_worker):
response = {"status": "OK", "data": data2}
if missing_keys:
keys2 = [key for key in keys if key not in data2]
- response = await retry_operation(self.scheduler.gather, keys=keys2)
+ response = await retry_operation(
+ self.scheduler.gather,
+ keys=keys2,
+ stimulus_id=f"client-gather-remote-{time()}",
+ )
if response["status"] == "OK":
response["data"].update(data2)
else: # ask scheduler to gather data for us
- response = await retry_operation(self.scheduler.gather, keys=keys)
+ response = await retry_operation(
+ self.scheduler.gather,
+ keys=keys,
+ stimulus_id=f"client-gather-remote-{time()}",
+ )
return response
@@ -2201,6 +2221,8 @@ async def _scatter(
d = await self._scatter(keymap(stringify, data), workers, broadcast)
return {k: d[stringify(k)] for k in data}
+ stimulus_id = f"client-scatter-{time()}"
+
if isinstance(data, type(range(0))):
data = list(data)
input_type = type(data)
@@ -2242,6 +2264,7 @@ async def _scatter(
who_has={key: [local_worker.address] for key in data},
nbytes=valmap(sizeof, data),
client=self.id,
+ stimulus_id=stimulus_id,
)
else:
@@ -2264,7 +2287,10 @@ async def _scatter(
)
await self.scheduler.update_data(
- who_has=who_has, nbytes=nbytes, client=self.id
+ who_has=who_has,
+ nbytes=nbytes,
+ client=self.id,
+ stimulus_id=stimulus_id,
)
else:
await self.scheduler.scatter(
@@ -2273,6 +2299,7 @@ async def _scatter(
client=self.id,
broadcast=broadcast,
timeout=timeout,
+ stimulus_id=stimulus_id,
)
out = {k: Future(k, self, inform=False) for k in data}
@@ -2396,7 +2423,12 @@ def scatter(
async def _cancel(self, futures, force=False):
keys = list({stringify(f.key) for f in futures_of(futures)})
- await self.scheduler.cancel(keys=keys, client=self.id, force=force)
+ await self.scheduler.cancel(
+ keys=keys,
+ client=self.id,
+ force=force,
+ stimulus_id=f"client-cancel-{time()}",
+ )
for k in keys:
st = self.futures.pop(k, None)
if st is not None:
@@ -2423,7 +2455,9 @@ def cancel(self, futures, asynchronous=None, force=False):
async def _retry(self, futures):
keys = list({stringify(f.key) for f in futures_of(futures)})
- response = await self.scheduler.retry(keys=keys, client=self.id)
+ response = await self.scheduler.retry(
+ keys=keys, client=self.id, stimulus_id=f"client-retry-{time()}"
+ )
for key in response:
st = self.futures[key]
st.retry()
@@ -2922,6 +2956,7 @@ def _graph_to_futures(
"fifo_timeout": fifo_timeout,
"actors": actors,
"code": self._get_computation_code(),
+ "stimulus_id": f"client-update-graph-hlg-{time()}",
}
)
return futures
@@ -3347,7 +3382,13 @@ async def _restart(self, timeout=no_default):
if timeout is not None:
timeout = parse_timedelta(timeout, "s")
- self._send_to_scheduler({"op": "restart", "timeout": timeout})
+ self._send_to_scheduler(
+ {
+ "op": "restart",
+ "timeout": timeout,
+ "stimulus_id": f"client-restart-{time()}",
+ }
+ )
self._restart_event = asyncio.Event()
try:
await asyncio.wait_for(self._restart_event.wait(), timeout)
@@ -3424,7 +3465,9 @@ async def _rebalance(self, futures=None, workers=None):
keys = list({stringify(f.key) for f in self.futures_of(futures)})
else:
keys = None
- result = await self.scheduler.rebalance(keys=keys, workers=workers)
+ result = await self.scheduler.rebalance(
+ keys=keys, workers=workers, stimulus_id=f"client-rebalance-{time()}"
+ )
if result["status"] == "partial-fail":
raise KeyError(f"Could not rebalance keys: {result['keys']}")
assert result["status"] == "OK", result
@@ -3459,7 +3502,11 @@ async def _replicate(self, futures, n=None, workers=None, branching_factor=2):
await _wait(futures)
keys = {stringify(f.key) for f in futures}
await self.scheduler.replicate(
- keys=list(keys), n=n, workers=workers, branching_factor=branching_factor
+ keys=list(keys),
+ n=n,
+ workers=workers,
+ branching_factor=branching_factor,
+ stimulus_id=f"client-replicate-{time()}",
)
def replicate(self, futures, n=None, workers=None, branching_factor=2, **kwargs):
@@ -4177,6 +4224,7 @@ def retire_workers(
self.scheduler.retire_workers,
workers=workers,
close_workers=close_workers,
+ stimulus_id=f"client-retire-workers-{time()}",
**kwargs,
)
@@ -5138,6 +5186,7 @@ def fire_and_forget(obj):
"op": "client-desires-keys",
"keys": [stringify(future.key)],
"client": "fire-and-forget",
+ "stimulus_id": f"client-fire-and-forget-{time()}",
}
)
diff --git a/distributed/core.py b/distributed/core.py
index 4b37721b8ba..ff8d48859c3 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -13,6 +13,7 @@
from collections import defaultdict
from collections.abc import Container
from contextlib import suppress
+from contextvars import copy_context
from enum import Enum
from functools import partial
from typing import Callable, ClassVar, TypedDict, TypeVar
@@ -619,7 +620,9 @@ async def handle_stream(self, comm, extra=None, every_cycle=()):
break
handler = self.stream_handlers[op]
if is_coroutine_function(handler):
- self.loop.add_callback(handler, **merge(extra, msg))
+ self.loop.add_callback(
+ copy_context().run, handler, **merge(extra, msg)
+ )
await gen.sleep(0)
else:
handler(**merge(extra, msg))
@@ -629,7 +632,7 @@ async def handle_stream(self, comm, extra=None, every_cycle=()):
for func in every_cycle:
if is_coroutine_function(func):
- self.loop.add_callback(func)
+ self.loop.add_callback(copy_context().run, func)
else:
func()
diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py
index d17f82fc893..e966b0f9a40 100644
--- a/distributed/deploy/adaptive.py
+++ b/distributed/deploy/adaptive.py
@@ -7,6 +7,7 @@
from dask.utils import parse_timedelta
from distributed.deploy.adaptive_core import AdaptiveCore
+from distributed.metrics import time
from distributed.protocol import pickle
from distributed.utils import log_errors
@@ -193,6 +194,7 @@ async def scale_down(self, workers):
names=workers,
remove=True,
close_workers=True,
+ stimulus_id=f"scale-down-{time()}",
)
# close workers more forcefully
diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py
index 3a9fa16a3e7..bc2a811a858 100644
--- a/distributed/deploy/spec.py
+++ b/distributed/deploy/spec.py
@@ -19,6 +19,7 @@
from distributed.core import CommClosedError, Status, rpc
from distributed.deploy.adaptive import Adaptive
from distributed.deploy.cluster import Cluster
+from distributed.metrics import time
from distributed.scheduler import Scheduler
from distributed.security import Security
from distributed.utils import NoOpAwaitable, TimeoutError, import_term, silence_logging
@@ -318,7 +319,10 @@ async def _correct_state_internal(self):
to_close = set(self.workers) - set(self.worker_spec)
if to_close:
if self.scheduler.status == Status.running:
- await self.scheduler_comm.retire_workers(workers=list(to_close))
+ await self.scheduler_comm.retire_workers(
+ workers=list(to_close),
+ stimulus_id=f"spec-cluster-correct-internal-state-{time()}",
+ )
tasks = [
asyncio.create_task(self.workers[w].close())
for w in to_close
diff --git a/distributed/diagnostics/tests/test_progress.py b/distributed/diagnostics/tests/test_progress.py
index 872b1e4c7f6..591e7b29dc0 100644
--- a/distributed/diagnostics/tests/test_progress.py
+++ b/distributed/diagnostics/tests/test_progress.py
@@ -209,7 +209,7 @@ async def test_group_timing(c, s, a, b):
]
)
- await s.restart()
+ await s.handle_restart(stimulus_id="test")
assert len(p.time) == 2
assert len(p.nthreads) == 2
assert len(p.compute) == 0
diff --git a/distributed/diagnostics/tests/test_widgets.py b/distributed/diagnostics/tests/test_widgets.py
index e47c1bd5bc9..844c9b5fba4 100644
--- a/distributed/diagnostics/tests/test_widgets.py
+++ b/distributed/diagnostics/tests/test_widgets.py
@@ -145,7 +145,7 @@ async def test_multi_progressbar_widget(c, s, a, b):
@gen_cluster()
async def test_multi_progressbar_widget_after_close(s, a, b):
- s.update_graph(
+ s.handle_update_graph(
tasks=valmap(
dumps_task,
{
@@ -166,6 +166,7 @@ async def test_multi_progressbar_widget_after_close(s, a, b):
"y-2": {"y-1"},
"e": {"y-2"},
},
+ stimulus_id="test",
)
p = MultiProgressWidget(["x-1", "x-2", "x-3"], scheduler=s.address)
@@ -231,7 +232,7 @@ def test_progressbar_cancel(client):
@gen_cluster()
async def test_multibar_complete(s, a, b):
- s.update_graph(
+ s.handle_update_graph(
tasks=valmap(
dumps_task,
{
@@ -252,6 +253,7 @@ async def test_multibar_complete(s, a, b):
"y-2": {"y-1"},
"e": {"y-2"},
},
+ stimulus_id="test",
)
p = MultiProgressWidget(["e"], scheduler=s.address, complete=True)
diff --git a/distributed/http/templates/task.html b/distributed/http/templates/task.html
index 0b5c10695e0..f10aaad5602 100644
--- a/distributed/http/templates/task.html
+++ b/distributed/http/templates/task.html
@@ -118,16 +118,18 @@
Transition Log
Key |
Start |
Finish |
+ Stimulus ID |
Recommended Key |
Recommended Action |
- {% for key, start, finish, recommendations, transition_time in scheduler.story(Task) %}
+ {% for key, start, finish, recommendations, stimulus_id, transition_time in scheduler.story(Task) %}
| {{ fromtimestamp(transition_time) }} |
{{key}} |
{{ start }} |
{{ finish }} |
+ {{ stimulus_id }} |
|
|
@@ -137,6 +139,7 @@ Transition Log
|
|
|
+ |
{{key2}} |
{{ rec }} |
diff --git a/distributed/nanny.py b/distributed/nanny.py
index db2371523e8..325329e7524 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -290,7 +290,10 @@ async def _unregister(self, timeout=10):
allowed_errors = (TimeoutError, CommClosedError, EnvironmentError, RPCClosed)
with suppress(allowed_errors):
await asyncio.wait_for(
- self.scheduler.unregister(address=self.worker_address), timeout
+ self.scheduler.unregister(
+ address=self.worker_address, stimulus_id=f"close-nanny-{time()}"
+ ),
+ timeout,
)
@property
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index fc96d5d02ec..42ed631cfba 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -25,6 +25,7 @@
Set,
)
from contextlib import suppress
+from contextvars import copy_context
from datetime import timedelta
from functools import partial
from numbers import Number
@@ -85,9 +86,11 @@
from distributed.stealing import WorkStealing
from distributed.stories import scheduler_story
from distributed.utils import (
+ STIMULUS_ID,
All,
TimeoutError,
empty_context,
+ expect_stimulus,
get_fileno_limit,
key_split,
key_split_group,
@@ -2373,16 +2376,21 @@ def _transition(self, key, finish: str, *args, **kwargs):
finish2 = ts._state
# FIXME downcast antipattern
scheduler = pep484_cast(Scheduler, self)
+ stimulus_id = STIMULUS_ID.get("")
scheduler.transition_log.append(
- (key, start, finish2, recommendations, time())
+ (key, start, finish2, recommendations, stimulus_id, time())
)
if parent._validate:
+ if stimulus_id == "":
+ raise LookupError(STIMULUS_ID.name)
+
logger.debug(
- "Transitioned %r %s->%s (actual: %s). Consequence: %s",
+ "Transitioned %r %s->%s (actual: %s) from %s. Consequence: %s",
key,
start,
finish2,
ts._state,
+ stimulus_id,
dict(recommendations),
)
if self.plugins:
@@ -2857,7 +2865,7 @@ def transition_processing_memory(
{
"op": "cancel-compute",
"key": key,
- "stimulus_id": f"processing-memory-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
]
@@ -2945,7 +2953,7 @@ def transition_memory_released(self, key, safe: bint = False):
worker_msg = {
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"memory-released-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
for ws in ts._who_has:
worker_msgs[ws._address] = [worker_msg]
@@ -3048,7 +3056,7 @@ def transition_erred_released(self, key):
w_msg = {
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"erred-released-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
for ws_addr in ts._erred_on:
worker_msgs[ws_addr] = [w_msg]
@@ -3127,7 +3135,7 @@ def transition_processing_released(self, key):
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"processing-released-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
]
@@ -3901,37 +3909,37 @@ def __init__(
worker_handlers = {
"task-finished": self.handle_task_finished,
"task-erred": self.handle_task_erred,
- "release-worker-data": self.release_worker_data,
- "add-keys": self.add_keys,
+ "release-worker-data": self.handle_release_worker_data,
+ "add-keys": self.handle_add_keys,
"missing-data": self.handle_missing_data,
"long-running": self.handle_long_running,
- "reschedule": self.reschedule,
+ "reschedule": self.handle_reschedule,
"keep-alive": lambda *args, **kwargs: None,
"log-event": self.log_worker_event,
"worker-status-change": self.handle_worker_status_change,
}
client_handlers = {
- "update-graph": self.update_graph,
- "update-graph-hlg": self.update_graph_hlg,
- "client-desires-keys": self.client_desires_keys,
- "update-data": self.update_data,
+ "update-graph": self.handle_update_graph,
+ "update-graph-hlg": self.handle_update_graph_hlg,
+ "client-desires-keys": self.handle_client_desires_keys,
+ "update-data": self.handle_update_data,
"report-key": self.report_on_key,
- "client-releases-keys": self.client_releases_keys,
+ "client-releases-keys": self.handle_client_releases_keys,
"heartbeat-client": self.client_heartbeat,
- "close-client": self.remove_client,
- "restart": self.restart,
+ "close-client": self.handle_remove_client,
+ "restart": self.handle_restart,
"subscribe-topic": self.subscribe_topic,
"unsubscribe-topic": self.unsubscribe_topic,
}
self.handlers = {
- "register-client": self.add_client,
- "scatter": self.scatter,
- "register-worker": self.add_worker,
+ "register-client": self.handle_add_client,
+ "scatter": self.handle_scatter,
+ "register-worker": self.handle_add_worker,
"register_nanny": self.add_nanny,
- "unregister": self.remove_worker,
- "gather": self.gather,
+ "unregister": self.handle_remove_worker,
+ "gather": self.handle_gather,
"cancel": self.stimulus_cancel,
"retry": self.stimulus_retry,
"feed": self.feed,
@@ -3954,12 +3962,12 @@ def __init__(
"nbytes": self.get_nbytes,
"versions": self.versions,
"add_keys": self.add_keys,
- "rebalance": self.rebalance,
- "replicate": self.replicate,
+ "rebalance": self.handle_rebalance,
+ "replicate": self.handle_replicate,
"run_function": self.run_function,
- "update_data": self.update_data,
+ "update_data": self.handle_update_data,
"set_resources": self.add_resources,
- "retire_workers": self.retire_workers,
+ "retire_workers": self.handle_retire_workers,
"get_metadata": self.get_metadata,
"set_metadata": self.set_metadata,
"set_restrictions": self.set_restrictions,
@@ -4548,6 +4556,7 @@ async def add_worker(
recommendations: dict = {}
client_msgs: dict = {}
worker_msgs: dict = {}
+
if nbytes:
assert isinstance(nbytes, dict)
already_released_keys = []
@@ -4578,7 +4587,7 @@ async def add_worker(
{
"op": "remove-replicas",
"keys": already_released_keys,
- "stimulus_id": f"reconnect-already-released-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
)
@@ -5003,7 +5012,7 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs):
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"already-released-or-forgotten-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
]
elif ts._state == "memory":
@@ -5042,6 +5051,7 @@ def stimulus_task_erred(
**kwargs,
)
+ @expect_stimulus(sync=True)
def stimulus_retry(self, keys, client=None):
parent: SchedulerState = cast(SchedulerState, self)
logger.info("Client %s requests to retry %d keys", client, len(keys))
@@ -5188,6 +5198,7 @@ def remove_worker_from_events():
return "OK"
+ @expect_stimulus(sync=True)
def stimulus_cancel(self, comm, keys=None, client=None, force=False):
"""Stop execution on a list of keys"""
logger.info("Client %s requests to cancel %d keys", client, len(keys))
@@ -5562,6 +5573,7 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1):
def handle_uncaught_error(self, **msg):
logger.exception(clean_exception(**msg)[1])
+ @expect_stimulus(sync=True)
def handle_task_finished(self, key=None, worker=None, **msg):
parent: SchedulerState = cast(SchedulerState, self)
if worker not in parent._workers_dv:
@@ -5578,6 +5590,7 @@ def handle_task_finished(self, key=None, worker=None, **msg):
self.send_all(client_msgs, worker_msgs)
+ @expect_stimulus(sync=True)
def handle_task_erred(self, key=None, **msg):
parent: SchedulerState = cast(SchedulerState, self)
recommendations: dict
@@ -5589,6 +5602,7 @@ def handle_task_erred(self, key=None, **msg):
self.send_all(client_msgs, worker_msgs)
+ @expect_stimulus(sync=True)
def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
"""Signal that `errant_worker` does not hold `key`
@@ -5636,6 +5650,7 @@ def release_worker_data(self, key, worker):
if recommendations:
self.transitions(recommendations)
+ @expect_stimulus(sync=True)
def handle_long_running(self, key=None, worker=None, compute_duration=None):
"""A task has seceded from the thread pool
@@ -5678,6 +5693,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None):
ws._long_running.add(ts)
self.check_idle_saturated(ws)
+ @expect_stimulus(sync=True)
def handle_worker_status_change(self, status: str, worker: str) -> None:
parent: SchedulerState = cast(SchedulerState, self)
ws: WorkerState = parent._workers_dv.get(worker) # type: ignore
@@ -5872,7 +5888,9 @@ def worker_send(self, worker, msg):
try:
stream_comms[worker].send(msg)
except (CommClosedError, AttributeError):
- self.loop.add_callback(self.remove_worker, address=worker)
+ self.loop.add_callback(
+ copy_context().run, self.remove_worker, address=worker
+ )
def client_send(self, client, msg):
"""Send message to client"""
@@ -5917,7 +5935,9 @@ def send_all(self, client_msgs: dict, worker_msgs: dict):
# worker already gone
pass
except (CommClosedError, AttributeError):
- self.loop.add_callback(self.remove_worker, address=worker)
+ self.loop.add_callback(
+ copy_context().run, self.remove_worker, address=worker
+ )
############################
# Less common interactions #
@@ -7026,7 +7046,11 @@ async def retire_workers(
ws.status = Status.closing_gracefully
self.running.discard(ws)
self.stream_comms[ws.address].send(
- {"op": "worker-status-change", "status": ws.status.name}
+ {
+ "op": "worker-status-change",
+ "status": ws.status.name,
+ "stimulus_id": STIMULUS_ID.get(),
+ }
)
coros.append(
@@ -7071,7 +7095,11 @@ async def _track_retire_worker(
# conditions and we can wait for a scheduler->worker->scheduler
# round-trip.
self.stream_comms[ws.address].send(
- {"op": "worker-status-change", "status": prev_status.name}
+ {
+ "op": "worker-status-change",
+ "status": prev_status.name,
+ "stimulus_id": STIMULUS_ID.get(),
+ }
)
return None, {}
@@ -7092,7 +7120,7 @@ async def _track_retire_worker(
logger.info("Retired worker %s", ws._address)
return ws._address, ws.identity()
- def add_keys(self, worker=None, keys=(), stimulus_id=None):
+ def add_keys(self, worker=None, keys=()):
"""
Learn that a worker has certain keys
@@ -7113,14 +7141,12 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None):
redundant_replicas.append(key)
if redundant_replicas:
- if not stimulus_id:
- stimulus_id = f"redundant-replicas-{time()}"
self.worker_send(
worker,
{
"op": "remove-replicas",
"keys": redundant_replicas,
- "stimulus_id": stimulus_id,
+ "stimulus_id": STIMULUS_ID.get(),
},
)
@@ -8073,7 +8099,9 @@ async def check_worker_ttl(self):
self.worker_ttl,
ws,
)
- await self.remove_worker(address=ws._address)
+ await self.handle_remove_worker(
+ address=ws._address, stimulus_id=f"check-worker-ttl-{time()}"
+ )
def check_idle(self):
parent: SchedulerState = cast(SchedulerState, self)
@@ -8204,6 +8232,27 @@ def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str):
}
)
+ handle_update_graph_hlg = expect_stimulus(sync=True)(update_graph_hlg)
+ handle_update_graph = expect_stimulus(sync=True)(update_graph)
+ handle_add_worker = expect_stimulus(sync=False)(add_worker)
+ handle_remove_worker = expect_stimulus(sync=False)(remove_worker)
+ handle_cancel_key = expect_stimulus(sync=True)(cancel_key)
+ handle_client_desires_keys = expect_stimulus(sync=True)(client_desires_keys)
+ handle_client_releases_keys = expect_stimulus(sync=True)(client_releases_keys)
+ handle_add_client = expect_stimulus(sync=False)(add_client)
+ handle_remove_client = expect_stimulus(sync=True)(remove_client)
+ handle_release_worker_data = expect_stimulus(sync=True)(release_worker_data)
+ handle_scatter = expect_stimulus(sync=False)(scatter)
+ handle_gather = expect_stimulus(sync=False)(gather)
+ handle_restart = expect_stimulus(sync=False)(restart)
+ handle_delete_worker_data = expect_stimulus(sync=False)(delete_worker_data)
+ handle_rebalance = expect_stimulus(sync=False)(rebalance)
+ handle_replicate = expect_stimulus(sync=False)(replicate)
+ handle_retire_workers = expect_stimulus(sync=False)(retire_workers)
+ handle_add_keys = expect_stimulus(sync=True)(add_keys)
+ handle_update_data = expect_stimulus(sync=True)(update_data)
+ handle_reschedule = expect_stimulus(sync=True)(reschedule)
+
@cfunc
@exceptval(check=False)
@@ -8337,7 +8386,7 @@ def _propagate_forgotten(
{
"op": "free-keys",
"keys": [key],
- "stimulus_id": f"propagate-forgotten-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
}
]
state.remove_all_replicas(ts)
@@ -8380,7 +8429,7 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
"key": ts._key,
"priority": ts._priority,
"duration": duration,
- "stimulus_id": f"compute-task-{time()}",
+ "stimulus_id": STIMULUS_ID.get(),
"who_has": {},
}
if ts._resource_restrictions:
diff --git a/distributed/stealing.py b/distributed/stealing.py
index 54ef0098c63..be7d757c5f7 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -17,7 +17,12 @@
from distributed.comm.addressing import get_address_host
from distributed.core import CommClosedError, Status
from distributed.diagnostics.plugin import SchedulerPlugin
-from distributed.utils import log_errors, recursive_to_dict
+from distributed.utils import (
+ STIMULUS_ID,
+ expect_stimulus,
+ log_errors,
+ recursive_to_dict,
+)
# Stealing requires multiple network bounces and if successful also task
# submission which may include code serialization. Therefore, be very
@@ -79,7 +84,7 @@ def __init__(self, scheduler):
self.in_flight_occupancy = defaultdict(lambda: 0)
self._in_flight_event = asyncio.Event()
- self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
+ self.scheduler.stream_handlers["steal-response"] = self.handle_move_task_confirm
async def start(self, scheduler=None):
"""Start the background coroutine to balance the tasks on the cluster.
@@ -265,7 +270,7 @@ def move_task_request(self, ts, victim, thief) -> str:
pdb.set_trace()
raise
- async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
+ async def move_task_confirm(self, *, key, state, worker=None):
try:
ts = self.scheduler.tasks[key]
except KeyError:
@@ -273,12 +278,12 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
return
try:
d = self.in_flight.pop(ts)
- if d["stimulus_id"] != stimulus_id:
- self.log(("stale-response", key, state, worker, stimulus_id))
+ if d["stimulus_id"] != STIMULUS_ID.get():
+ self.log(("stale-response", key, state, worker, STIMULUS_ID.get()))
self.in_flight[ts] = d
return
except KeyError:
- self.log(("already-aborted", key, state, worker, stimulus_id))
+ self.log(("already-aborted", key, state, worker, STIMULUS_ID.get()))
return
thief = d["thief"]
@@ -297,7 +302,7 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
assert ts.processing_on == victim
try:
- _log_msg = [key, state, victim.address, thief.address, stimulus_id]
+ _log_msg = [key, state, victim.address, thief.address, STIMULUS_ID.get()]
if ts.state != "processing":
self.scheduler._reevaluate_occupancy_worker(thief)
@@ -348,6 +353,8 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
self.scheduler.check_idle_saturated(thief)
self.scheduler.check_idle_saturated(victim)
+ handle_move_task_confirm = expect_stimulus(sync=False)(move_task_confirm)
+
def balance(self):
s = self.scheduler
diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index 1a0df078376..8ca0dbded04 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -4,9 +4,10 @@
import distributed
from distributed import Event
from distributed.core import CommClosedError
+from distributed.utils import STIMULUS_ID
from distributed.utils_test import (
_LockedCommPool,
- assert_worker_story,
+ assert_story,
gen_cluster,
inc,
slowinc,
@@ -82,7 +83,7 @@ def f(ev):
while "f1" in a.tasks:
await asyncio.sleep(0.01)
- assert_worker_story(
+ assert_story(
a.story("f1"),
[
("f1", "compute-task"),
@@ -160,7 +161,12 @@ async def wait_and_raise(*args, **kwargs):
await wait_for_state(fut1.key, "flight", b)
# Close in scheduler to ensure we transition and reschedule task properly
- await s.close_worker(worker=a.address)
+ token = STIMULUS_ID.set("test")
+ try:
+ await s.close_worker(worker=a.address)
+ finally:
+ STIMULUS_ID.reset(token)
+
await wait_for_state(fut1.key, "resumed", b)
lock.release()
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 5044c487bba..9671c7b492b 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -2060,15 +2060,15 @@ async def test_forget_simple(c, s, a, b):
assert set(s.tasks) == {x.key, y.key, z.key}
- s.client_releases_keys(keys=[x.key], client=c.id)
+ s.handle_client_releases_keys(keys=[x.key], client=c.id, stimulus_id="test")
assert x.key in s.tasks
- s.client_releases_keys(keys=[z.key], client=c.id)
+ s.handle_client_releases_keys(keys=[z.key], client=c.id, stimulus_id="test")
assert x.key not in s.tasks
assert z.key not in s.tasks
assert not s.tasks[y.key].dependents
- s.client_releases_keys(keys=[y.key], client=c.id)
+ s.handle_client_releases_keys(keys=[y.key], client=c.id, stimulus_id="test")
assert not s.tasks
@@ -2084,20 +2084,20 @@ async def test_forget_complex(e, s, A, B):
assert set(s.tasks) == {f.key for f in [ab, ac, cd, acab, a, b, c, d]}
- s.client_releases_keys(keys=[ab.key], client=e.id)
+ s.handle_client_releases_keys(keys=[ab.key], client=e.id, stimulus_id="test")
assert set(s.tasks) == {f.key for f in [ab, ac, cd, acab, a, b, c, d]}
- s.client_releases_keys(keys=[b.key], client=e.id)
+ s.handle_client_releases_keys(keys=[b.key], client=e.id, stimulus_id="test")
assert set(s.tasks) == {f.key for f in [ac, cd, acab, a, c, d]}
- s.client_releases_keys(keys=[acab.key], client=e.id)
+ s.handle_client_releases_keys(keys=[acab.key], client=e.id, stimulus_id="test")
assert set(s.tasks) == {f.key for f in [ac, cd, a, c, d]}
assert b.key not in s.tasks
while b.key in A.data or b.key in B.data:
await asyncio.sleep(0.01)
- s.client_releases_keys(keys=[ac.key], client=e.id)
+ s.handle_client_releases_keys(keys=[ac.key], client=e.id, stimulus_id="test")
assert set(s.tasks) == {f.key for f in [cd, a, c, d]}
@@ -2117,7 +2117,7 @@ async def test_forget_in_flight(e, s, A, B):
await asyncio.sleep(0.01)
s.validate_state()
- s.client_releases_keys(keys=[y.key], client=e.id)
+ s.handle_client_releases_keys(keys=[y.key], client=e.id, stimulus_id="test")
s.validate_state()
for k in [acab.key, ab.key, b.key]:
@@ -2136,21 +2136,21 @@ async def test_forget_errors(c, s, a, b):
assert y.key in s.exceptions_blame
assert z.key in s.exceptions_blame
- s.client_releases_keys(keys=[z.key], client=c.id)
+ s.handle_client_releases_keys(keys=[z.key], client=c.id, stimulus_id="test")
assert x.key in s.exceptions
assert x.key in s.exceptions_blame
assert y.key in s.exceptions_blame
assert z.key not in s.exceptions_blame
- s.client_releases_keys(keys=[x.key], client=c.id)
+ s.handle_client_releases_keys(keys=[x.key], client=c.id, stimulus_id="test")
assert x.key in s.exceptions
assert x.key in s.exceptions_blame
assert y.key in s.exceptions_blame
assert z.key not in s.exceptions_blame
- s.client_releases_keys(keys=[y.key], client=c.id)
+ s.handle_client_releases_keys(keys=[y.key], client=c.id, stimulus_id="test")
assert x.key not in s.exceptions
assert x.key not in s.exceptions_blame
@@ -4355,7 +4355,7 @@ async def test_scatter_type(c, s, a, b):
async def test_retire_workers_2(c, s, a, b):
[x] = await c.scatter([1], workers=a.address)
- await s.retire_workers(workers=[a.address])
+ await s.handle_retire_workers(workers=[a.address], stimulus_id="test")
assert b.data == {x.key: 1}
assert s.who_has == {x.key: {b.address}}
assert s.has_what == {b.address: {x.key}}
@@ -4367,7 +4367,9 @@ async def test_retire_workers_2(c, s, a, b):
async def test_retire_many_workers(c, s, *workers):
futures = await c.scatter(list(range(100)))
- await s.retire_workers(workers=[w.address for w in workers[:7]])
+ await s.handle_retire_workers(
+ workers=[w.address for w in workers[:7]], stimulus_id="test"
+ )
results = await c.gather(futures)
assert results == list(range(100))
@@ -4565,7 +4567,7 @@ def test_auto_normalize_collection_sync(c):
def assert_no_data_loss(scheduler):
- for key, start, finish, recommendations, _ in scheduler.transition_log:
+ for key, start, finish, recommendations, _, _ in scheduler.transition_log:
if start == "memory" and finish == "released":
for k, v in recommendations.items():
assert not (k == key and v == "waiting")
diff --git a/distributed/tests/test_cluster_dump.py b/distributed/tests/test_cluster_dump.py
index b01cf2611ca..1762929d378 100644
--- a/distributed/tests/test_cluster_dump.py
+++ b/distributed/tests/test_cluster_dump.py
@@ -8,7 +8,7 @@
import distributed
from distributed.cluster_dump import DumpArtefact, _tuple_to_list, write_state
-from distributed.utils_test import assert_worker_story, gen_cluster, gen_test, inc
+from distributed.utils_test import assert_story, gen_cluster, gen_test, inc
@pytest.mark.parametrize(
@@ -140,7 +140,7 @@ async def test_cluster_dump_story(c, s, a, b, tmp_path):
assert story.keys() == {"f1", "f2"}
for k, task_story in story.items():
- assert_worker_story(
+ assert_story(
task_story,
[
(k, "compute-task"),
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 2750868dc01..c4fc650472d 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -37,6 +37,7 @@
from distributed.utils import TimeoutError
from distributed.utils_test import (
BrokenComm,
+ assert_story,
captured_logger,
cluster,
dec,
@@ -292,22 +293,23 @@ async def test_no_workers(client, s):
@gen_cluster(nthreads=[])
async def test_retire_workers_empty(s):
- await s.retire_workers(workers=[])
+ await s.handle_retire_workers(workers=[], stimulus_id="test")
@gen_cluster()
async def test_remove_client(s, a, b):
- s.update_graph(
+ s.handle_update_graph(
tasks={"x": dumps_task((inc, 1)), "y": dumps_task((inc, "x"))},
dependencies={"x": [], "y": ["x"]},
keys=["y"],
client="ident",
+ stimulus_id="test",
)
assert s.tasks
assert s.dependencies
- s.remove_client(client="ident")
+ s.handle_remove_client(client="ident", stimulus_id="test")
assert not s.tasks
assert not s.dependencies
@@ -324,14 +326,15 @@ async def test_server_listens_to_other_ops(s, a, b):
@gen_cluster()
async def test_remove_worker_from_scheduler(s, a, b):
dsk = {("x-%d" % i): (inc, i) for i in range(20)}
- s.update_graph(
+ s.handle_update_graph(
tasks=valmap(dumps_task, dsk),
keys=list(dsk),
dependencies={k: set() for k in dsk},
+ stimulus_id="test",
)
assert a.address in s.stream_comms
- await s.remove_worker(address=a.address)
+ await s.handle_remove_worker(address=a.address, stimulus_id="test")
assert a.address not in s.nthreads
assert len(s.workers[b.address].processing) == len(dsk) # b owns everything
@@ -390,11 +393,12 @@ async def test_add_worker(s, a, b):
w.data["y"] = 1
dsk = {("x-%d" % i): (inc, i) for i in range(10)}
- s.update_graph(
+ s.handle_update_graph(
tasks=valmap(dumps_task, dsk),
keys=list(dsk),
client="client",
dependencies={k: set() for k in dsk},
+ stimulus_id="test",
)
s.validate_state()
await w
@@ -539,8 +543,12 @@ async def test_delete(c, s, a):
async def test_filtered_communication(s, a, b):
c = await connect(s.address)
f = await connect(s.address)
- await c.write({"op": "register-client", "client": "c", "versions": {}})
- await f.write({"op": "register-client", "client": "f", "versions": {}})
+ await c.write(
+ {"op": "register-client", "client": "c", "versions": {}, "stimulus_id": "test"}
+ )
+ await f.write(
+ {"op": "register-client", "client": "f", "versions": {}, "stimulus_id": "test"}
+ )
await c.read()
await f.read()
@@ -553,6 +561,7 @@ async def test_filtered_communication(s, a, b):
"dependencies": {"x": [], "y": ["x"]},
"client": "c",
"keys": ["y"],
+ "stimulus_id": "test",
}
)
@@ -566,6 +575,7 @@ async def test_filtered_communication(s, a, b):
"dependencies": {"x": [], "z": ["x"]},
"client": "f",
"keys": ["z"],
+ "stimulus_id": "test",
}
)
(msg,) = await c.read()
@@ -605,16 +615,17 @@ def test_dumps_task():
@gen_cluster()
async def test_ready_remove_worker(s, a, b):
- s.update_graph(
+ s.handle_update_graph(
tasks={"x-%d" % i: dumps_task((inc, i)) for i in range(20)},
keys=["x-%d" % i for i in range(20)],
client="client",
dependencies={"x-%d" % i: [] for i in range(20)},
+ stimulus_id="test",
)
assert all(len(w.processing) > w.nthreads for w in s.workers.values())
- await s.remove_worker(address=a.address)
+ await s.handle_remove_worker(address=a.address, stimulus_id="test")
assert set(s.workers) == {b.address}
assert all(len(w.processing) > w.nthreads for w in s.workers.values())
@@ -625,7 +636,7 @@ async def test_restart(c, s, a, b):
futures = c.map(inc, range(20))
await wait(futures)
- await s.restart()
+ await s.handle_restart(stimulus_id="test")
assert len(s.workers) == 2
@@ -779,7 +790,7 @@ async def test_file_descriptors_dont_leak(s):
@gen_cluster()
async def test_update_graph_culls(s, a, b):
- s.update_graph(
+ s.handle_update_graph(
tasks={
"x": dumps_task((inc, 1)),
"y": dumps_task((inc, "x")),
@@ -788,6 +799,7 @@ async def test_update_graph_culls(s, a, b):
keys=["y"],
dependencies={"y": "x", "x": [], "z": []},
client="client",
+ stimulus_id="test",
)
assert "z" not in s.tasks
assert "z" not in s.dependencies
@@ -810,7 +822,7 @@ async def test_story(c, s, a, b):
story = s.story(x.key)
assert all(line in s.transition_log for line in story)
assert len(story) < len(s.transition_log)
- assert all(x.key == line[0] or x.key in line[-2] for line in story)
+ assert all(x.key == line[0] or x.key in line[3] for line in story)
assert len(s.story(x.key, y.key)) > len(story)
@@ -821,7 +833,9 @@ async def test_story(c, s, a, b):
@gen_cluster(client=True, nthreads=[])
async def test_scatter_no_workers(c, s, direct):
with pytest.raises(TimeoutError):
- await s.scatter(data={"x": 1}, client="alice", timeout=0.1)
+ await s.handle_scatter(
+ data={"x": 1}, client="alice", timeout=0.1, stimulus_id="test"
+ )
start = time()
with pytest.raises(TimeoutError):
@@ -856,7 +870,7 @@ async def test_retire_workers(c, s, a, b):
assert s.workers_to_close() == [a.address]
- workers = await s.retire_workers()
+ workers = await s.handle_retire_workers(stimulus_id="test")
assert list(workers) == [a.address]
assert workers[a.address]["nthreads"] == a.nthreads
assert list(s.nthreads) == [b.address]
@@ -865,22 +879,22 @@ async def test_retire_workers(c, s, a, b):
assert s.workers[b.address].has_what == {s.tasks[x.key], s.tasks[y.key]}
- workers = await s.retire_workers()
+ workers = await s.handle_retire_workers(stimulus_id="test")
assert not workers
@gen_cluster(client=True)
async def test_retire_workers_n(c, s, a, b):
- await s.retire_workers(n=1, close_workers=True)
+ await s.handle_retire_workers(n=1, close_workers=True, stimulus_id="test")
assert len(s.workers) == 1
- await s.retire_workers(n=0, close_workers=True)
+ await s.handle_retire_workers(n=0, close_workers=True, stimulus_id="test")
assert len(s.workers) == 1
- await s.retire_workers(n=1, close_workers=True)
+ await s.handle_retire_workers(n=1, close_workers=True, stimulus_id="test")
assert len(s.workers) == 0
- await s.retire_workers(n=0, close_workers=True)
+ await s.handle_retire_workers(n=0, close_workers=True, stimulus_id="test")
assert len(s.workers) == 0
while not (
@@ -944,7 +958,7 @@ async def test_retire_workers_no_suspicious_tasks(c, s, a, b):
slowinc, 100, delay=0.5, workers=a.address, allow_other_workers=True
)
await asyncio.sleep(0.2)
- await s.retire_workers(workers=[a.address])
+ await s.handle_retire_workers(workers=[a.address], stimulus_id="test")
assert all(ts.suspicious == 0 for ts in s.tasks.values())
assert all(tp.suspicious == 0 for tp in s.task_prefixes.values())
@@ -1277,7 +1291,7 @@ async def test_close_nanny(c, s, a, b):
@gen_cluster(client=True)
async def test_retire_workers_close(c, s, a, b):
- await s.retire_workers(close_workers=True)
+ await s.handle_retire_workers(close_workers=True, stimulus_id="test")
assert not s.workers
while a.status != Status.closed and b.status != Status.closed:
await asyncio.sleep(0.01)
@@ -1286,7 +1300,7 @@ async def test_retire_workers_close(c, s, a, b):
@gen_cluster(client=True, Worker=Nanny)
async def test_retire_nannies_close(c, s, a, b):
nannies = [a, b]
- await s.retire_workers(close_workers=True, remove=True)
+ await s.handle_retire_workers(close_workers=True, remove=True, stimulus_id="test")
assert not s.workers
start = time()
@@ -1327,12 +1341,13 @@ async def test_scheduler_file():
@gen_cluster(client=True, nthreads=[])
async def test_non_existent_worker(c, s):
with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}):
- await s.add_worker(
+ await s.handle_add_worker(
address="127.0.0.1:5738",
status="running",
nthreads=2,
nbytes={},
host_info={},
+ stimulus_id="test",
)
futures = c.map(inc, range(10))
await asyncio.sleep(0.300)
@@ -1481,7 +1496,7 @@ async def test_reschedule(c, s, a, b):
await asyncio.sleep(0.001)
for future in x:
- s.reschedule(key=future.key)
+ s.handle_reschedule(key=future.key, stimulus_id="test")
# Worker b gets more of the original tasks
await wait(x)
@@ -1492,7 +1507,7 @@ async def test_reschedule(c, s, a, b):
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
async def test_reschedule_warns(c, s, a, b):
with captured_logger(logging.getLogger("distributed.scheduler")) as sched:
- s.reschedule(key="__this-key-does-not-exist__")
+ s.handle_reschedule(key="__this-key-does-not-exist__", stimulus_id="test")
assert "not found on the scheduler" in sched.getvalue()
assert "Aborting reschedule" in sched.getvalue()
@@ -1794,7 +1809,7 @@ async def f(dask_worker):
assert s.bandwidth_workers
- await s.restart()
+ await s.handle_restart(stimulus_id="test")
assert not s.bandwidth_workers
@@ -1914,7 +1929,7 @@ async def test_retire_names_str(c, s):
futures = c.map(inc, range(10))
await wait(futures)
assert a.data and b.data
- await s.retire_workers(names=[0])
+ await s.handle_retire_workers(names=[0], stimulus_id="test")
assert all(f.done() for f in futures)
assert len(b.data) == 10
@@ -2150,7 +2165,7 @@ async def test_gather_failing_cnn_error(c, s, a, b):
x = await c.scatter({"x": 1}, workers=a.address)
s.rpc = await FlakyConnectionPool(failing_connections=10)
- res = await s.gather(keys=["x"])
+ res = await s.handle_gather(keys=["x"], stimulus_id="test")
assert res["status"] == "error"
assert list(res["keys"]) == ["x"]
@@ -3117,7 +3132,7 @@ async def test_delete_worker_data(c, s, a, b):
assert b.data == {y.key: "y"}
assert s.tasks.keys() == {x.key, y.key, z.key}
- await s.delete_worker_data(a.address, [x.key, y.key])
+ await s.handle_delete_worker_data(a.address, [x.key, y.key], stimulus_id="test")
assert a.data == {z.key: "z"}
assert b.data == {y.key: "y"}
assert s.tasks.keys() == {y.key, z.key}
@@ -3131,8 +3146,8 @@ async def test_delete_worker_data_double_delete(c, s, a):
"""
x, y = await c.scatter(["x", "y"])
await asyncio.gather(
- s.delete_worker_data(a.address, [x.key]),
- s.delete_worker_data(a.address, [x.key]),
+ s.handle_delete_worker_data(a.address, [x.key], stimulus_id="test"),
+ s.handle_delete_worker_data(a.address, [x.key], stimulus_id="test"),
)
assert a.data == {y.key: "y"}
a_ws = s.workers[a.address]
@@ -3147,7 +3162,7 @@ async def test_delete_worker_data_bad_worker(s, a, b):
"""
await a.close()
assert s.workers.keys() == {b.address}
- await s.delete_worker_data(a.address, ["x"])
+ await s.handle_delete_worker_data(a.address, ["x"], stimulus_id="test")
@pytest.mark.parametrize("bad_first", [False, True])
@@ -3162,7 +3177,7 @@ async def test_delete_worker_data_bad_task(c, s, a, bad_first):
assert s.tasks.keys() == {x.key, y.key}
keys = ["notexist", x.key] if bad_first else [x.key, "notexist"]
- await s.delete_worker_data(a.address, keys)
+ await s.handle_delete_worker_data(a.address, keys, stimulus_id="test")
assert a.data == {y.key: "y"}
assert s.tasks.keys() == {y.key}
assert s.workers[a.address].nbytes == s.tasks[y.key].nbytes
@@ -3247,13 +3262,13 @@ async def test_worker_reconnect_task_memory(c, s, a):
while not a.executing_count and not a.data:
await asyncio.sleep(0.001)
- await s.remove_worker(address=a.address, close=False)
+ await s.handle_remove_worker(address=a.address, close=False, stimulus_id="test")
while not res.done():
await a.heartbeat()
await res
assert ("no-worker", "memory") in {
- (start, finish) for (_, start, finish, _, _) in s.transition_log
+ (start, finish) for (_, start, finish, _, _, _) in s.transition_log
}
@@ -3271,13 +3286,13 @@ async def test_worker_reconnect_task_memory_with_resources(c, s, a):
while not b.executing_count and not b.data:
await asyncio.sleep(0.001)
- await s.remove_worker(address=b.address, close=False)
+ await s.handle_remove_worker(address=b.address, close=False, stimulus_id="test")
while not res.done():
await b.heartbeat()
await res
assert ("no-worker", "memory") in {
- (start, finish) for (_, start, finish, _, _) in s.transition_log
+ (start, finish) for (_, start, finish, _, _, _) in s.transition_log
}
@@ -3538,3 +3553,40 @@ async def test_repr(s, a):
repr(ws_b)
== f""
)
+
+
+@gen_cluster(client=True)
+async def test_stimuli(c, s, a, b):
+ f = c.submit(inc, 1)
+ key = f.key
+
+ await f
+ await c.close()
+
+ assert_story(
+ s.story(key),
+ [
+ (key, "released", "waiting", {key: "processing"}),
+ (key, "waiting", "processing", {}),
+ (key, "processing", "memory", {}),
+ (
+ key,
+ "memory",
+ "forgotten",
+ {},
+ ),
+ ],
+ )
+
+ stimuli = [
+ "client-update-graph-hlg",
+ "client-update-graph-hlg",
+ "task-finished",
+ "client-close",
+ ]
+
+ stories = s.story(key)
+ assert len(stories) == len(stimuli)
+
+ for stimulus_id, story in zip(stimuli, stories):
+ assert story[-2].startswith(stimulus_id), (story[-2], stimulus_id)
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 55f1bbf043a..74e012341db 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1107,7 +1107,7 @@ async def test_steal_reschedule_reset_in_flight_occupancy(c, s, *workers):
steal.move_task_request(victim_ts, wsA, wsB)
- s.reschedule(victim_key)
+ s.handle_reschedule(victim_key, stimulus_id="test")
await c.gather(futs1)
del futs1
@@ -1173,7 +1173,7 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers):
steal.move_task_request(victim_ts, wsA, wsB)
s.set_restrictions(worker={victim_key: [wsB.address]})
- s.reschedule(victim_key)
+ s.handle_reschedule(victim_key, stimulus_id="test")
assert wsB == victim_ts.processing_on
# move_task_request is not responsible for respecting worker restrictions
steal.move_task_request(victim_ts, wsB, wsC)
diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py
index c920366c4c2..6b4a583b9ac 100644
--- a/distributed/tests/test_utils.py
+++ b/distributed/tests/test_utils.py
@@ -19,6 +19,7 @@
from distributed.metrics import time
from distributed.utils import (
LRU,
+ STIMULUS_ID,
All,
Log,
Logs,
@@ -27,6 +28,7 @@
_maybe_complex,
ensure_bytes,
ensure_ip,
+ expect_stimulus,
format_dashboard_link,
get_ip_interface,
get_traceback,
@@ -781,3 +783,16 @@ def __repr__(self):
],
}
assert recursive_to_dict(info) == expect
+
+
+def test_expect_stimulus():
+ @expect_stimulus(sync=True)
+ def fn(x):
+ return STIMULUS_ID.get()
+
+ assert fn(1, stimulus_id="test") == "test"
+
+ with pytest.raises(ValueError):
+ assert fn(1) == 1
+
+ assert fn(**{"x": 1, "stimulus_id": "test"}) == "test"
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index dc481aae17f..c0d377acc01 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -21,7 +21,7 @@
from distributed.utils_test import (
_LockedCommPool,
_UnhashableCallable,
- assert_worker_story,
+ assert_story,
check_process_leak,
cluster,
dump_cluster_state,
@@ -406,7 +406,7 @@ async def inner_test(c, s, a, b):
assert "workers" in state
-def test_assert_worker_story():
+def test_assert_story():
now = time()
story = [
("foo", "id1", now - 600),
@@ -414,38 +414,38 @@ def test_assert_worker_story():
("baz", {1: 2}, "id2", now),
]
# strict=False
- assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})])
- assert_worker_story(story, [])
- assert_worker_story(story, [("foo",)])
- assert_worker_story(story, [("foo",), ("bar",)])
- assert_worker_story(story, [("baz", lambda d: d[1] == 2)])
+ assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})])
+ assert_story(story, [])
+ assert_story(story, [("foo",)])
+ assert_story(story, [("foo",), ("bar",)])
+ assert_story(story, [("baz", lambda d: d[1] == 2)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo", "nomatch")])
+ assert_story(story, [("foo", "nomatch")])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz",)])
+ assert_story(story, [("baz",)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz", {1: 3})])
+ assert_story(story, [("baz", {1: 3})])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)])
+ assert_story(story, [("foo",), ("bar",), ("baz", "extra"), ("+1",)])
with pytest.raises(AssertionError):
- assert_worker_story(story, [("baz", lambda d: d[1] == 3)])
+ assert_story(story, [("baz", lambda d: d[1] == 3)])
with pytest.raises(KeyError): # Faulty lambda
- assert_worker_story(story, [("baz", lambda d: d[2] == 1)])
- assert_worker_story([], [])
- assert_worker_story([("foo", "id1", now)], [("foo",)])
+ assert_story(story, [("baz", lambda d: d[2] == 1)])
+ assert_story([], [])
+ assert_story([("foo", "id1", now)], [("foo",)])
with pytest.raises(AssertionError):
- assert_worker_story([], [("foo",)])
+ assert_story([], [("foo",)])
# strict=True
- assert_worker_story([], [], strict=True)
- assert_worker_story([("foo", "id1", now)], [("foo",)])
- assert_worker_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True)
+ assert_story([], [], strict=True)
+ assert_story([("foo", "id1", now)], [("foo",)])
+ assert_story(story, [("foo",), ("bar",), ("baz", {1: 2})], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("bar",)], strict=True)
+ assert_story(story, [("foo",), ("bar",)], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [("foo",), ("baz", {1: 2})], strict=True)
+ assert_story(story, [("foo",), ("baz", {1: 2})], strict=True)
with pytest.raises(AssertionError):
- assert_worker_story(story, [], strict=True)
+ assert_story(story, [], strict=True)
@pytest.mark.parametrize(
@@ -466,11 +466,11 @@ def test_assert_worker_story():
),
],
)
-def test_assert_worker_story_malformed_story(story_factory):
+def test_assert_story_malformed_story(story_factory):
# defer the calls to time() to when the test runs rather than collection
story = story_factory()
with pytest.raises(AssertionError, match="Malformed story event"):
- assert_worker_story(story, [])
+ assert_story(story, [])
@gen_cluster()
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index aa8549b07ba..14521c72c6f 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -4,6 +4,7 @@
import importlib
import logging
import os
+import re
import sys
import threading
import traceback
@@ -41,11 +42,11 @@
from distributed.metrics import time
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
-from distributed.utils import TimeoutError
+from distributed.utils import STIMULUS_ID, TimeoutError
from distributed.utils_test import (
TaskStateMetadataPlugin,
_LockedCommPool,
- assert_worker_story,
+ assert_story,
captured_logger,
dec,
div,
@@ -1403,7 +1404,7 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None:
"""Test that an in-memory key was transferred from worker w_from to worker w_to by
the Active Memory Manager and it was not recalculated on w_to
"""
- assert_worker_story(
+ assert_story(
w_to.story(key),
[
(key, "ensure-task-exists", "released"),
@@ -1687,7 +1688,12 @@ async def test_story_with_deps(c, s, a, b):
# Story now includes randomized stimulus_ids and timestamps.
stimulus_ids = {ev[-2] for ev in story}
- assert len(stimulus_ids) == 3, stimulus_ids
+ assert {sid[: re.search(r"\d", sid).start()] for sid in stimulus_ids} == {
+ "ensure-computing-",
+ "ensure-communicating-",
+ "task-finished-",
+ }
+
# This is a simple transition log
expected = [
("res", "compute-task"),
@@ -1697,7 +1703,7 @@ async def test_story_with_deps(c, s, a, b):
("res", "put-in-memory"),
("res", "executing", "memory", "memory", {}),
]
- assert_worker_story(story, expected, strict=True)
+ assert_story(story, expected, strict=True)
story = b.story("dep")
stimulus_ids = {ev[-2] for ev in story}
@@ -1712,7 +1718,7 @@ async def test_story_with_deps(c, s, a, b):
("dep", "put-in-memory"),
("dep", "flight", "memory", "memory", {"res": "ready"}),
]
- assert_worker_story(story, expected, strict=True)
+ assert_story(story, expected, strict=True)
@gen_cluster(client=True)
@@ -2541,7 +2547,12 @@ def __call__(self, *args, **kwargs):
ts = s.tasks[fut.key]
a.handle_steal_request(fut.key, stimulus_id="test")
- stealing_ext.scheduler.send_task_to_worker(b.address, ts)
+
+ try:
+ token = STIMULUS_ID.set("test")
+ stealing_ext.scheduler.send_task_to_worker(b.address, ts)
+ finally:
+ STIMULUS_ID.reset(token)
fut2 = c.submit(inc, fut, workers=[a.address])
fut3 = c.submit(inc, fut2, workers=[a.address])
@@ -2621,7 +2632,7 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b):
while fut1.key not in b.tasks or b.tasks[fut1.key].state == "flight":
await asyncio.sleep(0)
- s.handle_missing_data(key="f1", errant_worker=a.address)
+ s.handle_missing_data(key="f1", errant_worker=a.address, stimulus_id="test")
await fut2
@@ -2666,7 +2677,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b):
# same communication channel
for fut in (futA, futB):
- assert_worker_story(
+ assert_story(
b.story(fut.key),
[
("gather-dependencies", a.address, {fut.key}),
@@ -2725,7 +2736,7 @@ def __getstate__(self):
assert await y == 123
story = await c.run(lambda dask_worker: dask_worker.story("x"))
- assert_worker_story(
+ assert_story(
story[b],
[
("x", "ensure-task-exists", "released"),
@@ -2902,7 +2913,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
await f2
- assert_worker_story(a.story(f1.key), [(f1.key, "missing-dep")])
+ assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
assert a.tasks[f1.key].suspicious_count == 0
assert s.tasks[f1.key].suspicious == 0
@@ -2982,7 +2993,7 @@ async def test_missing_released_zombie_tasks_2(c, s, a, b):
while b.tasks:
await asyncio.sleep(0.01)
- assert_worker_story(
+ assert_story(
b.story(ts),
[("f1", "missing", "released", "released", {"f1": "forgotten"})],
)
@@ -2999,7 +3010,7 @@ async def test_worker_status_sync(s, a):
while ws.status != Status.running:
await asyncio.sleep(0.01)
- await s.retire_workers()
+ await s.handle_retire_workers(stimulus_id="test")
while ws.status != Status.closed:
await asyncio.sleep(0.01)
@@ -3094,7 +3105,7 @@ async def test_task_flight_compute_oserror(c, s, a, b):
("f1", "put-in-memory"),
("f1", "executing", "memory", "memory", {}),
]
- assert_worker_story(sum_story, expected_sum_story, strict=True)
+ assert_story(sum_story, expected_sum_story, strict=True)
@gen_cluster(client=True)
@@ -3274,7 +3285,9 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
while not mocked_gather.call_args:
await asyncio.sleep(0)
- await s.remove_worker(address=x.address, safe=True, close=close_worker)
+ await s.handle_remove_worker(
+ address=x.address, safe=True, close=close_worker, stimulus_id="test"
+ )
await _wait_for_state(fut2_key, b, intermediate_state)
diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index 78597a37e67..f03966ba656 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -107,5 +107,9 @@ def test_slots(cls):
def test_sendmsg_to_dict():
# Arbitrary sample class
- smsg = ReleaseWorkerDataMsg(key="x")
- assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"}
+ smsg = ReleaseWorkerDataMsg(key="x", stimulus_id="test")
+ assert smsg.to_dict() == {
+ "op": "release-worker-data",
+ "key": "x",
+ "stimulus_id": "test",
+ }
diff --git a/distributed/utils.py b/distributed/utils.py
index cd0d93d08dd..0eeb63eddf4 100644
--- a/distributed/utils.py
+++ b/distributed/utils.py
@@ -1437,6 +1437,45 @@ def __getattr__(name):
raise AttributeError(f"module {__name__} has no attribute {name}")
+STIMULUS_ID: ContextVar[str] = ContextVar("stimulus_id")
+
+
+def expect_stimulus(sync: bool = True):
+ def decorator(fn):
+ if sync:
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ token = STIMULUS_ID.set(kwargs.pop("stimulus_id"))
+ except KeyError:
+ raise ValueError(f"{fn} missing stimulus_id")
+
+ try:
+ return fn(*args, **kwargs)
+ finally:
+ STIMULUS_ID.reset(token)
+
+ else:
+
+ @functools.wraps(fn)
+ # stimulus_id is taken from kwargs (not a named parameter) and is not forwarded to fn
+ async def wrapper(*args, **kwargs):
+ try:
+ token = STIMULUS_ID.set(kwargs.pop("stimulus_id"))
+ except KeyError:
+ raise ValueError(f"{fn} missing stimulus_id")
+
+ try:
+ return await fn(*args, **kwargs)
+ finally:
+ STIMULUS_ID.reset(token)
+
+ return wrapper
+
+ return decorator
+
+
# Used internally by recursive_to_dict to stop infinite recursion. If an object has
# already been encountered, a string representation will be returned instead. This is
# necessary since we have multiple cyclic referencing data structures.
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index 31e6cbe08fe..53d4f1403ac 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -1875,7 +1875,7 @@ def xfail_ssl_issue5601():
raise
-def assert_worker_story(
+def assert_story(
story: list[tuple], expect: list[tuple], *, strict: bool = False
) -> None:
"""Test the output of ``Worker.story``
@@ -1941,7 +1941,7 @@ def assert_worker_story(
break
except StopIteration:
raise AssertionError(
- f"assert_worker_story({strict=}) failed\n"
+ f"assert_story({strict=}) failed\n"
f"story:\n{_format_story(story)}\n"
f"expect:\n{_format_story(expect)}"
) from None
diff --git a/distributed/worker.py b/distributed/worker.py
index 7a0628763dd..38e1e341577 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -919,22 +919,26 @@ def status(self, value):
"""
prev_status = self.status
ServerNode.status.__set__(self, value)
- self._send_worker_status_change()
+ self._send_worker_status_change(f"worker-status-change-{time()}")
if prev_status == Status.paused and value == Status.running:
self.ensure_computing()
self.ensure_communicating()
- def _send_worker_status_change(self) -> None:
+ def _send_worker_status_change(self, stimulus_id: str) -> None:
if (
self.batched_stream
and self.batched_stream.comm
and not self.batched_stream.comm.closed()
):
self.batched_stream.send(
- {"op": "worker-status-change", "status": self._status.name}
+ {
+ "op": "worker-status-change",
+ "status": self._status.name,
+ "stimulus_id": stimulus_id,
+ },
)
elif self._status != Status.closed:
- self.loop.call_later(0.05, self._send_worker_status_change)
+ self.loop.call_later(0.05, self._send_worker_status_change, stimulus_id)
async def get_metrics(self) -> dict:
try:
@@ -1075,6 +1079,7 @@ async def _register_with_scheduler(self):
versions=get_versions(),
metrics=await self.get_metrics(),
extra=await self.get_startup_information(),
+ stimulus_id=f"worker-connect-{time()}",
),
serializers=["msgpack"],
)
@@ -1457,7 +1462,9 @@ async def close(
if report and self.contact_address is not None:
await asyncio.wait_for(
self.scheduler.unregister(
- address=self.contact_address, safe=safe
+ address=self.contact_address,
+ safe=safe,
+ stimulus_id=f"worker-close-{time()}",
),
timeout,
)
@@ -1522,7 +1529,10 @@ async def close_gracefully(self, restart=None):
# Scheduler.retire_workers will set the status to closing_gracefully and push it
# back to this worker.
await self.scheduler.retire_workers(
- workers=[self.address], close_workers=False, remove=False
+ workers=[self.address],
+ close_workers=False,
+ remove=False,
+ stimulus_id=f"worker-close-gracefully-{time()}",
)
await self.close(safe=True, nanny=not restart)
@@ -1877,7 +1887,9 @@ def handle_compute_task(
pass
elif ts.state == "memory":
recommendations[ts] = "memory"
- instructions.append(self._get_task_finished_msg(ts))
+ instructions.append(
+ self._get_task_finished_msg(ts, stimulus_id=stimulus_id)
+ )
elif ts.state in {
"released",
"fetch",
@@ -2005,7 +2017,7 @@ def transition_memory_released(
recs, instructions = self.transition_generic_released(
ts, stimulus_id=stimulus_id
)
- instructions.append(ReleaseWorkerDataMsg(ts.key))
+ instructions.append(ReleaseWorkerDataMsg(key=ts.key, stimulus_id=stimulus_id))
return recs, instructions
def transition_waiting_constrained(
@@ -2028,7 +2040,7 @@ def transition_long_running_rescheduled(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
recs: Recs = {ts: "released"}
- smsg = RescheduleMsg(key=ts.key, worker=self.address)
+ smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id)
return recs, [smsg]
def transition_executing_rescheduled(
@@ -2039,7 +2051,7 @@ def transition_executing_rescheduled(
self._executing.discard(ts)
recs: Recs = {ts: "released"}
- smsg = RescheduleMsg(key=ts.key, worker=self.address)
+ smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id)
return recs, [smsg]
def transition_waiting_ready(
@@ -2116,6 +2128,7 @@ def transition_generic_error(
traceback_text=traceback_text,
thread=self.threads.get(ts.key),
startstops=ts.startstops,
+ stimulus_id=stimulus_id,
)
return {}, [smsg]
@@ -2300,7 +2313,7 @@ def transition_generic_memory(
return recs, []
if self.validate:
assert ts.key in self.data or ts.key in self.actors
- smsg = self._get_task_finished_msg(ts)
+ smsg = self._get_task_finished_msg(ts, stimulus_id=stimulus_id)
return recs, [smsg]
def transition_executing_memory(
@@ -2417,7 +2430,9 @@ def transition_executing_long_running(
ts.state = "long-running"
self._executing.discard(ts)
self.long_running.add(ts.key)
- smsg = LongRunningMsg(key=ts.key, compute_duration=compute_duration)
+ smsg = LongRunningMsg(
+ key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id
+ )
self.io_loop.add_callback(self.ensure_computing)
return {}, [smsg]
@@ -2701,7 +2716,9 @@ def ensure_communicating(self) -> None:
for el in skipped_worker_in_flight:
self.data_needed.push(el)
- def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg:
+ def _get_task_finished_msg(
+ self, ts: TaskState, stimulus_id: str
+ ) -> TaskFinishedMsg:
if ts.key not in self.data and ts.key not in self.actors:
raise RuntimeError(f"Task {ts} not ready")
typ = ts.type
@@ -2727,6 +2744,7 @@ def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg:
metadata=ts.metadata,
thread=self.threads.get(ts.key),
startstops=ts.startstops,
+ stimulus_id=stimulus_id,
)
def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs:
@@ -3030,7 +3048,12 @@ async def gather_dep(
self.has_what[worker].discard(ts.key)
self.log.append((d, "missing-dep", stimulus_id, time()))
self.batched_stream.send(
- {"op": "missing-data", "errant_worker": worker, "key": d}
+ {
+ "op": "missing-data",
+ "errant_worker": worker,
+ "key": d,
+ "stimulus_id": stimulus_id,
+ }
)
recommendations[ts] = "fetch" if ts.who_has else "missing"
del data, response
@@ -3135,7 +3158,7 @@ def handle_steal_request(self, key: str, stimulus_id: str) -> None:
# `transition_constrained_executing`
self.transition(ts, "released", stimulus_id=stimulus_id)
- def handle_worker_status_change(self, status: str) -> None:
+ def handle_worker_status_change(self, status: str, stimulus_id: str) -> None:
new_status = Status.lookup[status] # type: ignore
if (
@@ -3146,7 +3169,7 @@ def handle_worker_status_change(self, status: str) -> None:
"Invalid Worker.status transition: %s -> %s", self._status, new_status
)
# Reiterate the current status to the scheduler to restore sync
- self._send_worker_status_change()
+ self._send_worker_status_change(stimulus_id)
else:
# Update status and send confirmation to the Scheduler (see status.setter)
self.status = new_status
@@ -3483,6 +3506,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
if result["op"] == "task-finished":
if self.digests is not None:
self.digests["task-duration"].add(result["stop"] - result["start"])
+ new_stimulus_id = f"{result['op']}-{time()}"
return ExecuteSuccessEvent(
key=key,
value=result["result"],
@@ -3490,7 +3514,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
stop=result["stop"],
nbytes=result["nbytes"],
type=result["type"],
- stimulus_id=stimulus_id,
+ stimulus_id=new_stimulus_id,
)
if isinstance(result["actual-exception"], Reschedule):
@@ -3517,7 +3541,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
traceback=result["traceback"],
exception_text=result["exception_text"],
traceback_text=result["traceback_text"],
- stimulus_id=stimulus_id,
+ stimulus_id=f"task-erred-{time()}",
)
except Exception as exc:
@@ -3531,7 +3555,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
traceback=msg["traceback"],
exception_text=msg["exception_text"],
traceback_text=msg["traceback_text"],
- stimulus_id=stimulus_id,
+ stimulus_id=f"task-erred-{time()}",
)
@functools.singledispatchmethod
@@ -3596,7 +3620,7 @@ def _(self, ev: ExecuteFailureEvent) -> RecsInstrs:
ev.traceback,
ev.exception_text,
ev.traceback_text,
- )
+ ),
}, []
@handle_event.register
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 8ae454417c9..966be0a3527 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -288,6 +288,7 @@ class TaskFinishedMsg(SendMessageToScheduler):
metadata: dict
thread: int | None
startstops: list[StartStop]
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore
def to_dict(self) -> dict[str, Any]:
@@ -307,6 +308,7 @@ class TaskErredMsg(SendMessageToScheduler):
traceback_text: str
thread: int | None
startstops: list[StartStop]
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore
def to_dict(self) -> dict[str, Any]:
@@ -319,8 +321,9 @@ def to_dict(self) -> dict[str, Any]:
class ReleaseWorkerDataMsg(SendMessageToScheduler):
op = "release-worker-data"
- __slots__ = ("key",)
+ __slots__ = ("key", "stimulus_id")
key: str
+ stimulus_id: str
# Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception
@@ -328,18 +331,21 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler):
class RescheduleMsg(SendMessageToScheduler):
op = "reschedule"
- __slots__ = ("key", "worker")
+ # Not to be confused with the distributed.Reschedule Exception
+ __slots__ = ("key", "worker", "stimulus_id")
key: str
worker: str
+ stimulus_id: str
@dataclass
class LongRunningMsg(SendMessageToScheduler):
op = "long-running"
- __slots__ = ("key", "compute_duration")
+ __slots__ = ("key", "compute_duration", "stimulus_id")
key: str
compute_duration: float
+ stimulus_id: str
@dataclass
@@ -365,6 +371,7 @@ class ExecuteSuccessEvent(StateMachineEvent):
stop: float
nbytes: int
type: type | None
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore
@@ -377,6 +384,7 @@ class ExecuteFailureEvent(StateMachineEvent):
traceback: Serialize | None
exception_text: str
traceback_text: str
+ stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore