From 9631b88b44b6b0f38f7fb97c27ffe10738678142 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 15 Aug 2023 17:20:54 +0200 Subject: [PATCH 1/4] Handle CancelledError in ConnectionPool --- distributed/core.py | 18 ++++++++---------- distributed/worker.py | 6 +----- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index df6d3cb8e99..5d94b531bb4 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1542,11 +1542,6 @@ async def _connect(self, addr: str, timeout: float | None = None) -> Comm: raise finally: self._connecting_count -= 1 - except asyncio.CancelledError: - current_task = asyncio.current_task() - assert current_task - reason = self._reasons.pop(current_task, "ConnectionPool closing.") - raise CommClosedError(reason) finally: self._pending_count -= 1 @@ -1599,12 +1594,15 @@ def callback(task: asyncio.Task[Comm]) -> None: except asyncio.CancelledError: # This is an outside cancel attempt connect_attempt.cancel() - try: - await connect_attempt - except CommClosedError: - pass + await connect_attempt raise - return await connect_attempt + try: + return connect_attempt.result() + except asyncio.CancelledError: + current_task = asyncio.current_task() + assert current_task + reason = self._reasons.pop(current_task, "ConnectionPool closing.") + raise CommClosedError(reason) def reuse(self, addr: str, comm: Comm) -> None: """ diff --git a/distributed/worker.py b/distributed/worker.py index 683ad1cbeaf..d2eaef8877a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2112,11 +2112,7 @@ async def gather_dep( data=response["data"], stimulus_id=f"gather-dep-success-{time()}", ) - - # Note: CancelledError and asyncio.TimeoutError are rare conditions - # that can be raised by the network stack. - # See https://github.com/dask/distributed/issues/8006 - except (OSError, asyncio.CancelledError, asyncio.TimeoutError): + except OSError: logger.exception("Worker stream died during communication: %s", worker) self.state.log.append( ("gather-dep-failed", worker, to_gather, stimulus_id, time()) From b5b259cf012e5f5f7f0c6b0d20f3fdecce211877 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 15 Aug 2023 18:09:57 +0200 Subject: [PATCH 2/4] fixes --- distributed/core.py | 4 +--- distributed/tests/test_worker.py | 38 -------------------------------- 2 files changed, 1 insertion(+), 41 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 5d94b531bb4..94e88d25fdb 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1599,9 +1599,7 @@ def callback(task: asyncio.Task[Comm]) -> None: try: return connect_attempt.result() except asyncio.CancelledError: - current_task = asyncio.current_task() - assert current_task - reason = self._reasons.pop(current_task, "ConnectionPool closing.") + reason = self._reasons.pop(connect_attempt, "ConnectionPool closing.") raise CommClosedError(reason) def reuse(self, addr: str, comm: Comm) -> None: diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index bb52753205b..225c68b1d50 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3326,44 +3326,6 @@ async def test_gather_dep_no_longer_in_flight_tasks(c, s, a): assert not any("missing-dep" in msg for msg in f2_story) -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_gather_dep_cancelled_error(c, s, a): - """Something somewhere in the networking stack raises CancelledError while - gather_dep is running - - See Also - -------- - test_get_data_cancelled_error - https://github.com/dask/distributed/issues/8006 - """ - async with BlockedGetData(s.address) as b: - x = c.submit(inc, 1, key="x", workers=[b.address]) - y = c.submit(inc, x, key="y", workers=[a.address]) - await b.in_get_data.wait() - tasks = { - task for task in asyncio.all_tasks() if "gather_dep" in task.get_name() - } - assert tasks - # There should be only one task but cope with finding more just in case a - # previous test didn't properly clean up - for task in tasks: - task.cancel() - - b.block_get_data.set() - assert await y == 3 - - assert_story( - a.state.story("x"), - [ - ("x", "fetch", "flight", "flight", {}), - ("x", "flight", "missing", "missing", {}), - ("x", "missing", "fetch", "fetch", {}), - ("x", "fetch", "flight", "flight", {}), - ("x", "flight", "memory", "memory", {"y": "ready"}), - ], - ) - - @gen_cluster(client=True, nthreads=[("", 1)], timeout=5) async def test_get_data_cancelled_error(c, s, a): """Something somewhere in the networking stack raises CancelledError while From a29829dec47d08e0844209056b324937bd8d3288 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 15 Aug 2023 18:25:49 +0200 Subject: [PATCH 3/4] Remove dead reference --- distributed/tests/test_worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 225c68b1d50..0fc2f2cc3cd 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3333,7 +3333,6 @@ async def test_get_data_cancelled_error(c, s, a): See Also -------- - test_gather_dep_cancelled_error https://github.com/dask/distributed/issues/8006 """ From 6a2a8389739bde690e0d5ebe793ef418f826703c Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 15 Aug 2023 18:42:21 +0200 Subject: [PATCH 4/4] Add test --- distributed/tests/test_core.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 5eb05dcb1fb..4754073e458 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -886,6 +886,30 @@ async def remove_address(): assert connect_finished.cancelled() +@gen_test() +async def test_remove_cancels_connect_before_task_running(): + loop = asyncio.get_running_loop() + connect_finished = loop.create_future() + + async def connect(*args, **kwargs): + await connect_finished + + async def connect_to_server(): + with pytest.raises(CommClosedError, match="Address removed."): + await rpc.connect("tcp://0.0.0.0") + return True + + rpc = await ConnectionPool(limit=1) + with mock.patch("distributed.core.connect", connect): + t1 = asyncio.create_task(connect_to_server()) + # Cancel the actual connect task before it can even run + while not rpc._connecting: + await asyncio.sleep(0) + rpc.remove("tcp://0.0.0.0") + + assert await t1 + + @gen_test() async def test_connection_pool_respects_limit(): limit = 5