Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions distributed/comm/tests/test_ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from distributed import Client, Scheduler, wait
from distributed.comm import connect, listen, parse_address, ucx
from distributed.comm.core import CommClosedError
from distributed.comm.registry import backends, get_backend
from distributed.deploy.local import LocalCluster
from distributed.diagnostics.nvml import has_cuda_context
Expand Down Expand Up @@ -367,3 +368,15 @@ async def test_ucx_unreachable(
):
with pytest.raises(OSError, match="Timed out trying to connect to"):
await Client("ucx://255.255.255.255:12345", timeout=1, asynchronous=True)


@gen_test()
async def test_comm_closed_on_read_error():
reader, writer = await get_comm_pair()

# Depending on the UCP protocol selected, it may raise either
# `asyncio.TimeoutError` or `CommClosedError`, so validate either one.
with pytest.raises((asyncio.TimeoutError, CommClosedError)):
await asyncio.wait_for(reader.read(), 0.01)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, so here we're waiting for a read that will never be matched by a write, and so eventually we'll fail.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's right.


assert reader.closed()
12 changes: 7 additions & 5 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,12 +325,14 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
await self.ep.recv(header)
header = struct.unpack(header_fmt, header)
cuda_frames, sizes = header[:nframes], header[nframes:]
except (
ucp.exceptions.UCXCloseError,
ucp.exceptions.UCXCanceled,
) + (getattr(ucp.exceptions, "UCXConnectionReset", ()),):
except BaseException as e:
# In addition to UCX exceptions, may be CancelledError or a another
# "low-level" exception. The only safe thing to do is to abort.
# (See also https://github.com/dask/distributed/pull/6574).
self.abort()
raise CommClosedError("Connection closed by writer")
raise CommClosedError(
f"Connection closed by writer.\nInner exception: {e!r}"
)
else:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we also need to catch connection issues on line 354 as well.

So perhaps lines 353 and 354 should be replaced by:

try:
    for frame in recv_frames:
        await self.ep.recv(frame)
except BaseException as e:
    raise CommClosedError("Connection closed by writer.\nInner exception: {e!r}")

I had thought that one might be able to reduce synchronisation a little bit by using:

await asyncio.gather(*(map(self.ep.recv, recv_frames))

With a matching change in write of await asyncio.gather(*map(self.ep.send, send_frames)).

But I am unsure of the semantics of UCX wrt message overtaking. I think this could potentially result in the second (say) sent frame ending up in the first receive slot, which would be bad.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we also need to catch connection issues on line 354 as well.

So perhaps lines 353 and 354 should be replaced by:

try:
    for frame in recv_frames:
        await self.ep.recv(frame)
except BaseException as e:
    raise CommClosedError("Connection closed by writer.\nInner exception: {e!r}")

I'm not entirely sure we want that, maybe it never occurred in practice or just raising the original exception may be fine. I'm mostly concerned with unforeseen side-effects this may cause and would prefer not to mess with it now given it's not been a problem so far. WDYT?

I had thought that one might be able to reduce synchronisation a little bit by using:

await asyncio.gather(*(map(self.ep.recv, recv_frames))

With a matching change in write of await asyncio.gather(*map(self.ep.send, send_frames)).

But I am unsure of the semantics of UCX wrt message overtaking. I think this could potentially result in the second (say) sent frame ending up in the first receive slot, which would be bad.

I would expect that as well and had done it once and had to revert #5505 because that caused various issues, unfortunately. In any case, with the C++ UCX introduction of "multi-transfers" this will anyway be reduced to a single future, so I will not try to improve this code in its current form.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure we want that, maybe it never occurred in practice or just raising the original exception may be fine. I'm mostly concerned with unforeseen side-effects this may cause and would prefer not to mess with it now given it's not been a problem so far. WDYT?

What could happen (although my guess is that it would be low likelihood) is that we're receiving a bunch of frames, each await yields to the event loop, and in between awaits the remote endpoint is closed for some other reason.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but I fear that by raising a different exception now we may end up in some different control path that we didn't expect. I'm hoping that this patch can end up in the next Distributed release and it could be included in RAPIDS 22.10. I would be fine trying that out afterwards, but I'm a bit nervous of breaking something close to release time.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, thanks, makes sense.

# Recv frames
frames = [
Expand Down