Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 65 additions & 24 deletions distributed/shuffle/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Generic, NewType, TypeVar
from functools import partial
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar

from distributed.core import PooledRPCCall
from distributed.exceptions import Reschedule
Expand All @@ -21,27 +22,28 @@
from distributed.shuffle._limiter import ResourceLimiter

if TYPE_CHECKING:
import pandas as pd
# TODO import from typing (requires Python >=3.10)
from typing_extensions import TypeAlias

# avoid circular dependencies
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin

# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]


_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")

NDIndex: TypeAlias = tuple[int, ...]

ShuffleId = NewType("ShuffleId", str)


class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
def __init__(
self,
id: ShuffleId,
run_id: int,
output_workers: set[str],
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
Expand All @@ -52,7 +54,6 @@ def __init__(
):
self.id = id
self.run_id = run_id
self.output_workers = output_workers
self.local_address = local_address
self.executor = executor
self.rpc = rpc
Expand Down Expand Up @@ -215,7 +216,7 @@ async def add_partition(

@abc.abstractmethod
async def get_output_partition(
self, partition_id: _T_partition_id, key: str, meta: pd.DataFrame | None = None
self, partition_id: _T_partition_id, key: str, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""

Expand All @@ -230,13 +231,12 @@ def get_worker_plugin() -> ShuffleWorkerPlugin:
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
plugin: ShuffleWorkerPlugin | None = worker.plugins.get("shuffle") # type: ignore
if plugin is None:
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a ShuffleExtension. "
"Is pandas installed on the worker?"
)
return plugin
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e


_BARRIER_PREFIX = "shuffle-barrier-"
Expand All @@ -256,19 +256,60 @@ class ShuffleType(Enum):
ARRAY_RECHUNK = "ArrayRechunk"


@dataclass(eq=False)
class ShuffleState(abc.ABC):
_run_id_iterator: ClassVar[itertools.count] = itertools.count(1)
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1))) # type: ignore

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, this is this mypy bug

spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]

@property
def id(self) -> ShuffleId:
    """Shuffle ID of this run, read from the wrapped ``spec``."""
    wrapped_spec = self.spec
    return wrapped_spec.id


@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
Comment thread
hendrikmakait marked this conversation as resolved.
run_id: int
output_workers: set[str]

def create_new_run(
    self,
    plugin: ShuffleSchedulerPlugin,
) -> SchedulerShuffleState:
    """Build the scheduler-side state for a new run of this spec.

    The partition-to-worker mapping is obtained from
    ``_pin_output_workers``. It is stored in a new ``ShuffleRunSpec``
    together with this spec, and the distinct workers appearing in the
    mapping become the run's ``participating_workers``.
    """
    assignment = self._pin_output_workers(plugin)
    new_run_spec = ShuffleRunSpec(spec=self, worker_for=assignment)
    workers = set(assignment.values())
    return SchedulerShuffleState(
        run_spec=new_run_spec,
        participating_workers=workers,
    )

@abc.abstractmethod
def _pin_output_workers(
    self, plugin: ShuffleSchedulerPlugin
) -> dict[_T_partition_id, str]:
    """Pin the output tasks of this shuffle to workers.

    Implementations return the resulting mapping of partition ID to
    worker; ``create_new_run`` stores it as ``worker_for``.
    """

@abc.abstractmethod
def create_run_on_worker(
    self,
    run_id: int,
    worker_for: dict[_T_partition_id, str],
    plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
    """Instantiate the worker-side ``ShuffleRun`` for this spec.

    Parameters
    ----------
    run_id:
        Identifier of the run to create.
    worker_for:
        Mapping of partition ID to worker.
    plugin:
        The worker plugin on whose behalf the run is created.
    """


@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)

@abc.abstractmethod
def to_msg(self) -> dict[str, Any]:
"""Transform the shuffle state into a JSON-serializable message"""
@property
def id(self) -> ShuffleId:
    """Shuffle ID, forwarded from ``run_spec``."""
    current_spec = self.run_spec
    return current_spec.id

@property
def run_id(self) -> int:
    """Run ID, forwarded from ``run_spec``."""
    current_spec = self.run_spec
    return current_spec.run_id

def __str__(self) -> str:
    """Render as ``ClassName<id[run_id]>`` using the instance's class name."""
    return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
Expand Down
100 changes: 52 additions & 48 deletions distributed/shuffle/_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,15 @@

from __future__ import annotations

import os
import pickle
from collections import defaultdict
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple

import dask
from dask.base import tokenize
Expand All @@ -114,23 +116,21 @@
NDIndex,
ShuffleId,
ShuffleRun,
ShuffleState,
ShuffleType,
barrier_key,
ShuffleSpec,
get_worker_plugin,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._shuffle import shuffle_barrier
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
from distributed.shuffle._shuffle import barrier_key, shuffle_barrier
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.sizeof import sizeof

if TYPE_CHECKING:
import numpy as np
import pandas as pd
from typing_extensions import TypeAlias

import dask.array as da


ChunkedAxis: TypeAlias = tuple[float, ...] # chunks must either be an int or NaN
ChunkedAxes: TypeAlias = tuple[ChunkedAxis, ...]
NDSlice: TypeAlias = tuple[slice, ...]
Expand All @@ -147,10 +147,7 @@ def rechunk_transfer(
return get_worker_plugin().add_partition(
input,
partition_id=input_chunk,
shuffle_id=id,
type=ShuffleType.ARRAY_RECHUNK,
new=new,
old=old,
spec=ArrayRechunkSpec(id=id, new=new, old=old),
)
except Exception as e:
raise RuntimeError(f"rechunk_transfer failed during shuffle {id}") from e
Expand Down Expand Up @@ -296,18 +293,16 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]):
This object is responsible for splitting, sending, receiving and combining
data shards.

It is entirely agnostic to the distributed system and can perform a shuffle
with other `Shuffle` instances using `rpc` and `broadcast`.
It is entirely agnostic to the distributed system and can perform a rechunk
with other run instances using `rpc`.

The user of this needs to guarantee that only `Shuffle`s of the same unique
`ShuffleID` interact.
The user of this needs to guarantee that only `ArrayRechunkRun`s of the same unique
`ShuffleId` and `run_id` interact.

Parameters
----------
worker_for:
A mapping partition_id -> worker_address.
output_workers:
A set of all participating worker (addresses).
old:
Existing chunking of the array per dimension.
new:
Expand All @@ -322,8 +317,6 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]):
The scratch directory to buffer data in.
executor:
Thread pool to use for offloading compute.
loop:
The event loop.
rpc:
A callable returning a PooledRPCCall to contact other Shuffle instances.
Typically a ConnectionPool.
Expand All @@ -338,7 +331,6 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]):
def __init__(
self,
worker_for: dict[NDIndex, str],
output_workers: set,
old: ChunkedAxes,
new: ChunkedAxes,
id: ShuffleId,
Expand All @@ -354,7 +346,6 @@ def __init__(
super().__init__(
id=id,
run_id=run_id,
output_workers=output_workers,
local_address=local_address,
directory=directory,
executor=executor,
Expand Down Expand Up @@ -403,7 +394,9 @@ def _repartition_shards(self, data: list[bytes]) -> dict[NDIndex, bytes]:
repartitioned[id].append(shard)
return {k: pickle.dumps(v) for k, v in repartitioned.items()}

async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int:
async def add_partition(
self, data: np.ndarray, partition_id: NDIndex, **kwargs: Any
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
Expand Down Expand Up @@ -441,47 +434,58 @@ def _() -> dict[str, tuple[NDIndex, bytes]]:
return self.run_id

async def get_output_partition(
self, partition_id: NDIndex, key: str, meta: pd.DataFrame | None = None
self, partition_id: NDIndex, key: str, **kwargs: Any

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this switch to kwargs future-proofing?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Future-proofing and making it completely transparent to the plugin.

) -> np.ndarray:
self.raise_if_closed()
assert meta is None
assert self.transferred, "`get_output_partition` called before barrier task"
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")

await self._ensure_output_worker(partition_id, key)

await self.flush_receive()

data = self._read_from_disk(partition_id)

def _() -> np.ndarray:
return convert_chunk(data)

return await self.offload(_)
return await self.offload(convert_chunk, data)

def _get_assigned_worker(self, id: NDIndex) -> str:
    """Look up the worker assigned to partition ``id`` in ``self.worker_for``."""
    assignments = self.worker_for
    return assignments[id]


@dataclass(eq=False)
class ArrayRechunkState(ShuffleState):
type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK
worker_for: dict[NDIndex, str]
old: ChunkedAxes
@dataclass(frozen=True)
class ArrayRechunkSpec(ShuffleSpec[NDIndex]):
new: ChunkedAxes
old: ChunkedAxes

def to_msg(self) -> dict[str, Any]:
return {
"status": "OK",
"type": ArrayRechunkState.type,
"run_id": self.run_id,
"worker_for": self.worker_for,
"old": self.old,
"new": self.new,
"output_workers": self.output_workers,
}
def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[NDIndex, str]:
    """Pin every output chunk of the new chunking to a worker.

    The output partition IDs are the Cartesian product of the chunk
    positions along each axis of ``self.new``. The pinning itself is
    delegated to ``plugin._pin_output_workers``, which is handed
    ``_get_worker_for_hash_sharding`` as the worker-selection function.
    """
    axis_positions = [range(len(axis_chunks)) for axis_chunks in self.new]
    output_indices = product(*axis_positions)
    return plugin._pin_output_workers(
        self.id, output_indices, _get_worker_for_hash_sharding
    )

def create_run_on_worker(
    self,
    run_id: int,
    worker_for: dict[NDIndex, str],
    plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
    """Create the ``ArrayRechunkRun`` for this spec on a worker.

    The run's directory is ``shuffle-<id>-<run_id>`` inside the worker's
    local directory. The executor, memory limiters, address, RPC and
    scheduler handles are all taken from ``plugin`` and its worker.

    Parameters
    ----------
    run_id:
        Identifier of the run to create.
    worker_for:
        Mapping of output chunk index to worker.
    plugin:
        The worker plugin supplying the worker and shared resources.
    """
    worker = plugin.worker
    run_directory = os.path.join(
        worker.local_directory,
        f"shuffle-{self.id}-{run_id}",
    )
    run_kwargs = dict(
        worker_for=worker_for,
        old=self.old,
        new=self.new,
        id=self.id,
        run_id=run_id,
        directory=run_directory,
        executor=plugin._executor,
        local_address=worker.address,
        rpc=worker.rpc,
        scheduler=worker.scheduler,
        memory_limiter_disk=plugin.memory_limiter_disk,
        memory_limiter_comms=plugin.memory_limiter_comms,
    )
    return ArrayRechunkRun(**run_kwargs)


def get_worker_for_hash_sharding(
def _get_worker_for_hash_sharding(
output_partition: NDIndex, workers: Sequence[str]
) -> str:
"""Get address of target worker for this output partition using hash sharding"""
Expand Down
Loading