diff --git a/continuous_integration/environment-3.9.yaml b/continuous_integration/environment-3.9.yaml index ad4d340871e..c275371668d 100644 --- a/continuous_integration/environment-3.9.yaml +++ b/continuous_integration/environment-3.9.yaml @@ -28,7 +28,7 @@ dependencies: - pre-commit - prometheus_client - psutil - - pyarrow=7 + - pyarrow=12 - pynvml # Only tested here - pytest - pytest-cov diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 08ef2b3f616..a5fa4a35d09 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -1,10 +1,12 @@ from __future__ import annotations -from io import BytesIO -from typing import TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Any from packaging.version import parse +from dask.utils import parse_bytes + if TYPE_CHECKING: import pandas as pd import pyarrow as pa @@ -29,34 +31,27 @@ def check_minimal_arrow_version() -> None: """Verify that the the correct version of pyarrow is installed to support the P2P extension. - Raises a RuntimeError in case pyarrow is not installed or installed version - is not recent enough. + Raises a ModuleNotFoundError if pyarrow is not installed or an + ImportError if the installed version is not recent enough. """ - # First version to introduce Table.sort_by - minversion = "7.0.0" + # First version that supports concatenating extension arrays (apache/arrow#14463) + minversion = "12.0.0" try: import pyarrow as pa - except ImportError: - raise RuntimeError(f"P2P shuffling requires pyarrow>={minversion}") - + except ModuleNotFoundError: + raise ModuleNotFoundError(f"P2P shuffling requires pyarrow>={minversion}") if parse(pa.__version__) < parse(minversion): - raise RuntimeError( + raise ImportError( f"P2P shuffling requires pyarrow>={minversion} but only found {pa.__version__}" ) -def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: +def convert_shards(shards: list[pa.Table], meta: pd.DataFrame) -> pd.DataFrame: import pyarrow as pa from dask.dataframe.dispatch import from_pyarrow_table_dispatch - file = BytesIO(data) - end = len(data) - shards = [] - while file.tell() < end: - sr = pa.RecordBatchStreamReader(file) - shards.append(sr.read_all()) - table = pa.concat_tables(shards, promote=True) + table = pa.concat_tables(shards) df = from_pyarrow_table_dispatch(meta, table, self_destruct=True) return df.astype(meta.dtypes, copy=False) @@ -66,9 +61,7 @@ def list_of_buffers_to_table(data: list[bytes]) -> pa.Table: """Convert a list of arrow buffers and a schema to an Arrow Table""" import pyarrow as pa - return pa.concat_tables( - (deserialize_table(buffer) for buffer in data), promote=True - ) + return pa.concat_tables(deserialize_table(buffer) for buffer in data) def serialize_table(table: pa.Table) -> bytes: @@ -85,3 +78,42 @@ def deserialize_table(buffer: bytes) -> pa.Table: with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader: return reader.read_all() + + +def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]: + import pyarrow as pa + + from dask.dataframe.dispatch import pyarrow_schema_dispatch + + batch_size = parse_bytes("1 MiB") + batch = [] + shards = [] + schema = pyarrow_schema_dispatch(meta, preserve_index=True) + + with pa.OSFile(str(path), mode="rb") as f: + size = f.seek(0, whence=2) + f.seek(0) + prev = 0 + offset = f.tell() + while offset < size: + sr = pa.RecordBatchStreamReader(f) + shard = sr.read_all() + offset = f.tell() + batch.append(shard) + + if offset - prev >= batch_size: + table = pa.concat_tables(batch) + shards.append(_copy_table(table, schema)) + batch = [] + prev = offset + if batch: + table = pa.concat_tables(batch) + shards.append(_copy_table(table, schema)) + return shards, size + + +def _copy_table(table: pa.Table, schema: pa.Schema) -> pa.Table: + import pyarrow as pa + + arrs = [pa.concat_arrays(column.chunks) for column in table.columns] + return pa.table(data=arrs, schema=schema) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 9f9763ab091..be5d75ea527 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -11,6 +11,7 @@ from dataclasses import dataclass, field from enum import Enum from functools import partial +from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar from distributed.core import PooledRPCCall @@ -62,6 +63,7 @@ def __init__( self._disk_buffer = DiskShardsBuffer( directory=directory, + read=self.read, memory_limiter=memory_limiter_disk, ) @@ -180,10 +182,9 @@ def fail(self, exception: Exception) -> None: if not self.closed: self._exception = exception - def _read_from_disk(self, id: NDIndex) -> bytes: + def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing self.raise_if_closed() - data: bytes = self._disk_buffer.read("_".join(str(i) for i in id)) - return data + return self._disk_buffer.read("_".join(str(i) for i in id)) async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None: await self._receive(data) @@ -238,6 +239,10 @@ async def _get_output_partition( ) -> _T_partition_type: """Get an output partition to the shuffle run""" + @abc.abstractmethod + def read(self, path: Path) -> tuple[Any, int]: + """Read shards from disk""" + def get_worker_plugin() -> ShuffleWorkerPlugin: from distributed import get_worker diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 2b3dc37beed..cbdb6b2d410 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -3,6 +3,7 @@ import contextlib import pathlib import shutil +from typing import Any, Callable from distributed.shuffle._buffer import ShardsBuffer from distributed.shuffle._limiter import ResourceLimiter @@ -41,6 +42,7 @@ class DiskShardsBuffer(ShardsBuffer): def __init__( self, directory: str | pathlib.Path, + read: Callable[[pathlib.Path], tuple[Any, int]], memory_limiter: ResourceLimiter | None = None, ): super().__init__( @@ -50,6 +52,7 @@ def __init__( ) self.directory = pathlib.Path(directory) self.directory.mkdir(exist_ok=True) + self._read = read async def _process(self, id: str, shards: list[bytes]) -> None: """Write one buffer to file @@ -74,7 +77,7 @@ async def _process(self, id: str, shards: list[bytes]) -> None: for shard in shards: f.write(shard) - def read(self, id: int | str) -> bytes: + def read(self, id: int | str) -> Any: """Read a complete file back into memory""" self.raise_on_exception() if not self._inputs_done: @@ -82,11 +85,7 @@ def read(self, id: int | str) -> bytes: try: with self.time("read"): - with open( - self.directory / str(id), mode="rb", buffering=100_000_000 - ) as f: - data = f.read() - size = f.tell() + data, size = self._read((self.directory / str(id)).resolve()) except FileNotFoundError: raise KeyError(id) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index dc8bd0371a2..f0efd4f66f3 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -102,8 +102,8 @@ from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from io import BytesIO from itertools import product +from pathlib import Path from typing import TYPE_CHECKING, Any, NamedTuple import dask @@ -258,26 +258,22 @@ def split_axes(old: ChunkedAxes, new: ChunkedAxes) -> SplitAxes: return axes -def convert_chunk(data: bytes) -> np.ndarray: +def convert_chunk(shards: list[tuple[NDIndex, np.ndarray]]) -> np.ndarray: import numpy as np from dask.array.core import concatenate3 - file = BytesIO(data) - shards: dict[NDIndex, np.ndarray] = {} + indexed: dict[NDIndex, np.ndarray] = {} + for index, shard in shards: + indexed[index] = shard + del shards - while file.tell() < len(data): - for index, shard in pickle.load(file): - shards[index] = shard - - subshape = [max(dim) + 1 for dim in zip(*shards.keys())] - assert len(shards) == np.prod(subshape) + subshape = [max(dim) + 1 for dim in zip(*indexed.keys())] + assert len(indexed) == np.prod(subshape) rec_cat_arg = np.empty(subshape, dtype="O") - for index, shard in shards.items(): + for index, shard in indexed.items(): rec_cat_arg[tuple(index)] = shard - del data - del file arrs = rec_cat_arg.tolist() return concatenate3(arrs) @@ -427,8 +423,20 @@ def _() -> dict[str, tuple[NDIndex, bytes]]: async def _get_output_partition( self, partition_id: NDIndex, key: str, **kwargs: Any ) -> np.ndarray: - data = self._read_from_disk(partition_id) - return await self.offload(convert_chunk, data) + def _(partition_id: NDIndex) -> np.ndarray: + data = self._read_from_disk(partition_id) + return convert_chunk(data) + + return await self.offload(_, partition_id) + + def read(self, path: Path) -> tuple[Any, int]: + shards: list[tuple[NDIndex, np.ndarray]] = [] + with path.open(mode="rb") as f: + size = f.seek(0, os.SEEK_END) + f.seek(0) + while f.tell() < size: + shards.extend(pickle.load(f)) + return shards, size def _get_assigned_worker(self, id: NDIndex) -> str: return self.worker_for[id] diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 7d0b0a6a8f2..43fe87b4c5d 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from functools import partial +from pathlib import Path from typing import TYPE_CHECKING, Any import toolz @@ -20,8 +21,9 @@ from distributed.shuffle._arrow import ( check_dtype_support, check_minimal_arrow_version, - convert_partition, + convert_shards, list_of_buffers_to_table, + read_from_disk, serialize_table, ) from distributed.shuffle._core import ( @@ -321,7 +323,7 @@ def split_by_worker( return out -def split_by_partition(t: pa.Table, column: str) -> dict[Any, pa.Table]: +def split_by_partition(t: pa.Table, column: str) -> dict[int, pa.Table]: """ Split data into many arrow batches, partitioned by final partition """ @@ -383,6 +385,11 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): buffer. """ + column: str + meta: pd.DataFrame + partitions_of: dict[str, list[int]] + worker_for: pd.Series + def __init__( self, worker_for: dict[int, str], @@ -476,9 +483,12 @@ async def _get_output_partition( **kwargs: Any, ) -> pd.DataFrame: try: - data = self._read_from_disk((partition_id,)) - out = await self.offload(convert_partition, data, self.meta) + def _(partition_id: int, meta: pd.DataFrame) -> pd.DataFrame: + data = self._read_from_disk((partition_id,)) + return convert_shards(data, meta) + + out = await self.offload(_, partition_id, self.meta) except KeyError: out = self.meta.copy() return out @@ -486,6 +496,9 @@ async def _get_output_partition( def _get_assigned_worker(self, id: int) -> str: return self.worker_for[id] + def read(self, path: Path) -> tuple[Any, int]: + return read_from_disk(path, self.meta) + @dataclass(frozen=True) class DataFrameShuffleSpec(ShuffleSpec[int]): diff --git a/distributed/shuffle/tests/test_disk_buffer.py b/distributed/shuffle/tests/test_disk_buffer.py index 76a4a1e70c2..19192de5822 100644 --- a/distributed/shuffle/tests/test_disk_buffer.py +++ b/distributed/shuffle/tests/test_disk_buffer.py @@ -2,6 +2,7 @@ import asyncio import os +from pathlib import Path from typing import Any import pytest @@ -10,9 +11,16 @@ from distributed.utils_test import gen_test +def read_bytes(path: Path) -> tuple[bytes, int]: + with path.open("rb") as f: + data = f.read() + size = f.tell() + return data, size + + @gen_test() async def test_basic(tmp_path): - async with DiskShardsBuffer(directory=tmp_path) as mf: + async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf: await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) @@ -33,7 +41,7 @@ async def test_basic(tmp_path): @gen_test() async def test_read_before_flush(tmp_path): payload = {"1": b"foo"} - async with DiskShardsBuffer(directory=tmp_path) as mf: + async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf: with pytest.raises(RuntimeError): mf.read(1) @@ -51,7 +59,7 @@ async def test_read_before_flush(tmp_path): @pytest.mark.parametrize("count", [2, 100, 1000]) @gen_test() async def test_many(tmp_path, count): - async with DiskShardsBuffer(directory=tmp_path) as mf: + async with DiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf: d = {i: str(i).encode() * 100 for i in range(count)} for _ in range(10): @@ -76,7 +84,7 @@ async def _process(self, *args: Any, **kwargs: Any) -> None: @gen_test() async def test_exceptions(tmp_path): - async with BrokenDiskShardsBuffer(directory=tmp_path) as mf: + async with BrokenDiskShardsBuffer(directory=tmp_path, read=read_bytes) as mf: await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) while not mf._exception: @@ -107,6 +115,7 @@ async def test_high_pressure_flush_with_exception(tmp_path): async with EventuallyBrokenDiskShardsBuffer( directory=tmp_path, + read=read_bytes, ) as mf: tasks = [] for _ in range(10): diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 68f97a2558e..e3e63ded0e4 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import io import itertools import os import random @@ -31,8 +30,9 @@ from distributed.scheduler import KilledWorker, Scheduler from distributed.scheduler import TaskState as SchedulerTaskState from distributed.shuffle._arrow import ( - convert_partition, + convert_shards, list_of_buffers_to_table, + read_from_disk, serialize_table, ) from distributed.shuffle._limiter import ResourceLimiter @@ -936,7 +936,7 @@ async def test_heartbeat(c, s, a, b): await check_scheduler_cleanup(s) -def test_processing_chain(): +def test_processing_chain(tmp_path): """ This is a serial version of the entire compute chain @@ -1085,18 +1085,16 @@ def __init__(self, value: int) -> None: if w1 is not w2 ) - # Our simple file system - filesystem = defaultdict(io.BytesIO) - for partitions in splits_by_worker.values(): for partition, tables in partitions.items(): for table in tables: - filesystem[partition].write(serialize_table(table)) + with (tmp_path / str(partition)).open("ab") as f: + f.write(serialize_table(table)) out = {} - for k, bio in filesystem.items(): - bio.seek(0) - out[k] = convert_partition(bio.read(), meta) + for k in range(npartitions): + shards, _ = read_from_disk(tmp_path / str(k), meta) + out[k] = convert_shards(shards, meta) shuffled_df = pd.concat(df for df in out.values()) pd.testing.assert_frame_equal( diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 8822d39c42e..c855a6a9484 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3359,7 +3359,7 @@ def test_default_get(loop_in_thread): try: check_minimal_arrow_version() has_pyarrow = True - except RuntimeError: + except ImportError: pass loop = loop_in_thread with cluster() as (s, [a, b]):