Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
671ab10
Early copy
hendrikmakait Sep 2, 2023
9f08ff3
read fn
hendrikmakait Sep 2, 2023
bf4c737
OSFile
hendrikmakait Sep 2, 2023
c80bf63
Move deser logic
hendrikmakait Sep 2, 2023
2417a5a
meta
hendrikmakait Sep 2, 2023
871c043
Remove delayed enforcement
hendrikmakait Sep 2, 2023
20634e9
Remove delayed enforcement
hendrikmakait Sep 2, 2023
f76c622
Rechunk
hendrikmakait Sep 4, 2023
5a3b33e
Fix tests
hendrikmakait Sep 4, 2023
0149bbd
Merge branch 'main' into reduced-p2p-footprint
hendrikmakait Sep 4, 2023
a60cc07
Resolve path
hendrikmakait Sep 4, 2023
d5f81f4
Increase minimum pyarrow version
hendrikmakait Sep 5, 2023
491b9e6
Add assertion
hendrikmakait Sep 6, 2023
bbbe303
Merge branch 'main' into reduced-p2p-footprint
hendrikmakait Sep 6, 2023
6c80354
Merge branch 'main' into reduced-p2p-footprint
hendrikmakait Sep 6, 2023
bdc0d8e
Merge branch 'main' into reduced-p2p-footprint
hendrikmakait Sep 7, 2023
f14aba6
check_minimal_arrow_version
hendrikmakait Sep 7, 2023
67282f7
Offload entire read conversion
hendrikmakait Sep 7, 2023
429a7ac
offload
hendrikmakait Sep 7, 2023
2f92de0
batching
hendrikmakait Sep 8, 2023
e2368b5
minor
hendrikmakait Sep 8, 2023
629124a
smaller batches
hendrikmakait Sep 8, 2023
ecfe534
Dispatch
hendrikmakait Sep 8, 2023
396c719
Merge branch 'main' into reduced-p2p-footprint
hendrikmakait Sep 8, 2023
ecca1d8
[skip-caching]
hendrikmakait Sep 8, 2023
3bc2c5a
Fix test
hendrikmakait Sep 8, 2023
d0286a4
Merge branch 'main' into reduced-p2p-footprint
hendrikmakait Sep 8, 2023
f23c1aa
[skip-caching]
hendrikmakait Sep 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies:
- pre-commit
- prometheus_client
- psutil
- pyarrow=7
- pyarrow=12
- pynvml # Only tested here
- pytest
- pytest-cov
Expand Down
74 changes: 53 additions & 21 deletions distributed/shuffle/_arrow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from io import BytesIO
from typing import TYPE_CHECKING
from pathlib import Path
from typing import TYPE_CHECKING, Any

from packaging.version import parse

from dask.utils import parse_bytes

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa
Expand All @@ -29,34 +31,27 @@ def check_minimal_arrow_version() -> None:
"""Verify that the the correct version of pyarrow is installed to support
the P2P extension.

Raises a RuntimeError in case pyarrow is not installed or installed version
is not recent enough.
Raises a ModuleNotFoundError if pyarrow is not installed or an
ImportError if the installed version is not recent enough.
"""
# First version to introduce Table.sort_by
minversion = "7.0.0"
# First version that supports concatenating extension arrays (apache/arrow#14463)
minversion = "12.0.0"
try:
import pyarrow as pa
except ImportError:
raise RuntimeError(f"P2P shuffling requires pyarrow>={minversion}")

except ModuleNotFoundError:
raise ModuleNotFoundError(f"P2P shuffling requires pyarrow>={minversion}")
if parse(pa.__version__) < parse(minversion):
raise RuntimeError(
raise ImportError(
Comment on lines +34 to +44

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fjetter: Together with dask/dask#10496, get_default_shuffle_method should raise if pyarrow is outdated and choose tasks if it's not installed.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(testing it manually)

f"P2P shuffling requires pyarrow>={minversion} but only found {pa.__version__}"
)


def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame:
def convert_shards(shards: list[pa.Table], meta: pd.DataFrame) -> pd.DataFrame:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(disclaimer: still in early review) I once tried to move tables around instead of bytes but that messed up the event loop. We should check this before merging

import pyarrow as pa

from dask.dataframe.dispatch import from_pyarrow_table_dispatch

file = BytesIO(data)
end = len(data)
shards = []
while file.tell() < end:
sr = pa.RecordBatchStreamReader(file)
shards.append(sr.read_all())
table = pa.concat_tables(shards, promote=True)
table = pa.concat_tables(shards)

df = from_pyarrow_table_dispatch(meta, table, self_destruct=True)
return df.astype(meta.dtypes, copy=False)
Expand All @@ -66,9 +61,7 @@ def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:
"""Convert a list of arrow buffers and a schema to an Arrow Table"""
import pyarrow as pa

return pa.concat_tables(
(deserialize_table(buffer) for buffer in data), promote=True
)
return pa.concat_tables(deserialize_table(buffer) for buffer in data)


def serialize_table(table: pa.Table) -> bytes:
Expand All @@ -85,3 +78,42 @@ def deserialize_table(buffer: bytes) -> pa.Table:

with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader:
return reader.read_all()


def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]:
import pyarrow as pa

from dask.dataframe.dispatch import pyarrow_schema_dispatch

batch_size = parse_bytes("1 MiB")

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fragile and I don't really like it, but for now it seems to do the job. We will have to spend more time on performance optimization and understanding memory (de)allocation here to make this more robust.

batch = []
shards = []
schema = pyarrow_schema_dispatch(meta, preserve_index=True)

with pa.OSFile(str(path), mode="rb") as f:
size = f.seek(0, whence=2)
f.seek(0)
prev = 0
offset = f.tell()
while offset < size:
sr = pa.RecordBatchStreamReader(f)
shard = sr.read_all()
offset = f.tell()
batch.append(shard)

if offset - prev >= batch_size:
table = pa.concat_tables(batch)
shards.append(_copy_table(table, schema))
batch = []
prev = offset
if batch:
table = pa.concat_tables(batch)
shards.append(_copy_table(table, schema))
return shards, size


def _copy_table(table: pa.Table, schema: pa.Schema) -> pa.Table:
import pyarrow as pa

arrs = [pa.concat_arrays(column.chunks) for column in table.columns]
return pa.table(data=arrs, schema=schema)
11 changes: 8 additions & 3 deletions distributed/shuffle/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar

from distributed.core import PooledRPCCall
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(

self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)

Expand Down Expand Up @@ -180,10 +182,9 @@ def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception

def _read_from_disk(self, id: NDIndex) -> bytes:
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
data: bytes = self._disk_buffer.read("_".join(str(i) for i in id))
return data
return self._disk_buffer.read("_".join(str(i) for i in id))

async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
await self._receive(data)
Expand Down Expand Up @@ -238,6 +239,10 @@ async def _get_output_partition(
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""

@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""


def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
Expand Down
11 changes: 5 additions & 6 deletions distributed/shuffle/_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import pathlib
import shutil
from typing import Any, Callable

from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
Expand Down Expand Up @@ -41,6 +42,7 @@ class DiskShardsBuffer(ShardsBuffer):
def __init__(
self,
directory: str | pathlib.Path,
read: Callable[[pathlib.Path], tuple[Any, int]],
memory_limiter: ResourceLimiter | None = None,
):
super().__init__(
Expand All @@ -50,6 +52,7 @@ def __init__(
)
self.directory = pathlib.Path(directory)
self.directory.mkdir(exist_ok=True)
self._read = read

async def _process(self, id: str, shards: list[bytes]) -> None:
"""Write one buffer to file
Expand All @@ -74,19 +77,15 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
for shard in shards:
f.write(shard)

def read(self, id: int | str) -> bytes:
def read(self, id: int | str) -> Any:
"""Read a complete file back into memory"""
self.raise_on_exception()
if not self._inputs_done:
raise RuntimeError("Tried to read from file before done.")

try:
with self.time("read"):
with open(
self.directory / str(id), mode="rb", buffering=100_000_000
) as f:
data = f.read()
size = f.tell()
data, size = self._read((self.directory / str(id)).resolve())
except FileNotFoundError:
raise KeyError(id)

Expand Down
38 changes: 23 additions & 15 deletions distributed/shuffle/_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from io import BytesIO
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple

import dask
Expand Down Expand Up @@ -258,26 +258,22 @@ def split_axes(old: ChunkedAxes, new: ChunkedAxes) -> SplitAxes:
return axes


def convert_chunk(data: bytes) -> np.ndarray:
def convert_chunk(shards: list[tuple[NDIndex, np.ndarray]]) -> np.ndarray:
import numpy as np

from dask.array.core import concatenate3

file = BytesIO(data)
shards: dict[NDIndex, np.ndarray] = {}
indexed: dict[NDIndex, np.ndarray] = {}
for index, shard in shards:
indexed[index] = shard
del shards

while file.tell() < len(data):
for index, shard in pickle.load(file):
shards[index] = shard

subshape = [max(dim) + 1 for dim in zip(*shards.keys())]
assert len(shards) == np.prod(subshape)
subshape = [max(dim) + 1 for dim in zip(*indexed.keys())]
assert len(indexed) == np.prod(subshape)

rec_cat_arg = np.empty(subshape, dtype="O")
for index, shard in shards.items():
for index, shard in indexed.items():
rec_cat_arg[tuple(index)] = shard
del data
del file
arrs = rec_cat_arg.tolist()
return concatenate3(arrs)

Expand Down Expand Up @@ -427,8 +423,20 @@ def _() -> dict[str, tuple[NDIndex, bytes]]:
async def _get_output_partition(
self, partition_id: NDIndex, key: str, **kwargs: Any
) -> np.ndarray:
data = self._read_from_disk(partition_id)
return await self.offload(convert_chunk, data)
def _(partition_id: NDIndex) -> np.ndarray:
data = self._read_from_disk(partition_id)
return convert_chunk(data)

return await self.offload(_, partition_id)

def read(self, path: Path) -> tuple[Any, int]:
shards: list[tuple[NDIndex, np.ndarray]] = []
with path.open(mode="rb") as f:
size = f.seek(0, os.SEEK_END)
f.seek(0)
while f.tell() < size:
shards.extend(pickle.load(f))
return shards, size

def _get_assigned_worker(self, id: NDIndex) -> str:
return self.worker_for[id]
Expand Down
21 changes: 17 additions & 4 deletions distributed/shuffle/_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any

import toolz
Expand All @@ -20,8 +21,9 @@
from distributed.shuffle._arrow import (
check_dtype_support,
check_minimal_arrow_version,
convert_partition,
convert_shards,
list_of_buffers_to_table,
read_from_disk,
serialize_table,
)
from distributed.shuffle._core import (
Expand Down Expand Up @@ -321,7 +323,7 @@ def split_by_worker(
return out


def split_by_partition(t: pa.Table, column: str) -> dict[Any, pa.Table]:
def split_by_partition(t: pa.Table, column: str) -> dict[int, pa.Table]:
"""
Split data into many arrow batches, partitioned by final partition
"""
Expand Down Expand Up @@ -383,6 +385,11 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]):
buffer.
"""

column: str
meta: pd.DataFrame
partitions_of: dict[str, list[int]]
worker_for: pd.Series

def __init__(
self,
worker_for: dict[int, str],
Expand Down Expand Up @@ -476,16 +483,22 @@ async def _get_output_partition(
**kwargs: Any,
) -> pd.DataFrame:
try:
data = self._read_from_disk((partition_id,))

out = await self.offload(convert_partition, data, self.meta)
def _(partition_id: int, meta: pd.DataFrame) -> pd.DataFrame:
data = self._read_from_disk((partition_id,))
return convert_shards(data, meta)

out = await self.offload(_, partition_id, self.meta)
except KeyError:
out = self.meta.copy()
return out

def _get_assigned_worker(self, id: int) -> str:
return self.worker_for[id]

def read(self, path: Path) -> tuple[Any, int]:
return read_from_disk(path, self.meta)


@dataclass(frozen=True)
class DataFrameShuffleSpec(ShuffleSpec[int]):
Expand Down
Loading