From 7210c38ef59be4fada43504d7c6232e87e0bdadd Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 3 Apr 2023 15:54:55 -0700 Subject: [PATCH 01/20] basic support for cudf-backed collection in p2p shuffle --- distributed/shuffle/_worker_extension.py | 39 +++++++++++++++++++----- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 751d69e1f57..5df0f1b5081 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -10,12 +10,13 @@ from collections import defaultdict from collections.abc import Callable, Iterator from concurrent.futures import ThreadPoolExecutor +from importlib import import_module from io import BytesIO from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload import toolz -from dask.utils import parse_bytes +from dask.utils import parse_bytes, typename from distributed.core import PooledRPCCall from distributed.protocol import to_serialize @@ -419,6 +420,8 @@ class DataFrameShuffleRun(ShuffleRun[int, int, "pd.DataFrame"]): memory_limiter_comm: A ``ResourceLimiter`` limiting the total amount of memory used in either buffer. + dataframe_backend: + Backend dataframe library name. Default is "pandas". """ def __init__( @@ -436,8 +439,9 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, + dataframe_backend: str = "pandas", ): - import pandas as pd + lib = import_module(dataframe_backend) super().__init__( id=id, @@ -457,7 +461,8 @@ def __init__( for part, addr in worker_for.items(): partitions_of[addr].append(part) self.partitions_of = dict(partitions_of) - self.worker_for = pd.Series(worker_for, name="_workers").astype("category") + self.worker_for = lib.Series(worker_for, name="_workers").astype("category") + self._dataframe_backend = dataframe_backend async def receive(self, data: list[tuple[int, bytes]]) -> None: await self._receive(data) @@ -507,6 +512,13 @@ def _() -> dict[str, list[tuple[int, bytes]]]: await self._write_to_comm(out) return self.run_id + def _arrow_to_df(self, table) -> pd.DataFrame: + if self._dataframe_backend == "cudf": + import cudf + + return cudf.DataFrame.from_arrow(table) + return table.to_pandas() + async def get_output_partition(self, i: int) -> pd.DataFrame: self.raise_if_closed() assert self.transferred, "`get_output_partition` called before barrier task" @@ -524,11 +536,11 @@ async def get_output_partition(self, i: int) -> pd.DataFrame: def _() -> pd.DataFrame: df = convert_partition(data) - return df.to_pandas() + return self._arrow_to_df(df) out = await self.offload(_) except KeyError: - out = self.schema.empty_table().to_pandas() + out = self._arrow_to_df(self.schema.empty_table()) return out @@ -550,6 +562,7 @@ class ShuffleWorkerExtension: memory_limiter_comms: ResourceLimiter memory_limiter_disk: ResourceLimiter closed: bool + _backends: dict def __init__(self, worker: Worker) -> None: # Attach to worker @@ -566,6 +579,7 @@ def __init__(self, worker: Worker) -> None: self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB")) self.closed = False self._executor = ThreadPoolExecutor(self.worker.state.nthreads) + self._backends = {} # Handlers ########## @@ -628,6 +642,8 @@ def add_partition( ) -> int: if type == ShuffleType.DATAFRAME: kwargs["empty"] = data + if shuffle_id not in self._backends: + self._backends[shuffle_id] = typename(data).partition(".")[0] shuffle = self.get_or_create_shuffle(shuffle_id, type=type, **kwargs) return sync( self.worker.loop, @@ -758,7 +774,11 @@ async def _refresh_shuffle( id=shuffle_id, type=type, spec={ - "schema": pa.Schema.from_pandas(kwargs["empty"]) + "schema": pa.Schema.from_pandas( + kwargs["empty"].to_pandas() + if hasattr(kwargs["empty"], "to_pandas") + else kwargs["empty"] + ) .serialize() .to_pybytes(), "npartitions": kwargs["npartitions"], @@ -818,6 +838,7 @@ async def _( scheduler=self.worker.scheduler, memory_limiter_disk=self.memory_limiter_disk, memory_limiter_comms=self.memory_limiter_comms, + dataframe_backend=self._backends.get(shuffle_id, "pandas"), ) elif result["type"] == ShuffleType.ARRAY_RECHUNK: shuffle = ArrayRechunkRun( @@ -926,7 +947,11 @@ def split_by_worker( # assert len(df) == nrows # Not true if some outputs aren't wanted # FIXME: If we do not preserve the index something is corrupting the # bytestream such that it cannot be deserialized anymore - t = pa.Table.from_pandas(df, preserve_index=True) + t = ( + df.to_arrow(preserve_index=True) + if hasattr(df, "to_arrow") and callable(df.to_arrow) + else pa.Table.from_pandas(df, preserve_index=True) + ) t = t.sort_by("_worker") codes = np.asarray(t.select(["_worker"]))[0] t = t.drop(["_worker"]) From b03c132346561a150e863c4ec0c531f9a02a760a Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 3 Apr 2023 19:24:01 -0700 Subject: [PATCH 02/20] use schema metadata --- distributed/shuffle/_worker_extension.py | 28 ++++++++++-------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 5df0f1b5081..df295557ead 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -420,8 +420,6 @@ class DataFrameShuffleRun(ShuffleRun[int, int, "pd.DataFrame"]): memory_limiter_comm: A ``ResourceLimiter`` limiting the total amount of memory used in either buffer. - dataframe_backend: - Backend dataframe library name. Default is "pandas". """ def __init__( @@ -439,10 +437,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, - dataframe_backend: str = "pandas", ): - lib = import_module(dataframe_backend) - super().__init__( id=id, run_id=run_id, @@ -461,8 +456,7 @@ def __init__( for part, addr in worker_for.items(): partitions_of[addr].append(part) self.partitions_of = dict(partitions_of) - self.worker_for = lib.Series(worker_for, name="_workers").astype("category") - self._dataframe_backend = dataframe_backend + self.worker_for = worker_for async def receive(self, data: list[tuple[int, bytes]]) -> None: await self._receive(data) @@ -512,8 +506,8 @@ def _() -> dict[str, list[tuple[int, bytes]]]: await self._write_to_comm(out) return self.run_id - def _arrow_to_df(self, table) -> pd.DataFrame: - if self._dataframe_backend == "cudf": + def _arrow_to_df(self, table: pa.Table) -> pd.DataFrame: + if table.schema.metadata.get(b"dataframe", None) == b"cudf": import cudf return cudf.DataFrame.from_arrow(table) @@ -562,7 +556,6 @@ class ShuffleWorkerExtension: memory_limiter_comms: ResourceLimiter memory_limiter_disk: ResourceLimiter closed: bool - _backends: dict def __init__(self, worker: Worker) -> None: # Attach to worker @@ -579,7 +572,6 @@ def __init__(self, worker: Worker) -> None: self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB")) self.closed = False self._executor = ThreadPoolExecutor(self.worker.state.nthreads) - self._backends = {} # Handlers ########## @@ -642,8 +634,6 @@ def add_partition( ) -> int: if type == ShuffleType.DATAFRAME: kwargs["empty"] = data - if shuffle_id not in self._backends: - self._backends[shuffle_id] = typename(data).partition(".")[0] shuffle = self.get_or_create_shuffle(shuffle_id, type=type, **kwargs) return sync( self.worker.loop, @@ -838,7 +828,6 @@ async def _( scheduler=self.worker.scheduler, memory_limiter_disk=self.memory_limiter_disk, memory_limiter_comms=self.memory_limiter_comms, - dataframe_backend=self._backends.get(shuffle_id, "pandas"), ) elif result["type"] == ShuffleType.ARRAY_RECHUNK: shuffle = ArrayRechunkRun( @@ -927,7 +916,7 @@ def get_output_partition( def split_by_worker( df: pd.DataFrame, column: str, - worker_for: pd.Series, + worker_for: dict[int, str], ) -> dict[Any, pa.Table]: """ Split data into many arrow batches, partitioned by destination worker @@ -935,8 +924,12 @@ def split_by_worker( import numpy as np import pyarrow as pa + lib = typename(df).partition(".")[0] + worker_for_ser = ( + import_module(lib).Series(worker_for, name="_workers").astype("category") + ) df = df.merge( - right=worker_for.cat.codes.rename("_worker"), + right=worker_for_ser.cat.codes.rename("_worker"), left_on=column, right_index=True, how="inner", @@ -952,6 +945,7 @@ def split_by_worker( if hasattr(df, "to_arrow") and callable(df.to_arrow) else pa.Table.from_pandas(df, preserve_index=True) ) + t = t.replace_schema_metadata(t.schema.metadata | {"dataframe": lib}) t = t.sort_by("_worker") codes = np.asarray(t.select(["_worker"]))[0] t = t.drop(["_worker"]) @@ -968,7 +962,7 @@ def split_by_worker( unique_codes = codes[splits] out = { # FIXME https://github.com/pandas-dev/pandas-stubs/issues/43 - worker_for.cat.categories[code]: shard + worker_for_ser.cat.categories[code]: shard for code, shard in zip(unique_codes, shards) } assert sum(map(len, out.values())) == nrows From 0e867853246cd1ebcbc7ed4fe1e2403d8ebfcea2 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 3 Apr 2023 19:41:55 -0700 Subject: [PATCH 03/20] avoid leaving worker_for as dict --- distributed/shuffle/_worker_extension.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index df295557ead..2a601ae5fbc 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -10,7 +10,6 @@ from collections import defaultdict from collections.abc import Callable, Iterator from concurrent.futures import ThreadPoolExecutor -from importlib import import_module from io import BytesIO from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload @@ -438,6 +437,8 @@ def __init__( memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, ): + import pandas as pd + super().__init__( id=id, run_id=run_id, @@ -456,7 +457,7 @@ def __init__( for part, addr in worker_for.items(): partitions_of[addr].append(part) self.partitions_of = dict(partitions_of) - self.worker_for = worker_for + self.worker_for = pd.Series(worker_for, name="_workers").astype("category") async def receive(self, data: list[tuple[int, bytes]]) -> None: await self._receive(data) @@ -916,7 +917,7 @@ def get_output_partition( def split_by_worker( df: pd.DataFrame, column: str, - worker_for: dict[int, str], + worker_for: pd.Series, ) -> dict[Any, pa.Table]: """ Split data into many arrow batches, partitioned by destination worker @@ -925,11 +926,12 @@ def split_by_worker( import pyarrow as pa lib = typename(df).partition(".")[0] - worker_for_ser = ( - import_module(lib).Series(worker_for, name="_workers").astype("category") - ) + if lib == "cudf": + import cudf + + worker_for = cudf.from_pandas(worker_for) df = df.merge( - right=worker_for_ser.cat.codes.rename("_worker"), + right=worker_for.cat.codes.rename("_worker"), left_on=column, right_index=True, how="inner", @@ -962,7 +964,7 @@ def split_by_worker( unique_codes = codes[splits] out = { # FIXME https://github.com/pandas-dev/pandas-stubs/issues/43 - worker_for_ser.cat.categories[code]: shard + worker_for.cat.categories[code]: shard for code, shard in zip(unique_codes, shards) } assert sum(map(len, out.values())) == nrows From ce23c474c3e7456278d1b571d6b76d2c32f9afee Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 3 Apr 2023 19:48:43 -0700 Subject: [PATCH 04/20] avoid updating metadata for pandas --- distributed/shuffle/_worker_extension.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 2a601ae5fbc..045d73c519a 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -942,12 +942,11 @@ def split_by_worker( # assert len(df) == nrows # Not true if some outputs aren't wanted # FIXME: If we do not preserve the index something is corrupting the # bytestream such that it cannot be deserialized anymore - t = ( - df.to_arrow(preserve_index=True) - if hasattr(df, "to_arrow") and callable(df.to_arrow) - else pa.Table.from_pandas(df, preserve_index=True) - ) - t = t.replace_schema_metadata(t.schema.metadata | {"dataframe": lib}) + if hasattr(df, "to_arrow") and callable(df.to_arrow): + t = df.to_arrow(preserve_index=True) + t = t.replace_schema_metadata(t.schema.metadata | {"dataframe": lib}) + else: + t = pa.Table.from_pandas(df, preserve_index=True) t = t.sort_by("_worker") codes = np.asarray(t.select(["_worker"]))[0] t = t.drop(["_worker"]) From ad57395a54c9f3b3f5e27d9ec5210de0011482ae Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 24 May 2023 06:47:49 -0700 Subject: [PATCH 05/20] use new get_meta_library utility --- distributed/shuffle/_worker_extension.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index facea56b98b..7f20e3400cc 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -16,7 +16,7 @@ import toolz from dask.context import thread_state -from dask.utils import parse_bytes, typename +from dask.utils import get_meta_library, parse_bytes from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule @@ -945,11 +945,11 @@ def split_by_worker( import numpy as np import pyarrow as pa - lib = typename(df).partition(".")[0] - if lib == "cudf": - import cudf + # Allow cudf-based data + lib = get_meta_library(df) + if hasattr(lib, "from_pandas"): + worker_for = lib.from_pandas(worker_for) - worker_for = cudf.from_pandas(worker_for) df = df.merge( right=worker_for.cat.codes.rename("_worker"), left_on=column, From 964b57a711d75ac6279668726801c21f7e4ff712 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 24 May 2023 06:49:40 -0700 Subject: [PATCH 06/20] linting --- distributed/shuffle/_worker_extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 7f20e3400cc..0fe145de55f 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -945,9 +945,9 @@ def split_by_worker( import numpy as np import pyarrow as pa - # Allow cudf-based data lib = get_meta_library(df) if hasattr(lib, "from_pandas"): + # Allow cudf-based data worker_for = lib.from_pandas(worker_for) df = df.merge( From b088b25744914aa0831d703e580f6289c0b1a880 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 31 May 2023 10:31:04 -0700 Subject: [PATCH 07/20] save state --- distributed/shuffle/_worker_extension.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 0fe145de55f..6d722a34738 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -16,7 +16,7 @@ import toolz from dask.context import thread_state -from dask.utils import get_meta_library, parse_bytes +from dask.utils import funcname, get_meta_library, parse_bytes from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule @@ -945,11 +945,9 @@ def split_by_worker( import numpy as np import pyarrow as pa + # (cudf support) Align dataframe backends lib = get_meta_library(df) - if hasattr(lib, "from_pandas"): - # Allow cudf-based data - worker_for = lib.from_pandas(worker_for) - + worker_for = lib.Series(worker_for) df = df.merge( right=worker_for.cat.codes.rename("_worker"), left_on=column, @@ -964,7 +962,7 @@ def split_by_worker( # bytestream such that it cannot be deserialized anymore if hasattr(df, "to_arrow") and callable(df.to_arrow): t = df.to_arrow(preserve_index=True) - t = t.replace_schema_metadata(t.schema.metadata | {"dataframe": lib}) + t = t.replace_schema_metadata(t.schema.metadata | {"dataframe": funcname(lib)}) else: t = pa.Table.from_pandas(df, preserve_index=True) t = t.sort_by("_worker") From 20616bb71a4462d25bfc0ac9ec3484d82fb0ea33 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 8 Jun 2023 11:19:57 -0700 Subject: [PATCH 08/20] add test --- distributed/shuffle/tests/test_shuffle.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 28f94aeabc2..4a92d5ca46f 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -111,6 +111,28 @@ async def test_minimal_version(c, s, a, b): await c.compute(dd.shuffle.shuffle(df, "x", shuffle="p2p")) +@pytest.mark.gpu +@gen_cluster(client=True) +async def test_basic_cudf_support(c, s, a, b): + pytest.importorskip("dask_cudf") + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ).to_backend("cudf") + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + assert out.npartitions == df.npartitions + x, y = c.compute([df.x.size, out.x.size]) + x = await x + y = await y + assert x == y + + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) + + @pytest.mark.parametrize("npartitions", [None, 1, 20]) @gen_cluster(client=True) async def test_basic_integration(c, s, a, b, lose_annotations, npartitions): From d39a774a2bf1b27f6bdcaa8fbaca5f3e4b723a23 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Fri, 9 Jun 2023 06:17:32 -0700 Subject: [PATCH 09/20] leverage meta instead of custom pyarrow metadata --- distributed/shuffle/_arrow.py | 14 +++++++------- distributed/shuffle/_worker_extension.py | 5 +++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 58ae301193b..386a8279234 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -1,7 +1,7 @@ from __future__ import annotations from io import BytesIO -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from packaging.version import parse @@ -45,11 +45,11 @@ def check_minimal_arrow_version() -> None: ) -def _arrow_to_df(table: pa.Table) -> pd.DataFrame: - if table.schema.metadata.get(b"dataframe", None) == b"cudf": - import cudf - - return cudf.DataFrame.from_arrow(table) +def _arrow_to_df(table: pa.Table, like: Any) -> pd.DataFrame: + if hasattr(like, "from_arrow"): + # TODO: Dispatch on `meta` + # (see: https://github.com/dask/dask/pull/10312) + return like.from_arrow(table) return table.to_pandas(self_destruct=True) @@ -63,7 +63,7 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: sr = pa.RecordBatchStreamReader(file) shards.append(sr.read_all()) table = pa.concat_tables(shards) - df = _arrow_to_df(table) + df = _arrow_to_df(table, meta) return df.astype(meta.dtypes) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index ff9dac6dbbc..1a4de2a09e7 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -16,7 +16,7 @@ import toolz from dask.context import thread_state -from dask.utils import funcname, get_meta_library, parse_bytes +from dask.utils import get_meta_library, parse_bytes from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule @@ -953,8 +953,9 @@ def split_by_worker( # FIXME: If we do not preserve the index something is corrupting the # bytestream such that it cannot be deserialized anymore if hasattr(df, "to_arrow") and callable(df.to_arrow): + # TODO: Dispatch on `df` + # (see: https://github.com/dask/dask/pull/10312) t = df.to_arrow(preserve_index=True) - t = t.replace_schema_metadata(t.schema.metadata | {"dataframe": funcname(lib)}) else: t = pa.Table.from_pandas(df, preserve_index=True) t = t.sort_by("_worker") From f605ac9df75fad3987e806d7d5e60363adb6d7bb Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 14 Jun 2023 06:10:50 -0700 Subject: [PATCH 10/20] use dispatch functions --- distributed/shuffle/_arrow.py | 14 ++++---------- distributed/shuffle/_worker_extension.py | 9 ++------- distributed/shuffle/tests/test_shuffle.py | 9 +++++++++ 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 386a8279234..22fc7e8a42f 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -1,10 +1,12 @@ from __future__ import annotations from io import BytesIO -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from packaging.version import parse +from dask.dataframe.dispatch import from_pyarrow_table_dispatch + if TYPE_CHECKING: import pandas as pd import pyarrow as pa @@ -45,14 +47,6 @@ def check_minimal_arrow_version() -> None: ) -def _arrow_to_df(table: pa.Table, like: Any) -> pd.DataFrame: - if hasattr(like, "from_arrow"): - # TODO: Dispatch on `meta` - # (see: https://github.com/dask/dask/pull/10312) - return like.from_arrow(table) - return table.to_pandas(self_destruct=True) - - def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: import pyarrow as pa @@ -63,7 +57,7 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: sr = pa.RecordBatchStreamReader(file) shards.append(sr.read_all()) table = pa.concat_tables(shards) - df = _arrow_to_df(table, meta) + df = from_pyarrow_table_dispatch(meta, table) return df.astype(meta.dtypes) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index f75a6bd69e1..9a57868dff4 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -16,6 +16,7 @@ import toolz from dask.context import thread_state +from dask.dataframe.dispatch import to_pyarrow_table_dispatch from dask.utils import get_meta_library, parse_bytes from distributed.core import PooledRPCCall @@ -927,7 +928,6 @@ def split_by_worker( Split data into many arrow batches, partitioned by destination worker """ import numpy as np - import pyarrow as pa # (cudf support) Align dataframe backends lib = get_meta_library(df) @@ -944,12 +944,7 @@ def split_by_worker( # assert len(df) == nrows # Not true if some outputs aren't wanted # FIXME: If we do not preserve the index something is corrupting the # bytestream such that it cannot be deserialized anymore - if hasattr(df, "to_arrow") and callable(df.to_arrow): - # TODO: Dispatch on `df` - # (see: https://github.com/dask/dask/pull/10312) - t = df.to_arrow(preserve_index=True) - else: - t = pa.Table.from_pandas(df, preserve_index=True) + t = to_pyarrow_table_dispatch(df, preserve_index=True) t = t.sort_by("_worker") codes = np.asarray(t.select(["_worker"]))[0] t = t.drop(["_worker"]) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 0afd184bf9b..3d11ec55b0f 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -114,7 +114,16 @@ async def test_minimal_version(c, s, a, b): @pytest.mark.gpu @gen_cluster(client=True) async def test_basic_cudf_support(c, s, a, b): + cudf = pytest.importorskip("cudf") pytest.importorskip("dask_cudf") + + try: + from dask.dataframe.dispatch import to_pyarrow_table_dispatch + + to_pyarrow_table_dispatch(cudf.DataFrame()) + except (ImportError, TypeError): + pytest.skip(reason="Newer version of cudf is required.") + df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", From ef37b5b3dbebae0455b77d6b18f0908e8e933cfc Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 14 Jun 2023 06:12:36 -0700 Subject: [PATCH 11/20] clarify error message --- distributed/shuffle/tests/test_shuffle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 3d11ec55b0f..bdae197d455 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -121,8 +121,8 @@ async def test_basic_cudf_support(c, s, a, b): from dask.dataframe.dispatch import to_pyarrow_table_dispatch to_pyarrow_table_dispatch(cudf.DataFrame()) - except (ImportError, TypeError): - pytest.skip(reason="Newer version of cudf is required.") + except TypeError: + pytest.skip(reason="Newer version of dask_cudf is required.") df = dask.datasets.timeseries( start="2000-01-01", From d87db602696616295405182d43b488b73af39fde Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 14 Jun 2023 07:35:19 -0700 Subject: [PATCH 12/20] check attr --- distributed/shuffle/_arrow.py | 10 ++++++++-- distributed/shuffle/_worker_extension.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 22fc7e8a42f..d29b0e36638 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -5,7 +5,7 @@ from packaging.version import parse -from dask.dataframe.dispatch import from_pyarrow_table_dispatch +import dask.dataframe.dispatch as dataframe_dispatch if TYPE_CHECKING: import pandas as pd @@ -57,7 +57,13 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: sr = pa.RecordBatchStreamReader(file) shards.append(sr.read_all()) table = pa.concat_tables(shards) - df = from_pyarrow_table_dispatch(meta, table) + if hasattr(dataframe_dispatch, "from_pyarrow_table_dispatch"): + df = dataframe_dispatch.from_pyarrow_table_dispatch( + meta, table, self_destruct=True + ) + else: + # Backward compat + df = table.to_pandas(self_destruct=True) return df.astype(meta.dtypes) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 9a57868dff4..ece35a92b2b 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -15,8 +15,8 @@ import toolz +import dask.dataframe.dispatch as dataframe_dispatch from dask.context import thread_state -from dask.dataframe.dispatch import to_pyarrow_table_dispatch from dask.utils import get_meta_library, parse_bytes from distributed.core import PooledRPCCall @@ -944,7 +944,13 @@ def split_by_worker( # assert len(df) == nrows # Not true if some outputs aren't wanted # FIXME: If we do not preserve the index something is corrupting the # bytestream such that it cannot be deserialized anymore - t = to_pyarrow_table_dispatch(df, preserve_index=True) + if hasattr(dataframe_dispatch, "to_pyarrow_table_dispatch"): + t = dataframe_dispatch.to_pyarrow_table_dispatch(df, preserve_index=True) + else: + import pyarrow as pa + + # Backward compat + t = pa.Table.from_pandas(df, preserve_index=True) t = t.sort_by("_worker") codes = np.asarray(t.select(["_worker"]))[0] t = t.drop(["_worker"]) From d171dc21378324d7c06cca1ff74ae580aa7a3bcd Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 14 Jun 2023 07:42:22 -0700 Subject: [PATCH 13/20] catch importerror in test --- distributed/shuffle/tests/test_shuffle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index bdae197d455..1f009b0f843 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -121,8 +121,8 @@ async def test_basic_cudf_support(c, s, a, b): from dask.dataframe.dispatch import to_pyarrow_table_dispatch to_pyarrow_table_dispatch(cudf.DataFrame()) - except TypeError: - pytest.skip(reason="Newer version of dask_cudf is required.") + except (ImportError, TypeError): + pytest.skip(reason="Newer version of dask and/or dask_cudf is required.") df = dask.datasets.timeseries( start="2000-01-01", From cbd30264a620b1ca9c5572b532997a52b3b7f4c0 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 14 Jun 2023 12:51:32 -0700 Subject: [PATCH 14/20] assume latest version of dask --- distributed/shuffle/_arrow.py | 8 +------- distributed/shuffle/_worker_extension.py | 8 +------- distributed/shuffle/tests/test_shuffle.py | 4 ++-- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index d29b0e36638..c0f8b575083 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -57,13 +57,7 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: sr = pa.RecordBatchStreamReader(file) shards.append(sr.read_all()) table = pa.concat_tables(shards) - if hasattr(dataframe_dispatch, "from_pyarrow_table_dispatch"): - df = dataframe_dispatch.from_pyarrow_table_dispatch( - meta, table, self_destruct=True - ) - else: - # Backward compat - df = table.to_pandas(self_destruct=True) + df = dataframe_dispatch.from_pyarrow_table_dispatch(meta, table, self_destruct=True) return df.astype(meta.dtypes) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index ece35a92b2b..bc6aeac7281 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -944,13 +944,7 @@ def split_by_worker( # assert len(df) == nrows # Not true if some outputs aren't wanted # FIXME: If we do not preserve the index something is corrupting the # bytestream such that it cannot be deserialized anymore - if hasattr(dataframe_dispatch, "to_pyarrow_table_dispatch"): - t = dataframe_dispatch.to_pyarrow_table_dispatch(df, preserve_index=True) - else: - import pyarrow as pa - - # Backward compat - t = pa.Table.from_pandas(df, preserve_index=True) + t = dataframe_dispatch.to_pyarrow_table_dispatch(df, preserve_index=True) t = t.sort_by("_worker") codes = np.asarray(t.select(["_worker"]))[0] t = t.drop(["_worker"]) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 1f009b0f843..bdae197d455 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -121,8 +121,8 @@ async def test_basic_cudf_support(c, s, a, b): from dask.dataframe.dispatch import to_pyarrow_table_dispatch to_pyarrow_table_dispatch(cudf.DataFrame()) - except (ImportError, TypeError): - pytest.skip(reason="Newer version of dask and/or dask_cudf is required.") + except TypeError: + pytest.skip(reason="Newer version of dask_cudf is required.") df = dask.datasets.timeseries( start="2000-01-01", From dac7291487eed9441df52e395ac0c4ca90119ef0 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 14 Jun 2023 12:59:25 -0700 Subject: [PATCH 15/20] back to original imports --- distributed/shuffle/_arrow.py | 4 ++-- distributed/shuffle/_worker_extension.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index c0f8b575083..312974213b2 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -5,7 +5,7 @@ from packaging.version import parse -import dask.dataframe.dispatch as dataframe_dispatch +from dask.dataframe.dispatch import from_pyarrow_table_dispatch if TYPE_CHECKING: import pandas as pd @@ -57,7 +57,7 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: sr = pa.RecordBatchStreamReader(file) shards.append(sr.read_all()) table = pa.concat_tables(shards) - df = dataframe_dispatch.from_pyarrow_table_dispatch(meta, table, self_destruct=True) + df = from_pyarrow_table_dispatch(meta, table, self_destruct=True) return df.astype(meta.dtypes) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index bc6aeac7281..9a57868dff4 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -15,8 +15,8 @@ import toolz -import dask.dataframe.dispatch as dataframe_dispatch from dask.context import thread_state +from dask.dataframe.dispatch import to_pyarrow_table_dispatch from dask.utils import get_meta_library, parse_bytes from distributed.core import PooledRPCCall @@ -944,7 +944,7 @@ def split_by_worker( # assert len(df) == nrows # Not true if some outputs aren't wanted # FIXME: If we do not preserve the index something is corrupting the # bytestream such that it cannot be deserialized anymore - t = dataframe_dispatch.to_pyarrow_table_dispatch(df, preserve_index=True) + t = to_pyarrow_table_dispatch(df, preserve_index=True) t = t.sort_by("_worker") codes = np.asarray(t.select(["_worker"]))[0] t = t.drop(["_worker"]) From aae562ac41686754a14bbb73d672420aca8005d5 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 20 Jun 2023 16:06:13 -0700 Subject: [PATCH 16/20] ignore dispatch warnings --- distributed/shuffle/tests/test_shuffle.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index bdae197d455..f980dd44c46 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -112,6 +112,9 @@ async def test_minimal_version(c, s, a, b): @pytest.mark.gpu +@pytest.mark.filterwarnings( + "ignore:Ignoring the following arguments to `from_pyarrow_table_dispatch`." +) @gen_cluster(client=True) async def test_basic_cudf_support(c, s, a, b): cudf = pytest.importorskip("cudf") From 338ecff444c15c0eed41b039815c2d24931dac88 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 21 Jun 2023 15:08:17 -0700 Subject: [PATCH 17/20] move imports --- distributed/shuffle/_arrow.py | 4 ++-- distributed/shuffle/_worker_extension.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index ea131f846a2..bee28245cb0 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -5,8 +5,6 @@ from packaging.version import parse -from dask.dataframe.dispatch import from_pyarrow_table_dispatch - if TYPE_CHECKING: import pandas as pd import pyarrow as pa @@ -51,6 +49,8 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: import pandas as pd import pyarrow as pa + from dask.dataframe.dispatch import from_pyarrow_table_dispatch + file = BytesIO(data) end = len(data) shards = [] diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 8e189d034aa..829219cf513 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -16,7 +16,6 @@ import toolz from dask.context import thread_state -from dask.dataframe.dispatch import to_pyarrow_table_dispatch from dask.utils import get_meta_library, parse_bytes from distributed.core import PooledRPCCall @@ -933,6 +932,8 @@ def split_by_worker( """ import numpy as np + from dask.dataframe.dispatch import to_pyarrow_table_dispatch + # (cudf support) Align dataframe backends lib = get_meta_library(df) worker_for = lib.Series(worker_for) From 3da3a05a716a451c561cad6b6e5c1e35f928c240 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 22 Jun 2023 11:35:16 -0700 Subject: [PATCH 18/20] use _constructor_sliced --- distributed/shuffle/_worker_extension.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 829219cf513..df22e2b367d 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -16,7 +16,7 @@ import toolz from dask.context import thread_state -from dask.utils import get_meta_library, parse_bytes +from dask.utils import parse_bytes from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule @@ -934,9 +934,8 @@ def split_by_worker( from dask.dataframe.dispatch import to_pyarrow_table_dispatch - # (cudf support) Align dataframe backends - lib = get_meta_library(df) - worker_for = lib.Series(worker_for) + # (cudf support) Avoid pd.Series + worker_for = df._constructor_sliced(worker_for) df = df.merge( right=worker_for.cat.codes.rename("_worker"), left_on=column, From b221af55a569b152bcd6881aa1a0a2adb948b710 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 24 Jul 2023 14:18:31 -0700 Subject: [PATCH 19/20] DataFrame constructor bugfix --- distributed/shuffle/_shuffle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 62bb13c90b7..32ffc5ede15 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -103,7 +103,7 @@ def rearrange_by_column_p2p( column: str, npartitions: int | None = None, ) -> DataFrame: - from dask.dataframe import DataFrame + from dask.dataframe.core import new_dd_object meta = df._meta check_dtype_support(meta) @@ -125,7 +125,7 @@ def rearrange_by_column_p2p( name_input=df._name, meta_input=meta, ) - return DataFrame( + return new_dd_object( HighLevelGraph.from_collections(name, layer, [df]), name, meta, From 99232e74a48bcb98784a79c8dbfe33b00fc433c9 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 24 Jul 2023 14:48:39 -0700 Subject: [PATCH 20/20] mypy workaround --- distributed/shuffle/_worker_plugin.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index e58cf41357f..46c3325266a 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -989,7 +989,9 @@ def split_by_worker( from dask.dataframe.dispatch import to_pyarrow_table_dispatch # (cudf support) Avoid pd.Series - worker_for = df._constructor_sliced(worker_for) + constructor = df._constructor_sliced + assert isinstance(constructor, type) + worker_for = constructor(worker_for) df = df.merge( right=worker_for.cat.codes.rename("_worker"), left_on=column,