-
-
Notifications
You must be signed in to change notification settings - Fork 762
Don't share host_array when receiving from network #8308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -220,18 +220,16 @@ async def read(self, deserializers=None): | |
| fmt_size = struct.calcsize(fmt) | ||
|
|
||
| try: | ||
| frames_nbytes = await stream.read_bytes(fmt_size) | ||
| (frames_nbytes,) = struct.unpack(fmt, frames_nbytes) | ||
|
|
||
| frames = host_array(frames_nbytes) | ||
| for i, j in sliding_window( | ||
| 2, | ||
| range(0, frames_nbytes + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE), | ||
| ): | ||
| chunk = frames[i:j] | ||
| chunk_nbytes = chunk.nbytes | ||
| n = await stream.read_into(chunk) | ||
| assert n == chunk_nbytes, (n, chunk_nbytes) | ||
| # Don't store multiple numpy or parquet buffers into the same buffer, or | ||
| # none will be released until all are released. | ||
| frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size) | ||
| (frames_nosplit_nbytes,) = struct.unpack(fmt, frames_nosplit_nbytes_bin) | ||
| frames_nosplit = await read_bytes_rw(stream, frames_nosplit_nbytes) | ||
| frames, buffers_nbytes = unpack_frames(frames_nosplit, partial=True) | ||
| for buffer_nbytes in buffers_nbytes: | ||
| buffer = await read_bytes_rw(stream, buffer_nbytes) | ||
| frames.append(buffer) | ||
|
|
||
| except StreamClosedError as e: | ||
| self.stream = None | ||
| self._closed = True | ||
|
|
@@ -247,8 +245,6 @@ async def read(self, deserializers=None): | |
| raise | ||
| else: | ||
| try: | ||
| frames = unpack_frames(frames) | ||
|
|
||
| msg = await from_frames( | ||
| frames, | ||
| deserialize=self.deserialize, | ||
|
|
@@ -278,23 +274,10 @@ async def write(self, msg, serializers=None, on_error="message"): | |
| }, | ||
| frame_split_size=self.max_shard_size, | ||
| ) | ||
| frames_nbytes = [nbytes(f) for f in frames] | ||
| frames_nbytes_total = sum(frames_nbytes) | ||
|
|
||
| header = pack_frames_prelude(frames) | ||
| header = struct.pack("Q", nbytes(header) + frames_nbytes_total) + header | ||
|
|
||
| frames = [header, *frames] | ||
| frames_nbytes = [nbytes(header), *frames_nbytes] | ||
| frames_nbytes_total += frames_nbytes[0] | ||
|
|
||
| if frames_nbytes_total < 2**17: # 128kiB | ||
| # small enough, send in one go | ||
| frames = [b"".join(frames)] | ||
| frames_nbytes = [frames_nbytes_total] | ||
| frames, frames_nbytes, frames_nbytes_total = _add_frames_header(frames) | ||
|
|
||
| try: | ||
| # trick to enque all frames for writing beforehand | ||
| # trick to enqueue all frames for writing beforehand | ||
| for each_frame_nbytes, each_frame in zip(frames_nbytes, frames): | ||
| if each_frame_nbytes: | ||
| # Make sure that `len(data) == data.nbytes` | ||
|
|
@@ -371,6 +354,78 @@ def extra_info(self): | |
| return self._extra | ||
|
|
||
|
|
||
| async def read_bytes_rw(stream: IOStream, n: int) -> memoryview: | ||
| """Read n bytes from stream. Unlike stream.read_bytes, allow for | ||
| very large messages and return a writeable buffer. | ||
| """ | ||
| buf = host_array(n) | ||
|
|
||
| for i, j in sliding_window( | ||
| 2, | ||
| range(0, n + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE), | ||
| ): | ||
| chunk = buf[i:j] | ||
| actual = await stream.read_into(chunk) # type: ignore[arg-type] | ||
| assert actual == chunk.nbytes | ||
|
|
||
| return buf | ||
|
|
||
|
|
||
| def _add_frames_header( | ||
| frames: list[bytes | memoryview], | ||
| ) -> tuple[list[bytes | memoryview], list[int], int]: | ||
| """ """ | ||
| frames_nbytes = [nbytes(f) for f in frames] | ||
| frames_nbytes_total = sum(frames_nbytes) | ||
|
|
||
| # Calculate the number of bytes that are inclusive of: | ||
| # - prelude | ||
| # - msgpack header | ||
| # - simple pickle bytes | ||
| # - compressed buffers | ||
| # - first uncompressed buffer (possibly sharded), IFF the pickle bytes are | ||
| # negligible in size | ||
| # | ||
| # All these can be fetched by read() into a single buffer with a single call to | ||
| # Tornado, because they will be dereferenced soon after they are deserialized. | ||
| # Read uncompressed numpy/parquet buffers, which will survive indefinitely past | ||
| # the end of read(), into their own host arrays so that their memory can be | ||
| # released independently. | ||
| frames_nbytes_nosplit = 0 | ||
| first_uncompressed_buffer: object = None | ||
| for frame, nb in zip(frames, frames_nbytes): | ||
| buffer = frame.obj if isinstance(frame, memoryview) else frame | ||
| if not isinstance(buffer, bytes): | ||
| # Uncompressed buffer; it will be referenced by the unpickled object | ||
| if first_uncompressed_buffer is None: | ||
| if frames_nbytes_nosplit > max(2048, nb * 0.05): | ||
| # Don't extend the lifespan of non-trivial amounts of pickled bytes | ||
| # to that of the buffers | ||
| break | ||
| first_uncompressed_buffer = buffer | ||
| elif first_uncompressed_buffer is not buffer: # don't split sharded frame | ||
| # Always store 2+ separate numpy/parquet objects onto separate | ||
| # buffers | ||
| break | ||
|
|
||
| frames_nbytes_nosplit += nb | ||
|
|
||
| header = pack_frames_prelude(frames) | ||
| header = struct.pack("Q", nbytes(header) + frames_nbytes_nosplit) + header | ||
| header_nbytes = nbytes(header) | ||
|
|
||
| frames = [header, *frames] | ||
| frames_nbytes = [header_nbytes, *frames_nbytes] | ||
| frames_nbytes_total += header_nbytes | ||
|
|
||
| if frames_nbytes_total < 2**17: # 128kiB | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Given the testing you did recently, do you think this number still makes sense? Is this something to look into, or is it not worth the effort?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It looks about right given my recent testing. |
||
| # small enough, send in one go | ||
| frames = [b"".join(frames)] | ||
| frames_nbytes = [frames_nbytes_total] | ||
|
|
||
| return frames, frames_nbytes, frames_nbytes_total | ||
|
|
||
|
|
||
| class TLS(TCP): | ||
| """ | ||
| A TLS-specific version of TCP. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,10 +26,11 @@ | |
| ) | ||
| from distributed.comm.registry import backends, get_backend | ||
| from distributed.comm.tcp import get_stream_address | ||
| from distributed.compatibility import asyncio_run | ||
| from distributed.compatibility import WINDOWS, asyncio_run | ||
| from distributed.config import get_loop_factory | ||
| from distributed.metrics import time | ||
| from distributed.protocol import Serialized, deserialize, serialize, to_serialize | ||
| from distributed.protocol.utils_test import get_host_array | ||
| from distributed.utils import get_ip, get_ipv6, get_mp_context, wait_for | ||
| from distributed.utils_test import ( | ||
| gen_test, | ||
|
|
@@ -1379,3 +1380,95 @@ def test_register_backend_entrypoint(tmp_path): | |
| with get_mp_context().Pool(1) as pool: | ||
| assert pool.apply(_get_backend_on_path, args=(tmp_path,)) == 1 | ||
| pool.join() | ||
|
|
||
|
|
||
| class OpaqueList(list): | ||
| """Don't let the serialization layer traverse this object""" | ||
|
|
||
| pass | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "list_cls", | ||
| [ | ||
| list, # Use protocol.numpy.serialize_numpy_array / deserialize_numpy_array | ||
| OpaqueList, # Use generic pickle.dumps / pickle.loads | ||
| ], | ||
| ) | ||
| @gen_test() | ||
| async def test_do_not_share_buffers(tcp, list_cls): | ||
| """Test that two objects with buffer interface in the same message do not share | ||
| their buffer upon deserialization | ||
|
|
||
| See Also | ||
| -------- | ||
| test_share_buffer_with_header | ||
| test_ucx.py::test_do_not_share_buffers | ||
| """ | ||
| np = pytest.importorskip("numpy") | ||
|
|
||
| async def handle_comm(comm): | ||
| msg = await comm.read() | ||
| msg["data"] = to_serialize(list_cls([np.array([1, 2]), np.array([3, 4])])) | ||
| await comm.write(msg) | ||
| await comm.close() | ||
|
|
||
| listener = await tcp.TCPListener("127.0.0.1", handle_comm) | ||
| comm = await connect(listener.contact_address) | ||
|
|
||
| await comm.write({"op": "ping"}) | ||
| msg = await comm.read() | ||
| await comm.close() | ||
|
|
||
| a, b = msg["data"] | ||
| assert get_host_array(a) is not get_host_array(b) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "nbytes_np,nbytes_other,expect_separate_buffer", | ||
| [ | ||
| (1, 0, False), # <2 kiB (including prologue and msgpack header) | ||
| (1, 1800, False), # <2 kiB | ||
| (1, 2100, True), # >2 kiB | ||
| (200_000, 9500, False), # <5% of numpy array | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. IIUC the 9500 extra bytes here will be written to the same memory buffer the numpy array is using, i.e. those 9500 bytes will only be released once the numpy array is released. Similarly, the numpy array will only be released once the other thing has been released. However, that other thing is guaranteed to be a bytes object or some header information or some other garbage that is guaranteed to be released after the message is deserialized. So, in other words, we're accepting a memory overhead of up to 5% for numpy arrays/arrow tables/etc. (and previously this could've been a multiple, depending on how large a single fetch was).
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
This is correct. This is really only material for pandas objects with substantial pure-python index / columns / other metadata; numpy objects tend to be <100 bytes worth of metadata.
It was worse than a multiple.
|
||
| (200_000, 10500, True), # >5% of numpy array | ||
| (350_000, 0, False), # sharded buffer | ||
| ], | ||
| ) | ||
| @gen_test() | ||
| async def test_share_buffer_with_header( | ||
| tcp, nbytes_np, nbytes_other, expect_separate_buffer | ||
| ): | ||
| """Test that a numpy or parquet object shares the buffer with its serialized header | ||
| to improve performance, but only as long as the header is trivial in size. | ||
|
|
||
| See Also | ||
| -------- | ||
| test_do_not_share_buffers | ||
| """ | ||
| np = pytest.importorskip("numpy") | ||
| if tcp is asyncio_tcp and WINDOWS: | ||
| pytest.xfail("asyncio_tcp is faulty on windows") | ||
|
|
||
| async def handle_comm(comm): | ||
| comm.max_shard_size = 250_000 | ||
| msg = await comm.read() | ||
| msg["np"] = to_serialize(np.random.randint(0, 256, nbytes_np, dtype="u1")) | ||
| msg["other"] = np.random.bytes(nbytes_other) | ||
| await comm.write(msg) | ||
| await comm.close() | ||
|
|
||
| listener = await tcp.TCPListener("127.0.0.1", handle_comm) | ||
| comm = await connect(listener.contact_address) | ||
|
|
||
| await comm.write({"op": "ping"}) | ||
| msg = await comm.read() | ||
| await comm.close() | ||
|
|
||
| a = msg["np"] | ||
| ha = get_host_array(a) | ||
| if tcp is asyncio_tcp: | ||
| # TODO unimplemented optimization. Buffers are always split. | ||
| assert ha.nbytes == a.nbytes | ||
| else: | ||
| assert (ha.nbytes == a.nbytes) == expect_separate_buffer | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import pytest | ||
|
|
||
| from distributed.protocol.utils import host_array | ||
| from distributed.protocol.utils_test import get_host_array | ||
|
|
||
|
|
||
| def test_get_host_array(): | ||
| np = pytest.importorskip("numpy") | ||
|
|
||
| a = np.array([1, 2, 3]) | ||
| assert get_host_array(a) is a | ||
| assert get_host_array(a[1:]) is a | ||
| assert get_host_array(a[1:][1:]) is a | ||
|
|
||
| buf = host_array(3) | ||
| a = np.frombuffer(buf, dtype="u1") | ||
| assert get_host_array(a) is buf.obj | ||
| assert get_host_array(a[1:]) is buf.obj | ||
| a = np.frombuffer(buf[1:], dtype="u1") | ||
| assert get_host_array(a) is buf.obj | ||
|
|
||
| a = np.frombuffer(bytearray(3), dtype="u1") | ||
| with pytest.raises(TypeError): | ||
| get_host_array(a) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: To my knowledge, we're never sending parquet over the network unless, of course, a user decides to do this themselves.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess you are referring to pyarrow Table objects or anything that can be directly instantiated from a buffer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wrong location for the comment?
parquet was a typo; I meant arrow.