diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9a7162a486e..cfe39f9c72b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -6673,7 +6673,7 @@ async def retire_workers( return worker_keys - def add_keys(self, comm=None, worker=None, keys=()): + def add_keys(self, comm=None, worker=None, keys=(), stimulus_id=None): """ Learn that a worker has certain keys @@ -6694,12 +6694,14 @@ def add_keys(self, comm=None, worker=None, keys=()): redundant_replicas.append(key) if redundant_replicas: + if not stimulus_id: + stimulus_id = f"redundant-replicas-{time()}" self.worker_send( worker, { "op": "remove-replicas", "keys": redundant_replicas, - "stimulus_id": f"redundant-replicas-{time()}", + "stimulus_id": stimulus_id, }, ) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index c28a521ea72..8c9413dec56 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2760,8 +2760,12 @@ def _acquire_replicas(scheduler, worker, *futures): def _remove_replicas(scheduler, worker, *futures): keys = [f.key for f in futures] - - scheduler.stream_comms[worker.address].send( + ws = scheduler.workers[worker.address] + for k in keys: + ts = scheduler.tasks[k] + if ws in ts.who_has: + scheduler.remove_replica(ts, ws) + scheduler.stream_comms[ws.address].send( { "op": "remove-replicas", "keys": keys, @@ -2852,21 +2856,26 @@ async def test_remove_replica_simple(c, s, a, b): _remove_replicas(s, b, *futs) + assert all(len(s.tasks[f.key].who_has) == 1 for f in futs) + while b.tasks: await asyncio.sleep(0.01) - # might take a moment for the reply to reach the scheduler - while not all(len(s.tasks[f.key].who_has) == 1 for f in futs): - await asyncio.sleep(0.01) + # Ensure there is no delayed reply to re-register the key + await asyncio.sleep(0.01) + assert all(s.tasks[f.key].who_has == {s.workers[a.address]} for f in futs) -@gen_cluster(client=True) +@gen_cluster( + client=True, + config={"distributed.comm.recent-messages-log-length": 1_000}, +) async def test_remove_replica_while_computing(c, s, *workers): futs = c.map(inc, range(10), workers=[workers[0].address]) # All interesting things will happen on that worker w = workers[1] - intermediate = c.map(slowinc, futs, delay=0.1, workers=[w.address]) + intermediate = c.map(slowinc, futs, delay=0.05, workers=[w.address]) def reduce(*args, **kwargs): import time @@ -2875,24 +2884,52 @@ def reduce(*args, **kwargs): return final = c.submit(reduce, intermediate, workers=[w.address], key="final") - while final.key not in w.tasks: + + while not any(f.key in w.tasks for f in intermediate): await asyncio.sleep(0.001) + # The scheduler removes keys from who_has/has_what immediately + # Make sure the worker responds to the rejection and the scheduler corrects + # the state + ws = s.workers[w.address] + while not any(s.tasks[fut.key] in ws.has_what for fut in futs): + await asyncio.sleep(0.001) + + _remove_replicas(s, w, *futs) + # Scheduler removed keys immediately... + assert not any(s.tasks[fut.key] in ws.has_what for fut in futs) + # ... but the state is properly restored + while not any(s.tasks[fut.key] in ws.has_what for fut in futs): + await asyncio.sleep(0.01) + + # The worker should reject all of these since they are required while not all(fut.done() for fut in intermediate): - # The worker should reject all of these since they are required _remove_replicas(s, w, *futs) - _remove_replicas(s, w, *intermediate) - await asyncio.sleep(0.001) + await asyncio.sleep(0.01) await wait(intermediate) + # If a request is rejected, the worker responds with an add-keys message to + # reenlist the key in the schedulers state system to avoid race conditions, + # see also https://github.com/dask/distributed/issues/5265 + rejections = set() + for msg in w.log: + if msg[0] == "remove-replica-rejected": + rejections.update(msg[1]) + for rejected_key in rejections: + + def answer_sent(key): + for batch in w.batched_stream.recent_message_log: + for msg in batch: + if "op" in msg and msg["op"] == "add-keys" and key in msg["keys"]: + return True + return False + + assert answer_sent(rejected_key) + # Since intermediate is done, futs replicas may be removed. # They might be already gone due to the above remove replica calls _remove_replicas(s, w, *futs) - # the intermediate tasks should not be touched because they are still needed - # (the scheduler should not have made the above call but we should be safe - # regarless) - assert all(w.tasks[f.key].state == "memory" for f in intermediate) while any(w.tasks[f.key].state != "released" for f in futs if f.key in w.tasks): await asyncio.sleep(0.001) diff --git a/distributed/worker.py b/distributed/worker.py index 7d1a2710b8c..ac343d0b10a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1586,7 +1586,9 @@ def update_data( self.log.append((key, "receive-from-scatter")) if report: - scheduler_messages.append({"op": "add-keys", "keys": list(data)}) + scheduler_messages.append( + {"op": "add-keys", "keys": list(data), "stimulus_id": stimulus_id} + ) self.transitions(recommendations, stimulus_id=stimulus_id) for msg in scheduler_messages: @@ -1634,10 +1636,23 @@ def handle_remove_replicas(self, keys, stimulus_id): """ self.log.append(("remove-replicas", keys, stimulus_id)) recommendations = {} + + rejected = [] for key in keys: ts = self.tasks.get(key) - if ts and not ts.is_protected(): + if ts is None or ts.state != "memory": + continue + if not ts.is_protected(): + self.log.append(("remove-replica-confirmed", ts.key, stimulus_id)) recommendations[ts] = "released" if ts.dependents else "forgotten" + else: + rejected.append(key) + + if rejected: + self.log.append(("remove-replica-rejected", rejected, stimulus_id)) + self.batched_stream.send( + {"op": "add-keys", "keys": rejected, "stimulus_id": stimulus_id} + ) self.transitions(recommendations=recommendations, stimulus_id=stimulus_id) @@ -2111,14 +2126,14 @@ def transition_executing_long_running(self, ts, compute_duration, *, stimulus_id def transition_released_memory(self, ts, value, *, stimulus_id): recs, smsgs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) - smsgs.append({"op": "add-keys", "keys": [ts.key]}) + smsgs.append({"op": "add-keys", "keys": [ts.key], "stimulus_id": stimulus_id}) return recs, smsgs def transition_flight_memory(self, ts, value, *, stimulus_id): self._in_flight_tasks.discard(ts) ts.coming_from = None recs, smsgs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) - smsgs.append({"op": "add-keys", "keys": [ts.key]}) + smsgs.append({"op": "add-keys", "keys": [ts.key], "stimulus_id": stimulus_id}) return recs, smsgs def transition_released_forgotten(self, ts, *, stimulus_id):