diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index f9bc91cf423..082b274a483 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -23,7 +23,7 @@ pack_frames_prelude, unpack_frames, ) -from distributed.utils import has_keyword +from distributed.utils import ensure_memoryview, has_keyword dask_serialize = dask.utils.Dispatch("dask_serialize") dask_deserialize = dask.utils.Dispatch("dask_deserialize") @@ -765,12 +765,11 @@ def _serialize_array(obj): @dask_deserialize.register(array) def _deserialize_array(header, frames): a = array(header["typecode"]) - for f in map(memoryview, frames): - try: - f = f.cast("B") - except TypeError: - f = f.tobytes() - a.frombytes(f) + nframes = len(frames) + if nframes == 1: + a.frombytes(ensure_memoryview(frames[0])) + elif nframes > 1: + a.frombytes(b"".join(map(ensure_memoryview, frames))) return a diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 8c4eb51c02b..27368393b7d 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -32,7 +32,7 @@ to_serialize, ) from distributed.protocol.serialize import check_dask_serializable -from distributed.utils import nbytes +from distributed.utils import ensure_memoryview, nbytes from distributed.utils_test import gen_test, inc @@ -88,12 +88,28 @@ def test_serialize_bytestrings(): assert bb == b +def test_serialize_empty_array(): + a = array("I") + + # serialize array + header, frames = serialize(a) + assert frames[0] == memoryview(a) + # drop empty frame + del frames[:] + # deserialize with no frames + a2 = deserialize(header, frames) + assert type(a2) == type(a) + assert a2.typecode == a.typecode + assert a2 == a + + @pytest.mark.parametrize( "typecode", ["b", "B", "h", "H", "i", "I", "l", "L", "q", "Q", "f", "d"] ) def test_serialize_arrays(typecode): - a = array(typecode) - a.extend(range(5)) + a = array(typecode, range(5)) + + # handle normal round trip through serialization header, frames = serialize(a) assert frames[0] == memoryview(a) a2 = deserialize(header, frames) @@ -101,6 +117,16 @@ def test_serialize_arrays(typecode): assert a2.typecode == a.typecode assert a2 == a + # split up frames to test joining them back together + header, frames = serialize(a) + (f,) = frames + f = ensure_memoryview(f) + frames = [f[:1], f[1:2], f[2:-1], f[-1:]] + a3 = deserialize(header, frames) + assert type(a3) == type(a) + assert a3.typecode == a.typecode + assert a3 == a + def test_Serialize(): s = Serialize(123)