Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 57 additions & 50 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
inc,
slowinc,
wait_for_state,
wait_for_stimulus,
)
from distributed.worker_state_machine import (
ComputeTaskEvent,
Execute,
FreeKeysEvent,
GatherDep,
GatherDepNetworkFailureEvent,
)


Expand Down Expand Up @@ -375,60 +383,59 @@ def block_execution(event, lock):
assert await fut2 == 2


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
async def test_cancelled_resumed_after_flight_with_dependencies(c, s, w2, w3):
# See https://github.com/dask/distributed/pull/6327#discussion_r872231090
block_get_data_1 = asyncio.Lock()
enter_get_data_1 = asyncio.Event()
await block_get_data_1.acquire()
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_cancelled_resumed_after_flight_with_dependencies(c, s, a):
"""A task is in flight from b to a.
While a is waiting, b dies. The scheduler notices before a and reschedules the
task on a itself (as the only surviving replica was just lost).
Test that the worker eventually computes the task.

See https://github.com/dask/distributed/pull/6327#discussion_r872231090
See test_cancelled_resumed_after_flight_with_dependencies_workerstate below.
"""
async with await BlockedGetData(s.address) as b:
x = c.submit(inc, 1, key="x", workers=[b.address], allow_other_workers=True)
y = c.submit(inc, x, key="y", workers=[a.address])
await b.in_get_data.wait()

class BlockGetDataWorker(Worker):
def __init__(self, *args, get_data_event, get_data_lock, **kwargs):
self._get_data_event = get_data_event
self._get_data_lock = get_data_lock
super().__init__(*args, **kwargs)
# Make b dead to s, but not to a
await s.remove_worker(b.address, stimulus_id="stim-id")

async def get_data(self, comm, *args, **kwargs):
self._get_data_event.set()
async with self._get_data_lock:
return await super().get_data(comm, *args, **kwargs)

async with await BlockGetDataWorker(
s.address,
get_data_event=enter_get_data_1,
get_data_lock=block_get_data_1,
name="w1",
) as w1:

f1 = c.submit(inc, 1, key="f1", workers=[w1.address])
f2 = c.submit(inc, 2, key="f2", workers=[w1.address])
f3 = c.submit(sum, [f1, f2], key="f3", workers=[w1.address])

await wait(f3)
f4 = c.submit(inc, f3, key="f4", workers=[w2.address])

await enter_get_data_1.wait()
s.set_restrictions(
{
f1.key: {w3.address},
f2.key: {w3.address},
f3.key: {w2.address},
}
)
await s.remove_worker(w1.address, stimulus_id="stim-id")
# Wait for the scheduler to reschedule x on a.
# We want the comms from the scheduler to reach a before b closes the RPC
# channel, causing a.gather_dep() to raise OSError.
await wait_for_stimulus(ComputeTaskEvent, a, key="x")

# b closed; a.gather_dep() fails. Note that, in the current implementation, x won't
# be recomputed on a until this happens.
assert await y == 3

await wait_for_state(f3.key, "resumed", w2)
assert_story(
w2.state.log,
[
(f3.key, "flight", "released", "cancelled", {}),
# ...
(f3.key, "cancelled", "waiting", "resumed", {}),
],
)
# w1 closed

assert await f4 == 6
def test_cancelled_resumed_after_flight_with_dependencies_workerstate(ws):
    """WorkerState-only counterpart of
    test_cancelled_resumed_after_flight_with_dependencies: the same sequence of
    stimuli is fed directly to the state machine and the emitted instructions
    are checked.
    """
    peer = "127.0.0.1:2"
    stimuli = [
        # y needs x, whose only replica is on the peer: x goes in flight from it
        ComputeTaskEvent.dummy(key="y", who_has={"x": [peer]}, stimulus_id="s1"),
        # The scheduler has noticed that the peer is unresponsive before this
        # worker did. With the last replica of x gone it cancels x's dependents,
        # which cancels x too.
        FreeKeysEvent(keys=["y"], stimulus_id="s2"),
        # x is rescheduled on the very worker that was fetching it. No Execute
        # instruction is emitted yet because the GatherDep is still outstanding.
        ComputeTaskEvent.dummy(key="x", stimulus_id="s3"),
        # Eventually (~30s) the TCP socket to the peer times out; only now is
        # the Execute instruction produced.
        GatherDepNetworkFailureEvent(worker=peer, total_nbytes=1, stimulus_id="s4"),
    ]
    instructions = ws.handle_stimulus(*stimuli)

    expected = [
        GatherDep(worker=peer, to_gather={"x"}, total_nbytes=1, stimulus_id="s1"),
        # Execute carries the stimulus_id of the network failure, not of the
        # ComputeTaskEvent that rescheduled x
        Execute(key="x", stimulus_id="s4"),
    ]
    assert instructions == expected
    assert ws.tasks["x"].state == "executing"


@pytest.mark.parametrize("wait_for_processing", [True, False])
Expand Down
23 changes: 22 additions & 1 deletion distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@
raises_with_cause,
tls_only_security,
wait_for_state,
wait_for_stimulus,
)
from distributed.worker import fail_hard
from distributed.worker_state_machine import (
ComputeTaskEvent,
InvalidTaskState,
InvalidTransition,
PauseEvent,
Expand Down Expand Up @@ -920,7 +922,7 @@ async def test_freeze_batched_send():
assert e.count == 3


@gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_wait_for_state(c, s, a, capsys):
ev = Event()
x = c.submit(lambda ev: ev.wait(), ev, key="x")
Expand All @@ -947,3 +949,22 @@ async def test_wait_for_state(c, s, a, capsys):
f"tasks[x].state='memory' on {s.address}; expected state='bad_state'\n"
f"tasks[y] not found on {s.address}\n"
)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_wait_for_stimulus(c, s, a):
    """wait_for_stimulus stays pending until an event of the requested type,
    matching the given keyword filters, shows up on worker ``a``.
    """
    any_compute = asyncio.create_task(wait_for_stimulus(ComputeTaskEvent, a))
    compute_y = asyncio.create_task(wait_for_stimulus(ComputeTaskEvent, a, key="y"))
    await asyncio.sleep(0.05)
    # Nothing has been submitted yet, so neither waiter may have completed
    assert not any_compute.done()
    assert not compute_y.done()

    x = c.submit(inc, 1, key="x")
    event = await any_compute
    assert isinstance(event, ComputeTaskEvent)
    # A stimulus that was already logged is found when called directly...
    await wait_for_stimulus(ComputeTaskEvent, a, key="x")
    # ...and when the helper is shipped to the worker through c.run
    await c.run(wait_for_stimulus, ComputeTaskEvent, key="x")
    # The key="y" filter must keep the second waiter pending
    assert not compute_y.done()

    y = c.submit(inc, 1, key="y")
    await compute_y
34 changes: 24 additions & 10 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,29 @@ def test_computetask_to_dict():
assert ev3.priority == (0,) # List is automatically converted back to tuple


def test_computetask_dummy():
    """ComputeTaskEvent.dummy fills every field that is not passed explicitly
    with a fixed default, and derives nbytes from who_has when nbytes is
    omitted.
    """
    ev = ComputeTaskEvent.dummy(key="x", stimulus_id="s")
    assert ev == ComputeTaskEvent(
        key="x",
        who_has={},
        nbytes={},
        priority=(0,),
        duration=1.0,
        run_spec=None,
        resource_restrictions={},
        actor=False,
        annotations={},
        stimulus_id="s",
        function=None,
        args=None,
        kwargs=None,
    )

    # nbytes is generated from who_has if omitted.
    # who_has maps each key to a collection of worker addresses, so pass a list:
    # a bare string would be iterated character by character.
    ev2 = ComputeTaskEvent.dummy(
        key="x", who_has={"y": ["127.0.0.1:2"]}, stimulus_id="s"
    )
    assert ev2.nbytes == {"y": 1}


def test_updatedata_to_dict():
"""The potentially very large UpdateDataEvent.data is not stored in the log"""
ev = UpdateDataEvent(
Expand Down Expand Up @@ -933,19 +956,10 @@ def test_gather_priority(ws):
stimulus_id="compute1",
),
# A higher-priority task, even if scheduled later, is fetched first
ComputeTaskEvent(
ComputeTaskEvent.dummy(
key="z",
who_has={"y": ["127.0.0.7:1"]},
nbytes={"y": 1},
priority=(0,),
duration=1.0,
run_spec=None,
function=None,
args=None,
kwargs=None,
resource_restrictions={},
actor=False,
annotations={},
stimulus_id="compute2",
),
UnpauseEvent(stimulus_id="unpause"),
Expand Down
26 changes: 25 additions & 1 deletion distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
sync,
)
from distributed.worker import WORKER_ANY_RUNNING, Worker
from distributed.worker_state_machine import InvalidTransition
from distributed.worker_state_machine import InvalidTransition, StateMachineEvent
from distributed.worker_state_machine import TaskState as WorkerTaskState
from distributed.worker_state_machine import WorkerState

Expand Down Expand Up @@ -2400,6 +2400,9 @@ def freeze_batched_send(bcomm: BatchedSend) -> Iterator[LockedComm]:
async def wait_for_state(
key: str, state: str, dask_worker: Worker | Scheduler, *, interval: float = 0.01
) -> None:
"""Wait for a task to appear on a Worker or on the Scheduler and to be in a specific
state.
"""
if isinstance(dask_worker, Worker):
tasks = dask_worker.state.tasks
elif isinstance(dask_worker, Scheduler):
Expand All @@ -2424,6 +2427,27 @@ async def wait_for_state(
raise


async def wait_for_stimulus(
    type_: type[StateMachineEvent] | tuple[type[StateMachineEvent], ...],
    dask_worker: Worker,
    *,
    interval: float = 0.01,
    **matches: Any,
) -> StateMachineEvent:
    """Poll the stimulus log of a worker's state until a matching event appears.

    Parameters
    ----------
    type_
        Event class, or tuple of classes, the event must be an instance of
    dask_worker
        Object whose ``state.stimulus_log`` is scanned
    interval
        Seconds to sleep between two polls
    **matches
        Attribute name -> expected value; all must compare equal on the event

    Returns
    -------
    The first event in the log, in log order, that satisfies all criteria.
    This coroutine never returns if no such event is ever logged.
    """

    def _is_match(candidate: Any) -> bool:
        # isinstance is tested first so that attributes are only read from
        # events of the requested type
        return isinstance(candidate, type_) and all(
            getattr(candidate, attr) == expected for attr, expected in matches.items()
        )

    stimulus_log = dask_worker.state.stimulus_log
    seen_tail = None
    while True:
        # Rescan only when the newest entry differs from the one seen at the
        # previous scan
        if stimulus_log and stimulus_log[-1] is not seen_tail:
            seen_tail = stimulus_log[-1]
            for candidate in stimulus_log:
                if _is_match(candidate):
                    return candidate
        await asyncio.sleep(interval)


@pytest.fixture
def ws():
state = WorkerState(address="127.0.0.1:1", transition_counter_max=50_000)
Expand Down
32 changes: 32 additions & 0 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,38 @@ def to_loggable(self, *, handled: float) -> StateMachineEvent:
def _after_from_dict(self) -> None:
    """Reset ``run_spec`` to a SerializedTask whose fields are all None.

    NOTE(review): judging by the name this is a hook run after from_dict();
    the caller is not visible here - confirm.
    """
    self.run_spec = SerializedTask(task=None, function=None, args=None, kwargs=None)

@staticmethod
def dummy(
    *,
    key: str,
    who_has: dict[str, Collection[str]] | None = None,
    nbytes: dict[str, int] | None = None,
    priority: tuple[int, ...] = (0,),
    duration: float = 1.0,
    resource_restrictions: dict[str, float] | None = None,
    actor: bool = False,
    annotations: dict | None = None,
    stimulus_id: str,
) -> ComputeTaskEvent:
    """Return a ComputeTaskEvent for ``key`` in which every field that is not
    passed explicitly gets a placeholder value.

    Convenience constructor meant for unit tests only.

    ``run_spec``, ``function``, ``args`` and ``kwargs`` are always None.
    Omitted (or empty) dict arguments become empty dicts, except ``nbytes``,
    which falls back to a size of 1 for every key of ``who_has``.
    """
    dependencies = who_has or {}
    # One byte per dependency unless the caller supplied sizes
    sizes = nbytes or dict.fromkeys(dependencies, 1)
    restrictions = resource_restrictions or {}
    notes = annotations or {}
    return ComputeTaskEvent(
        key=key,
        stimulus_id=stimulus_id,
        who_has=dependencies,
        nbytes=sizes,
        priority=priority,
        duration=duration,
        resource_restrictions=restrictions,
        actor=actor,
        annotations=notes,
        run_spec=None,
        function=None,
        args=None,
        kwargs=None,
    )


@dataclass
class ExecuteSuccessEvent(StateMachineEvent):
Expand Down