Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 44 additions & 28 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,27 +254,28 @@ async def write(
) -> int:
if self.closed():
raise CommClosedError("Endpoint is closed -- unable to send message")
try:
if serializers is None:
serializers = ("cuda", "dask", "pickle", "error")
# msg can also be a list of dicts when sending batched messages
frames = await to_frames(
msg,
serializers=serializers,
on_error=on_error,
allow_offload=self.allow_offload,
)
nframes = len(frames)
cuda_frames = tuple(hasattr(f, "__cuda_array_interface__") for f in frames)
sizes = tuple(nbytes(f) for f in frames)
cuda_send_frames, send_frames = zip(
*(
(is_cuda, each_frame)
for is_cuda, each_frame in zip(cuda_frames, frames)
if nbytes(each_frame) > 0
)

if serializers is None:
serializers = ("cuda", "dask", "pickle", "error")
# msg can also be a list of dicts when sending batched messages
frames = await to_frames(
msg,
serializers=serializers,
on_error=on_error,
allow_offload=self.allow_offload,
)
nframes = len(frames)
cuda_frames = tuple(hasattr(f, "__cuda_array_interface__") for f in frames)
sizes = tuple(nbytes(f) for f in frames)
cuda_send_frames, send_frames = zip(
*(
(is_cuda, each_frame)
for is_cuda, each_frame in zip(cuda_frames, frames)
if nbytes(each_frame) > 0
)
)

try:
# Send meta data

# Send close flag and number of frames (_Bool, int64)
Expand Down Expand Up @@ -326,7 +327,7 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
header = struct.unpack(header_fmt, header)
cuda_frames, sizes = header[:nframes], header[nframes:]
except BaseException as e:
# In addition to UCX exceptions, may be CancelledError or a another
# In addition to UCX exceptions, may be CancelledError or another
# "low-level" exception. The only safe thing to do is to abort.
# (See also https://github.com/dask/distributed/pull/6574).
self.abort()
Expand All @@ -352,14 +353,29 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
if any(cuda_recv_frames):
synchronize_stream(0)

for each_frame in recv_frames:
await self.ep.recv(each_frame)
msg = await from_frames(
frames,
deserialize=self.deserialize,
deserializers=deserializers,
allow_offload=self.allow_offload,
)
try:
for each_frame in recv_frames:
await self.ep.recv(each_frame)
except BaseException as e:
# In addition to UCX exceptions, may be CancelledError or another
# "low-level" exception. The only safe thing to do is to abort.
# (See also https://github.com/dask/distributed/pull/6574).
self.abort()
raise CommClosedError(
f"Connection closed by writer.\nInner exception: {e!r}"
)

try:
msg = await from_frames(
frames,
deserialize=self.deserialize,
deserializers=deserializers,
allow_offload=self.allow_offload,
)
except EOFError:
# Frames possibly garbled or truncated by communication error
self.abort()
raise CommClosedError("Aborted stream on truncated data")
return msg

async def close(self):
Expand Down