diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index fe8a38f5f00..e7dd72c73cf 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -10,7 +10,8 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar, Generic, NewType, TypeVar +from functools import partial +from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule @@ -21,27 +22,28 @@ from distributed.shuffle._limiter import ResourceLimiter if TYPE_CHECKING: - import pandas as pd + # TODO import from typing (requires Python >=3.10) from typing_extensions import TypeAlias - # avoid circular dependencies + from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin + + # circular dependencies from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin +ShuffleId = NewType("ShuffleId", str) +NDIndex: TypeAlias = tuple[int, ...] + + _T_partition_id = TypeVar("_T_partition_id") _T_partition_type = TypeVar("_T_partition_type") _T = TypeVar("_T") -NDIndex: TypeAlias = tuple[int, ...] 
- -ShuffleId = NewType("ShuffleId", str) - class ShuffleRun(Generic[_T_partition_id, _T_partition_type]): def __init__( self, id: ShuffleId, run_id: int, - output_workers: set[str], local_address: str, directory: str, executor: ThreadPoolExecutor, @@ -52,7 +54,6 @@ def __init__( ): self.id = id self.run_id = run_id - self.output_workers = output_workers self.local_address = local_address self.executor = executor self.rpc = rpc @@ -215,7 +216,7 @@ async def add_partition( @abc.abstractmethod async def get_output_partition( - self, partition_id: _T_partition_id, key: str, meta: pd.DataFrame | None = None + self, partition_id: _T_partition_id, key: str, **kwargs: Any ) -> _T_partition_type: """Get an output partition to the shuffle run""" @@ -230,13 +231,12 @@ def get_worker_plugin() -> ShuffleWorkerPlugin: "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; " "please confirm that you've created a distributed Client and are submitting this computation through it." ) from e - plugin: ShuffleWorkerPlugin | None = worker.plugins.get("shuffle") # type: ignore - if plugin is None: + try: + return worker.plugins["shuffle"] # type: ignore + except KeyError as e: raise RuntimeError( - f"The worker {worker.address} does not have a ShuffleExtension. " - "Is pandas installed on the worker?" - ) - return plugin + f"The worker {worker.address} does not have a P2P shuffle plugin." 
+ ) from e _BARRIER_PREFIX = "shuffle-barrier-" @@ -256,19 +256,60 @@ class ShuffleType(Enum): ARRAY_RECHUNK = "ArrayRechunk" -@dataclass(eq=False) -class ShuffleState(abc.ABC): - _run_id_iterator: ClassVar[itertools.count] = itertools.count(1) +@dataclass(frozen=True) +class ShuffleRunSpec(Generic[_T_partition_id]): + run_id: int = field(init=False, default_factory=partial(next, itertools.count(1))) # type: ignore + spec: ShuffleSpec + worker_for: dict[_T_partition_id, str] + + @property + def id(self) -> ShuffleId: + return self.spec.id + +@dataclass(frozen=True) +class ShuffleSpec(abc.ABC, Generic[_T_partition_id]): id: ShuffleId - run_id: int - output_workers: set[str] + + def create_new_run( + self, + plugin: ShuffleSchedulerPlugin, + ) -> SchedulerShuffleState: + worker_for = self._pin_output_workers(plugin) + return SchedulerShuffleState( + run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for), + participating_workers=set(worker_for.values()), + ) + + @abc.abstractmethod + def _pin_output_workers( + self, plugin: ShuffleSchedulerPlugin + ) -> dict[_T_partition_id, str]: + """Pin output tasks to workers and return the mapping of partition ID to worker.""" + + @abc.abstractmethod + def create_run_on_worker( + self, + run_id: int, + worker_for: dict[_T_partition_id, str], + plugin: ShuffleWorkerPlugin, + ) -> ShuffleRun: + """Create the new shuffle run on the worker.""" + + +@dataclass(eq=False) +class SchedulerShuffleState(Generic[_T_partition_id]): + run_spec: ShuffleRunSpec participating_workers: set[str] _archived_by: str | None = field(default=None, init=False) - @abc.abstractmethod - def to_msg(self) -> dict[str, Any]: - """Transform the shuffle state into a JSON-serializable message""" + @property + def id(self) -> ShuffleId: + return self.run_spec.id + + @property + def run_id(self) -> int: + return self.run_spec.run_id def __str__(self) -> str: return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>" diff --git 
a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 0c6116646e7..d9f3e179753 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -96,13 +96,15 @@ from __future__ import annotations +import os import pickle from collections import defaultdict from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple +from itertools import product +from typing import TYPE_CHECKING, Any, NamedTuple import dask from dask.base import tokenize @@ -114,23 +116,21 @@ NDIndex, ShuffleId, ShuffleRun, - ShuffleState, - ShuffleType, - barrier_key, + ShuffleSpec, get_worker_plugin, ) from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._shuffle import shuffle_barrier +from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin +from distributed.shuffle._shuffle import barrier_key, shuffle_barrier +from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof if TYPE_CHECKING: import numpy as np - import pandas as pd from typing_extensions import TypeAlias import dask.array as da - ChunkedAxis: TypeAlias = tuple[float, ...] # chunks must either be an int or NaN ChunkedAxes: TypeAlias = tuple[ChunkedAxis, ...] NDSlice: TypeAlias = tuple[slice, ...] @@ -147,10 +147,7 @@ def rechunk_transfer( return get_worker_plugin().add_partition( input, partition_id=input_chunk, - shuffle_id=id, - type=ShuffleType.ARRAY_RECHUNK, - new=new, - old=old, + spec=ArrayRechunkSpec(id=id, new=new, old=old), ) except Exception as e: raise RuntimeError(f"rechunk_transfer failed during shuffle {id}") from e @@ -296,18 +293,16 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): This object is responsible for splitting, sending, receiving and combining data shards. 
- It is entirely agnostic to the distributed system and can perform a shuffle - with other `Shuffle` instances using `rpc` and `broadcast`. + It is entirely agnostic to the distributed system and can perform a rechunk + with other run instances using `rpc`. - The user of this needs to guarantee that only `Shuffle`s of the same unique - `ShuffleID` interact. + The user of this needs to guarantee that only `ArrayRechunkRun`s of the same unique + `ShuffleID` and `run_id` interact. Parameters ---------- worker_for: A mapping partition_id -> worker_address. - output_workers: - A set of all participating worker (addresses). old: Existing chunking of the array per dimension. new: @@ -322,8 +317,6 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): The scratch directory to buffer data in. executor: Thread pool to use for offloading compute. - loop: - The event loop. rpc: A callable returning a PooledRPCCall to contact other Shuffle instances. Typically a ConnectionPool. @@ -338,7 +331,6 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): def __init__( self, worker_for: dict[NDIndex, str], - output_workers: set, old: ChunkedAxes, new: ChunkedAxes, id: ShuffleId, @@ -354,7 +346,6 @@ def __init__( super().__init__( id=id, run_id=run_id, - output_workers=output_workers, local_address=local_address, directory=directory, executor=executor, @@ -403,7 +394,9 @@ def _repartition_shards(self, data: list[bytes]) -> dict[NDIndex, bytes]: repartitioned[id].append(shard) return {k: pickle.dumps(v) for k, v in repartitioned.items()} - async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int: + async def add_partition( + self, data: np.ndarray, partition_id: NDIndex, **kwargs: Any + ) -> int: self.raise_if_closed() if self.transferred: raise RuntimeError(f"Cannot add more partitions to {self}") @@ -441,47 +434,58 @@ def _() -> dict[str, tuple[NDIndex, bytes]]: return self.run_id async def get_output_partition( - self, partition_id: NDIndex, key: 
str, meta: pd.DataFrame | None = None + self, partition_id: NDIndex, key: str, **kwargs: Any ) -> np.ndarray: self.raise_if_closed() - assert meta is None - assert self.transferred, "`get_output_partition` called before barrier task" + if not self.transferred: + raise RuntimeError("`get_output_partition` called before barrier task") await self._ensure_output_worker(partition_id, key) - await self.flush_receive() - data = self._read_from_disk(partition_id) - - def _() -> np.ndarray: - return convert_chunk(data) - - return await self.offload(_) + return await self.offload(convert_chunk, data) def _get_assigned_worker(self, id: NDIndex) -> str: return self.worker_for[id] -@dataclass(eq=False) -class ArrayRechunkState(ShuffleState): - type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK - worker_for: dict[NDIndex, str] - old: ChunkedAxes +@dataclass(frozen=True) +class ArrayRechunkSpec(ShuffleSpec[NDIndex]): new: ChunkedAxes + old: ChunkedAxes - def to_msg(self) -> dict[str, Any]: - return { - "status": "OK", - "type": ArrayRechunkState.type, - "run_id": self.run_id, - "worker_for": self.worker_for, - "old": self.old, - "new": self.new, - "output_workers": self.output_workers, - } + def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[NDIndex, str]: + parts_out = product(*(range(len(c)) for c in self.new)) + return plugin._pin_output_workers( + self.id, parts_out, _get_worker_for_hash_sharding + ) + + def create_run_on_worker( + self, + run_id: int, + worker_for: dict[NDIndex, str], + plugin: ShuffleWorkerPlugin, + ) -> ShuffleRun: + return ArrayRechunkRun( + worker_for=worker_for, + old=self.old, + new=self.new, + id=self.id, + run_id=run_id, + directory=os.path.join( + plugin.worker.local_directory, + f"shuffle-{self.id}-{run_id}", + ), + executor=plugin._executor, + local_address=plugin.worker.address, + rpc=plugin.worker.rpc, + scheduler=plugin.worker.scheduler, + memory_limiter_disk=plugin.memory_limiter_disk, + 
memory_limiter_comms=plugin.memory_limiter_comms, + ) -def get_worker_for_hash_sharding( +def _get_worker_for_hash_sharding( output_partition: NDIndex, workers: Sequence[str] ) -> str: """Get address of target worker for this output partition using hash sharding""" diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 9fdf8f10c2a..0a511145bc8 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -4,24 +4,19 @@ import logging from collections import defaultdict from collections.abc import Callable, Iterable, Sequence -from functools import partial -from itertools import product from typing import TYPE_CHECKING, Any from distributed.diagnostics.plugin import SchedulerPlugin from distributed.protocol.pickle import dumps +from distributed.protocol.serialize import ToPickle from distributed.shuffle._core import ( + SchedulerShuffleState, ShuffleId, - ShuffleState, - ShuffleType, + ShuffleRunSpec, + ShuffleSpec, barrier_key, id_from_key, ) -from distributed.shuffle._rechunk import ArrayRechunkState, get_worker_for_hash_sharding -from distributed.shuffle._shuffle import ( - DataFrameShuffleState, - get_worker_for_range_sharding, -) from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin if TYPE_CHECKING: @@ -47,10 +42,10 @@ class ShuffleSchedulerPlugin(SchedulerPlugin): """ scheduler: Scheduler - active_shuffles: dict[ShuffleId, ShuffleState] + active_shuffles: dict[ShuffleId, SchedulerShuffleState] heartbeats: defaultdict[ShuffleId, dict] - _shuffles: defaultdict[ShuffleId, set[ShuffleState]] - _archived_by_stimulus: defaultdict[str, set[ShuffleState]] + _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]] + _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]] def __init__(self, scheduler: Scheduler): self.scheduler = scheduler @@ -79,7 +74,8 @@ def shuffle_ids(self) -> set[ShuffleId]: async def barrier(self, id: ShuffleId, run_id: int) -> 
None: shuffle = self.active_shuffles[id] - assert shuffle.run_id == run_id, f"{run_id=} does not match {shuffle}" + if shuffle.run_id != run_id: + raise ValueError(f"{run_id=} does not match {shuffle}") msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id} await self.scheduler.broadcast( msg=msg, @@ -107,7 +103,7 @@ def heartbeat(self, ws: WorkerState, data: dict) -> None: if shuffle_id in self.shuffle_ids(): self.heartbeats[shuffle_id][ws.address].update(d) - def get(self, id: ShuffleId, worker: str) -> dict[str, Any]: + def get(self, id: ShuffleId, worker: str) -> ToPickle[ShuffleRunSpec]: if worker not in self.scheduler.workers: # This should never happen raise RuntimeError( @@ -115,36 +111,27 @@ def get(self, id: ShuffleId, worker: str) -> dict[str, Any]: ) # pragma: nocover state = self.active_shuffles[id] state.participating_workers.add(worker) - return state.to_msg() + return ToPickle(state.run_spec) def get_or_create( self, - id: ShuffleId, + spec: ShuffleSpec, key: str, - type: str, worker: str, - spec: dict[str, Any], - ) -> dict: + ) -> ToPickle[ShuffleRunSpec]: try: - return self.get(id, worker) + return self.get(spec.id, worker) except KeyError: # FIXME: The current implementation relies on the barrier task to be # known by its name. If the name has been mangled, we cannot guarantee # that the shuffle works as intended and should fail instead. 
- self._raise_if_barrier_unknown(id) + self._raise_if_barrier_unknown(spec.id) self._raise_if_task_not_processing(key) - - state: ShuffleState - if type == ShuffleType.DATAFRAME: - state = self._create_dataframe_shuffle_state(id, spec) - elif type == ShuffleType.ARRAY_RECHUNK: - state = self._create_array_rechunk_state(id, spec) - else: # pragma: no cover - raise TypeError(type) - self.active_shuffles[id] = state - self._shuffles[id].add(state) + state = spec.create_new_run(self) + self.active_shuffles[spec.id] = state + self._shuffles[spec.id].add(state) state.participating_workers.add(worker) - return state.to_msg() + return ToPickle(state.run_spec) def _raise_if_barrier_unknown(self, id: ShuffleId) -> None: key = barrier_key(id) @@ -162,30 +149,6 @@ def _raise_if_task_not_processing(self, key: str) -> None: if task.state != "processing": raise RuntimeError(f"Expected {task} to be processing, is {task.state}.") - def _create_dataframe_shuffle_state( - self, id: ShuffleId, spec: dict[str, Any] - ) -> DataFrameShuffleState: - column = spec["column"] - npartitions = spec["npartitions"] - parts_out = spec["parts_out"] - assert column is not None - assert npartitions is not None - assert parts_out is not None - - pick_worker = partial(get_worker_for_range_sharding, npartitions) - - mapping = self._pin_output_workers(id, parts_out, pick_worker) - output_workers = set(mapping.values()) - - return DataFrameShuffleState( - id=id, - run_id=next(ShuffleState._run_id_iterator), - worker_for=mapping, - column=column, - output_workers=output_workers, - participating_workers=output_workers.copy(), - ) - def _pin_output_workers( self, id: ShuffleId, @@ -232,28 +195,6 @@ def _pin_output_workers( return mapping - def _create_array_rechunk_state( - self, id: ShuffleId, spec: dict[str, Any] - ) -> ArrayRechunkState: - old = spec["old"] - new = spec["new"] - assert old is not None - assert new is not None - - parts_out = product(*(range(len(c)) for c in new)) - mapping = 
self._pin_output_workers(id, parts_out, get_worker_for_hash_sharding) - output_workers = set(mapping.values()) - - return ArrayRechunkState( - id=id, - run_id=next(ShuffleState._run_id_iterator), - worker_for=mapping, - output_workers=output_workers, - old=old, - new=new, - participating_workers=output_workers.copy(), - ) - def _set_restriction(self, ts: TaskState, worker: str) -> None: if "shuffle_original_restrictions" in ts.annotations: # This may occur if multiple barriers share the same output task, @@ -361,7 +302,7 @@ def transition( if not archived: del self._archived_by_stimulus[shuffle._archived_by] - def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None: + def _fail_on_workers(self, shuffle: SchedulerShuffleState, message: str) -> None: worker_msgs = { worker: [ { diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 4a4f4601ea1..e16340c6399 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -1,11 +1,13 @@ from __future__ import annotations import logging +import os from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Union +from functools import partial +from typing import TYPE_CHECKING, Any, Union import toolz @@ -26,13 +28,14 @@ NDIndex, ShuffleId, ShuffleRun, - ShuffleState, - ShuffleType, + ShuffleSpec, barrier_key, get_worker_plugin, ) from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin +from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof logger = logging.getLogger("distributed.shuffle") @@ -57,12 +60,10 @@ def shuffle_transfer( try: return get_worker_plugin().add_partition( input, - 
shuffle_id=id, - type=ShuffleType.DATAFRAME, - partition_id=input_partition, - npartitions=npartitions, - column=column, - parts_out=parts_out, + input_partition, + spec=DataFrameShuffleSpec( + id=id, npartitions=npartitions, column=column, parts_out=parts_out + ), ) except ShuffleClosedError: raise Reschedule() @@ -341,17 +342,15 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): data shards. It is entirely agnostic to the distributed system and can perform a shuffle - with other `Shuffle` instances using `rpc` and `broadcast`. + with other run instances using `rpc`. - The user of this needs to guarantee that only `Shuffle`s of the same unique - `ShuffleID` interact. + The user of this needs to guarantee that only `DataFrameShuffleRun`s of the + same unique `ShuffleID` and `run_id` interact. Parameters ---------- worker_for: A mapping partition_id -> worker_address. - output_workers: - A set of all participating worker (addresses). column: The data column we split the input partition by. id: @@ -364,8 +363,6 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): The scratch directory to buffer data in. executor: Thread pool to use for offloading compute. - loop: - The event loop. rpc: A callable returning a PooledRPCCall to contact other Shuffle instances. Typically a ConnectionPool. 
@@ -380,7 +377,6 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): def __init__( self, worker_for: dict[int, str], - output_workers: set, column: str, id: ShuffleId, run_id: int, @@ -397,7 +393,6 @@ def __init__( super().__init__( id=id, run_id=run_id, - output_workers=output_workers, local_address=local_address, directory=directory, executor=executor, @@ -443,7 +438,12 @@ def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, bytes]: del data return {(k,): serialize_table(v) for k, v in groups.items()} - async def add_partition(self, data: pd.DataFrame, partition_id: int) -> int: + async def add_partition( + self, + data: pd.DataFrame, + partition_id: int, + **kwargs: Any, + ) -> int: self.raise_if_closed() if self.transferred: raise RuntimeError(f"Cannot add more partitions to {self}") @@ -462,11 +462,17 @@ def _() -> dict[str, tuple[int, bytes]]: return self.run_id async def get_output_partition( - self, partition_id: int, key: str, meta: pd.DataFrame | None = None + self, + partition_id: int, + key: str, + meta: pd.DataFrame | None = None, + **kwargs: Any, ) -> pd.DataFrame: self.raise_if_closed() - assert meta is not None - assert self.transferred, "`get_output_partition` called before barrier task" + if meta is None: + raise ValueError("Expected meta keyword argument") + if not self.transferred: + raise RuntimeError("`get_output_partition` called before barrier task") await self._ensure_output_worker(partition_id, key) @@ -474,10 +480,7 @@ async def get_output_partition( try: data = self._read_from_disk((partition_id,)) - def _() -> pd.DataFrame: - return convert_partition(data, meta) # type: ignore - - out = await self.offload(_) + out = await self.offload(convert_partition, data, meta) except KeyError: out = meta.copy() return out @@ -486,24 +489,38 @@ def _get_assigned_worker(self, id: int) -> str: return self.worker_for[id] -@dataclass(eq=False) -class DataFrameShuffleState(ShuffleState): - type: ClassVar[ShuffleType] = 
ShuffleType.DATAFRAME - worker_for: dict[int, str] +@dataclass(frozen=True) +class DataFrameShuffleSpec(ShuffleSpec[int]): + npartitions: int column: str - - def to_msg(self) -> dict[str, Any]: - return { - "status": "OK", - "type": DataFrameShuffleState.type, - "run_id": self.run_id, - "worker_for": self.worker_for, - "column": self.column, - "output_workers": self.output_workers, - } + parts_out: set[int] + + def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[int, str]: + pick_worker = partial(_get_worker_for_range_sharding, self.npartitions) + return plugin._pin_output_workers(self.id, self.parts_out, pick_worker) + + def create_run_on_worker( + self, run_id: int, worker_for: dict[int, str], plugin: ShuffleWorkerPlugin + ) -> ShuffleRun: + return DataFrameShuffleRun( + column=self.column, + worker_for=worker_for, + id=self.id, + run_id=run_id, + directory=os.path.join( + plugin.worker.local_directory, + f"shuffle-{self.id}-{run_id}", + ), + executor=plugin._executor, + local_address=plugin.worker.address, + rpc=plugin.worker.rpc, + scheduler=plugin.worker.scheduler, + memory_limiter_disk=plugin.memory_limiter_disk, + memory_limiter_comms=plugin.memory_limiter_comms, + ) -def get_worker_for_range_sharding( +def _get_worker_for_range_sharding( npartitions: int, output_partition: int, workers: Sequence[str] ) -> str: """Get address of target worker for this output partition using range sharding""" diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 7340b8fe6e6..bc8b1f3d4d1 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -2,7 +2,6 @@ import asyncio import logging -import os from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Any, overload @@ -10,11 +9,16 @@ from dask.utils import parse_bytes from distributed.diagnostics.plugin import WorkerPlugin -from distributed.shuffle._core import NDIndex, ShuffleId, ShuffleRun, 
ShuffleType +from distributed.protocol.serialize import ToPickle +from distributed.shuffle._core import ( + NDIndex, + ShuffleId, + ShuffleRun, + ShuffleRunSpec, + ShuffleSpec, +) from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._rechunk import ArrayRechunkRun -from distributed.shuffle._shuffle import DataFrameShuffleRun from distributed.utils import log_errors, sync if TYPE_CHECKING: @@ -23,6 +27,7 @@ from distributed.worker import Worker + logger = logging.getLogger(__name__) @@ -127,17 +132,17 @@ def shuffle_fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None def add_partition( self, data: Any, - partition_id: int | tuple[int, ...], - shuffle_id: ShuffleId, - type: ShuffleType, + partition_id: int | NDIndex, + spec: ShuffleSpec, **kwargs: Any, ) -> int: - shuffle = self.get_or_create_shuffle(shuffle_id, type=type, **kwargs) + shuffle = self.get_or_create_shuffle(spec) return sync( self.worker.loop, shuffle.add_partition, data=data, partition_id=partition_id, + **kwargs, ) async def _barrier(self, shuffle_id: ShuffleId, run_ids: list[int]) -> int: @@ -149,7 +154,8 @@ async def _barrier(self, shuffle_id: ShuffleId, run_ids: list[int]) -> int: """ run_id = run_ids[0] # Assert that all input data has been shuffled using the same run_id - assert all(run_id == id for id in run_ids) + if any(run_id != id for id in run_ids): + raise RuntimeError(f"Expected all run IDs to match: {run_ids=}") # Tell all peers that we've reached the barrier # Note that this will call `shuffle_inputs_done` on our own worker as well shuffle = await self._get_shuffle_run(shuffle_id, run_id) @@ -194,10 +200,8 @@ async def _get_shuffle_run( async def _get_or_create_shuffle( self, - shuffle_id: ShuffleId, - type: ShuffleType, + spec: ShuffleSpec, key: str, - **kwargs: Any, ) -> ShuffleRun: """Get or create a shuffle matching the ID and data spec. 
@@ -210,13 +214,12 @@ async def _get_or_create_shuffle( key: Task key triggering the function """ - shuffle = self.shuffles.get(shuffle_id, None) + shuffle = self.shuffles.get(spec.id, None) if shuffle is None: shuffle = await self._refresh_shuffle( - shuffle_id=shuffle_id, - type=type, + shuffle_id=spec.id, + spec=ToPickle(spec), key=key, - kwargs=kwargs, ) if self.closed: @@ -236,58 +239,38 @@ async def _refresh_shuffle( async def _refresh_shuffle( self, shuffle_id: ShuffleId, - type: ShuffleType, + spec: ToPickle, key: str, - kwargs: dict, ) -> ShuffleRun: ... async def _refresh_shuffle( self, shuffle_id: ShuffleId, - type: ShuffleType | None = None, + spec: ToPickle | None = None, key: str | None = None, - kwargs: dict | None = None, ) -> ShuffleRun: - result: dict[str, Any] - if type is None: + result: ShuffleRunSpec + if spec is None: result = await self.worker.scheduler.shuffle_get( id=shuffle_id, worker=self.worker.address, ) - elif type == ShuffleType.DATAFRAME: - assert kwargs is not None - result = await self.worker.scheduler.shuffle_get_or_create( - id=shuffle_id, - key=key, - type=type, - spec={ - "npartitions": kwargs["npartitions"], - "column": kwargs["column"], - "parts_out": kwargs["parts_out"], - }, - worker=self.worker.address, - ) - elif type == ShuffleType.ARRAY_RECHUNK: - assert kwargs is not None + else: result = await self.worker.scheduler.shuffle_get_or_create( - id=shuffle_id, + spec=spec, key=key, - type=type, - spec=kwargs, worker=self.worker.address, ) - else: # pragma: no cover - raise TypeError(type) - if result["status"] == "error": - raise RuntimeError(result["message"]) - assert result["status"] == "OK" + # if result["status"] == "error": + # raise RuntimeError(result["message"]) + # assert result["status"] == "OK" if self.closed: raise ShuffleClosedError(f"{self} has already been closed") if shuffle_id in self.shuffles: existing = self.shuffles[shuffle_id] - if existing.run_id >= result["run_id"]: + if existing.run_id >= 
result.run_id: return existing else: self.shuffles.pop(shuffle_id) @@ -305,66 +288,13 @@ async def _( self.worker._ongoing_background_tasks.call_soon(_, self, existing) - shuffle = self._create_shuffle_run(shuffle_id, result) + shuffle: ShuffleRun = result.spec.create_run_on_worker( + result.run_id, result.worker_for, self + ) self.shuffles[shuffle_id] = shuffle self._runs.add(shuffle) return shuffle - def _create_shuffle_run( - self, shuffle_id: ShuffleId, result: dict[str, Any] - ) -> ShuffleRun: - shuffle: ShuffleRun - if result["type"] == ShuffleType.DATAFRAME: - shuffle = self._create_dataframe_shuffle_run(shuffle_id, result) - elif result["type"] == ShuffleType.ARRAY_RECHUNK: - shuffle = self._create_array_rechunk_run(shuffle_id, result) - else: # pragma: no cover - raise TypeError(result["type"]) - return shuffle - - def _create_dataframe_shuffle_run( - self, shuffle_id: ShuffleId, result: dict[str, Any] - ) -> DataFrameShuffleRun: - return DataFrameShuffleRun( - column=result["column"], - worker_for=result["worker_for"], - output_workers=result["output_workers"], - id=shuffle_id, - run_id=result["run_id"], - directory=os.path.join( - self.worker.local_directory, - f"shuffle-{shuffle_id}-{result['run_id']}", - ), - executor=self._executor, - local_address=self.worker.address, - rpc=self.worker.rpc, - scheduler=self.worker.scheduler, - memory_limiter_disk=self.memory_limiter_disk, - memory_limiter_comms=self.memory_limiter_comms, - ) - - def _create_array_rechunk_run( - self, shuffle_id: ShuffleId, result: dict[str, Any] - ) -> ArrayRechunkRun: - return ArrayRechunkRun( - worker_for=result["worker_for"], - output_workers=result["output_workers"], - old=result["old"], - new=result["new"], - id=shuffle_id, - run_id=result["run_id"], - directory=os.path.join( - self.worker.local_directory, - f"shuffle-{shuffle_id}-{result['run_id']}", - ), - executor=self._executor, - local_address=self.worker.address, - rpc=self.worker.rpc, - scheduler=self.worker.scheduler, - 
memory_limiter_disk=self.memory_limiter_disk, - memory_limiter_comms=self.memory_limiter_comms, - ) - async def teardown(self, worker: Worker) -> None: assert not self.closed @@ -406,18 +336,14 @@ def get_shuffle_run( def get_or_create_shuffle( self, - shuffle_id: ShuffleId, - type: ShuffleType, - **kwargs: Any, + spec: ShuffleSpec, ) -> ShuffleRun: key = thread_state.key return sync( self.worker.loop, self._get_or_create_shuffle, - shuffle_id, - type, + spec, key, - **kwargs, ) def get_output_partition( diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 62f3b1b9c93..fcdee28f857 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -22,7 +22,7 @@ from distributed.shuffle._rechunk import ( ArrayRechunkRun, Split, - get_worker_for_hash_sharding, + _get_worker_for_hash_sharding, split_axes, ) from distributed.shuffle.tests.utils import AbstractShuffleTestPool @@ -55,8 +55,6 @@ def new_shuffle( ): s = Shuffle( worker_for=worker_for_mapping, - # FIXME: Is output_workers redundant with worker_for? 
- output_workers=set(worker_for_mapping.values()), old=old, new=new, directory=directory / name, @@ -97,7 +95,7 @@ async def test_lowlevel_rechunk( new_indices = list(product(*(range(len(dim)) for dim in new))) for i, idx in enumerate(new_indices): - worker_for_mapping[idx] = get_worker_for_hash_sharding(i, workers) + worker_for_mapping[idx] = _get_worker_for_hash_sharding(i, workers) assert len(set(worker_for_mapping.values())) == min(n_workers, len(new_indices)) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 8b6c00a2116..2209164cb45 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -37,7 +37,7 @@ from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin from distributed.shuffle._shuffle import ( DataFrameShuffleRun, - get_worker_for_range_sharding, + _get_worker_for_range_sharding, split_by_partition, split_by_worker, ) @@ -467,7 +467,7 @@ def mock_get_worker_for_range_sharding( return a.address with mock.patch( - "distributed.shuffle._scheduler_plugin.get_worker_for_range_sharding", + "distributed.shuffle._shuffle._get_worker_for_range_sharding", mock_get_worker_for_range_sharding, ): df = dask.datasets.timeseries( @@ -500,7 +500,7 @@ def mock_mock_get_worker_for_range_sharding( return a.address with mock.patch( - "distributed.shuffle._scheduler_plugin.get_worker_for_range_sharding", + "distributed.shuffle._shuffle._get_worker_for_range_sharding", mock_mock_get_worker_for_range_sharding, ): async with Nanny(s.address, nthreads=1) as n: @@ -565,7 +565,7 @@ async def inputs_done(self) -> None: @mock.patch( - "distributed.shuffle._worker_plugin.DataFrameShuffleRun", + "distributed.shuffle._shuffle.DataFrameShuffleRun", BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) @@ -624,7 +624,7 @@ def shuffle_restarted(): @mock.patch( - "distributed.shuffle._worker_plugin.DataFrameShuffleRun", + 
"distributed.shuffle._shuffle.DataFrameShuffleRun", BlockedInputsDoneShuffle, ) @gen_cluster( @@ -672,7 +672,7 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b): @mock.patch( - "distributed.shuffle._worker_plugin.DataFrameShuffleRun", + "distributed.shuffle._shuffle.DataFrameShuffleRun", BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) @@ -734,7 +734,7 @@ def shuffle_restarted(): @pytest.mark.slow @mock.patch( - "distributed.shuffle._worker_plugin.DataFrameShuffleRun", + "distributed.shuffle._shuffle.DataFrameShuffleRun", BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -1456,8 +1456,6 @@ def new_shuffle( s = Shuffle( column="_partition", worker_for=worker_for_mapping, - # FIXME: Is output_workers redundant with worker_for? - output_workers=set(worker_for_mapping.values()), directory=directory / name, id=ShuffleId(name), run_id=next(AbstractShuffleTestPool._shuffle_run_id_iterator), @@ -1501,7 +1499,7 @@ async def test_basic_lowlevel_shuffle( worker_for_mapping = {} for part in range(npartitions): - worker_for_mapping[part] = get_worker_for_range_sharding( + worker_for_mapping[part] = _get_worker_for_range_sharding( npartitions, part, workers ) assert len(set(worker_for_mapping.values())) == min(n_workers, npartitions) @@ -1575,7 +1573,7 @@ async def test_error_offload(tmp_path, loop_in_thread): partitions_for_worker = defaultdict(list) for part in range(npartitions): - worker_for_mapping[part] = w = get_worker_for_range_sharding( + worker_for_mapping[part] = w = _get_worker_for_range_sharding( npartitions, part, workers ) partitions_for_worker[w].append(part) @@ -1626,7 +1624,7 @@ async def test_error_send(tmp_path, loop_in_thread): partitions_for_worker = defaultdict(list) for part in range(npartitions): - worker_for_mapping[part] = w = get_worker_for_range_sharding( + worker_for_mapping[part] = w = _get_worker_for_range_sharding( npartitions, part, workers ) 
partitions_for_worker[w].append(part) @@ -1676,7 +1674,7 @@ async def test_error_receive(tmp_path, loop_in_thread): partitions_for_worker = defaultdict(list) for part in range(npartitions): - worker_for_mapping[part] = w = get_worker_for_range_sharding( + worker_for_mapping[part] = w = _get_worker_for_range_sharding( npartitions, part, workers ) partitions_for_worker[w].append(part) @@ -1843,15 +1841,15 @@ async def test_shuffle_run_consistency(c, s, a): out = out.persist() shuffle_id = await wait_until_new_shuffle_is_initialized(s) - shuffle_dict = scheduler_ext.get(shuffle_id, a.worker_address) + spec = scheduler_ext.get(shuffle_id, a.worker_address).data # Worker plugin can fetch the current run - assert await worker_plugin._get_shuffle_run(shuffle_id, shuffle_dict["run_id"]) + assert await worker_plugin._get_shuffle_run(shuffle_id, spec.run_id) # This should never occur, but fetching an ID larger than the ID available on # the scheduler should result in an error. with pytest.raises(RuntimeError, match="invalid"): - await worker_plugin._get_shuffle_run(shuffle_id, shuffle_dict["run_id"] + 1) + await worker_plugin._get_shuffle_run(shuffle_id, spec.run_id + 1) # Finish first execution worker_plugin.block_barrier.set() @@ -1868,17 +1866,17 @@ async def test_shuffle_run_consistency(c, s, a): new_shuffle_id = await wait_until_new_shuffle_is_initialized(s) assert shuffle_id == new_shuffle_id - new_shuffle_dict = scheduler_ext.get(shuffle_id, a.worker_address) + new_spec = scheduler_ext.get(shuffle_id, a.worker_address).data # Check invariant that the new run ID is larger than the previous - assert shuffle_dict["run_id"] < new_shuffle_dict["run_id"] + assert spec.run_id < new_spec.run_id # Worker plugin can fetch the new shuffle run - assert await worker_plugin._get_shuffle_run(shuffle_id, new_shuffle_dict["run_id"]) + assert await worker_plugin._get_shuffle_run(shuffle_id, new_spec.run_id) # Fetching a stale run from a worker aware of the new run raises an error with 
pytest.raises(RuntimeError, match="stale"): - await worker_plugin._get_shuffle_run(shuffle_id, shuffle_dict["run_id"]) + await worker_plugin._get_shuffle_run(shuffle_id, spec.run_id) worker_plugin.block_barrier.set() await out @@ -1892,13 +1890,11 @@ async def test_shuffle_run_consistency(c, s, a): independent_shuffle_id = await wait_until_new_shuffle_is_initialized(s) assert shuffle_id != independent_shuffle_id - independent_shuffle_dict = scheduler_ext.get( - independent_shuffle_id, a.worker_address - ) + independent_spec = scheduler_ext.get(independent_shuffle_id, a.worker_address).data # Check invariant that the new run ID is larger than the previous # for independent shuffles - assert new_shuffle_dict["run_id"] < independent_shuffle_dict["run_id"] + assert new_spec.run_id < independent_spec.run_id worker_plugin.block_barrier.set() await out diff --git a/distributed/shuffle/tests/test_shuffle_plugins.py b/distributed/shuffle/tests/test_shuffle_plugins.py index d8fec10834c..0ddc9132dc8 100644 --- a/distributed/shuffle/tests/test_shuffle_plugins.py +++ b/distributed/shuffle/tests/test_shuffle_plugins.py @@ -4,19 +4,18 @@ import pytest +from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin from distributed.shuffle._shuffle import ( - get_worker_for_range_sharding, + _get_worker_for_range_sharding, split_by_partition, split_by_worker, ) +from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin +from distributed.utils_test import gen_cluster pd = pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") -from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin -from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin -from distributed.utils_test import gen_cluster - @gen_cluster([("", 1)]) async def test_installation_on_worker(s, a): @@ -52,7 +51,7 @@ def test_split_by_worker(): worker_for_mapping = {} npartitions = 3 for part in range(npartitions): - worker_for_mapping[part] = 
get_worker_for_range_sharding( + worker_for_mapping[part] = _get_worker_for_range_sharding( npartitions, part, workers ) worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category") @@ -90,15 +89,15 @@ def test_split_by_worker_many_workers(): npartitions = 10 worker_for_mapping = {} for part in range(npartitions): - worker_for_mapping[part] = get_worker_for_range_sharding( + worker_for_mapping[part] = _get_worker_for_range_sharding( npartitions, part, workers ) worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category") out = split_by_worker(df, "_partition", worker_for) - assert get_worker_for_range_sharding(npartitions, 5, workers) in out - assert get_worker_for_range_sharding(npartitions, 0, workers) in out - assert get_worker_for_range_sharding(npartitions, 7, workers) in out - assert get_worker_for_range_sharding(npartitions, 1, workers) in out + assert _get_worker_for_range_sharding(npartitions, 5, workers) in out + assert _get_worker_for_range_sharding(npartitions, 0, workers) in out + assert _get_worker_for_range_sharding(npartitions, 7, workers) in out + assert _get_worker_for_range_sharding(npartitions, 1, workers) in out assert sum(map(len, out.values())) == len(df)