From 24d8001947b2c2d7e2f1f5373907a1b48e0d0520 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 29 Mar 2022 10:53:52 +0200 Subject: [PATCH] Ensure multiple clients can cancel their key without interference --- distributed/client.py | 2 ++ distributed/scheduler.py | 39 +++++++++++++--------- distributed/tests/test_as_completed.py | 8 ++--- distributed/tests/test_client.py | 45 +++++++++++++++++--------- distributed/tests/test_worker.py | 4 +-- 5 files changed, 61 insertions(+), 37 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 82474fb7620..caf42dc19ad 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4881,6 +4881,8 @@ def _get_and_raise(self): if self.raise_errors and future.status == "error": typ, exc, tb = result raise exc.with_traceback(tb) + elif future.status == "cancelled": + res = (res[0], CancelledError(future.key)) return res def __next__(self): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c9ddc2ec67c..e6250762dea 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4298,36 +4298,45 @@ def remove_worker_from_events(): return "OK" - def stimulus_cancel(self, comm, keys=None, client=None, force=False): + async def stimulus_cancel(self, keys, client, force=False): """Stop execution on a list of keys""" logger.info("Client %s requests to cancel %d keys", client, len(keys)) if client: self.log_event( client, {"action": "cancel", "count": len(keys), "force": force} ) - for key in keys: - self.cancel_key(key, client, force=force) - def cancel_key(self, key, client, retries=5, force=False): + await asyncio.gather( + *[self._cancel_key(key, client, force=force) for key in keys] + ) + + async def _cancel_key(self, key, client, force=False): """Cancel a particular key and all dependents""" # TODO: this should be converted to use the transition mechanism - ts: TaskState = self.tasks.get(key) + ts: TaskState | None = self.tasks.get(key) dts: TaskState try: cs: ClientState = self.clients[client] except KeyError: return - if ts is None or not ts.who_wants: # no key yet, lets try again in a moment - if retries: - self.loop.call_later( - 0.2, lambda: self.cancel_key(key, client, retries - 1) - ) - return + + # no key yet, lets try again in a moment + start = time() + while ts is None or not ts.who_wants: + await asyncio.sleep(0.1) + ts = self.tasks.get(key) + if time() - start >= 1: + return + if force or ts.who_wants == {cs}: # no one else wants this key - for dts in list(ts.dependents): - self.cancel_key(dts.key, client, force=force) - logger.info("Scheduler cancels key %s. Force=%s", key, force) - self.report({"op": "cancelled-key", "key": key}) + await asyncio.gather( + *[ + self._cancel_key(dts.key, client, force=force) + for dts in ts.dependents + ] + ) + logger.info("Scheduler cancels key %s. Force=%s", key, force) + self.report({"op": "cancelled-key", "key": key}) clients = list(ts.who_wants) if force else [cs] for cs in clients: self.client_releases_keys( diff --git a/distributed/tests/test_as_completed.py b/distributed/tests/test_as_completed.py index 0e0f4d254fc..7b0e376a92b 100644 --- a/distributed/tests/test_as_completed.py +++ b/distributed/tests/test_as_completed.py @@ -240,12 +240,12 @@ async def test_str(c, s, a, b): @gen_cluster(client=True) async def test_as_completed_with_results_no_raise_async(c, s, a, b): - x = c.submit(throws, 1) - y = c.submit(inc, 5) - z = c.submit(inc, 1) + x = c.submit(throws, 1, key="x") + y = c.submit(inc, 5, key="y") + z = c.submit(inc, 1, key="z") ac = as_completed([x, y, z], with_results=True, raise_errors=False) - c.loop.add_callback(y.cancel) + await y.cancel() res = [el async for el in ac] dd = {r[0]: r[1:] for r in res} diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index a0ee538f6fa..5a2dbc79f6b 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2305,30 +2305,43 @@ async def test_cancel_tuple_key(c, s, a, b): @gen_cluster() async def test_cancel_multi_client(s, a, b): - c = await Client(s.address, asynchronous=True) - f = await Client(s.address, asynchronous=True) + async with Client(s.address, asynchronous=True, name="c") as c: + async with Client(s.address, asynchronous=True, name="f") as f: - x = c.submit(slowinc, 1) - y = f.submit(slowinc, 1) + x = c.submit(slowinc, 1) + y = f.submit(slowinc, 1) - assert x.key == y.key + assert x.key == y.key - await c.cancel([x]) + # Ensure both clients are known to the scheduler. + await y + await x - assert x.cancelled() - assert not y.cancelled() + await c.cancel([x]) - while y.key not in s.tasks: - await asyncio.sleep(0.01) + # Give the scheduler time to pass messages + await asyncio.sleep(0.1) - out = await y - assert out == 2 + assert x.cancelled() + assert not y.cancelled() - with pytest.raises(CancelledError): - await x + out = await y + assert out == 2 - await c.close() - await f.close() + with pytest.raises(CancelledError): + await x + + +@gen_cluster(nthreads=[("", 1)], client=True) +async def test_cancel_before_known_to_scheduler(c, s, a): + with captured_logger("distributed.scheduler") as slogs: + f = c.submit(inc, 1) + await c.cancel([f]) + + with pytest.raises(CancelledError): + await f + + assert "Scheduler cancels key" in slogs.getvalue() @gen_cluster(client=True) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index bb99438e9f4..7421fed1756 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -687,9 +687,9 @@ async def test_restrictions(c, s, a, b): await x ts = a.tasks[x.key] assert ts.resource_restrictions == {"A": 1} - await c._cancel(x) + await c.cancel([x]) - while ts.state != "memory": + while ts.state == "executing": # Resource should be unavailable while task isn't finished assert a.available_resources == {"A": 0} await asyncio.sleep(0.01)