Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 84 additions & 29 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,16 @@ async def read(self, deserializers=None):
fmt_size = struct.calcsize(fmt)

try:
frames_nbytes = await stream.read_bytes(fmt_size)
(frames_nbytes,) = struct.unpack(fmt, frames_nbytes)

frames = host_array(frames_nbytes)
for i, j in sliding_window(
2,
range(0, frames_nbytes + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE),
):
chunk = frames[i:j]
chunk_nbytes = chunk.nbytes
n = await stream.read_into(chunk)
assert n == chunk_nbytes, (n, chunk_nbytes)
# Don't store multiple numpy or parquet buffers into the same buffer, or

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: To my knowledge, we're never sending parquet over the network unless of course a user decides to do this themselves.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you are referring to pyarrow Table objects or anything that can be directly instantiated from a buffer

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wrong location for the comment?
parquet was a typo; i meant arrow.

# none will be released until all are released.
frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
(frames_nosplit_nbytes,) = struct.unpack(fmt, frames_nosplit_nbytes_bin)
frames_nosplit = await read_bytes_rw(stream, frames_nosplit_nbytes)
frames, buffers_nbytes = unpack_frames(frames_nosplit, partial=True)
for buffer_nbytes in buffers_nbytes:
buffer = await read_bytes_rw(stream, buffer_nbytes)
frames.append(buffer)

except StreamClosedError as e:
self.stream = None
self._closed = True
Expand All @@ -247,8 +245,6 @@ async def read(self, deserializers=None):
raise
else:
try:
frames = unpack_frames(frames)

msg = await from_frames(
frames,
deserialize=self.deserialize,
Expand Down Expand Up @@ -278,23 +274,10 @@ async def write(self, msg, serializers=None, on_error="message"):
},
frame_split_size=self.max_shard_size,
)
frames_nbytes = [nbytes(f) for f in frames]
frames_nbytes_total = sum(frames_nbytes)

header = pack_frames_prelude(frames)
header = struct.pack("Q", nbytes(header) + frames_nbytes_total) + header

frames = [header, *frames]
frames_nbytes = [nbytes(header), *frames_nbytes]
frames_nbytes_total += frames_nbytes[0]

if frames_nbytes_total < 2**17: # 128kiB
# small enough, send in one go
frames = [b"".join(frames)]
frames_nbytes = [frames_nbytes_total]
frames, frames_nbytes, frames_nbytes_total = _add_frames_header(frames)

try:
# trick to enque all frames for writing beforehand
# trick to enqueue all frames for writing beforehand
for each_frame_nbytes, each_frame in zip(frames_nbytes, frames):
if each_frame_nbytes:
# Make sure that `len(data) == data.nbytes`
Expand Down Expand Up @@ -371,6 +354,78 @@ def extra_info(self):
return self._extra


async def read_bytes_rw(stream: IOStream, n: int) -> memoryview:
"""Read n bytes from stream. Unlike stream.read_bytes, allow for
very large messages and return a writeable buffer.
"""
buf = host_array(n)

for i, j in sliding_window(
2,
range(0, n + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE),
):
chunk = buf[i:j]
actual = await stream.read_into(chunk) # type: ignore[arg-type]
assert actual == chunk.nbytes

return buf


def _add_frames_header(
frames: list[bytes | memoryview],
) -> tuple[list[bytes | memoryview], list[int], int]:
""" """
frames_nbytes = [nbytes(f) for f in frames]
frames_nbytes_total = sum(frames_nbytes)

# Calculate the number of bytes that are inclusive of:
# - prelude
# - msgpack header
# - simple pickle bytes
# - compressed buffers
# - first uncompressed buffer (possibly sharded), IFF the pickle bytes are
# negligible in size
#
# All these can be fetched by read() into a single buffer with a single call to
# Tornado, because they will be dereferenced soon after they are deserialized.
# Read uncompressed numpy/parquet buffers, which will survive indefinitely past
# the end of read(), into their own host arrays so that their memory can be
# released independently.
frames_nbytes_nosplit = 0
first_uncompressed_buffer: object = None
for frame, nb in zip(frames, frames_nbytes):
buffer = frame.obj if isinstance(frame, memoryview) else frame
if not isinstance(buffer, bytes):
# Uncompressed buffer; it will be referenced by the unpickled object
if first_uncompressed_buffer is None:
if frames_nbytes_nosplit > max(2048, nb * 0.05):
# Don't extend the lifespan of non-trivial amounts of pickled bytes
# to that of the buffers
break
first_uncompressed_buffer = buffer
elif first_uncompressed_buffer is not buffer: # don't split sharded frame
# Always store 2+ separate numpy/parquet objects onto separate
# buffers
break

frames_nbytes_nosplit += nb

header = pack_frames_prelude(frames)
header = struct.pack("Q", nbytes(header) + frames_nbytes_nosplit) + header
header_nbytes = nbytes(header)

frames = [header, *frames]
frames_nbytes = [header_nbytes, *frames_nbytes]
frames_nbytes_total += header_nbytes

if frames_nbytes_total < 2**17: # 128kiB

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the testing you did recently. Do you think this number still makes sense? Something to look into or not worth the effort?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks about right given my recent testing.

# small enough, send in one go
frames = [b"".join(frames)]
frames_nbytes = [frames_nbytes_total]

return frames, frames_nbytes, frames_nbytes_total


class TLS(TCP):
"""
A TLS-specific version of TCP.
Expand Down
95 changes: 94 additions & 1 deletion distributed/comm/tests/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
)
from distributed.comm.registry import backends, get_backend
from distributed.comm.tcp import get_stream_address
from distributed.compatibility import asyncio_run
from distributed.compatibility import WINDOWS, asyncio_run
from distributed.config import get_loop_factory
from distributed.metrics import time
from distributed.protocol import Serialized, deserialize, serialize, to_serialize
from distributed.protocol.utils_test import get_host_array
from distributed.utils import get_ip, get_ipv6, get_mp_context, wait_for
from distributed.utils_test import (
gen_test,
Expand Down Expand Up @@ -1379,3 +1380,95 @@ def test_register_backend_entrypoint(tmp_path):
with get_mp_context().Pool(1) as pool:
assert pool.apply(_get_backend_on_path, args=(tmp_path,)) == 1
pool.join()


class OpaqueList(list):
"""Don't let the serialization layer travese this object"""

pass


@pytest.mark.parametrize(
"list_cls",
[
list, # Use protocol.numpy.serialize_numpy_array / deserialize_numpy_array
OpaqueList, # Use generic pickle.dumps / pickle.loads
],
)
@gen_test()
async def test_do_not_share_buffers(tcp, list_cls):
"""Test that two objects with buffer interface in the same message do not share
their buffer upon deserialization

See Also
--------
test_share_buffer_with_header
test_ucx.py::test_do_not_share_buffers
"""
np = pytest.importorskip("numpy")

async def handle_comm(comm):
msg = await comm.read()
msg["data"] = to_serialize(list_cls([np.array([1, 2]), np.array([3, 4])]))
await comm.write(msg)
await comm.close()

listener = await tcp.TCPListener("127.0.0.1", handle_comm)
comm = await connect(listener.contact_address)

await comm.write({"op": "ping"})
msg = await comm.read()
await comm.close()

a, b = msg["data"]
assert get_host_array(a) is not get_host_array(b)


@pytest.mark.parametrize(
"nbytes_np,nbytes_other,expect_separate_buffer",
[
(1, 0, False), # <2 kiB (including prologue and msgpack header)
(1, 1800, False), # <2 kiB
(1, 2100, True), # >2 kiB
(200_000, 9500, False), # <5% of numpy array

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC the 9500 extra bytes here will be written to the same memory buffer the numpy array is using, i.e. those 9500 bytes will only be released once the numpy array is released. Similarly, the numpy array will only be released once the other thing has been released. However, that other thing is guaranteed to be a bytes object or some header information or some other garbage that is guaranteed to be released after the message is deserialized.

So, in other words, we're accepting a memory overhead of up to 5% for numpy arrays/arrow tables/etc. (and previously this could've been a multiple, depending on how large a single fetch was)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, in other words, we're accepting a memory overhead of up to 5% for numpy arrays/arrow tables/etc.

This is correct. This is really only material for pandas objects with substantial pure-python index / columns / other metadata; numpy objects tend to be <100 bytes worth of metadata.

(and previously this could've been a multiple, depending on how large a single fetch was)

It was worse than a multiple.
There were two nightmare scenarios:

  • pandas dataframe heavy with object string columns, with some numerical columns. The whole serialized data for the object columns remains alive for as long as the deserialized object is alive, because it's referenced by the numerical columns.
  • The key at the top of the WorkerStateMachine.fetch heap is a 49 MiB nump array. The second object in the heap from the same worker is a 1 MiB numpy array (or vice versa). The two are fetched together (distributed.worker.transfer.message-bytes-limit: 50 MiB). The 49 MiB array will survive its own free-keys command, as it is referenced by the 1 MiB array.

(200_000, 10500, True), # >5% of numpy array
(350_000, 0, False), # sharded buffer
],
)
@gen_test()
async def test_share_buffer_with_header(
tcp, nbytes_np, nbytes_other, expect_separate_buffer
):
"""Test that a numpy or parquet object shares the buffer with its serialized header
to improve performance, but only as long as the header is trivial in size.

See Also
--------
test_do_not_share_buffers
"""
np = pytest.importorskip("numpy")
if tcp is asyncio_tcp and WINDOWS:
pytest.xfail("asyncio_tcp is faulty on windows")

async def handle_comm(comm):
comm.max_shard_size = 250_000
msg = await comm.read()
msg["np"] = to_serialize(np.random.randint(0, 256, nbytes_np, dtype="u1"))
msg["other"] = np.random.bytes(nbytes_other)
await comm.write(msg)
await comm.close()

listener = await tcp.TCPListener("127.0.0.1", handle_comm)
comm = await connect(listener.contact_address)

await comm.write({"op": "ping"})
msg = await comm.read()
await comm.close()

a = msg["np"]
ha = get_host_array(a)
if tcp is asyncio_tcp:
# TODO unimplemented optimization. Buffers are always split.
assert ha.nbytes == a.nbytes
else:
assert (ha.nbytes == a.nbytes) == expect_separate_buffer
28 changes: 28 additions & 0 deletions distributed/comm/tests/test_ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
has_cuda_context,
)
from distributed.protocol import to_serialize
from distributed.protocol.utils_test import get_host_array
from distributed.utils_test import gen_test, inc

try:
Expand Down Expand Up @@ -432,3 +433,30 @@ async def test_embedded_cupy_array(
x = da.from_array(a, chunks=(10000,))
b = await client.compute(x)
cupy.testing.assert_array_equal(a, b)


@gen_test()
async def test_do_not_share_buffers(ucx_loop):
"""Test that two objects with buffer interface in the same message do not share
their buffer upon deserialization.

See Also
--------
test_comms.py::test_do_not_share_buffers
"""
np = pytest.importorskip("numpy")

com, serv_com = await get_comm_pair()
msg = {"data": to_serialize([np.array([1, 2]), np.array([3, 4])])}

await com.write(msg)
result = await serv_com.read()
await com.close()
await serv_com.close()

a, b = result["data"]
ha = get_host_array(a)
hb = get_host_array(b)
assert ha is not hb
assert ha.nbytes == a.nbytes
assert hb.nbytes == a.nbytes
42 changes: 40 additions & 2 deletions distributed/protocol/tests/test_protocol_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,54 @@

import pytest

from distributed.protocol.utils import merge_memoryviews, pack_frames, unpack_frames
from distributed.protocol.utils import (
merge_memoryviews,
pack_frames,
pack_frames_prelude,
unpack_frames,
)


def test_pack_frames():
frames = [b"123", b"asdf"]
b = pack_frames(frames)
assert isinstance(b, bytes)
frames2 = unpack_frames(b)
assert frames2 == frames

assert frames == frames2

@pytest.mark.parametrize("extra", [b"456", b""])
def test_unpack_frames_remainder(extra):
frames = [b"123", b"asdf"]
b = pack_frames(frames)
assert isinstance(b, bytes)

frames2 = unpack_frames(b + extra)
assert frames2 == frames

frames2 = unpack_frames(b + extra, remainder=True)
assert isinstance(frames2[-1], memoryview)
assert frames2 == frames + [extra]


def test_unpack_frames_partial():
frames = [b"123", b"asdf"]
frames.insert(0, pack_frames_prelude(frames))

frames2, missing_lenghts = unpack_frames(b"".join(frames), partial=True)
assert frames2 == frames[1:]
assert missing_lenghts == []

frames2, missing_lenghts = unpack_frames(b"".join(frames[:-1]), partial=True)
assert frames2 == frames[1:-1]
assert missing_lenghts == [4]

frames2, missing_lenghts = unpack_frames(frames[0], partial=True)
assert frames2 == []
assert missing_lenghts == [3, 4]

with pytest.raises(AssertionError):
unpack_frames(b"".join(frames[:-1]))


class TestMergeMemroyviews:
Expand Down
26 changes: 26 additions & 0 deletions distributed/protocol/tests/test_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

import pytest

from distributed.protocol.utils import host_array
from distributed.protocol.utils_test import get_host_array


def test_get_host_array():
np = pytest.importorskip("numpy")

a = np.array([1, 2, 3])
assert get_host_array(a) is a
assert get_host_array(a[1:]) is a
assert get_host_array(a[1:][1:]) is a

buf = host_array(3)
a = np.frombuffer(buf, dtype="u1")
assert get_host_array(a) is buf.obj
assert get_host_array(a[1:]) is buf.obj
a = np.frombuffer(buf[1:], dtype="u1")
assert get_host_array(a) is buf.obj

a = np.frombuffer(bytearray(3), dtype="u1")
with pytest.raises(TypeError):
get_host_array(a)
Loading