diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 497288ff4a7..2db267f111c 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -19,12 +19,8 @@ from distributed.comm.addressing import parse_host_port, unparse_host_port from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector from distributed.comm.registry import Backend -from distributed.comm.utils import ( - ensure_concrete_host, - from_frames, - host_array, - to_frames, -) +from distributed.comm.utils import ensure_concrete_host, from_frames, to_frames +from distributed.protocol.utils import host_array from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6 logger = logging.getLogger(__name__) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index d2da61aa287..e3e8029e7a1 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -35,10 +35,9 @@ ensure_concrete_host, from_frames, get_tcp_server_address, - host_array, to_frames, ) -from distributed.protocol.utils import pack_frames_prelude, unpack_frames +from distributed.protocol.utils import host_array, pack_frames_prelude, unpack_frames from distributed.system import MEMORY_LIMIT from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6, nbytes diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 674d2f7e49d..ecb6eb081ff 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -22,17 +22,13 @@ from distributed.comm.addressing import parse_host_port, unparse_host_port from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector from distributed.comm.registry import Backend, backends -from distributed.comm.utils import ( - ensure_concrete_host, - from_frames, - host_array, - to_frames, -) +from distributed.comm.utils import ensure_concrete_host, from_frames, to_frames from distributed.diagnostics.nvml import ( CudaDeviceInfo, get_device_index_and_uuid, has_cuda_context, ) +from distributed.protocol.utils import host_array from distributed.utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes logger = logging.getLogger(__name__) diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index d1b2d2791ee..e91ddc2c839 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -19,26 +19,6 @@ OFFLOAD_THRESHOLD = parse_bytes(OFFLOAD_THRESHOLD) -# Find the function, `host_array()`, to use when allocating new host arrays -try: - # Use NumPy, when available, to avoid memory initialization cost. - # A `bytearray` is zero-initialized using `calloc`, which we don't need. - # `np.empty` both skips the zero-initialization, and - # uses hugepages when available ( https://github.com/numpy/numpy/pull/14216 ). - import numpy - - def numpy_host_array(n: int) -> memoryview: - return numpy.empty((n,), dtype="u1").data - - host_array = numpy_host_array -except ImportError: - - def builtin_host_array(n: int) -> memoryview: - return memoryview(bytearray(n)) - - host_array = builtin_host_array - - async def to_frames( msg, allow_offload=True, diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index f25a2ce3c77..98426332fed 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -21,6 +21,7 @@ from distributed.protocol.compression import decompress, maybe_compress from distributed.protocol.utils import ( frame_split_size, + host_array_from_buffers, merge_memoryviews, msgpack_opts, pack_frames_prelude, @@ -504,7 +505,7 @@ def merge_and_deserialize(header, frames, deserializers=None): try: merged = merge_memoryviews(subframes) except (ValueError, TypeError): - merged = bytearray().join(subframes) + merged = host_array_from_buffers(subframes) merged_frames.append(merged) diff --git a/distributed/protocol/tests/__init__.py b/distributed/protocol/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/distributed/protocol/tests/test_utils.py b/distributed/protocol/tests/test_utils.py new file mode 100644 index 00000000000..b67b650f97b --- /dev/null +++ b/distributed/protocol/tests/test_utils.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import pytest + +from distributed.protocol.utils import host_array, host_array_from_buffers + + +def test_host_array(): + a = host_array(5) + a[:3] = b"abc" + a[3:] = b"de" + assert bytes(a) == b"abcde" + + +def test_host_array_from_buffers(): + a = host_array_from_buffers([b"abc", b"de"]) + a[:1] = b"f" + assert bytes(a) == b"fbcde" + + +def test_host_array_from_buffers_numpy(): + """Test for word sizes larger than 1 byte""" + np = pytest.importorskip("numpy") + a = host_array_from_buffers( + [np.array([1, 2], dtype="u1"), np.array([3, 4], dtype="u8")] + ) + assert a.nbytes == 18 diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 201a7e3da13..1f42fe387e6 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -2,7 +2,7 @@ import ctypes import struct -from collections.abc import Collection, Sequence +from collections.abc import Collection, Iterable, Sequence import dask @@ -18,6 +18,35 @@ msgpack_opts["raw"] = False +# Find the function, `host_array()`, to use when allocating new host arrays +try: + # Use NumPy, when available, to avoid memory initialization cost. + # A `bytearray` is zero-initialized using `calloc`, which we don't need. + # `np.empty` both skips the zero-initialization, and + # uses hugepages when available ( https://github.com/numpy/numpy/pull/14216 ). + import numpy + + def host_array(n: int) -> memoryview: + return numpy.empty((n,), dtype="u1").data + +except ImportError: + + def host_array(n: int) -> memoryview: + return memoryview(bytearray(n)) + + +def host_array_from_buffers( + buffers: Iterable[bytes | bytearray | memoryview], +) -> memoryview: + mvs = [memoryview(buf) for buf in buffers] + out = host_array(sum(mv.nbytes for mv in mvs)) + offset = 0 + for mv in mvs: + out[offset : offset + mv.nbytes] = mv.cast("B") + offset += mv.nbytes + return out + + def frame_split_size( frame: bytes | memoryview, n: int = BIG_BYTES_SHARD_SIZE ) -> list[memoryview]: