diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 5ecefcd6093..aa5a2824eeb 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -413,3 +413,22 @@ async def test_comm_closed_on_read_error(): await wait_for(reader.read(), 0.01) assert reader.closed() + + +@gen_test() +async def test_embedded_cupy_array( + ucx_loop, +): + cupy = pytest.importorskip("cupy") + da = pytest.importorskip("dask.array") + np = pytest.importorskip("numpy") + + async with LocalCluster( + protocol="ucx", n_workers=1, threads_per_worker=1, asynchronous=True + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + assert cluster.scheduler_address.startswith("ucx://") + a = cupy.arange(10000) + x = da.from_array(a, chunks=(10000,)) + b = await client.compute(x) + cupy.testing.assert_array_equal(a, b) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index d58ee011297..58740af62ce 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -146,6 +146,8 @@ def _decode_default(obj): sub_header = msgpack.loads(frames[offset]) offset += 1 sub_frames = frames[offset : offset + sub_header["num-sub-frames"]] + if "compression" in sub_header: + sub_frames = decompress(sub_header, sub_frames) if allow_pickle: return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames) else: