From d3c1d62f62076b6a092318715c46a716f72c6559 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Tue, 9 Jun 2020 09:58:22 -0700 Subject: [PATCH 1/8] Test pickle serialization with NumPy object arrays --- distributed/protocol/tests/test_pickle.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index bd784117186..1bdf00fc163 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -74,6 +74,20 @@ def test_pickle_numpy(): assert (loads(dumps(x)) == x).all() assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all() + x = np.array([np.arange(3), np.arange(4, 6)], dtype=object) + x2 = loads(dumps(x)) + assert x.shape == x2.shape + assert x.dtype == x2.dtype + assert x.strides == x2.strides + for e_x, e_x2 in zip(x.flat, x2.flat): + np.testing.assert_equal(e_x, e_x2) + x3 = deserialize(*serialize(x, serializers=("pickle",))) + assert x.shape == x3.shape + assert x.dtype == x3.dtype + assert x.strides == x3.strides + for e_x, e_x3 in zip(x.flat, x3.flat): + np.testing.assert_equal(e_x, e_x3) + if HIGHEST_PROTOCOL >= 5: x = np.ones(5000) From 77427ef210d9e929b2eb6466aad982573c8d68eb Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Tue, 9 Jun 2020 09:58:23 -0700 Subject: [PATCH 2/8] Test equality after checking array strides --- distributed/protocol/tests/test_numpy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 830991fd56a..78636405a69 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -81,10 +81,11 @@ def test_dumps_serialize_numpy(x): assert isinstance(frame, (bytes, memoryview)) y = deserialize(header, frames) - np.testing.assert_equal(x, y) if x.flags.c_contiguous or x.flags.f_contiguous: assert x.strides == y.strides + np.testing.assert_equal(x, y) + @pytest.mark.parametrize( "x", From 6529237b88df92497ab0b94c197143725986e12e Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Tue, 9 Jun 2020 09:58:24 -0700 Subject: [PATCH 3/8] Check array shape and type as well --- distributed/protocol/tests/test_numpy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 78636405a69..76ae9da87fc 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -81,6 +81,8 @@ def test_dumps_serialize_numpy(x): assert isinstance(frame, (bytes, memoryview)) y = deserialize(header, frames) + assert x.shape == y.shape + assert x.dtype == y.dtype if x.flags.c_contiguous or x.flags.f_contiguous: assert x.strides == y.strides From 9e82ce1ee86057a513e60de94fdf6cd5f6200efb Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Tue, 9 Jun 2020 09:58:25 -0700 Subject: [PATCH 4/8] Handle object array tests specially --- distributed/protocol/tests/test_numpy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 76ae9da87fc..cf293ce18c5 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -86,7 +86,11 @@ def test_dumps_serialize_numpy(x): if x.flags.c_contiguous or x.flags.f_contiguous: assert x.strides == y.strides - np.testing.assert_equal(x, y) + if x.dtype.char == "O": + for e_x, e_y in zip(x.flat, y.flat): + np.testing.assert_equal(e_x, e_y) + else: + np.testing.assert_equal(x, y) @pytest.mark.parametrize( From 3759929ee72dc1505a1dc067b218723327f765d5 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Tue, 9 Jun 2020 09:58:25 -0700 Subject: [PATCH 5/8] Test pickling a ragged NumPy array --- distributed/protocol/tests/test_numpy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index cf293ce18c5..fd723019578 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -57,6 +57,7 @@ def test_serialize(): np.array(["abc"], dtype="S3"), np.array(["abc"], dtype="U3"), np.array(["abc"], dtype=object), + np.array([np.arange(3), np.arange(4, 6)], dtype=object), np.ones(shape=(5,), dtype=("f8", 32)), np.ones(shape=(5,), dtype=[("x", "f8", 32)]), np.ones(shape=(5,), dtype=np.dtype([("a", "i1"), ("b", "f8")], align=False)), From 1487612e6cbe6d7833f2167dc67322671f8fc937 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Tue, 9 Jun 2020 09:58:27 -0700 Subject: [PATCH 6/8] Use pickle protocol 5 with NumPy object arrays --- distributed/protocol/numpy.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index 2140c2f0c4e..caeca147169 100644 --- a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -4,7 +4,7 @@ from .serialize import dask_serialize, dask_deserialize from . import pickle -from ..utils import log_errors +from ..utils import log_errors, nbytes def itemsize(dt): @@ -22,7 +22,10 @@ def itemsize(dt): def serialize_numpy_ndarray(x): if x.dtype.hasobject: header = {"pickle": True} - frames = [pickle.dumps(x)] + frames = [None] + buffer_callback = lambda f: frames.append(memoryview(f)) + frames[0] = pickle.dumps(x, buffer_callback=buffer_callback) + header["lengths"] = tuple(map(nbytes, frames)) return header, frames # We cannot blindly pickle the dtype as some may fail pickling, @@ -96,10 +99,10 @@ def serialize_numpy_ndarray(x): @dask_deserialize.register(np.ndarray) def deserialize_numpy_ndarray(header, frames): with log_errors(): - (frame,) = frames - if header.get("pickle"): - return pickle.loads(frame) + return pickle.loads(frames[0], buffers=frames[1:]) + + (frame,) = frames is_custom, dt = header["dtype"] if is_custom: From 406dd3ba0915d36576934e7cbf0c3b66c4530e8c Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Tue, 9 Jun 2020 10:02:11 -0700 Subject: [PATCH 7/8] Assert ragged array components extracted as frames --- distributed/protocol/tests/test_pickle.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index 1bdf00fc163..9ee496f5e9f 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -81,7 +81,12 @@ def test_pickle_numpy(): assert x.strides == x2.strides for e_x, e_x2 in zip(x.flat, x2.flat): np.testing.assert_equal(e_x, e_x2) - x3 = deserialize(*serialize(x, serializers=("pickle",))) + h, f = serialize(x, serializers=("pickle",)) + if HIGHEST_PROTOCOL >= 5: + assert len(f) == 3 + else: + assert len(f) == 1 + x3 = deserialize(h, f) assert x.shape == x3.shape assert x.dtype == x3.dtype assert x.strides == x3.strides From e07d26ba8136703a5896ebffd18167b0a980eaf1 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Tue, 9 Jun 2020 11:56:54 -0700 Subject: [PATCH 8/8] Assert additional frames when pickling --- distributed/protocol/tests/test_numpy.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index fd723019578..0e299632902 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -14,6 +14,7 @@ ) from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE from distributed.protocol.numpy import itemsize +from distributed.protocol.pickle import HIGHEST_PROTOCOL from distributed.protocol.compression import maybe_compress from distributed.system import MEMORY_LIMIT from distributed.utils import tmpfile, nbytes @@ -80,6 +81,11 @@ def test_dumps_serialize_numpy(x): frames = decompress(header, frames) for frame in frames: assert isinstance(frame, (bytes, memoryview)) + if x.dtype.char == "O" and any(isinstance(e, np.ndarray) for e in x.flat): + if HIGHEST_PROTOCOL >= 5: + assert len(frames) > 1 + else: + assert len(frames) == 1 y = deserialize(header, frames) assert x.shape == y.shape