diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 2c3b40158b6..eec104573ba 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -49,6 +49,8 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: import pandas as pd import pyarrow as pa + from dask.dataframe.dispatch import from_pyarrow_table_dispatch + file = BytesIO(data) end = len(data) shards = [] @@ -67,7 +69,9 @@ def default_types_mapper(pyarrow_dtype: pa.DataType) -> object: return pd.StringDtype("pyarrow") return None - df = table.to_pandas(self_destruct=True, types_mapper=default_types_mapper) + df = from_pyarrow_table_dispatch( + meta, table, self_destruct=True, types_mapper=default_types_mapper + ) return df.astype(meta.dtypes, copy=False) diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 4a4f4601ea1..1e0ad120f46 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -97,7 +97,7 @@ def rearrange_by_column_p2p( column: str, npartitions: int | None = None, ) -> DataFrame: - from dask.dataframe import DataFrame + from dask.dataframe.core import new_dd_object meta = df._meta check_dtype_support(meta) @@ -119,7 +119,7 @@ def rearrange_by_column_p2p( name_input=df._name, meta_input=meta, ) - return DataFrame( + return new_dd_object( HighLevelGraph.from_collections(name, layer, [df]), name, meta, @@ -273,8 +273,13 @@ def split_by_worker( Split data into many arrow batches, partitioned by destination worker """ import numpy as np - import pyarrow as pa + from dask.dataframe.dispatch import to_pyarrow_table_dispatch + + # (cudf support) Avoid pd.Series + constructor = df._constructor_sliced + assert isinstance(constructor, type) + worker_for = constructor(worker_for) df = df.merge( right=worker_for.cat.codes.rename("_worker"), left_on=column, @@ -287,7 +292,7 @@ def split_by_worker( # assert len(df) == nrows # Not true if some outputs aren't wanted # FIXME: If we do not preserve the index something is corrupting the # bytestream such that it cannot be deserialized anymore - t = pa.Table.from_pandas(df, preserve_index=True) + t = to_pyarrow_table_dispatch(df, preserve_index=True) t = t.sort_by("_worker") codes = np.asarray(t["_worker"]) t = t.drop(["_worker"]) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 8b6c00a2116..f70b3b84d9c 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -121,6 +121,40 @@ async def test_minimal_version(c, s, a, b): await c.compute(dd.shuffle.shuffle(df, "x", shuffle="p2p")) +@pytest.mark.gpu +@pytest.mark.filterwarnings( + "ignore:Ignoring the following arguments to `from_pyarrow_table_dispatch`." +) +@gen_cluster(client=True) +async def test_basic_cudf_support(c, s, a, b): + cudf = pytest.importorskip("cudf") + pytest.importorskip("dask_cudf") + + try: + from dask.dataframe.dispatch import to_pyarrow_table_dispatch + + to_pyarrow_table_dispatch(cudf.DataFrame()) + except TypeError: + pytest.skip(reason="Newer version of dask_cudf is required.") + + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ).to_backend("cudf") + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + assert out.npartitions == df.npartitions + x, y = c.compute([df.x.size, out.x.size]) + x = await x + y = await y + assert x == y + + await check_worker_cleanup(a) + await check_worker_cleanup(b) + await check_scheduler_cleanup(s) + + def get_shuffle_run_from_worker(shuffle_id: ShuffleId, worker: Worker) -> ShuffleRun: plugin = worker.plugins["shuffle"] assert isinstance(plugin, ShuffleWorkerPlugin)