Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion distributed/comm/tests/test_ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from distributed.comm import ucx, parse_address
from distributed.protocol import to_serialize
from distributed.deploy.local import LocalCluster
from distributed.utils_test import gen_test, loop, inc # noqa: 401
from distributed.utils_test import gen_cluster, gen_test, loop, inc # noqa: 401

from .test_comms import check_deserialize

Expand Down Expand Up @@ -294,3 +294,20 @@ def test_tcp_localcluster(loop):
# assert any(w.data == {x.key: 2} for w in c.workers)
# assert e.loop is c.loop
# print(c.scheduler.workers)


@pytest.mark.asyncio
async def test_cudf_join():
from dask.distributed import Scheduler, Worker
import dask

cudf = pytest.importorskip("cudf")
async with Scheduler(protocol="ucx", port=0, interface="ib0") as s:
async with Worker(s.address, port=0) as a, Worker(s.address, port=0) as b:
async with Client(s.address, asynchronous=True) as c:
df = dask.datasets.timeseries(
dtypes={"x": int, "y": float}, freq="1s"
).partitions[:2]
df = df.map_partitions(cudf.from_pandas)
await c.compute(df.x.sum())
await c.compute(df[["x"]].sum())
1 change: 0 additions & 1 deletion distributed/protocol/cudf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def serialize_cudf_dataframe(x):
arrays.extend(null_masks)

header = {
"is_cuda": len(arrays),
"subheaders": sub_headers,
# TODO: the header must be msgpack (de)serializable.
# See if we can avoid names, and just use integer positions.
Expand Down
1 change: 0 additions & 1 deletion distributed/protocol/cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def serialize_cupy_ndarray(x):
# used in the ucx comms for gpu/cpu message passing
# 'lengths' set by dask
header = x.__cuda_array_interface__.copy()
header["is_cuda"] = 1
header["dtype"] = dtype
return header, [data]

Expand Down
1 change: 0 additions & 1 deletion distributed/protocol/numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def serialize_numba_ndarray(x):
# used in the ucx comms for gpu/cpu message passing
# 'lengths' set by dask
header = x.__cuda_array_interface__.copy()
header["is_cuda"] = 1
header["dtype"] = dtype
return header, [data]

Expand Down
29 changes: 29 additions & 0 deletions distributed/protocol/tests/test_cudf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest

cudf = pytest.importorskip("cudf")

from distributed.protocol import serialize, deserialize
import dask.dataframe as dd


@pytest.mark.parametrize(
"df",
[
cudf.Series([1, 2, 3]),
cudf.Series([1, 2, None]),
cudf.DataFrame({"x": [1, 2, 3], "y": [1.0, 2.0, 3.0]}),
cudf.DataFrame({"x": [1, 2, 3], "s": ["a", "bb", "ccc"]}),
cudf.DataFrame(
{"x": [1, 2, None], "y": [1.0, 2.0, None], "s": ["a", "bb", None]}
),
],
)
def test_basic(df):
header, frames = serialize(
df, serializers=("cuda", "dask", "pickle"), on_error="raise"
)
assert header["serializer"] == "cuda"
assert not any(isinstance(frame, (bytes, memoryview)) for frame in frames)

df2 = deserialize(header, frames, deserializers=("cuda", "dask", "pickle"))
dd.assert_eq(df, df2)