From b3916208c2f9e34a719f26f4eea6209869eaceab Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 1 Aug 2021 10:10:45 -0500 Subject: [PATCH 1/2] Split large header in comms Today we try to split up large messages in comms. This is useful in a few situations: 1. Websockets, which often pass frames through middleware that requires small messages 2. TLS, which fails on some OpenSSL versions with frames above the size of an int We correctly cut up data frames into smaller pieces to address these issues. However we don't apply this same logic to the header frame, which may still contain very large bytestrings. This commit adds a workaround in protocol dumps/loads which watches for this event and splits the header frame up if necessary. It works, but it's not very smooth. I would prefer that in the future we think about what a proper header should look like and ensure that it contains no user data. In the meantime this should help. --- distributed/comm/tests/test_ws.py | 11 +++++++++++ distributed/comm/ws.py | 8 ++++++++ distributed/protocol/core.py | 16 +++++++++++++++- 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/distributed/comm/tests/test_ws.py b/distributed/comm/tests/test_ws.py index 08d1cbce0ee..5bb826be133 100644 --- a/distributed/comm/tests/test_ws.py +++ b/distributed/comm/tests/test_ws.py @@ -209,3 +209,14 @@ async def test_wss_roundtrip(c, s, a, b): future = await c.scatter(x) y = await future assert (x == y).all() + + +@gen_cluster(client=True, scheduler_kwargs={"protocol": "ws://"}) +async def test_ws_roundtrip_large(c, s, a, b): + import numpy as np + + x = np.random.random(25000000) + + future = c.submit(lambda x: x, x) + y = await future + assert (x == y).all() diff --git a/distributed/comm/ws.py b/distributed/comm/ws.py index 46ce55b86e5..2ab059496bf 100644 --- a/distributed/comm/ws.py +++ b/distributed/comm/ws.py @@ -120,6 +120,10 @@ async def write(self, msg, serializers=None, on_error=None): }, frame_split_size=BIG_BYTES_SHARD_SIZE, ) + assert all(len(frame) <= BIG_BYTES_SHARD_SIZE for frame in frames), list( + map(len, frames) + ) + n = struct.pack("Q", len(frames)) try: await self.handler.write_message(n, binary=True) @@ -218,6 +222,10 @@ async def write(self, msg, serializers=None, on_error=None): }, frame_split_size=BIG_BYTES_SHARD_SIZE, ) + assert all(len(frame) <= BIG_BYTES_SHARD_SIZE for frame in frames), list( + map(len, frames) + ) + n = struct.pack("Q", len(frames)) try: await self.sock.write_message(n, binary=True) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 1be2d761e35..bcaf8a057a6 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -74,6 +74,14 @@ def _encode_default(obj): return msgpack_encode_default(obj) frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True) + + if len(frames[0]) > frame_split_size: + from distributed.protocol.utils import frame_split_size as split + + msg_frames = split(frames[0], n=frame_split_size) + header = msgpack.dumps({"large-header": True, "count": len(msg_frames)}) + frames = [header] + msg_frames + frames[1:] + return frames except Exception: @@ -108,9 +116,15 @@ def _decode_default(obj): else: return msgpack_decode_default(obj) - return msgpack.loads( + result = msgpack.loads( frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts ) + if isinstance(result, dict) and "large-header" in result: + frame = b"".join(frames[1 : result["count"] + 1]) + frames = [frame] + frames[result["count"] + 1 :] + return loads(frames, deserialize=deserialize, deserializers=deserializers) + else: + return result except Exception: logger.critical("Failed to deserialize", exc_info=True) From 76d9f03632e2a3587ddb894d4034027c856f50d0 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 1 Aug 2021 13:59:36 -0500 Subject: [PATCH 2/2] allow frame_split_size to be zero --- distributed/protocol/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index bcaf8a057a6..f4173ccb4c0 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -75,7 +75,7 @@ def _encode_default(obj): frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True) - if len(frames[0]) > frame_split_size: + if frame_split_size and len(frames[0]) > frame_split_size: from distributed.protocol.utils import frame_split_size as split msg_frames = split(frames[0], n=frame_split_size)