From bf430816842f2528ab463479da179f2e29deba0d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sun, 29 Oct 2023 19:49:51 +0100 Subject: [PATCH 1/2] Don't share host_array between objects --- distributed/comm/tcp.py | 114 +++++++++++++----- distributed/comm/tests/test_comms.py | 95 ++++++++++++++- distributed/comm/tests/test_ucx.py | 28 +++++ .../protocol/tests/test_protocol_utils.py | 42 ++++++- distributed/protocol/tests/test_utils_test.py | 26 ++++ distributed/protocol/utils.py | 51 +++++++- distributed/protocol/utils_test.py | 30 +++++ 7 files changed, 352 insertions(+), 34 deletions(-) create mode 100644 distributed/protocol/tests/test_utils_test.py create mode 100644 distributed/protocol/utils_test.py diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index e3e8029e7a1..8a2271f2bb2 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -220,18 +220,16 @@ async def read(self, deserializers=None): fmt_size = struct.calcsize(fmt) try: - frames_nbytes = await stream.read_bytes(fmt_size) - (frames_nbytes,) = struct.unpack(fmt, frames_nbytes) - - frames = host_array(frames_nbytes) - for i, j in sliding_window( - 2, - range(0, frames_nbytes + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE), - ): - chunk = frames[i:j] - chunk_nbytes = chunk.nbytes - n = await stream.read_into(chunk) - assert n == chunk_nbytes, (n, chunk_nbytes) + # Don't store multiple numpy or parquet buffers into the same buffer, or + # none will be released until all are released. + frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size) + (frames_nosplit_nbytes,) = struct.unpack(fmt, frames_nosplit_nbytes_bin) + frames_nosplit = await read_bytes_rw(stream, frames_nosplit_nbytes) + frames, buffers_nbytes = unpack_frames(frames_nosplit, partial=True) + for buffer_nbytes in buffers_nbytes: + buffer = await read_bytes_rw(stream, buffer_nbytes) + frames.append(buffer) + except StreamClosedError as e: self.stream = None self._closed = True @@ -247,8 +245,6 @@ async def read(self, deserializers=None): raise else: try: - frames = unpack_frames(frames) - msg = await from_frames( frames, deserialize=self.deserialize, @@ -278,23 +274,10 @@ async def write(self, msg, serializers=None, on_error="message"): }, frame_split_size=self.max_shard_size, ) - frames_nbytes = [nbytes(f) for f in frames] - frames_nbytes_total = sum(frames_nbytes) - - header = pack_frames_prelude(frames) - header = struct.pack("Q", nbytes(header) + frames_nbytes_total) + header - - frames = [header, *frames] - frames_nbytes = [nbytes(header), *frames_nbytes] - frames_nbytes_total += frames_nbytes[0] - - if frames_nbytes_total < 2**17: # 128kiB - # small enough, send in one go - frames = [b"".join(frames)] - frames_nbytes = [frames_nbytes_total] + frames, frames_nbytes, frames_nbytes_total = _add_frames_header(frames) try: - # trick to enque all frames for writing beforehand + # trick to enqueue all frames for writing beforehand for each_frame_nbytes, each_frame in zip(frames_nbytes, frames): if each_frame_nbytes: # Make sure that `len(data) == data.nbytes` @@ -371,6 +354,79 @@ def extra_info(self): return self._extra +async def read_bytes_rw(stream: IOStream, n: int) -> memoryview: + """Read n bytes from stream. Unlike stream.read_bytes, allow for + very large messages and return a writeable buffer. + """ + buf = host_array(n) + + for i, j in sliding_window( + 2, + range(0, n + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE), + ): + chunk = buf[i:j] + chunk_nbytes = chunk.nbytes + n = await stream.read_into(chunk) # type: ignore[arg-type] + assert n == chunk_nbytes, (n, chunk_nbytes) + + return buf + + +def _add_frames_header( + frames: list[bytes | memoryview], +) -> tuple[list[bytes | memoryview], list[int], int]: + """ """ + frames_nbytes = [nbytes(f) for f in frames] + frames_nbytes_total = sum(frames_nbytes) + + # Calculate the number of bytes that are inclusive of: + # - prelude + # - msgpack header + # - simple pickle bytes + # - compressed buffers + # - first uncompressed buffer (possibly sharded), IFF the pickle bytes are + # negligible in size + # + # All these can be fetched by read() into a single buffer with a single call to + # Tornado, because they will be dereferenced soon after they are deserialized. + # Read uncompressed numpy/parquet buffers, which will survive indefinitely past + # the end of read(), into their own host arrays so that their memory can be + # released independently. + frames_nbytes_nosplit = 0 + first_uncompressed_buffer: object = None + for frame, nb in zip(frames, frames_nbytes): + buffer = frame.obj if isinstance(frame, memoryview) else frame + if not isinstance(buffer, bytes): + # Uncompressed buffer; it will be referenced by the unpickled object + if first_uncompressed_buffer is None: + if frames_nbytes_nosplit > max(2048, nb * 0.05): + # Don't extend the lifespan of non-trivial amounts of pickled bytes + # to that of the buffers + break + first_uncompressed_buffer = buffer + elif first_uncompressed_buffer is not buffer: # don't split sharded frame + # Always store 2+ separate numpy/parquet objects onto separate + # buffers + break + + frames_nbytes_nosplit += nb + + header = pack_frames_prelude(frames) + header = struct.pack("Q", nbytes(header) + frames_nbytes_nosplit) + header + header_nbytes = nbytes(header) + + frames = [header, *frames] + frames_nbytes = [header_nbytes, *frames_nbytes] + frames_nbytes_total += header_nbytes + + if frames_nbytes_total < 2**17: # 128kiB + # small enough, send in one go + frames = [b"".join(frames)] + frames_nbytes = [frames_nbytes_total] + + return frames, frames_nbytes, frames_nbytes_total + + class TLS(TCP): """ A TLS-specific version of TCP. diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 1dac2935e8e..ef06d4c8505 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -26,10 +26,11 @@ ) from distributed.comm.registry import backends, get_backend from distributed.comm.tcp import get_stream_address -from distributed.compatibility import asyncio_run +from distributed.compatibility import WINDOWS, asyncio_run from distributed.config import get_loop_factory from distributed.metrics import time from distributed.protocol import Serialized, deserialize, serialize, to_serialize +from distributed.protocol.utils_test import get_host_array from distributed.utils import get_ip, get_ipv6, get_mp_context, wait_for from distributed.utils_test import ( gen_test, @@ -1379,3 +1380,95 @@ def test_register_backend_entrypoint(tmp_path): with get_mp_context().Pool(1) as pool: assert pool.apply(_get_backend_on_path, args=(tmp_path,)) == 1 pool.join() + + +class OpaqueList(list): + """Don't let the serialization layer travese this object""" + + pass + + +@pytest.mark.parametrize( + "list_cls", + [ + list, # Use protocol.numpy.serialize_numpy_array / deserialize_numpy_array + OpaqueList, # Use generic pickle.dumps / pickle.loads + ], +) +@gen_test() +async def test_do_not_share_buffers(tcp, list_cls): + """Test that two objects with buffer interface in the same message do not share + their buffer upon deserialization + + See Also + -------- + test_share_buffer_with_header + test_ucx.py::test_do_not_share_buffers + """ + np = pytest.importorskip("numpy") + + async def handle_comm(comm): + msg = await comm.read() + msg["data"] = to_serialize(list_cls([np.array([1, 2]), np.array([3, 4])])) + await comm.write(msg) + await comm.close() + + listener = await tcp.TCPListener("127.0.0.1", handle_comm) + comm = await connect(listener.contact_address) + + await comm.write({"op": "ping"}) + msg = await comm.read() + await comm.close() + + a, b = msg["data"] + assert get_host_array(a) is not get_host_array(b) + + +@pytest.mark.parametrize( + "nbytes_np,nbytes_other,expect_separate_buffer", + [ + (1, 0, False), # <2 kiB (including prologue and msgpack header) + (1, 1800, False), # <2 kiB + (1, 2100, True), # >2 kiB + (200_000, 9500, False), # <5% of numpy array + (200_000, 10500, True), # >5% of numpy array + (350_000, 0, False), # sharded buffer + ], +) +@gen_test() +async def test_share_buffer_with_header( + tcp, nbytes_np, nbytes_other, expect_separate_buffer +): + """Test that a numpy or parquet object shares the buffer with its serialized header + to improve performance, but only as long as the header is trivial in size. + + See Also + -------- + test_do_not_share_buffers + """ + np = pytest.importorskip("numpy") + if tcp is asyncio_tcp and WINDOWS: + pytest.xfail("asyncio_tcp is faulty on windows") + + async def handle_comm(comm): + comm.max_shard_size = 250_000 + msg = await comm.read() + msg["np"] = to_serialize(np.random.randint(0, 256, nbytes_np, dtype="u1")) + msg["other"] = np.random.bytes(nbytes_other) + await comm.write(msg) + await comm.close() + + listener = await tcp.TCPListener("127.0.0.1", handle_comm) + comm = await connect(listener.contact_address) + + await comm.write({"op": "ping"}) + msg = await comm.read() + await comm.close() + + a = msg["np"] + ha = get_host_array(a) + if tcp is asyncio_tcp: + # TODO unimplemented optimization. Buffers are always split. + assert ha.nbytes == a.nbytes + else: + assert (ha.nbytes == a.nbytes) == expect_separate_buffer diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index aa5a2824eeb..a8435810afd 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -26,6 +26,7 @@ has_cuda_context, ) from distributed.protocol import to_serialize +from distributed.protocol.utils_test import get_host_array from distributed.utils_test import gen_test, inc try: @@ -432,3 +433,30 @@ async def test_embedded_cupy_array( x = da.from_array(a, chunks=(10000,)) b = await client.compute(x) cupy.testing.assert_array_equal(a, b) + + +@gen_test() +async def test_do_not_share_buffers(ucx_loop): + """Test that two objects with buffer interface in the same message do not share + their buffer upon deserialization. + + See Also + -------- + test_comms.py::test_do_not_share_buffers + """ + np = pytest.importorskip("numpy") + + com, serv_com = await get_comm_pair() + msg = {"data": to_serialize([np.array([1, 2]), np.array([3, 4])])} + + await com.write(msg) + result = await serv_com.read() + await com.close() + await serv_com.close() + + a, b = result["data"] + ha = get_host_array(a) + hb = get_host_array(b) + assert ha is not hb + assert ha.nbytes == a.nbytes + assert hb.nbytes == a.nbytes diff --git a/distributed/protocol/tests/test_protocol_utils.py b/distributed/protocol/tests/test_protocol_utils.py index 3d9b1b51df8..50380f86e78 100644 --- a/distributed/protocol/tests/test_protocol_utils.py +++ b/distributed/protocol/tests/test_protocol_utils.py @@ -2,7 +2,12 @@ import pytest -from distributed.protocol.utils import merge_memoryviews, pack_frames, unpack_frames +from distributed.protocol.utils import ( + merge_memoryviews, + pack_frames, + pack_frames_prelude, + unpack_frames, +) def test_pack_frames(): @@ -10,8 +15,41 @@ def test_pack_frames(): b = pack_frames(frames) assert isinstance(b, bytes) frames2 = unpack_frames(b) + assert frames2 == frames - assert frames == frames2 + +@pytest.mark.parametrize("extra", [b"456", b""]) +def test_unpack_frames_remainder(extra): + frames = [b"123", b"asdf"] + b = pack_frames(frames) + assert isinstance(b, bytes) + + frames2 = unpack_frames(b + extra) + assert frames2 == frames + + frames2 = unpack_frames(b + extra, remainder=True) + assert isinstance(frames2[-1], memoryview) + assert frames2 == frames + [extra] + + +def test_unpack_frames_partial(): + frames = [b"123", b"asdf"] + frames.insert(0, pack_frames_prelude(frames)) + + frames2, missing_lenghts = unpack_frames(b"".join(frames), partial=True) + assert frames2 == frames[1:] + assert missing_lenghts == [] + + frames2, missing_lenghts = unpack_frames(b"".join(frames[:-1]), partial=True) + assert frames2 == frames[1:-1] + assert missing_lenghts == [4] + + frames2, missing_lenghts = unpack_frames(frames[0], partial=True) + assert frames2 == [] + assert missing_lenghts == [3, 4] + + with pytest.raises(AssertionError): + unpack_frames(b"".join(frames[:-1])) class TestMergeMemroyviews: diff --git a/distributed/protocol/tests/test_utils_test.py b/distributed/protocol/tests/test_utils_test.py new file mode 100644 index 00000000000..c3861eea74c --- /dev/null +++ b/distributed/protocol/tests/test_utils_test.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import pytest + +from distributed.protocol.utils import host_array +from distributed.protocol.utils_test import get_host_array + + +def test_get_host_array(): + np = pytest.importorskip("numpy") + + a = np.array([1, 2, 3]) + assert get_host_array(a) is a + assert get_host_array(a[1:]) is a + assert get_host_array(a[1:][1:]) is a + + buf = host_array(3) + a = np.frombuffer(buf, dtype="u1") + assert get_host_array(a) is buf.obj + assert get_host_array(a[1:]) is buf.obj + a = np.frombuffer(buf[1:], dtype="u1") + assert get_host_array(a) is buf.obj + + a = np.frombuffer(bytearray(3), dtype="u1") + with pytest.raises(TypeError): + get_host_array(a) diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 1f42fe387e6..f7fa7a9984a 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -3,6 +3,7 @@ import ctypes import struct from collections.abc import Collection, Iterable, Sequence +from typing import Literal, overload import dask @@ -90,12 +91,44 @@ def pack_frames(frames: Collection[bytes | bytearray | memoryview]) -> bytes: return b"".join([pack_frames_prelude(frames), *frames]) -def unpack_frames(b): +@overload +def unpack_frames( + b: bytes | bytearray | memoryview, + *, + remainder: bool = False, + partial: Literal[False] = False, +) -> list[memoryview]: + ... + + +@overload +def unpack_frames( + b: bytes | bytearray | memoryview, + *, + remainder: bool = False, + partial: Literal[True], +) -> tuple[list[memoryview], list[int]]: + ... + + +def unpack_frames(b, *, remainder=False, partial=False): """Unpack bytes into a sequence of frames This assumes that length information is at the front of the bytestring, as performed by pack_frames + Parameters + ---------- + b: + packed frames, as returned by :func:`pack_frames` + remainder: + If True, return one extra frame at the end which is the continuation of a + stream created by concatenating multiple calls to :func:`pack_frames`. + This last frame will be empty at the end of the stream. + partial: + If True, allow for b to contain less frames than what the preamble indicates; + return a tuple of ([frames so far], [lengths of missing frames]) + See Also -------- pack_frames @@ -110,12 +143,26 @@ def unpack_frames(b): frames = [] start = fmt_size * (1 + n_frames) + nb = b.nbytes + end = 0 + missing_lengths = [] for length in lengths: + if partial and start == nb: + missing_lengths.extend(lengths[len(frames) :]) + break + end = start + length frames.append(b[start:end]) start = end + assert end <= nb - return frames + if remainder: + frames.append(b[start:]) + + if partial: + return frames, missing_lengths + else: + return frames def merge_memoryviews(mvs: Sequence[memoryview]) -> memoryview: diff --git a/distributed/protocol/utils_test.py b/distributed/protocol/utils_test.py new file mode 100644 index 00000000000..d72c98df372 --- /dev/null +++ b/distributed/protocol/utils_test.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy + + +def get_host_array(a: numpy.ndarray) -> numpy.ndarray: + """Given a numpy array, find the underlying memory allocated by either + distributed.protocol.utils.host_array or internally by numpy + """ + import numpy + + assert isinstance(a, numpy.ndarray) + o: object = a + while True: + if isinstance(o, memoryview): + o = o.obj + elif isinstance(o, numpy.ndarray): + if o.base is not None: + o = o.base + else: + return o + else: + # distributed.comm.utils.host_array() uses numpy.empty() + raise TypeError( + "Array uses a buffer allocated neither internally nor by host_array: " + f"{type(o)}" + ) From 0136a5b8aa39d9e4467dbd4a4e02d395bad6d101 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 3 Nov 2023 21:51:39 +0100 Subject: [PATCH 2/2] Don't shadow 'n' variable --- distributed/comm/tcp.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 8a2271f2bb2..79e8046743d 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -365,9 +365,8 @@ async def read_bytes_rw(stream: IOStream, n: int) -> memoryview: range(0, n + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE), ): chunk = buf[i:j] - chunk_nbytes = chunk.nbytes - n = await stream.read_into(chunk) # type: ignore[arg-type] - assert n == chunk_nbytes, (n, chunk_nbytes) + actual = await stream.read_into(chunk) # type: ignore[arg-type] + assert actual == chunk.nbytes return buf