From 671ab10420ce286e1e2a4a68da05821dc0eb0d19 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Sat, 2 Sep 2023 10:54:13 +0200 Subject: [PATCH 01/22] Early copy --- distributed/shuffle/_arrow.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 08ef2b3f616..2483ed0813c 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -55,7 +55,10 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: shards = [] while file.tell() < end: sr = pa.RecordBatchStreamReader(file) - shards.append(sr.read_all()) + shard = sr.read_all() + arrs = [pa.concat_arrays(column.chunks) for column in shard.columns] + shard = pa.table(data=arrs, schema=shard.schema) + shards.append(shard) table = pa.concat_tables(shards, promote=True) df = from_pyarrow_table_dispatch(meta, table, self_destruct=True) From 9f08ff38b75d582b0689fb6230da203647db4afd Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Sat, 2 Sep 2023 11:11:19 +0200 Subject: [PATCH 02/22] read fn --- distributed/shuffle/_arrow.py | 10 +++++++++- distributed/shuffle/_core.py | 6 ++++++ distributed/shuffle/_disk.py | 11 +++++------ distributed/shuffle/_shuffle.py | 5 +++++ 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 2483ed0813c..18d30565aa6 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -1,7 +1,8 @@ from __future__ import annotations from io import BytesIO -from typing import TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Any from packaging.version import parse @@ -88,3 +89,10 @@ def deserialize_table(buffer: bytes) -> pa.Table: with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader: return reader.read_all() + + +def read_from_disk(path: Path) -> tuple[Any, int]: + with open(path, mode="rb", buffering=100_000_000) as f: + data = f.read() + size = f.tell() + return data, size diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index e7dd72c73cf..5abb1e35df1 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -11,6 +11,7 @@ from dataclasses import dataclass, field from enum import Enum from functools import partial +from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar from distributed.core import PooledRPCCall @@ -62,6 +63,7 @@ def __init__( self._disk_buffer = DiskShardsBuffer( directory=directory, + read=self.read, memory_limiter=memory_limiter_disk, ) @@ -220,6 +222,10 @@ async def get_output_partition( ) -> _T_partition_type: """Get an output partition to the shuffle run""" + @abc.abstractmethod + def read(self, path: Path) -> tuple[Any, int]: + raise NotImplementedError() + def get_worker_plugin() -> ShuffleWorkerPlugin: from distributed import get_worker diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 2b3dc37beed..3052ee5faa2 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -3,6 +3,7 @@ import contextlib import pathlib import shutil +from typing import Any, Callable from distributed.shuffle._buffer import ShardsBuffer from distributed.shuffle._limiter import ResourceLimiter @@ -41,6 +42,7 @@ class DiskShardsBuffer(ShardsBuffer): def __init__( self, directory: str | pathlib.Path, + read: Callable[[pathlib.Path], tuple[Any, int]], memory_limiter: ResourceLimiter | None = None, ): super().__init__( @@ -50,6 +52,7 @@ def __init__( ) self.directory = pathlib.Path(directory) self.directory.mkdir(exist_ok=True) + self._read = read async def _process(self, id: str, shards: list[bytes]) -> None: """Write one buffer to file @@ -74,7 +77,7 @@ async def _process(self, id: str, shards: list[bytes]) -> None: for shard in shards: f.write(shard) - def read(self, id: int | str) -> bytes: + def read(self, id: int | str) -> Any: """Read a complete file back into memory""" self.raise_on_exception() if not self._inputs_done: @@ -82,11 +85,7 @@ def read(self, id: int | str) -> bytes: try: with self.time("read"): - with open( - self.directory / str(id), mode="rb", buffering=100_000_000 - ) as f: - data = f.read() - size = f.tell() + data, size = self._read(self.directory / str(id)) except FileNotFoundError: raise KeyError(id) diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 81e83b66585..d8f9ca14fd7 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from functools import partial +from pathlib import Path from typing import TYPE_CHECKING, Any, Union import toolz @@ -22,6 +23,7 @@ check_minimal_arrow_version, convert_partition, list_of_buffers_to_table, + read_from_disk, serialize_table, ) from distributed.shuffle._core import ( @@ -501,6 +503,9 @@ async def get_output_partition( def _get_assigned_worker(self, id: int) -> str: return self.worker_for[id] + def read(self, path: Path) -> tuple[Any, int]: + return read_from_disk(path) + @dataclass(frozen=True) class DataFrameShuffleSpec(ShuffleSpec[int]): From bf4c7379136d3b22644895a0dc48348e2d40b0d2 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Sat, 2 Sep 2023 11:21:40 +0200 Subject: [PATCH 03/22] OSFile --- distributed/shuffle/_arrow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 18d30565aa6..0ca5769175f 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -92,7 +92,9 @@ def deserialize_table(buffer: bytes) -> pa.Table: def read_from_disk(path: Path) -> tuple[Any, int]: - with open(path, mode="rb", buffering=100_000_000) as f: + import pyarrow as pa + + with pa.OSFile(str(path), mode="rb") as f: data = f.read() size = f.tell() return data, size From c80bf63e3f1710eab2c045f2a1a9b50d950f9be1 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Sat, 2 Sep 2023 11:28:57 +0200 Subject: [PATCH 04/22] Move deser logic --- distributed/shuffle/_arrow.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 0ca5769175f..08a7d7b809f 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -51,16 +51,7 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: from dask.dataframe.dispatch import from_pyarrow_table_dispatch - file = BytesIO(data) - end = len(data) - shards = [] - while file.tell() < end: - sr = pa.RecordBatchStreamReader(file) - shard = sr.read_all() - arrs = [pa.concat_arrays(column.chunks) for column in shard.columns] - shard = pa.table(data=arrs, schema=shard.schema) - shards.append(shard) - table = pa.concat_tables(shards, promote=True) + table = pa.concat_tables(data, promote=True) df = from_pyarrow_table_dispatch(meta, table, self_destruct=True) return df.astype(meta.dtypes, copy=False) @@ -94,7 +85,16 @@ def deserialize_table(buffer: bytes) -> pa.Table: def read_from_disk(path: Path) -> tuple[Any, int]: import pyarrow as pa + shards = [] with pa.OSFile(str(path), mode="rb") as f: - data = f.read() + pos = f.tell() + end = f.seek(0, whence=2) + f.seek(pos) + while f.tell() < end: + sr = pa.RecordBatchStreamReader(f) + shard = sr.read_all() + arrs = [pa.concat_arrays(column.chunks) for column in shard.columns] + shard = pa.table(data=arrs, schema=shard.schema) + shards.append(shard) size = f.tell() - return data, size + return shards, size From 2417a5a9481c1bd4ba6e91b8ac1317adbe01ee13 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Sat, 2 Sep 2023 11:46:30 +0200 Subject: [PATCH 05/22] meta --- distributed/shuffle/_arrow.py | 5 +++-- distributed/shuffle/_shuffle.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 08a7d7b809f..169fafe976f 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -82,10 +82,11 @@ def deserialize_table(buffer: bytes) -> pa.Table: return reader.read_all() -def read_from_disk(path: Path) -> tuple[Any, int]: +def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]: import pyarrow as pa shards = [] + schema = pa.Schema.from_pandas(meta, preserve_index=True) with pa.OSFile(str(path), mode="rb") as f: pos = f.tell() end = f.seek(0, whence=2) @@ -94,7 +95,7 @@ def read_from_disk(path: Path) -> tuple[Any, int]: sr = pa.RecordBatchStreamReader(f) shard = sr.read_all() arrs = [pa.concat_arrays(column.chunks) for column in shard.columns] - shard = pa.table(data=arrs, schema=shard.schema) + shard = pa.table(data=arrs, schema=schema) shards.append(shard) size = f.tell() return shards, size diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index d8f9ca14fd7..a6518813c01 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -504,7 +504,7 @@ def _get_assigned_worker(self, id: int) -> str: return self.worker_for[id] def read(self, path: Path) -> tuple[Any, int]: - return read_from_disk(path) + return read_from_disk(path, self.meta) @dataclass(frozen=True) From 871c043f895926c6f2635d6ba4c73965c0309a38 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Sat, 2 Sep 2023 12:07:31 +0200 Subject: [PATCH 06/22] Remove delayed enforcement --- distributed/shuffle/_arrow.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 169fafe976f..838cd0fa47b 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -51,10 +51,9 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: from dask.dataframe.dispatch import from_pyarrow_table_dispatch - table = pa.concat_tables(data, promote=True) + table = pa.concat_tables(data) - df = from_pyarrow_table_dispatch(meta, table, self_destruct=True) - return df.astype(meta.dtypes, copy=False) + return from_pyarrow_table_dispatch(meta, table, self_destruct=True) def list_of_buffers_to_table(data: list[bytes]) -> pa.Table: @@ -62,7 +61,7 @@ def list_of_buffers_to_table(data: list[bytes]) -> pa.Table: import pyarrow as pa return pa.concat_tables( - (deserialize_table(buffer) for buffer in data), promote=True + (deserialize_table(buffer) for buffer in data) ) From 20634e9b74dadce1a95a6ea88c69d528e42f5dab Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Sat, 2 Sep 2023 12:11:58 +0200 Subject: [PATCH 07/22] Remove delayed enforcement --- distributed/shuffle/_arrow.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 838cd0fa47b..f406dee1c23 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -60,9 +60,7 @@ def list_of_buffers_to_table(data: list[bytes]) -> pa.Table: """Convert a list of arrow buffers and a schema to an Arrow Table""" import pyarrow as pa - return pa.concat_tables( - (deserialize_table(buffer) for buffer in data) - ) + return pa.concat_tables(deserialize_table(buffer) for buffer in data) def serialize_table(table: pa.Table) -> bytes: From f76c6227d091631de5d632c2a0e3250cd98839b7 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 4 Sep 2023 16:44:57 +0200 Subject: [PATCH 08/22] Rechunk --- distributed/shuffle/_arrow.py | 9 +++------ distributed/shuffle/_core.py | 7 +++---- distributed/shuffle/_rechunk.py | 31 ++++++++++++++++++------------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index f406dee1c23..18184c7b3e1 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -1,6 +1,5 @@ from __future__ import annotations -from io import BytesIO from pathlib import Path from typing import TYPE_CHECKING, Any @@ -85,14 +84,12 @@ def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]: shards = [] schema = pa.Schema.from_pandas(meta, preserve_index=True) with pa.OSFile(str(path), mode="rb") as f: - pos = f.tell() - end = f.seek(0, whence=2) - f.seek(pos) - while f.tell() < end: + size = f.seek(0, whence=2) + f.seek(0) + while f.tell() < size: sr = pa.RecordBatchStreamReader(f) shard = sr.read_all() arrs = [pa.concat_arrays(column.chunks) for column in shard.columns] shard = pa.table(data=arrs, schema=schema) shards.append(shard) - size = f.tell() return shards, size diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 5abb1e35df1..cda78df6902 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -182,10 +182,9 @@ def fail(self, exception: Exception) -> None: if not self.closed: self._exception = exception - def _read_from_disk(self, id: NDIndex) -> bytes: + def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing self.raise_if_closed() - data: bytes = self._disk_buffer.read("_".join(str(i) for i in id)) - return data + return self._disk_buffer.read("_".join(str(i) for i in id)) async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None: await self._receive(data) @@ -224,7 +223,7 @@ async def get_output_partition( @abc.abstractmethod def read(self, path: Path) -> tuple[Any, int]: - raise NotImplementedError() + """Read shards from disk""" def get_worker_plugin() -> ShuffleWorkerPlugin: diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index d9f3e179753..949f21eda53 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -102,8 +102,8 @@ from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from io import BytesIO from itertools import product +from pathlib import Path from typing import TYPE_CHECKING, Any, NamedTuple import dask @@ -263,26 +263,22 @@ def split_axes(old: ChunkedAxes, new: ChunkedAxes) -> SplitAxes: return axes -def convert_chunk(data: bytes) -> np.ndarray: +def convert_chunk(shards: list[tuple[NDIndex, np.ndarray]]) -> np.ndarray: import numpy as np from dask.array.core import concatenate3 - file = BytesIO(data) - shards: dict[NDIndex, np.ndarray] = {} + indexed: dict[NDIndex, np.ndarray] = {} + for index, shard in shards: + indexed[index] = shard + del shards - while file.tell() < len(data): - for index, shard in pickle.load(file): - shards[index] = shard - - subshape = [max(dim) + 1 for dim in zip(*shards.keys())] - assert len(shards) == np.prod(subshape) + subshape = [max(dim) + 1 for dim in zip(*indexed.keys())] + assert len(indexed) == np.prod(subshape) rec_cat_arg = np.empty(subshape, dtype="O") - for index, shard in shards.items(): + for index, shard in indexed.items(): rec_cat_arg[tuple(index)] = shard - del data - del file arrs = rec_cat_arg.tolist() return concatenate3(arrs) @@ -445,6 +441,15 @@ async def get_output_partition( data = self._read_from_disk(partition_id) return await self.offload(convert_chunk, data) + def read(self, path: Path) -> tuple[Any, int]: + shards: list[tuple[NDIndex, np.ndarray]] = [] + with path.open(mode="rb") as f: + size = f.seek(0, os.SEEK_END) + f.seek(0) + while f.tell() < size: + shards.extend(pickle.load(f)) + return shards, size + def _get_assigned_worker(self, id: NDIndex) -> str: return self.worker_for[id] From 5a3b33e520ef89de56a26228af95812b0341a7cd Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 4 Sep 2023 17:30:03 +0200 Subject: [PATCH 09/22] Fix tests --- distributed/shuffle/_arrow.py | 7 ++++--- distributed/shuffle/_shuffle.py | 4 ++-- distributed/shuffle/tests/test_disk_buffer.py | 17 +++++++++++++---- distributed/shuffle/tests/test_shuffle.py | 18 ++++++++---------- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 18184c7b3e1..a1af0575643 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -45,14 +45,15 @@ def check_minimal_arrow_version() -> None: ) -def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: +def convert_shards(shards: list[pa.Table], meta: pd.DataFrame) -> pd.DataFrame: import pyarrow as pa from dask.dataframe.dispatch import from_pyarrow_table_dispatch - table = pa.concat_tables(data) + table = pa.concat_tables(shards) - return from_pyarrow_table_dispatch(meta, table, self_destruct=True) + df = from_pyarrow_table_dispatch(meta, table, self_destruct=True) + return df.astype(meta.dtypes, copy=False) def list_of_buffers_to_table(data: list[bytes]) -> pa.Table: diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index a6518813c01..788d33a08dd 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -21,7 +21,7 @@ from distributed.shuffle._arrow import ( check_dtype_support, check_minimal_arrow_version, - convert_partition, + convert_shards, list_of_buffers_to_table, read_from_disk, serialize_table, @@ -495,7 +495,7 @@ async def get_output_partition( try: data = self._read_from_disk((partition_id,)) - out = await self.offload(convert_partition, data, self.meta) + out = await self.offload(convert_shards, data, self.meta) except KeyError: out = self.meta.copy() return out diff --git a/distributed/shuffle/tests/test_disk_buffer.py b/distributed/shuffle/tests/test_disk_buffer.py index 76a4a1e70c2..19192de5822 100644 --- a/distributed/shuffle/tests/test_disk_buffer.py +++ b/distributed/shuffle/tests/test_disk_buffer.py @@ -2,6 +2,7 @@ import asyncio import os +from pathlib import Path from typing import Any import pytest @@ -10,9 +11,16 @@ from distributed.utils_test import gen_test +def read_bytes(path: Path) -> tuple[bytes, int]: + with path.open("rb") as f: + data = f.read() + size = f.tell() + return data, size + + @gen_test() async def test_basic(tmp_path): - async with DiskShardsBuffer(directory=tmp_path) as mf: + async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf: await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) @@ -33,7 +41,7 @@ async def test_basic(tmp_path): @gen_test() async def test_read_before_flush(tmp_path): payload = {"1": b"foo"} - async with DiskShardsBuffer(directory=tmp_path) as mf: + async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf: with pytest.raises(RuntimeError): mf.read(1) @@ -51,7 +59,7 @@ async def test_read_before_flush(tmp_path): @pytest.mark.parametrize("count", [2, 100, 1000]) @gen_test() async def test_many(tmp_path, count): - async with DiskShardsBuffer(directory=tmp_path) as mf: + async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf: d = {i: str(i).encode() * 100 for i in range(count)} for _ in range(10): @@ -76,7 +84,7 @@ async def _process(self, *args: Any, **kwargs: Any) -> None: @gen_test() async def test_exceptions(tmp_path): - async with BrokenDiskShardsBuffer(directory=tmp_path) as mf: + async with BrokenDiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf: await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) while not mf._exception: @@ -107,6 +115,7 @@ async def test_high_pressure_flush_with_exception(tmp_path): async with EventuallyBrokenDiskShardsBuffer( directory=tmp_path, + read=read_bytes, ) as mf: tasks = [] for _ in range(10): diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index fc78ef1fda1..b2924f2c99a 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import io import itertools import os import random @@ -31,8 +30,9 @@ from distributed.scheduler import KilledWorker, Scheduler from distributed.scheduler import TaskState as SchedulerTaskState from distributed.shuffle._arrow import ( - convert_partition, + convert_shards, list_of_buffers_to_table, + read_from_disk, serialize_table, ) from distributed.shuffle._limiter import ResourceLimiter @@ -936,7 +936,7 @@ async def test_heartbeat(c, s, a, b): await check_scheduler_cleanup(s) -def test_processing_chain(): +def test_processing_chain(tmp_path): """ This is a serial version of the entire compute chain @@ -1085,18 +1085,16 @@ def __init__(self, value: int) -> None: if w1 is not w2 ) - # Our simple file system - filesystem = defaultdict(io.BytesIO) - for partitions in splits_by_worker.values(): for partition, tables in partitions.items(): for table in tables: - filesystem[partition].write(serialize_table(table)) + with (tmp_path / str(partition)).open("ab") as f: + f.write(serialize_table(table)) out = {} - for k, bio in filesystem.items(): - bio.seek(0) - out[k] = convert_partition(bio.read(), meta) + for k in range(npartitions): + shards, _ = read_from_disk(tmp_path / str(k), meta) + out[k] = convert_shards(shards, meta) shuffled_df = pd.concat(df for df in out.values()) pd.testing.assert_frame_equal( From a60cc07befba1cc6d98f04f233e1d47123d7afed Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 4 Sep 2023 19:10:14 +0200 Subject: [PATCH 10/22] Resolve path --- distributed/shuffle/_disk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 3052ee5faa2..cbdb6b2d410 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -85,7 +85,7 @@ def read(self, id: int | str) -> Any: try: with self.time("read"): - data, size = self._read(self.directory / str(id)) + data, size = self._read((self.directory / str(id)).resolve()) except FileNotFoundError: raise KeyError(id) From d5f81f46727f68b8222b85c4840f3bd923d23fc7 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 5 Sep 2023 20:46:49 +0200 Subject: [PATCH 11/22] Increase minimum pyarrow version --- continuous_integration/environment-3.9.yaml | 2 +- distributed/shuffle/_arrow.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/continuous_integration/environment-3.9.yaml b/continuous_integration/environment-3.9.yaml index ad4d340871e..c275371668d 100644 --- a/continuous_integration/environment-3.9.yaml +++ b/continuous_integration/environment-3.9.yaml @@ -28,7 +28,7 @@ dependencies: - pre-commit - prometheus_client - psutil - - pyarrow=7 + - pyarrow=12 - pynvml # Only tested here - pytest - pytest-cov diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index a1af0575643..3312e9145a9 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -32,8 +32,8 @@ def check_minimal_arrow_version() -> None: Raises a RuntimeError in case pyarrow is not installed or installed version is not recent enough. """ - # First version to introduce Table.sort_by - minversion = "7.0.0" + # First version that supports concatenating extension arrays (apache/arrow#14463) + minversion = "12.0.0" try: import pyarrow as pa except ImportError: From 491b9e67b28f0110fae6cd8a7143146e5a124ca8 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 6 Sep 2023 10:25:59 +0200 Subject: [PATCH 12/22] Add assertion --- distributed/shuffle/_shuffle.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 788d33a08dd..53263819118 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -105,9 +105,15 @@ def rearrange_by_column_p2p( column: str, npartitions: int | None = None, ) -> DataFrame: + import pandas as pd + from dask.dataframe.core import new_dd_object meta = df._meta + if not pd.api.types.is_integer_dtype(meta[column]): + raise TypeError( + f"Expected meta {column=} to be an integer column, is {meta[column].dtype}." + ) check_dtype_support(meta) npartitions = npartitions or df.npartitions token = tokenize(df, column, npartitions) @@ -327,7 +333,7 @@ def split_by_worker( return out -def split_by_partition(t: pa.Table, column: str) -> dict[Any, pa.Table]: +def split_by_partition(t: pa.Table, column: str) -> dict[int, pa.Table]: """ Split data into many arrow batches, partitioned by final partition """ @@ -389,6 +395,11 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): buffer. """ + column: str + meta: pd.DataFrame + partitions_of: dict[str, list[int]] + worker_for: pd.Series + def __init__( self, worker_for: dict[int, str], From f14aba61bf55a65e51b0eac7ba4561b3808c7891 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 7 Sep 2023 11:01:40 +0200 Subject: [PATCH 13/22] check_minimal_arrow_version --- distributed/shuffle/_arrow.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 3312e9145a9..ce2eda517c5 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -29,18 +29,17 @@ def check_minimal_arrow_version() -> None: """Verify that the the correct version of pyarrow is installed to support the P2P extension. - Raises a RuntimeError in case pyarrow is not installed or installed version - is not recent enough. + Raises a ModuleNotFoundError if pyarrow is not installed or an + ImportError if the installed version is not recent enough. """ # First version that supports concatenating extension arrays (apache/arrow#14463) minversion = "12.0.0" try: import pyarrow as pa - except ImportError: - raise RuntimeError(f"P2P shuffling requires pyarrow>={minversion}") - + except ModuleNotFoundError: + raise ModuleNotFoundError(f"P2P shuffling requires pyarrow>={minversion}") if parse(pa.__version__) < parse(minversion): - raise RuntimeError( + raise ImportError( f"P2P shuffling requires pyarrow>={minversion} but only found {pa.__version__}" ) From 67282f74fe4f50f589f04fcfdfea4ce2dd3e59d5 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 7 Sep 2023 18:56:03 +0200 Subject: [PATCH 14/22] Offload entire read conversion --- distributed/shuffle/_shuffle.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 63b8264dbdb..31a0452185a 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -504,9 +504,12 @@ async def get_output_partition( await self.flush_receive() try: - data = self._read_from_disk((partition_id,)) - out = await self.offload(convert_shards, data, self.meta) + def _(partition_id: int, meta: pd.DataFrame) -> pd.DataFrame: + data = self._read_from_disk((partition_id,)) + return convert_shards(data, meta) + + out = await self.offload(_, partition_id, self.meta) except KeyError: out = self.meta.copy() return out From 429a7ac44fed176cef146651a1869f81d2dc59ce Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 7 Sep 2023 19:03:43 +0200 Subject: [PATCH 15/22] offload --- distributed/shuffle/_rechunk.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 949f21eda53..c0913313f7c 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -438,8 +438,12 @@ async def get_output_partition( await self._ensure_output_worker(partition_id, key) await self.flush_receive() - data = self._read_from_disk(partition_id) - return await self.offload(convert_chunk, data) + + def _(partition_id: NDIndex) -> np.ndarray: + data = self._read_from_disk(partition_id) + return convert_chunk(data) + + return await self.offload(_, partition_id) def read(self, path: Path) -> tuple[Any, int]: shards: list[tuple[NDIndex, np.ndarray]] = [] From 2f92de062763dcfd87f994f330055ff0ab7d8e48 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 8 Sep 2023 10:38:57 +0200 Subject: [PATCH 16/22] batching --- distributed/shuffle/_arrow.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index ce2eda517c5..a38f68f8b7e 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -5,6 +5,8 @@ from packaging.version import parse +from dask.utils import parse_bytes + if TYPE_CHECKING: import pandas as pd import pyarrow as pa @@ -81,15 +83,35 @@ def deserialize_table(buffer: bytes) -> pa.Table: def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]: import pyarrow as pa + batch_size = parse_bytes("10 MiB") + batch = [] shards = [] schema = pa.Schema.from_pandas(meta, preserve_index=True) + with pa.OSFile(str(path), mode="rb") as f: size = f.seek(0, whence=2) f.seek(0) - while f.tell() < size: + prev = 0 + offset = f.tell() + while offset() < size: sr = pa.RecordBatchStreamReader(f) shard = sr.read_all() - arrs = [pa.concat_arrays(column.chunks) for column in shard.columns] - shard = pa.table(data=arrs, schema=schema) - shards.append(shard) + offset = f.tell() + batch.append(shard) + + if offset - prev >= batch_size: + table = pa.concat_tables(batch) + shards.append(_copy_table(table, schema)) + batch = [] + prev = offset + if batch: + table = pa.concat_tables(batch) + shards.append(_copy_table(table, schema)) return shards, size + + +def _copy_table(table: pa.Table, schema: pa.Schema) -> pa.Table: + import pyarrow as pa + + arrs = [pa.concat_arrays(column.chunks) for column in table.columns] + return pa.table(data=arrs, schema=schema) From e2368b55171cc737016bb7f8e441946211be42a5 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 8 Sep 2023 10:51:07 +0200 Subject: [PATCH 17/22] minor --- distributed/shuffle/_arrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index a38f68f8b7e..b34851b91fd 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -93,7 +93,7 @@ def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]: f.seek(0) prev = 0 offset = f.tell() - while offset() < size: + while offset < size: sr = pa.RecordBatchStreamReader(f) shard = sr.read_all() offset = f.tell() From 629124a5f71eb75f19abe1bb172cb7cee80e47df Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 8 Sep 2023 12:18:55 +0200 Subject: [PATCH 18/22] smaller batches --- distributed/shuffle/_arrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index b34851b91fd..3222adb88c7 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -83,7 +83,7 @@ def deserialize_table(buffer: bytes) -> pa.Table: def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]: import pyarrow as pa - batch_size = parse_bytes("10 MiB") + batch_size = parse_bytes("1 MiB") batch = [] shards = [] schema = pa.Schema.from_pandas(meta, preserve_index=True) From ecfe534208eaff1c40fdf4a896033f9996a41c72 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 8 Sep 2023 16:04:19 +0200 Subject: [PATCH 19/22] Dispatch --- distributed/shuffle/_arrow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 3222adb88c7..a5fa4a35d09 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -83,10 +83,12 @@ def deserialize_table(buffer: bytes) -> pa.Table: def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]: import pyarrow as pa + from dask.dataframe.dispatch import pyarrow_schema_dispatch + batch_size = parse_bytes("1 MiB") batch = [] shards = [] - schema = pa.Schema.from_pandas(meta, preserve_index=True) + schema = pyarrow_schema_dispatch(meta, preserve_index=True) with pa.OSFile(str(path), mode="rb") as f: size = f.seek(0, whence=2) From ecca1d8d0e6c52b0c8a89a55c2fad13dc61f7255 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 8 Sep 2023 17:29:31 +0200 Subject: [PATCH 20/22] [skip-caching] From 3bc2c5a78b42e30f792f50b0b5047f90c76d034d Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 8 Sep 2023 18:33:43 +0200 Subject: [PATCH 21/22] Fix test --- distributed/tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 8822d39c42e..c855a6a9484 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3359,7 +3359,7 @@ def test_default_get(loop_in_thread): try: check_minimal_arrow_version() has_pyarrow = True - except RuntimeError: + except ImportError: pass loop = loop_in_thread with cluster() as (s, [a, b]): From f23c1aad3791e90603cfcf707cad8c802c736dd7 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 8 Sep 2023 19:49:37 +0200 Subject: [PATCH 22/22] [skip-caching]