diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index d88582bc67c..03f23116018 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -240,11 +240,12 @@ async def read(self, deserializers=None): self._closed = True if not sys.is_finalizing(): convert_stream_closed_error(self, e) - except Exception: - # Some OSError or a another "low-level" exception. We do not really know what - # was already read from the underlying socket, so it is not even safe to retry - # here using the same stream. The only safe thing to do is to abort. - # (See also GitHub #4133). + except BaseException: + # Some OSError, CancelledError or a another "low-level" exception. + # We do not really know what was already read from the underlying + # socket, so it is not even safe to retry here using the same stream. + # The only safe thing to do is to abort. + # (See also GitHub #4133, #6548). self.abort() raise else: @@ -317,11 +318,14 @@ async def write(self, msg, serializers=None, on_error="message"): self._closed = True if not sys.is_finalizing(): convert_stream_closed_error(self, e) - except Exception: + except BaseException: # Some OSError or a another "low-level" exception. We do not really know # what was already written to the underlying socket, so it is not even safe # to retry here using the same stream. The only safe thing to do is to # abort. (See also GitHub #4133). + # In case of, for instance, KeyboardInterrupts or other + # BaseExceptions that could be handled further upstream, we equally + # want to discard this comm self.abort() raise diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index fae1e5b6d23..cff48105dea 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -857,8 +857,15 @@ async def handle_comm(comm): await comm.close() +class CustomBase(BaseException): + # We don't want to interfere with KeyboardInterrupts or CancelledErrors for + # this test + ... + + +@pytest.mark.parametrize("exc_type", [BufferError, CustomBase]) @gen_test() -async def test_comm_closed_on_buffer_error(tcp): +async def test_comm_closed_on_write_error(tcp, exc_type): # Internal errors from comm.stream.write, such as # BufferError should lead to the stream being closed # and not re-used. See GitHub #4133 @@ -868,12 +875,29 @@ async def test_comm_closed_on_buffer_error(tcp): reader, writer = await get_tcp_comm_pair() def _write(data): - raise BufferError + raise exc_type() writer.stream.write = _write - with pytest.raises(BufferError): + with pytest.raises(exc_type): await writer.write("x") - assert writer.stream is None + + assert writer.closed() + + await reader.close() + await writer.close() + + +@gen_test() +async def test_comm_closed_on_read_error(tcp): + if tcp is asyncio_tcp: + pytest.skip("Not applicable for asyncio") + + reader, writer = await get_tcp_comm_pair() + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(reader.read(), 0.01) + + assert reader.closed() await reader.close() await writer.close()