Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
8e3d41f
Move `ensure_memoryview` to `distributed.utils`
jakirkham May 5, 2022
20b3f9f
Replace `ret` with `mv` for clarity
jakirkham May 5, 2022
6fd7179
Coerce `obj` to `memoryview` only if needed
jakirkham May 5, 2022
007c8d1
Require `contiguous` data for `.cast("B")`
jakirkham May 5, 2022
ec4220c
Copy to `bytes` first in non-contiguous case
jakirkham May 5, 2022
0a63b8d
Shortcut trivial `memoryview` case
jakirkham May 5, 2022
5fa88cf
Fill out docstring & add comments
jakirkham May 5, 2022
e932c42
Drop unused conversion of `min_size`
jakirkham May 5, 2022
c3297f7
Join `if` cases together
jakirkham May 5, 2022
288bdb4
Use `nbytes` to get `payload` size
jakirkham May 5, 2022
7871181
Consolidate payload size checks
jakirkham May 5, 2022
e8672d3
Use `memoryview` of `payload`
jakirkham May 5, 2022
d5b9c69
Add tests of `ensure_memoryview`
jakirkham May 5, 2022
8a8125d
Test compression with `memoryview`
jakirkham May 5, 2022
bbed121
Drop blank line & add a comment
jakirkham May 5, 2022
2658a4e
Unwrap comment
jakirkham May 5, 2022
c4299c1
Test empty `bytes` with `memoryview`
jakirkham May 5, 2022
6ef14e7
Coerce `b` to `memoryview` to avoid copies
jakirkham May 5, 2022
e74d3c8
Add blank line
jakirkham May 5, 2022
f99b523
Special case fewer `parts` in `byte_sample`
jakirkham May 5, 2022
e4500b9
Shortcut `n == 0` in `byte_sample` sooner
jakirkham May 5, 2022
7c801e5
Always return a `memoryview` from `byte_sample`
jakirkham May 5, 2022
2a26502
Nest compressibility checks
jakirkham May 5, 2022
d3064f2
Lighten up on comments around compressibility
jakirkham May 5, 2022
2f71f88
Use `;` instead of `,` in comment
jakirkham May 5, 2022
17677b8
Use `.nbytes` with `memoryview`s
jakirkham May 5, 2022
28d45a3
Clarify `ensure_memoryview` cases in comments
jakirkham May 5, 2022
4fd06a1
Also fast path `size == 0`
jakirkham May 5, 2022
afb1a99
`assert` both `size` & `n` are well behaved
jakirkham May 5, 2022
fc35a13
Use `islice` with `starts`
jakirkham May 5, 2022
689c8a5
From `random` just `import` `randint`
jakirkham May 5, 2022
a25fb6c
Fast path `not compression` case
jakirkham May 5, 2022
72dbad0
Normalize args after size check
jakirkham May 5, 2022
3fd096a
Use `mv.nbytes` in compression check
jakirkham May 5, 2022
2dab947
Consolidate size check code
jakirkham May 5, 2022
d02a4e9
Consolidate `size` & `n` handling
jakirkham May 5, 2022
dcb8c60
Compute largest `start` once
jakirkham May 5, 2022
63d94e0
Consolidate fast paths
jakirkham May 5, 2022
5d6aa1c
Simplify final comment
jakirkham May 5, 2022
cf760cc
Fuse loops in `byte_sample` to make `parts`
jakirkham May 6, 2022
8612caa
Set `start` to `next_start` at end
jakirkham May 6, 2022
883f43f
Tidy comments
jakirkham May 6, 2022
8de6296
Tweak wording
jakirkham May 6, 2022
e4e86bc
Also note `shape` change in comment
jakirkham May 6, 2022
57df6fd
Clarify `size` given sample selection behavior
jakirkham May 6, 2022
3aded1e
Tweak comment
jakirkham May 6, 2022
8aaf04d
Fix comparisons
jakirkham May 6, 2022
dc04e4c
Shorten docstring in `ensure_memoryview`
jakirkham May 6, 2022
75213a4
Preallocate `parts` to match intended size
jakirkham May 6, 2022
a2d891a
Just use `int`s for `min_size` & `sample_size`
jakirkham May 6, 2022
8661564
Call `x.tobytes()` once and assign it
jakirkham May 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 44 additions & 38 deletions distributed/protocol/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
from __future__ import annotations

import logging
import random
from collections.abc import Callable
from contextlib import suppress
from random import randint
from typing import Literal

from packaging.version import parse as parse_version
from tlz import identity

import dask

from distributed.utils import ensure_bytes
from distributed.utils import ensure_memoryview, nbytes

compressions: dict[
str | None | Literal[False],
Expand Down Expand Up @@ -120,24 +120,37 @@ def byte_sample(b, size, n):
----------
b : bytes or memoryview
size : int
size of each sample to collect
target size of each sample to collect
(may be smaller if samples collide)
n : int
number of samples to collect
"""
starts = [random.randint(0, len(b) - size) for j in range(n)]
ends = []
for i, start in enumerate(starts[:-1]):
ends.append(min(start + size, starts[i + 1]))
ends.append(starts[-1] + size)

parts = [b[start:end] for start, end in zip(starts, ends)]
return b"".join(map(ensure_bytes, parts))
assert size >= 0 and n >= 0
if size == 0 or n == 0:
return memoryview(b"")

b = ensure_memoryview(b)

parts = n * [None]
max_start = b.nbytes - size
start = randint(0, max_start)
for i in range(n - 1):
next_start = randint(0, max_start)
end = min(start + size, next_start)
parts[i] = b[start:end]
start = next_start
parts[-1] = b[start : start + size]

if n == 1:
return parts[0]
Comment thread
jakirkham marked this conversation as resolved.
else:
return memoryview(b"".join(parts))


def maybe_compress(
payload,
min_size=1e4,
sample_size=1e4,
min_size=10_000,
sample_size=10_000,
nsamples=5,
compression=dask.config.get("distributed.comm.compression"),
):
Expand All @@ -151,37 +164,30 @@ def maybe_compress(
return the original
4. We return the compressed result
"""
if compression == "auto":
compression = default_compression

if not compression:
return None, payload
if len(payload) < min_size:
return None, payload
if len(payload) > 2**31: # Too large, compression libraries often fail
if not (min_size <= nbytes(payload) <= 2**31):
# Either too small to bother
# or too large (compression libraries often fail)
return None, payload

min_size = int(min_size)
sample_size = int(sample_size)

# Normalize function arguments
if compression == "auto":
compression = default_compression
compress = compressions[compression]["compress"]

# Compress a sample, return original if not very compressed
sample = byte_sample(payload, sample_size, nsamples)
if len(compress(sample)) > 0.9 * len(sample): # sample not very compressible
return None, payload

if type(payload) is memoryview:
nbytes = payload.itemsize * len(payload)
else:
nbytes = len(payload)

compressed = compress(ensure_bytes(payload))

if len(compressed) > 0.9 * nbytes: # full data not very compressible
return None, payload
else:
return compression, compressed
# Take a view of payload for efficient usage
mv = ensure_memoryview(payload)

# Try compressing a sample to see if it compresses well
sample = byte_sample(mv, sample_size, nsamples)
if len(compress(sample)) <= 0.9 * sample.nbytes:
# Try compressing the real thing and check how compressed it is
compressed = compress(mv)
if len(compressed) <= 0.9 * mv.nbytes:
return compression, compressed
# Skip compression as the sample or the data didn't compress well
return None, payload


def decompress(header, frames):
Expand Down
10 changes: 1 addition & 9 deletions distributed/protocol/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,11 @@
serialize_and_split,
)
from distributed.protocol.utils import msgpack_opts
from distributed.utils import ensure_memoryview

logger = logging.getLogger(__name__)


def ensure_memoryview(obj):
    """Ensure `obj` is a memoryview of datatype bytes

    Parameters
    ----------
    obj : bytes-like
        Any object supporting the buffer protocol.

    Returns
    -------
    memoryview
        For non-empty input, a 1-D view of format ``"B"``. It shares memory
        with ``obj`` when the buffer is C-contiguous and is backed by a
        ``bytes`` copy otherwise. Empty input is returned as a plain
        ``memoryview(obj)`` without casting.
    """
    ret = memoryview(obj)
    if not ret.nbytes:
        # Nothing to cast; hand back the empty view as-is
        return ret
    if ret.c_contiguous:
        # Zero-copy reshape & cast to unsigned bytes
        return ret.cast("B")
    # `memoryview.cast` is restricted to C-contiguous views, so strided or
    # Fortran-ordered data is first copied into a contiguous `bytes` object
    return memoryview(ret.tobytes())


def dumps(
msg, serializers=None, on_error="message", context=None, frame_split_size=None
) -> list:
Expand Down
19 changes: 15 additions & 4 deletions distributed/protocol/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,32 @@ def test_compression_1():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.ones(1000000)
frames = dumps({"x": Serialize(x.tobytes())})
assert sum(map(nbytes, frames)) < x.nbytes
b = x.tobytes()
frames = dumps({"x": Serialize(b)})
assert sum(map(nbytes, frames)) < nbytes(b)
y = loads(frames)
assert {"x": x.tobytes()} == y
assert {"x": b} == y


def test_compression_2():
    """Check that no frame is marked as compressed for random array data.

    The array's buffer (``x.data``) is serialized directly, and every entry
    of the ``"compression"`` header field is expected to be ``None``.
    """
    pytest.importorskip("lz4")
    np = pytest.importorskip("numpy")
    x = np.random.random(10000)
    # Serialize once, from the array's buffer; an earlier redundant
    # `x.tobytes()` serialization was immediately overwritten and is dropped
    msg = dumps(to_serialize(x.data))
    compression = msgpack.loads(msg[1]).get("compression")
    assert all(c is None for c in compression)


def test_compression_3():
    """Round-trip an all-ones array buffer through ``dumps``/``loads``.

    The serialized frames must be smaller in total than the raw array, and
    loading them must give back a message equal to the original.
    """
    pytest.importorskip("lz4")
    np = pytest.importorskip("numpy")
    ones = np.ones(1000000)
    message = {"x": Serialize(ones.data)}
    frames = dumps(message)
    total_frame_bytes = sum(nbytes(frame) for frame in frames)
    assert total_frame_bytes < ones.nbytes
    roundtripped = loads(frames)
    assert {"x": ones.data} == roundtripped


def test_compression_without_deserialization():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
Expand Down
31 changes: 31 additions & 0 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_maybe_complex,
ensure_bytes,
ensure_ip,
ensure_memoryview,
format_dashboard_link,
get_ip_interface,
get_traceback,
Expand Down Expand Up @@ -269,6 +270,36 @@ def test_ensure_bytes_pyarrow_buffer():
assert isinstance(result, bytes)


def test_ensure_memoryview_empty():
    """Empty ``bytes`` input yields an empty ``memoryview``."""
    empty_view = ensure_memoryview(b"")
    assert isinstance(empty_view, memoryview)
    assert empty_view == memoryview(b"")


def test_ensure_memoryview():
    """Several bytes-like inputs holding ``b"1"`` all map to an equal view."""
    inputs = (
        b"1",
        memoryview(b"1"),
        bytearray(b"1"),
        array.array("b", [49]),
    )
    for obj in inputs:
        view = ensure_memoryview(obj)
        assert isinstance(view, memoryview)
        assert view == memoryview(b"1")


def test_ensure_memoryview_ndarray():
    """A strided, transposed 2-D array becomes a flat contiguous byte view."""
    np = pytest.importorskip("numpy")
    # Every other column of a 3x4 array, then transposed
    strided = np.arange(12).reshape(3, 4)[:, ::2].T
    view = ensure_memoryview(strided)
    assert isinstance(view, memoryview)
    assert (view.ndim, view.format) == (1, "B")
    assert view.contiguous


def test_ensure_memoryview_pyarrow_buffer():
    """A ``pyarrow`` buffer is accepted and returned as a ``memoryview``."""
    pa = pytest.importorskip("pyarrow")
    view = ensure_memoryview(pa.py_buffer(b"123"))
    assert isinstance(view, memoryview)


def test_nbytes():
np = pytest.importorskip("numpy")

Expand Down
19 changes: 19 additions & 0 deletions distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,25 @@ def ensure_bytes(s):
) from e


def ensure_memoryview(obj):
    """Ensure `obj` is a 1-D contiguous `uint8` `memoryview`

    Parameters
    ----------
    obj : bytes-like
        Any object supporting the buffer protocol (an existing
        ``memoryview`` is used directly rather than re-wrapped).

    Returns
    -------
    memoryview
        A 1-D view of format ``"B"``. It shares memory with ``obj`` when the
        buffer is C-contiguous; otherwise it is backed by a ``bytes`` copy of
        the data in logical (C) order. Empty input yields a fresh empty view.
    """
    mv: memoryview
    if type(obj) is memoryview:
        mv = obj
    else:
        mv = memoryview(obj)

    if not mv.nbytes:
        # Drop `obj` reference to permit freeing underlying data
        return memoryview(b"")
    elif mv.c_contiguous:
        # Perform zero-copy reshape & cast
        return mv.cast("B")
    else:
        # Copy to contiguous form of expected shape & type.
        # `memoryview.cast` only accepts C-contiguous views, so this branch
        # handles strided data and also Fortran-ordered data (for which
        # `mv.contiguous` is true but `cast` would raise `TypeError`).
        return memoryview(mv.tobytes())


def open_port(host=""):
"""Return a probably-open port

Expand Down