diff --git a/distributed/core.py b/distributed/core.py index df6d3cb8e99..94e88d25fdb 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1542,11 +1542,6 @@ async def _connect(self, addr: str, timeout: float | None = None) -> Comm: raise finally: self._connecting_count -= 1 - except asyncio.CancelledError: - current_task = asyncio.current_task() - assert current_task - reason = self._reasons.pop(current_task, "ConnectionPool closing.") - raise CommClosedError(reason) finally: self._pending_count -= 1 @@ -1599,12 +1594,13 @@ def callback(task: asyncio.Task[Comm]) -> None: except asyncio.CancelledError: # This is an outside cancel attempt connect_attempt.cancel() - try: - await connect_attempt - except CommClosedError: - pass + await connect_attempt raise - return await connect_attempt + try: + return connect_attempt.result() + except asyncio.CancelledError: + reason = self._reasons.pop(connect_attempt, "ConnectionPool closing.") + raise CommClosedError(reason) def reuse(self, addr: str, comm: Comm) -> None: """ diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 5eb05dcb1fb..4754073e458 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -886,6 +886,30 @@ async def remove_address(): assert connect_finished.cancelled() +@gen_test() +async def test_remove_cancels_connect_before_task_running(): + loop = asyncio.get_running_loop() + connect_finished = loop.create_future() + + async def connect(*args, **kwargs): + await connect_finished + + async def connect_to_server(): + with pytest.raises(CommClosedError, match="Address removed."): + await rpc.connect("tcp://0.0.0.0") + return True + + rpc = await ConnectionPool(limit=1) + with mock.patch("distributed.core.connect", connect): + t1 = asyncio.create_task(connect_to_server()) + # Cancel the actual connect task before it can even run + while not rpc._connecting: + await asyncio.sleep(0) + rpc.remove("tcp://0.0.0.0") + + assert await t1 + + @gen_test() async def test_connection_pool_respects_limit(): limit = 5 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index bb52753205b..0fc2f2cc3cd 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3326,44 +3326,6 @@ async def test_gather_dep_no_longer_in_flight_tasks(c, s, a): assert not any("missing-dep" in msg for msg in f2_story) -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_gather_dep_cancelled_error(c, s, a): - """Something somewhere in the networking stack raises CancelledError while - gather_dep is running - - See Also - -------- - test_get_data_cancelled_error - https://github.com/dask/distributed/issues/8006 - """ - async with BlockedGetData(s.address) as b: - x = c.submit(inc, 1, key="x", workers=[b.address]) - y = c.submit(inc, x, key="y", workers=[a.address]) - await b.in_get_data.wait() - tasks = { - task for task in asyncio.all_tasks() if "gather_dep" in task.get_name() - } - assert tasks - # There should be only one task but cope with finding more just in case a - # previous test didn't properly clean up - for task in tasks: - task.cancel() - - b.block_get_data.set() - assert await y == 3 - - assert_story( - a.state.story("x"), - [ - ("x", "fetch", "flight", "flight", {}), - ("x", "flight", "missing", "missing", {}), - ("x", "missing", "fetch", "fetch", {}), - ("x", "fetch", "flight", "flight", {}), - ("x", "flight", "memory", "memory", {"y": "ready"}), - ], - ) - - @gen_cluster(client=True, nthreads=[("", 1)], timeout=5) async def test_get_data_cancelled_error(c, s, a): """Something somewhere in the networking stack raises CancelledError while @@ -3371,7 +3333,6 @@ async def test_get_data_cancelled_error(c, s, a): See Also -------- - test_gather_dep_cancelled_error https://github.com/dask/distributed/issues/8006 """ diff --git a/distributed/worker.py b/distributed/worker.py index 683ad1cbeaf..d2eaef8877a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2112,11 +2112,7 @@ async def gather_dep( data=response["data"], stimulus_id=f"gather-dep-success-{time()}", ) - - # Note: CancelledError and asyncio.TimeoutError are rare conditions - # that can be raised by the network stack. - # See https://github.com/dask/distributed/issues/8006 - except (OSError, asyncio.CancelledError, asyncio.TimeoutError): + except OSError: logger.exception("Worker stream died during communication: %s", worker) self.state.log.append( ("gather-dep-failed", worker, to_gather, stimulus_id, time())