diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py index 5995d1d9d51..5cbc21445c7 100644 --- a/distributed/protocol/compression.py +++ b/distributed/protocol/compression.py @@ -6,9 +6,9 @@ from __future__ import annotations import logging -import random from collections.abc import Callable from contextlib import suppress +from random import randint from typing import Literal from packaging.version import parse as parse_version @@ -16,7 +16,7 @@ import dask -from distributed.utils import ensure_bytes +from distributed.utils import ensure_memoryview, nbytes compressions: dict[ str | None | Literal[False], @@ -120,24 +120,37 @@ def byte_sample(b, size, n): ---------- b : bytes or memoryview size : int - size of each sample to collect + target size of each sample to collect + (may be smaller if samples collide) n : int number of samples to collect """ - starts = [random.randint(0, len(b) - size) for j in range(n)] - ends = [] - for i, start in enumerate(starts[:-1]): - ends.append(min(start + size, starts[i + 1])) - ends.append(starts[-1] + size) - - parts = [b[start:end] for start, end in zip(starts, ends)] - return b"".join(map(ensure_bytes, parts)) + assert size >= 0 and n >= 0 + if size == 0 or n == 0: + return memoryview(b"") + + b = ensure_memoryview(b) + + parts = n * [None] + max_start = b.nbytes - size + start = randint(0, max_start) + for i in range(n - 1): + next_start = randint(0, max_start) + end = min(start + size, next_start) + parts[i] = b[start:end] + start = next_start + parts[-1] = b[start : start + size] + + if n == 1: + return parts[0] + else: + return memoryview(b"".join(parts)) def maybe_compress( payload, - min_size=1e4, - sample_size=1e4, + min_size=10_000, + sample_size=10_000, nsamples=5, compression=dask.config.get("distributed.comm.compression"), ): @@ -151,37 +164,30 @@ def maybe_compress( return the original 4. We return the compressed result """ - if compression == "auto": - compression = default_compression - if not compression: return None, payload - if len(payload) < min_size: - return None, payload - if len(payload) > 2**31: # Too large, compression libraries often fail + if not (min_size <= nbytes(payload) <= 2**31): + # Either too small to bother + # or too large (compression libraries often fail) return None, payload - min_size = int(min_size) - sample_size = int(sample_size) - + # Normalize function arguments + if compression == "auto": + compression = default_compression compress = compressions[compression]["compress"] - # Compress a sample, return original if not very compressed - sample = byte_sample(payload, sample_size, nsamples) - if len(compress(sample)) > 0.9 * len(sample): # sample not very compressible - return None, payload - - if type(payload) is memoryview: - nbytes = payload.itemsize * len(payload) - else: - nbytes = len(payload) - - compressed = compress(ensure_bytes(payload)) - - if len(compressed) > 0.9 * nbytes: # full data not very compressible - return None, payload - else: - return compression, compressed + # Take a view of payload for efficient usage + mv = ensure_memoryview(payload) + + # Try compressing a sample to see if it compresses well + sample = byte_sample(mv, sample_size, nsamples) + if len(compress(sample)) <= 0.9 * sample.nbytes: + # Try compressing the real thing and check how compressed it is + compressed = compress(mv) + if len(compressed) <= 0.9 * mv.nbytes: + return compression, compressed + # Skip compression as the sample or the data didn't compress well + return None, payload def decompress(header, frames): diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 0e0ae003b5f..db063922794 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -17,19 +17,11 @@ serialize_and_split, ) from distributed.protocol.utils import msgpack_opts +from distributed.utils import ensure_memoryview logger = logging.getLogger(__name__) -def ensure_memoryview(obj): - """Ensure `obj` is a memoryview of datatype bytes""" - ret = memoryview(obj) - if ret.nbytes: - return ret.cast("B") - else: - return ret - - def dumps( msg, serializers=None, on_error="message", context=None, frame_split_size=None ) -> list: diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index b2d0abd64fd..8c281003a76 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -48,21 +48,32 @@ def test_compression_1(): pytest.importorskip("lz4") np = pytest.importorskip("numpy") x = np.ones(1000000) - frames = dumps({"x": Serialize(x.tobytes())}) - assert sum(map(nbytes, frames)) < x.nbytes + b = x.tobytes() + frames = dumps({"x": Serialize(b)}) + assert sum(map(nbytes, frames)) < nbytes(b) y = loads(frames) - assert {"x": x.tobytes()} == y + assert {"x": b} == y def test_compression_2(): pytest.importorskip("lz4") np = pytest.importorskip("numpy") x = np.random.random(10000) - msg = dumps(to_serialize(x.tobytes())) + msg = dumps(to_serialize(x.data)) compression = msgpack.loads(msg[1]).get("compression") assert all(c is None for c in compression) +def test_compression_3(): + pytest.importorskip("lz4") + np = pytest.importorskip("numpy") + x = np.ones(1000000) + frames = dumps({"x": Serialize(x.data)}) + assert sum(map(nbytes, frames)) < x.nbytes + y = loads(frames) + assert {"x": x.data} == y + + def test_compression_without_deserialization(): pytest.importorskip("lz4") np = pytest.importorskip("numpy") diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 3358dbc1907..f50bd2c0081 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -28,6 +28,7 @@ _maybe_complex, ensure_bytes, ensure_ip, + ensure_memoryview, format_dashboard_link, get_ip_interface, get_traceback, @@ -269,6 +270,36 @@ def test_ensure_bytes_pyarrow_buffer(): assert isinstance(result, bytes) +def test_ensure_memoryview_empty(): + result = ensure_memoryview(b"") + assert isinstance(result, memoryview) + assert result == memoryview(b"") + + +def test_ensure_memoryview(): + data = [b"1", memoryview(b"1"), bytearray(b"1"), array.array("b", [49])] + for d in data: + result = ensure_memoryview(d) + assert isinstance(result, memoryview) + assert result == memoryview(b"1") + + +def test_ensure_memoryview_ndarray(): + np = pytest.importorskip("numpy") + result = ensure_memoryview(np.arange(12).reshape(3, 4)[:, ::2].T) + assert isinstance(result, memoryview) + assert result.ndim == 1 + assert result.format == "B" + assert result.contiguous + + +def test_ensure_memoryview_pyarrow_buffer(): + pa = pytest.importorskip("pyarrow") + buf = pa.py_buffer(b"123") + result = ensure_memoryview(buf) + assert isinstance(result, memoryview) + + def test_nbytes(): np = pytest.importorskip("numpy") diff --git a/distributed/utils.py b/distributed/utils.py index 40116f6b01b..475ea836aa5 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1013,6 +1013,25 @@ def ensure_bytes(s): ) from e +def ensure_memoryview(obj): + """Ensure `obj` is a 1-D contiguous `uint8` `memoryview`""" + mv: memoryview + if type(obj) is memoryview: + mv = obj + else: + mv = memoryview(obj) + + if not mv.nbytes: + # Drop `obj` reference to permit freeing underlying data + return memoryview(b"") + elif mv.contiguous: + # Perform zero-copy reshape & cast + return mv.cast("B") + else: + # Copy to contiguous form of expected shape & type + return memoryview(mv.tobytes()) + + def open_port(host=""): """Return a probably-open port