diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 9c73a538156..fc07d0489be 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -254,27 +254,28 @@ async def write( ) -> int: if self.closed(): raise CommClosedError("Endpoint is closed -- unable to send message") - try: - if serializers is None: - serializers = ("cuda", "dask", "pickle", "error") - # msg can also be a list of dicts when sending batched messages - frames = await to_frames( - msg, - serializers=serializers, - on_error=on_error, - allow_offload=self.allow_offload, - ) - nframes = len(frames) - cuda_frames = tuple(hasattr(f, "__cuda_array_interface__") for f in frames) - sizes = tuple(nbytes(f) for f in frames) - cuda_send_frames, send_frames = zip( - *( - (is_cuda, each_frame) - for is_cuda, each_frame in zip(cuda_frames, frames) - if nbytes(each_frame) > 0 - ) + + if serializers is None: + serializers = ("cuda", "dask", "pickle", "error") + # msg can also be a list of dicts when sending batched messages + frames = await to_frames( + msg, + serializers=serializers, + on_error=on_error, + allow_offload=self.allow_offload, + ) + nframes = len(frames) + cuda_frames = tuple(hasattr(f, "__cuda_array_interface__") for f in frames) + sizes = tuple(nbytes(f) for f in frames) + cuda_send_frames, send_frames = zip( + *( + (is_cuda, each_frame) + for is_cuda, each_frame in zip(cuda_frames, frames) + if nbytes(each_frame) > 0 ) + ) + try: # Send meta data # Send close flag and number of frames (_Bool, int64) @@ -326,7 +327,7 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): header = struct.unpack(header_fmt, header) cuda_frames, sizes = header[:nframes], header[nframes:] except BaseException as e: - # In addition to UCX exceptions, may be CancelledError or a another + # In addition to UCX exceptions, may be CancelledError or another # "low-level" exception. The only safe thing to do is to abort. # (See also https://github.com/dask/distributed/pull/6574). self.abort() @@ -352,14 +353,29 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): if any(cuda_recv_frames): synchronize_stream(0) - for each_frame in recv_frames: - await self.ep.recv(each_frame) - msg = await from_frames( - frames, - deserialize=self.deserialize, - deserializers=deserializers, - allow_offload=self.allow_offload, - ) + try: + for each_frame in recv_frames: + await self.ep.recv(each_frame) + except BaseException as e: + # In addition to UCX exceptions, may be CancelledError or another + # "low-level" exception. The only safe thing to do is to abort. + # (See also https://github.com/dask/distributed/pull/6574). + self.abort() + raise CommClosedError( + f"Connection closed by writer.\nInner exception: {e!r}" + ) + + try: + msg = await from_frames( + frames, + deserialize=self.deserialize, + deserializers=deserializers, + allow_offload=self.allow_offload, + ) + except EOFError: + # Frames possibly garbled or truncated by communication error + self.abort() + raise CommClosedError("Aborted stream on truncated data") return msg async def close(self):