From e0adc3ac3c6a43325019e798ab75f5adb6e64d2d Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 10 Aug 2023 23:05:06 +0200 Subject: [PATCH 1/7] Refactor and add dispatch Refactor and introduce dispatch more dispatches Generic ToPickle --- distributed/shuffle/_core.py | 48 +++++- distributed/shuffle/_rechunk.py | 127 +++++++++++----- distributed/shuffle/_scheduler_plugin.py | 101 +++---------- distributed/shuffle/_shuffle.py | 128 ++++++++++++---- distributed/shuffle/_worker_plugin.py | 141 ++++-------------- distributed/shuffle/tests/test_rechunk.py | 4 +- distributed/shuffle/tests/test_shuffle.py | 42 +++--- .../shuffle/tests/test_shuffle_plugins.py | 21 ++- 8 files changed, 321 insertions(+), 291 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index fe8a38f5f00..7659f691a4a 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -12,6 +12,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, ClassVar, Generic, NewType, TypeVar +from dask.utils import Dispatch + from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule from distributed.protocol import to_serialize @@ -21,19 +23,25 @@ from distributed.shuffle._limiter import ResourceLimiter if TYPE_CHECKING: - import pandas as pd + # TODO import from typing (requires Python >=3.10) from typing_extensions import TypeAlias - # avoid circular dependencies + # circular dependency from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin + +ShuffleId = NewType("ShuffleId", str) +NDIndex: TypeAlias = tuple[int, ...] + + _T_partition_id = TypeVar("_T_partition_id") _T_partition_type = TypeVar("_T_partition_type") _T = TypeVar("_T") -NDIndex: TypeAlias = tuple[int, ...] -ShuffleId = NewType("ShuffleId", str) +spec_to_scheduler_state = Dispatch("spec_to_scheduler_state") +scheduler_state_to_run_spec = Dispatch("scheduler_state_to_run_spec") +run_spec_to_worker_run = Dispatch("run_spec_to_worker_run") class ShuffleRun(Generic[_T_partition_id, _T_partition_type]): @@ -215,7 +223,7 @@ async def add_partition( @abc.abstractmethod async def get_output_partition( - self, partition_id: _T_partition_id, key: str, meta: pd.DataFrame | None = None + self, partition_id: _T_partition_id, key: str, **kwargs: Any ) -> _T_partition_type: """Get an output partition to the shuffle run""" @@ -275,3 +283,33 @@ def __str__(self) -> str: def __hash__(self) -> int: return hash(self.run_id) + + +@dataclass +class ShuffleSpec: + id: ShuffleId + + +@dataclass +class ShuffleRunSpec(Generic[_T_partition_id]): + id: ShuffleId + run_id: int + worker_for: dict[_T_partition_id, str] + output_workers: set[str] # TODO: Is this necessary? + + +@dataclass(eq=False) +class SchedulerShuffleState: + _run_id_iterator: ClassVar[itertools.count] = itertools.count(1) + + id: ShuffleId + run_id: int + output_workers: set[str] + participating_workers: set[str] + _archived_by: str | None = field(default=None, init=False) + + def __str__(self) -> str: + return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>" + + def __hash__(self) -> int: + return hash(self.run_id) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 0c6116646e7..d70badcdc16 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -96,13 +96,15 @@ from __future__ import annotations +import os import pickle from collections import defaultdict from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple +from itertools import product +from typing import TYPE_CHECKING, Any, NamedTuple import dask from dask.base import tokenize @@ -112,24 +114,29 @@ from distributed.exceptions import Reschedule from distributed.shuffle._core import ( NDIndex, + SchedulerShuffleState, ShuffleId, ShuffleRun, - ShuffleState, - ShuffleType, - barrier_key, + ShuffleRunSpec, + ShuffleSpec, get_worker_plugin, + run_spec_to_worker_run, + scheduler_state_to_run_spec, + spec_to_scheduler_state, ) from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._shuffle import shuffle_barrier +from distributed.shuffle._shuffle import barrier_key, shuffle_barrier +from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof if TYPE_CHECKING: import numpy as np - import pandas as pd from typing_extensions import TypeAlias import dask.array as da + # circular dependency + from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin ChunkedAxis: TypeAlias = tuple[float, ...] # chunks must either be an int or NaN ChunkedAxes: TypeAlias = tuple[ChunkedAxis, ...] @@ -147,10 +154,7 @@ def rechunk_transfer( return get_worker_plugin().add_partition( input, partition_id=input_chunk, - shuffle_id=id, - type=ShuffleType.ARRAY_RECHUNK, - new=new, - old=old, + spec=ArrayRechunkSpec(id=id, new=new, old=old), ) except Exception as e: raise RuntimeError(f"rechunk_transfer failed during shuffle {id}") from e @@ -403,7 +407,9 @@ def _repartition_shards(self, data: list[bytes]) -> dict[NDIndex, bytes]: repartitioned[id].append(shard) return {k: pickle.dumps(v) for k, v in repartitioned.items()} - async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int: + async def add_partition( + self, data: np.ndarray, partition_id: NDIndex, **kwargs: Any + ) -> int: self.raise_if_closed() if self.transferred: raise RuntimeError(f"Cannot add more partitions to {self}") @@ -441,47 +447,98 @@ def _() -> dict[str, tuple[NDIndex, bytes]]: return self.run_id async def get_output_partition( - self, partition_id: NDIndex, key: str, meta: pd.DataFrame | None = None + self, partition_id: NDIndex, key: str, **kwargs: Any ) -> np.ndarray: self.raise_if_closed() - assert meta is None - assert self.transferred, "`get_output_partition` called before barrier task" + if not self.transferred: + raise RuntimeError("`get_output_partition` called before barrier task") await self._ensure_output_worker(partition_id, key) - await self.flush_receive() - data = self._read_from_disk(partition_id) - - def _() -> np.ndarray: - return convert_chunk(data) - - return await self.offload(_) + return await self.offload(convert_chunk, data) def _get_assigned_worker(self, id: NDIndex) -> str: return self.worker_for[id] +@dataclass +class ArrayRechunkSpec(ShuffleSpec): + new: ChunkedAxes + old: ChunkedAxes + + +@dataclass +class ArrayRechunkRunSpec(ShuffleRunSpec): + new: ChunkedAxes + old: ChunkedAxes + + @dataclass(eq=False) -class ArrayRechunkState(ShuffleState): - type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK +class ArrayRechunkState(SchedulerShuffleState): worker_for: dict[NDIndex, str] old: ChunkedAxes new: ChunkedAxes - def to_msg(self) -> dict[str, Any]: - return { - "status": "OK", - "type": ArrayRechunkState.type, - "run_id": self.run_id, - "worker_for": self.worker_for, - "old": self.old, - "new": self.new, - "output_workers": self.output_workers, - } - -def get_worker_for_hash_sharding( +@spec_to_scheduler_state.register(ArrayRechunkSpec) # type: ignore[misc] +def _array_rechunk_from_spec( + spec: ArrayRechunkSpec, plugin: ShuffleSchedulerPlugin +) -> ArrayRechunkState: + parts_out = product(*(range(len(c)) for c in spec.new)) + mapping = plugin._pin_output_workers( + spec.id, parts_out, _get_worker_for_hash_sharding + ) + output_workers = set(mapping.values()) + + return ArrayRechunkState( + id=spec.id, + run_id=next(SchedulerShuffleState._run_id_iterator), + worker_for=mapping, + output_workers=output_workers, + old=spec.old, + new=spec.new, + participating_workers=output_workers.copy(), + ) + + +@scheduler_state_to_run_spec.register(ArrayRechunkState) # type: ignore[misc] +def _array_state_to_run_spec(state: ArrayRechunkState) -> ArrayRechunkRunSpec: + return ArrayRechunkRunSpec( + id=state.id, + run_id=state.run_id, + worker_for=state.worker_for, + output_workers=state.output_workers, + new=state.new, + old=state.old, + ) + + +@run_spec_to_worker_run.register(ArrayRechunkRunSpec) # type: ignore[misc] +def _array_run_spec_to_run( + spec: ArrayRechunkRunSpec, plugin: ShuffleWorkerPlugin +) -> ShuffleRun: + return ArrayRechunkRun( + worker_for=spec.worker_for, + output_workers=spec.output_workers, + old=spec.old, + new=spec.new, + id=spec.id, + run_id=spec.run_id, + directory=os.path.join( + plugin.worker.local_directory, + f"shuffle-{spec.id}-{spec.run_id}", + ), + executor=plugin._executor, + local_address=plugin.worker.address, + rpc=plugin.worker.rpc, + scheduler=plugin.worker.scheduler, + memory_limiter_disk=plugin.memory_limiter_disk, + memory_limiter_comms=plugin.memory_limiter_comms, + ) + + +def _get_worker_for_hash_sharding( output_partition: NDIndex, workers: Sequence[str] ) -> str: """Get address of target worker for this output partition using hash sharding""" diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 9fdf8f10c2a..7ff79e2ee53 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -4,23 +4,20 @@ import logging from collections import defaultdict from collections.abc import Callable, Iterable, Sequence -from functools import partial -from itertools import product from typing import TYPE_CHECKING, Any from distributed.diagnostics.plugin import SchedulerPlugin from distributed.protocol.pickle import dumps +from distributed.protocol.serialize import ToPickle from distributed.shuffle._core import ( + SchedulerShuffleState, ShuffleId, - ShuffleState, - ShuffleType, + ShuffleRunSpec, + ShuffleSpec, barrier_key, id_from_key, -) -from distributed.shuffle._rechunk import ArrayRechunkState, get_worker_for_hash_sharding -from distributed.shuffle._shuffle import ( - DataFrameShuffleState, - get_worker_for_range_sharding, + scheduler_state_to_run_spec, + spec_to_scheduler_state, ) from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin @@ -47,10 +44,10 @@ class ShuffleSchedulerPlugin(SchedulerPlugin): """ scheduler: Scheduler - active_shuffles: dict[ShuffleId, ShuffleState] + active_shuffles: dict[ShuffleId, SchedulerShuffleState] heartbeats: defaultdict[ShuffleId, dict] - _shuffles: defaultdict[ShuffleId, set[ShuffleState]] - _archived_by_stimulus: defaultdict[str, set[ShuffleState]] + _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]] + _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]] def __init__(self, scheduler: Scheduler): self.scheduler = scheduler @@ -79,7 +76,8 @@ def shuffle_ids(self) -> set[ShuffleId]: async def barrier(self, id: ShuffleId, run_id: int) -> None: shuffle = self.active_shuffles[id] - assert shuffle.run_id == run_id, f"{run_id=} does not match {shuffle}" + if shuffle.run_id != run_id: + raise ValueError(f"{run_id=} does not match {shuffle}") msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id} await self.scheduler.broadcast( msg=msg, @@ -107,7 +105,7 @@ def heartbeat(self, ws: WorkerState, data: dict) -> None: if shuffle_id in self.shuffle_ids(): self.heartbeats[shuffle_id][ws.address].update(d) - def get(self, id: ShuffleId, worker: str) -> dict[str, Any]: + def get(self, id: ShuffleId, worker: str) -> ToPickle[ShuffleRunSpec]: if worker not in self.scheduler.workers: # This should never happen raise RuntimeError( @@ -115,36 +113,27 @@ def get(self, id: ShuffleId, worker: str) -> dict[str, Any]: ) # pragma: nocover state = self.active_shuffles[id] state.participating_workers.add(worker) - return state.to_msg() + return ToPickle(scheduler_state_to_run_spec(state)) def get_or_create( self, - id: ShuffleId, + spec: ShuffleSpec, key: str, - type: str, worker: str, - spec: dict[str, Any], - ) -> dict: + ) -> ToPickle[ShuffleRunSpec]: try: - return self.get(id, worker) + return self.get(spec.id, worker) except KeyError: # FIXME: The current implementation relies on the barrier task to be # known by its name. If the name has been mangled, we cannot guarantee # that the shuffle works as intended and should fail instead. - self._raise_if_barrier_unknown(id) + self._raise_if_barrier_unknown(spec.id) self._raise_if_task_not_processing(key) - - state: ShuffleState - if type == ShuffleType.DATAFRAME: - state = self._create_dataframe_shuffle_state(id, spec) - elif type == ShuffleType.ARRAY_RECHUNK: - state = self._create_array_rechunk_state(id, spec) - else: # pragma: no cover - raise TypeError(type) - self.active_shuffles[id] = state - self._shuffles[id].add(state) + state: SchedulerShuffleState = spec_to_scheduler_state(spec, self) + self.active_shuffles[spec.id] = state + self._shuffles[spec.id].add(state) state.participating_workers.add(worker) - return state.to_msg() + return ToPickle(scheduler_state_to_run_spec(state)) def _raise_if_barrier_unknown(self, id: ShuffleId) -> None: key = barrier_key(id) @@ -162,30 +151,6 @@ def _raise_if_task_not_processing(self, key: str) -> None: if task.state != "processing": raise RuntimeError(f"Expected {task} to be processing, is {task.state}.") - def _create_dataframe_shuffle_state( - self, id: ShuffleId, spec: dict[str, Any] - ) -> DataFrameShuffleState: - column = spec["column"] - npartitions = spec["npartitions"] - parts_out = spec["parts_out"] - assert column is not None - assert npartitions is not None - assert parts_out is not None - - pick_worker = partial(get_worker_for_range_sharding, npartitions) - - mapping = self._pin_output_workers(id, parts_out, pick_worker) - output_workers = set(mapping.values()) - - return DataFrameShuffleState( - id=id, - run_id=next(ShuffleState._run_id_iterator), - worker_for=mapping, - column=column, - output_workers=output_workers, - participating_workers=output_workers.copy(), - ) - def _pin_output_workers( self, id: ShuffleId, @@ -232,28 +197,6 @@ def _pin_output_workers( return mapping - def _create_array_rechunk_state( - self, id: ShuffleId, spec: dict[str, Any] - ) -> ArrayRechunkState: - old = spec["old"] - new = spec["new"] - assert old is not None - assert new is not None - - parts_out = product(*(range(len(c)) for c in new)) - mapping = self._pin_output_workers(id, parts_out, get_worker_for_hash_sharding) - output_workers = set(mapping.values()) - - return ArrayRechunkState( - id=id, - run_id=next(ShuffleState._run_id_iterator), - worker_for=mapping, - output_workers=output_workers, - old=old, - new=new, - participating_workers=output_workers.copy(), - ) - def _set_restriction(self, ts: TaskState, worker: str) -> None: if "shuffle_original_restrictions" in ts.annotations: # This may occur if multiple barriers share the same output task, @@ -361,7 +304,7 @@ def transition( if not archived: del self._archived_by_stimulus[shuffle._archived_by] - def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None: + def _fail_on_workers(self, shuffle: SchedulerShuffleState, message: str) -> None: worker_msgs = { worker: [ { diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 4a4f4601ea1..4be2ea0707c 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -1,11 +1,13 @@ from __future__ import annotations import logging +import os from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Union +from functools import partial +from typing import TYPE_CHECKING, Any, Union import toolz @@ -24,15 +26,20 @@ ) from distributed.shuffle._core import ( NDIndex, + SchedulerShuffleState, ShuffleId, ShuffleRun, - ShuffleState, - ShuffleType, + ShuffleRunSpec, + ShuffleSpec, barrier_key, get_worker_plugin, + run_spec_to_worker_run, + scheduler_state_to_run_spec, + spec_to_scheduler_state, ) from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof logger = logging.getLogger("distributed.shuffle") @@ -45,6 +52,9 @@ from dask.dataframe import DataFrame + # circular dependency + from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin + def shuffle_transfer( input: pd.DataFrame, @@ -57,12 +67,10 @@ def shuffle_transfer( try: return get_worker_plugin().add_partition( input, - shuffle_id=id, - type=ShuffleType.DATAFRAME, - partition_id=input_partition, - npartitions=npartitions, - column=column, - parts_out=parts_out, + input_partition, + spec=DataFrameShuffleSpec( + id=id, npartitions=npartitions, column=column, parts_out=parts_out + ), ) except ShuffleClosedError: raise Reschedule() @@ -443,7 +451,12 @@ def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, bytes]: del data return {(k,): serialize_table(v) for k, v in groups.items()} - async def add_partition(self, data: pd.DataFrame, partition_id: int) -> int: + async def add_partition( + self, + data: pd.DataFrame, + partition_id: int, + **kwargs: Any, + ) -> int: self.raise_if_closed() if self.transferred: raise RuntimeError(f"Cannot add more partitions to {self}") @@ -462,11 +475,17 @@ def _() -> dict[str, tuple[int, bytes]]: return self.run_id async def get_output_partition( - self, partition_id: int, key: str, meta: pd.DataFrame | None = None + self, + partition_id: int, + key: str, + meta: pd.DataFrame | None = None, + **kwargs: Any, ) -> pd.DataFrame: self.raise_if_closed() - assert meta is not None - assert self.transferred, "`get_output_partition` called before barrier task" + if meta is None: + raise ValueError("Excepted meta keyword argument") + if not self.transferred: + raise RuntimeError("`get_output_partition` called before barrier task") await self._ensure_output_worker(partition_id, key) @@ -474,10 +493,7 @@ async def get_output_partition( try: data = self._read_from_disk((partition_id,)) - def _() -> pd.DataFrame: - return convert_partition(data, meta) # type: ignore - - out = await self.offload(_) + out = await self.offload(convert_partition, data, meta) except KeyError: out = meta.copy() return out @@ -486,24 +502,78 @@ def _get_assigned_worker(self, id: int) -> str: return self.worker_for[id] +@dataclass +class DataFrameShuffleSpec(ShuffleSpec): + npartitions: int + column: str + parts_out: set[int] + + +@dataclass +class DataFrameShuffleRunSpec(ShuffleRunSpec): + column: str + + @dataclass(eq=False) -class DataFrameShuffleState(ShuffleState): - type: ClassVar[ShuffleType] = ShuffleType.DATAFRAME +class DataFrameShuffleState(SchedulerShuffleState): worker_for: dict[int, str] column: str - def to_msg(self) -> dict[str, Any]: - return { - "status": "OK", - "type": DataFrameShuffleState.type, - "run_id": self.run_id, - "worker_for": self.worker_for, - "column": self.column, - "output_workers": self.output_workers, - } + +@spec_to_scheduler_state.register(DataFrameShuffleSpec) # type: ignore[misc] +def _dataframe_spec_to_state( + spec: DataFrameShuffleSpec, plugin: ShuffleSchedulerPlugin +) -> DataFrameShuffleState: + pick_worker = partial(_get_worker_for_range_sharding, spec.npartitions) + + mapping = plugin._pin_output_workers(spec.id, spec.parts_out, pick_worker) + output_workers = set(mapping.values()) + + return DataFrameShuffleState( + id=spec.id, + run_id=next(SchedulerShuffleState._run_id_iterator), + worker_for=mapping, + column=spec.column, + output_workers=output_workers, + participating_workers=output_workers.copy(), + ) + + +@scheduler_state_to_run_spec.register(DataFrameShuffleState) # type: ignore[misc] +def _dataframe_state_to_run_spec(state: DataFrameShuffleState) -> ShuffleRunSpec: + return DataFrameShuffleRunSpec( + id=state.id, + run_id=state.run_id, + worker_for=state.worker_for, + output_workers=state.output_workers, + column=state.column, + ) + + +@run_spec_to_worker_run.register(DataFrameShuffleRunSpec) # type: ignore[misc] +def _dataframe_run_spec_to_run( + spec: DataFrameShuffleRunSpec, plugin: ShuffleWorkerPlugin +) -> DataFrameShuffleRun: + return DataFrameShuffleRun( + column=spec.column, + worker_for=spec.worker_for, + output_workers=spec.output_workers, + id=spec.id, + run_id=spec.run_id, + directory=os.path.join( + plugin.worker.local_directory, + f"shuffle-{spec.id}-{spec.run_id}", + ), + executor=plugin._executor, + local_address=plugin.worker.address, + rpc=plugin.worker.rpc, + scheduler=plugin.worker.scheduler, + memory_limiter_disk=plugin.memory_limiter_disk, + memory_limiter_comms=plugin.memory_limiter_comms, + ) -def get_worker_for_range_sharding( +def _get_worker_for_range_sharding( npartitions: int, output_partition: int, workers: Sequence[str] ) -> str: """Get address of target worker for this output partition using range sharding""" diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 7340b8fe6e6..fcc79afa984 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -2,7 +2,6 @@ import asyncio import logging -import os from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Any, overload @@ -10,11 +9,17 @@ from dask.utils import parse_bytes from distributed.diagnostics.plugin import WorkerPlugin -from distributed.shuffle._core import NDIndex, ShuffleId, ShuffleRun, ShuffleType +from distributed.protocol.serialize import ToPickle +from distributed.shuffle._core import ( + NDIndex, + ShuffleId, + ShuffleRun, + ShuffleRunSpec, + ShuffleSpec, + run_spec_to_worker_run, +) from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._rechunk import ArrayRechunkRun -from distributed.shuffle._shuffle import DataFrameShuffleRun from distributed.utils import log_errors, sync if TYPE_CHECKING: @@ -23,6 +28,7 @@ from distributed.worker import Worker + logger = logging.getLogger(__name__) @@ -127,17 +133,17 @@ def shuffle_fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None def add_partition( self, data: Any, - partition_id: int | tuple[int, ...], - shuffle_id: ShuffleId, - type: ShuffleType, + partition_id: int | NDIndex, + spec: ShuffleSpec, **kwargs: Any, ) -> int: - shuffle = self.get_or_create_shuffle(shuffle_id, type=type, **kwargs) + shuffle = self.get_or_create_shuffle(spec) return sync( self.worker.loop, shuffle.add_partition, data=data, partition_id=partition_id, + **kwargs, ) async def _barrier(self, shuffle_id: ShuffleId, run_ids: list[int]) -> int: @@ -149,7 +155,8 @@ async def _barrier(self, shuffle_id: ShuffleId, run_ids: list[int]) -> int: """ run_id = run_ids[0] # Assert that all input data has been shuffled using the same run_id - assert all(run_id == id for id in run_ids) + if any(run_id != id for id in run_ids): + raise RuntimeError(f"Expected all run IDs to match: {run_ids=}") # Tell all peers that we've reached the barrier # Note that this will call `shuffle_inputs_done` on our own worker as well shuffle = await self._get_shuffle_run(shuffle_id, run_id) @@ -194,10 +201,8 @@ async def _get_shuffle_run( async def _get_or_create_shuffle( self, - shuffle_id: ShuffleId, - type: ShuffleType, + spec: ShuffleSpec, key: str, - **kwargs: Any, ) -> ShuffleRun: """Get or create a shuffle matching the ID and data spec. @@ -210,13 +215,12 @@ async def _get_or_create_shuffle( key: Task key triggering the function """ - shuffle = self.shuffles.get(shuffle_id, None) + shuffle = self.shuffles.get(spec.id, None) if shuffle is None: shuffle = await self._refresh_shuffle( - shuffle_id=shuffle_id, - type=type, + shuffle_id=spec.id, + spec=ToPickle(spec), key=key, - kwargs=kwargs, ) if self.closed: @@ -236,58 +240,38 @@ async def _refresh_shuffle( async def _refresh_shuffle( self, shuffle_id: ShuffleId, - type: ShuffleType, + spec: ToPickle, key: str, - kwargs: dict, ) -> ShuffleRun: ... async def _refresh_shuffle( self, shuffle_id: ShuffleId, - type: ShuffleType | None = None, + spec: ToPickle | None = None, key: str | None = None, - kwargs: dict | None = None, ) -> ShuffleRun: - result: dict[str, Any] - if type is None: + result: ShuffleRunSpec + if spec is None: result = await self.worker.scheduler.shuffle_get( id=shuffle_id, worker=self.worker.address, ) - elif type == ShuffleType.DATAFRAME: - assert kwargs is not None + else: result = await self.worker.scheduler.shuffle_get_or_create( - id=shuffle_id, + spec=spec, key=key, - type=type, - spec={ - "npartitions": kwargs["npartitions"], - "column": kwargs["column"], - "parts_out": kwargs["parts_out"], - }, worker=self.worker.address, ) - elif type == ShuffleType.ARRAY_RECHUNK: - assert kwargs is not None - result = await self.worker.scheduler.shuffle_get_or_create( - id=shuffle_id, - key=key, - type=type, - spec=kwargs, - worker=self.worker.address, - ) - else: # pragma: no cover - raise TypeError(type) - if result["status"] == "error": - raise RuntimeError(result["message"]) - assert result["status"] == "OK" + # if result["status"] == "error": + # raise RuntimeError(result["message"]) + # assert result["status"] == "OK" if self.closed: raise ShuffleClosedError(f"{self} has already been closed") if shuffle_id in self.shuffles: existing = self.shuffles[shuffle_id] - if existing.run_id >= result["run_id"]: + if existing.run_id >= result.run_id: return existing else: self.shuffles.pop(shuffle_id) @@ -305,66 +289,11 @@ async def _( self.worker._ongoing_background_tasks.call_soon(_, self, existing) - shuffle = self._create_shuffle_run(shuffle_id, result) + shuffle: ShuffleRun = run_spec_to_worker_run(result, self) self.shuffles[shuffle_id] = shuffle self._runs.add(shuffle) return shuffle - def _create_shuffle_run( - self, shuffle_id: ShuffleId, result: dict[str, Any] - ) -> ShuffleRun: - shuffle: ShuffleRun - if result["type"] == ShuffleType.DATAFRAME: - shuffle = self._create_dataframe_shuffle_run(shuffle_id, result) - elif result["type"] == ShuffleType.ARRAY_RECHUNK: - shuffle = self._create_array_rechunk_run(shuffle_id, result) - else: # pragma: no cover - raise TypeError(result["type"]) - return shuffle - - def _create_dataframe_shuffle_run( - self, shuffle_id: ShuffleId, result: dict[str, Any] - ) -> DataFrameShuffleRun: - return DataFrameShuffleRun( - column=result["column"], - worker_for=result["worker_for"], - output_workers=result["output_workers"], - id=shuffle_id, - run_id=result["run_id"], - directory=os.path.join( - self.worker.local_directory, - f"shuffle-{shuffle_id}-{result['run_id']}", - ), - executor=self._executor, - local_address=self.worker.address, - rpc=self.worker.rpc, - scheduler=self.worker.scheduler, - memory_limiter_disk=self.memory_limiter_disk, - memory_limiter_comms=self.memory_limiter_comms, - ) - - def _create_array_rechunk_run( - self, shuffle_id: ShuffleId, result: dict[str, Any] - ) -> ArrayRechunkRun: - return ArrayRechunkRun( - worker_for=result["worker_for"], - output_workers=result["output_workers"], - old=result["old"], - new=result["new"], - id=shuffle_id, - run_id=result["run_id"], - directory=os.path.join( - self.worker.local_directory, - f"shuffle-{shuffle_id}-{result['run_id']}", - ), - executor=self._executor, - local_address=self.worker.address, - rpc=self.worker.rpc, - scheduler=self.worker.scheduler, - memory_limiter_disk=self.memory_limiter_disk, - memory_limiter_comms=self.memory_limiter_comms, - ) - async def teardown(self, worker: Worker) -> None: assert not self.closed @@ -406,18 +335,14 @@ def get_shuffle_run( def get_or_create_shuffle( self, - shuffle_id: ShuffleId, - type: ShuffleType, - **kwargs: Any, + spec: ShuffleSpec, ) -> ShuffleRun: key = thread_state.key return sync( self.worker.loop, self._get_or_create_shuffle, - shuffle_id, - type, + spec, key, - **kwargs, ) def get_output_partition( diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 62f3b1b9c93..72a3183ab6c 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -22,7 +22,7 @@ from distributed.shuffle._rechunk import ( ArrayRechunkRun, Split, - get_worker_for_hash_sharding, + _get_worker_for_hash_sharding, split_axes, ) from distributed.shuffle.tests.utils import AbstractShuffleTestPool @@ -97,7 +97,7 @@ async def test_lowlevel_rechunk( new_indices = list(product(*(range(len(dim)) for dim in new))) for i, idx in enumerate(new_indices): - worker_for_mapping[idx] = get_worker_for_hash_sharding(i, workers) + worker_for_mapping[idx] = _get_worker_for_hash_sharding(i, workers) assert len(set(worker_for_mapping.values())) == min(n_workers, len(new_indices)) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 8b6c00a2116..d25d67061cb 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -37,7 +37,7 @@ from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin from distributed.shuffle._shuffle import ( DataFrameShuffleRun, - get_worker_for_range_sharding, + _get_worker_for_range_sharding, split_by_partition, split_by_worker, ) @@ -467,7 +467,7 @@ def mock_get_worker_for_range_sharding( return a.address with mock.patch( - "distributed.shuffle._scheduler_plugin.get_worker_for_range_sharding", + "distributed.shuffle._shuffle._get_worker_for_range_sharding", mock_get_worker_for_range_sharding, ): df = dask.datasets.timeseries( @@ -500,7 +500,7 @@ def mock_mock_get_worker_for_range_sharding( return a.address with mock.patch( - "distributed.shuffle._scheduler_plugin.get_worker_for_range_sharding", + "distributed.shuffle._shuffle._get_worker_for_range_sharding", mock_mock_get_worker_for_range_sharding, ): async with Nanny(s.address, nthreads=1) as n: @@ -565,7 +565,7 @@ async def inputs_done(self) -> None: @mock.patch( - "distributed.shuffle._worker_plugin.DataFrameShuffleRun", + "distributed.shuffle._shuffle.DataFrameShuffleRun", BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) @@ -624,7 +624,7 @@ def shuffle_restarted(): @mock.patch( - "distributed.shuffle._worker_plugin.DataFrameShuffleRun", + "distributed.shuffle._shuffle.DataFrameShuffleRun", BlockedInputsDoneShuffle, ) @gen_cluster( @@ -672,7 +672,7 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b): @mock.patch( - "distributed.shuffle._worker_plugin.DataFrameShuffleRun", + "distributed.shuffle._shuffle.DataFrameShuffleRun", BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) @@ -734,7 +734,7 @@ def shuffle_restarted(): @pytest.mark.slow @mock.patch( - "distributed.shuffle._worker_plugin.DataFrameShuffleRun", + "distributed.shuffle._shuffle.DataFrameShuffleRun", BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -1501,7 +1501,7 @@ async def test_basic_lowlevel_shuffle( worker_for_mapping = {} for part in range(npartitions): - worker_for_mapping[part] = get_worker_for_range_sharding( + worker_for_mapping[part] = _get_worker_for_range_sharding( npartitions, part, workers ) assert len(set(worker_for_mapping.values())) == min(n_workers, npartitions) @@ -1575,7 +1575,7 @@ async def test_error_offload(tmp_path, loop_in_thread): partitions_for_worker = defaultdict(list) for part in range(npartitions): - worker_for_mapping[part] = w = get_worker_for_range_sharding( + worker_for_mapping[part] = w = _get_worker_for_range_sharding( npartitions, part, workers ) partitions_for_worker[w].append(part) @@ -1626,7 +1626,7 @@ async def test_error_send(tmp_path, loop_in_thread): partitions_for_worker = defaultdict(list) for part in range(npartitions): - worker_for_mapping[part] = w = get_worker_for_range_sharding( + worker_for_mapping[part] = w = _get_worker_for_range_sharding( npartitions, part, workers ) partitions_for_worker[w].append(part) @@ -1676,7 +1676,7 @@ async def test_error_receive(tmp_path, loop_in_thread): partitions_for_worker = defaultdict(list) for part in range(npartitions): - worker_for_mapping[part] = w = get_worker_for_range_sharding( + worker_for_mapping[part] = w = _get_worker_for_range_sharding( npartitions, part, workers ) partitions_for_worker[w].append(part) @@ -1843,15 +1843,15 @@ async def test_shuffle_run_consistency(c, s, a): out = out.persist() shuffle_id = await wait_until_new_shuffle_is_initialized(s) - shuffle_dict = scheduler_ext.get(shuffle_id, a.worker_address) + spec = scheduler_ext.get(shuffle_id, a.worker_address).data # Worker plugin can fetch the current run - assert await worker_plugin._get_shuffle_run(shuffle_id, shuffle_dict["run_id"]) + assert await worker_plugin._get_shuffle_run(shuffle_id, spec.run_id) # This should never occur, but fetching an ID larger than the ID available on # the scheduler should result in an error. with pytest.raises(RuntimeError, match="invalid"): - await worker_plugin._get_shuffle_run(shuffle_id, shuffle_dict["run_id"] + 1) + await worker_plugin._get_shuffle_run(shuffle_id, spec.run_id + 1) # Finish first execution worker_plugin.block_barrier.set() @@ -1868,17 +1868,17 @@ async def test_shuffle_run_consistency(c, s, a): new_shuffle_id = await wait_until_new_shuffle_is_initialized(s) assert shuffle_id == new_shuffle_id - new_shuffle_dict = scheduler_ext.get(shuffle_id, a.worker_address) + new_spec = scheduler_ext.get(shuffle_id, a.worker_address).data # Check invariant that the new run ID is larger than the previous - assert shuffle_dict["run_id"] < new_shuffle_dict["run_id"] + assert spec.run_id < new_spec.run_id # Worker plugin can fetch the new shuffle run - assert await worker_plugin._get_shuffle_run(shuffle_id, new_shuffle_dict["run_id"]) + assert await worker_plugin._get_shuffle_run(shuffle_id, new_spec.run_id) # Fetching a stale run from a worker aware of the new run raises an error with pytest.raises(RuntimeError, match="stale"): - await worker_plugin._get_shuffle_run(shuffle_id, shuffle_dict["run_id"]) + await worker_plugin._get_shuffle_run(shuffle_id, spec.run_id) worker_plugin.block_barrier.set() await out @@ -1892,13 +1892,11 @@ async def test_shuffle_run_consistency(c, s, a): independent_shuffle_id = await wait_until_new_shuffle_is_initialized(s) assert shuffle_id != independent_shuffle_id - independent_shuffle_dict = scheduler_ext.get( - independent_shuffle_id, a.worker_address - ) + independent_spec = scheduler_ext.get(independent_shuffle_id, a.worker_address).data # Check invariant that the new run ID is larger than the previous # for independent shuffles - assert new_shuffle_dict["run_id"] < independent_shuffle_dict["run_id"] + assert new_spec.run_id < independent_spec.run_id worker_plugin.block_barrier.set() await out diff --git a/distributed/shuffle/tests/test_shuffle_plugins.py b/distributed/shuffle/tests/test_shuffle_plugins.py index d8fec10834c..0ddc9132dc8 100644 --- a/distributed/shuffle/tests/test_shuffle_plugins.py +++ b/distributed/shuffle/tests/test_shuffle_plugins.py @@ -4,19 +4,18 @@ import pytest +from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin from distributed.shuffle._shuffle import ( - get_worker_for_range_sharding, + _get_worker_for_range_sharding, split_by_partition, split_by_worker, ) +from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin +from distributed.utils_test import gen_cluster pd = pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") -from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin -from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin -from distributed.utils_test import gen_cluster - @gen_cluster([("", 1)]) async def test_installation_on_worker(s, a): @@ -52,7 +51,7 @@ def test_split_by_worker(): worker_for_mapping = {} npartitions = 3 for part in range(npartitions): - worker_for_mapping[part] = get_worker_for_range_sharding( + worker_for_mapping[part] = _get_worker_for_range_sharding( npartitions, part, workers ) worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category") @@ -90,15 +89,15 @@ def test_split_by_worker_many_workers(): npartitions = 10 worker_for_mapping = {} for part in range(npartitions): - worker_for_mapping[part] = get_worker_for_range_sharding( + worker_for_mapping[part] = _get_worker_for_range_sharding( npartitions, part, workers ) worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category") out = split_by_worker(df, "_partition", worker_for) - assert get_worker_for_range_sharding(npartitions, 5, workers) in out - assert get_worker_for_range_sharding(npartitions, 0, workers) in out - assert get_worker_for_range_sharding(npartitions, 7, workers) in out - assert get_worker_for_range_sharding(npartitions, 1, workers) in out + assert _get_worker_for_range_sharding(npartitions, 5, workers) in out + assert _get_worker_for_range_sharding(npartitions, 0, workers) in out + assert _get_worker_for_range_sharding(npartitions, 7, workers) in out + assert _get_worker_for_range_sharding(npartitions, 1, workers) in out assert sum(map(len, out.values())) == len(df) From 82b677dbbf0bd92bbbc9d2a770f792045d80fc93 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 15 Aug 2023 18:14:44 +0200 Subject: [PATCH 2/7] Adjust error message --- distributed/shuffle/_core.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 7659f691a4a..d3540e9d37f 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -238,12 +238,12 @@ def get_worker_plugin() -> ShuffleWorkerPlugin: "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; " "please confirm that you've created a distributed Client and are submitting this computation through it." ) from e - plugin: ShuffleWorkerPlugin | None = worker.plugins.get("shuffle") # type: ignore - if plugin is None: + try: + plugin: ShuffleWorkerPlugin = worker.plugins["shuffle"] # type: ignore + except KeyError as e: raise RuntimeError( - f"The worker {worker.address} does not have a ShuffleExtension. " - "Is pandas installed on the worker?" - ) + f"The worker {worker.address} does not have a P2P shuffle plugin." + ) from e return plugin From 2cebf93d5862fbb6e7c7fb37a9b4854cd668f314 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 15 Aug 2023 18:15:08 +0200 Subject: [PATCH 3/7] Minor --- distributed/shuffle/_core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index d3540e9d37f..486752291b2 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -239,12 +239,11 @@ def get_worker_plugin() -> ShuffleWorkerPlugin: "please confirm that you've created a distributed Client and are submitting this computation through it." ) from e try: - plugin: ShuffleWorkerPlugin = worker.plugins["shuffle"] # type: ignore + return worker.plugins["shuffle"] # type: ignore except KeyError as e: raise RuntimeError( f"The worker {worker.address} does not have a P2P shuffle plugin." ) from e - return plugin _BARRIER_PREFIX = "shuffle-barrier-" From 368f52ecd48691143b05fbcbeb0d6f17d9cd84fc Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 16 Aug 2023 13:41:13 +0200 Subject: [PATCH 4/7] Drop output_workers --- distributed/shuffle/_core.py | 5 ----- distributed/shuffle/_rechunk.py | 10 ++-------- distributed/shuffle/_shuffle.py | 10 ++-------- distributed/shuffle/tests/test_rechunk.py | 2 -- distributed/shuffle/tests/test_shuffle.py | 2 -- 5 files changed, 4 insertions(+), 25 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 486752291b2..a98e1196017 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -49,7 +49,6 @@ def __init__( self, id: ShuffleId, run_id: int, - output_workers: set[str], local_address: str, directory: str, executor: ThreadPoolExecutor, @@ -60,7 +59,6 @@ def __init__( ): self.id = id self.run_id = run_id - self.output_workers = output_workers self.local_address = local_address self.executor = executor self.rpc = rpc @@ -269,7 +267,6 @@ class ShuffleState(abc.ABC): id: ShuffleId run_id: int - output_workers: set[str] participating_workers: set[str] _archived_by: str | None = field(default=None, init=False) @@ -294,7 +291,6 @@ class ShuffleRunSpec(Generic[_T_partition_id]): id: ShuffleId run_id: int worker_for: dict[_T_partition_id, str] - output_workers: set[str] # TODO: Is this necessary? @dataclass(eq=False) @@ -303,7 +299,6 @@ class SchedulerShuffleState: id: ShuffleId run_id: int - output_workers: set[str] participating_workers: set[str] _archived_by: str | None = field(default=None, init=False) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index d70badcdc16..8922c515c9a 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -308,10 +308,9 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): Parameters ---------- + # FIXME worker_for: A mapping partition_id -> worker_address. - output_workers: - A set of all participating worker (addresses). old: Existing chunking of the array per dimension. new: @@ -342,7 +341,6 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): def __init__( self, worker_for: dict[NDIndex, str], - output_workers: set, old: ChunkedAxes, new: ChunkedAxes, id: ShuffleId, @@ -358,7 +356,6 @@ def __init__( super().__init__( id=id, run_id=run_id, - output_workers=output_workers, local_address=local_address, directory=directory, executor=executor, @@ -495,10 +492,9 @@ def _array_rechunk_from_spec( id=spec.id, run_id=next(SchedulerShuffleState._run_id_iterator), worker_for=mapping, - output_workers=output_workers, old=spec.old, new=spec.new, - participating_workers=output_workers.copy(), + participating_workers=output_workers, ) @@ -508,7 +504,6 @@ def _array_state_to_run_spec(state: ArrayRechunkState) -> ArrayRechunkRunSpec: id=state.id, run_id=state.run_id, worker_for=state.worker_for, - output_workers=state.output_workers, new=state.new, old=state.old, ) @@ -520,7 +515,6 @@ def _array_run_spec_to_run( ) -> ShuffleRun: return ArrayRechunkRun( worker_for=spec.worker_for, - output_workers=spec.output_workers, old=spec.old, new=spec.new, id=spec.id, diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 4be2ea0707c..ce1e506a88a 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -356,10 +356,9 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): Parameters ---------- + # FIXME worker_for: A mapping partition_id -> worker_address. - output_workers: - A set of all participating worker (addresses). column: The data column we split the input partition by. id: @@ -388,7 +387,6 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): def __init__( self, worker_for: dict[int, str], - output_workers: set, column: str, id: ShuffleId, run_id: int, @@ -405,7 +403,6 @@ def __init__( super().__init__( id=id, run_id=run_id, - output_workers=output_workers, local_address=local_address, directory=directory, executor=executor, @@ -534,8 +531,7 @@ def _dataframe_spec_to_state( run_id=next(SchedulerShuffleState._run_id_iterator), worker_for=mapping, column=spec.column, - output_workers=output_workers, - participating_workers=output_workers.copy(), + participating_workers=output_workers, ) @@ -545,7 +541,6 @@ def _dataframe_state_to_run_spec(state: DataFrameShuffleState) -> ShuffleRunSpec id=state.id, run_id=state.run_id, worker_for=state.worker_for, - output_workers=state.output_workers, column=state.column, ) @@ -557,7 +552,6 @@ def _dataframe_run_spec_to_run( return DataFrameShuffleRun( column=spec.column, worker_for=spec.worker_for, - output_workers=spec.output_workers, id=spec.id, run_id=spec.run_id, directory=os.path.join( diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 72a3183ab6c..fcdee28f857 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -55,8 +55,6 @@ def new_shuffle( ): s = Shuffle( worker_for=worker_for_mapping, - # FIXME: Is output_workers redundant with worker_for? - output_workers=set(worker_for_mapping.values()), old=old, new=new, directory=directory / name, diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index d25d67061cb..2209164cb45 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -1456,8 +1456,6 @@ def new_shuffle( s = Shuffle( column="_partition", worker_for=worker_for_mapping, - # FIXME: Is output_workers redundant with worker_for? - output_workers=set(worker_for_mapping.values()), directory=directory / name, id=ShuffleId(name), run_id=next(AbstractShuffleTestPool._shuffle_run_id_iterator), From 7283a5626988e11158e64b724fc05a279540024d Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 16 Aug 2023 15:52:26 +0200 Subject: [PATCH 5/7] Refactor --- distributed/shuffle/_core.py | 87 ++++++++++--------- distributed/shuffle/_rechunk.py | 106 +++++++---------------- distributed/shuffle/_scheduler_plugin.py | 8 +- distributed/shuffle/_shuffle.py | 96 ++++++-------------- distributed/shuffle/_worker_plugin.py | 5 +- 5 files changed, 111 insertions(+), 191 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index a98e1196017..3f0e5614e13 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -10,9 +10,8 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar, Generic, NewType, TypeVar - -from dask.utils import Dispatch +from functools import partial +from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule @@ -26,9 +25,10 @@ # TODO import from typing (requires Python >=3.10) from typing_extensions import TypeAlias - # circular dependency - from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin + from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin + # circular dependencies + from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin ShuffleId = NewType("ShuffleId", str) NDIndex: TypeAlias = tuple[int, ...] @@ -39,11 +39,6 @@ _T = TypeVar("_T") -spec_to_scheduler_state = Dispatch("spec_to_scheduler_state") -scheduler_state_to_run_spec = Dispatch("scheduler_state_to_run_spec") -run_spec_to_worker_run = Dispatch("run_spec_to_worker_run") - - class ShuffleRun(Generic[_T_partition_id, _T_partition_type]): def __init__( self, @@ -261,47 +256,61 @@ class ShuffleType(Enum): ARRAY_RECHUNK = "ArrayRechunk" -@dataclass(eq=False) -class ShuffleState(abc.ABC): - _run_id_iterator: ClassVar[itertools.count] = itertools.count(1) - - id: ShuffleId - run_id: int - participating_workers: set[str] - _archived_by: str | None = field(default=None, init=False) +@dataclass(frozen=True) +class ShuffleRunSpec(Generic[_T_partition_id]): + run_id: int = field(init=False, default_factory=partial(next, itertools.count(1))) # type: ignore + spec: ShuffleSpec + worker_for: dict[_T_partition_id, str] - @abc.abstractmethod - def to_msg(self) -> dict[str, Any]: - """Transform the shuffle state into a JSON-serializable message""" + @property + def id(self) -> ShuffleId: + return self.spec.id - def __str__(self) -> str: - return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>" - def __hash__(self) -> int: - return hash(self.run_id) - - -@dataclass -class ShuffleSpec: +@dataclass(frozen=True) +class ShuffleSpec(abc.ABC, Generic[_T_partition_id]): id: ShuffleId + def create_new_run( + self, + plugin: ShuffleSchedulerPlugin, + ) -> SchedulerShuffleState: + worker_for = self._pin_output_workers(plugin) + return SchedulerShuffleState( + run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for), + participating_workers=set(worker_for.values()), + ) -@dataclass -class ShuffleRunSpec(Generic[_T_partition_id]): - id: ShuffleId - run_id: int - worker_for: dict[_T_partition_id, str] + @abc.abstractmethod + def _pin_output_workers( + self, plugin: ShuffleSchedulerPlugin + ) -> dict[_T_partition_id, str]: + """TODO""" + @abc.abstractmethod + def initialize_run_on_worker( + self, + run_id: int, + worker_for: dict[_T_partition_id, str], + plugin: ShuffleWorkerPlugin, + ) -> ShuffleRun: + """TODO""" -@dataclass(eq=False) -class SchedulerShuffleState: - _run_id_iterator: ClassVar[itertools.count] = itertools.count(1) - id: ShuffleId - run_id: int +@dataclass(eq=False) +class SchedulerShuffleState(Generic[_T_partition_id]): + run_spec: ShuffleRunSpec participating_workers: set[str] _archived_by: str | None = field(default=None, init=False) + @property + def id(self) -> ShuffleId: + return self.run_spec.id + + @property + def run_id(self) -> int: + return self.run_spec.run_id + def __str__(self) -> str: return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>" diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 8922c515c9a..ec5f3f0093c 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -114,17 +114,13 @@ from distributed.exceptions import Reschedule from distributed.shuffle._core import ( NDIndex, - SchedulerShuffleState, ShuffleId, ShuffleRun, - ShuffleRunSpec, ShuffleSpec, get_worker_plugin, - run_spec_to_worker_run, - scheduler_state_to_run_spec, - spec_to_scheduler_state, ) from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin from distributed.shuffle._shuffle import barrier_key, shuffle_barrier from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof @@ -135,9 +131,6 @@ import dask.array as da - # circular dependency - from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin - ChunkedAxis: TypeAlias = tuple[float, ...] # chunks must either be an int or NaN ChunkedAxes: TypeAlias = tuple[ChunkedAxis, ...] NDSlice: TypeAlias = tuple[slice, ...] @@ -459,77 +452,40 @@ def _get_assigned_worker(self, id: NDIndex) -> str: return self.worker_for[id] -@dataclass -class ArrayRechunkSpec(ShuffleSpec): - new: ChunkedAxes - old: ChunkedAxes - - -@dataclass -class ArrayRechunkRunSpec(ShuffleRunSpec): +@dataclass(frozen=True) +class ArrayRechunkSpec(ShuffleSpec[NDIndex]): new: ChunkedAxes old: ChunkedAxes + def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[NDIndex, str]: + parts_out = product(*(range(len(c)) for c in self.new)) + return plugin._pin_output_workers( + self.id, parts_out, _get_worker_for_hash_sharding + ) -@dataclass(eq=False) -class ArrayRechunkState(SchedulerShuffleState): - worker_for: dict[NDIndex, str] - old: ChunkedAxes - new: ChunkedAxes - - -@spec_to_scheduler_state.register(ArrayRechunkSpec) # type: ignore[misc] -def _array_rechunk_from_spec( - spec: ArrayRechunkSpec, plugin: ShuffleSchedulerPlugin -) -> ArrayRechunkState: - parts_out = product(*(range(len(c)) for c in spec.new)) - mapping = plugin._pin_output_workers( - spec.id, parts_out, _get_worker_for_hash_sharding - ) - output_workers = set(mapping.values()) - - return ArrayRechunkState( - id=spec.id, - run_id=next(SchedulerShuffleState._run_id_iterator), - worker_for=mapping, - old=spec.old, - new=spec.new, - participating_workers=output_workers, - ) - - -@scheduler_state_to_run_spec.register(ArrayRechunkState) # type: ignore[misc] -def _array_state_to_run_spec(state: ArrayRechunkState) -> ArrayRechunkRunSpec: - return ArrayRechunkRunSpec( - id=state.id, - run_id=state.run_id, - worker_for=state.worker_for, - new=state.new, - old=state.old, - ) - - -@run_spec_to_worker_run.register(ArrayRechunkRunSpec) # type: ignore[misc] -def _array_run_spec_to_run( - spec: ArrayRechunkRunSpec, plugin: ShuffleWorkerPlugin -) -> ShuffleRun: - return ArrayRechunkRun( - worker_for=spec.worker_for, - old=spec.old, - new=spec.new, - id=spec.id, - run_id=spec.run_id, - directory=os.path.join( - plugin.worker.local_directory, - f"shuffle-{spec.id}-{spec.run_id}", - ), - executor=plugin._executor, - local_address=plugin.worker.address, - rpc=plugin.worker.rpc, - scheduler=plugin.worker.scheduler, - memory_limiter_disk=plugin.memory_limiter_disk, - memory_limiter_comms=plugin.memory_limiter_comms, - ) + def initialize_run_on_worker( + self, + run_id: int, + worker_for: dict[NDIndex, str], + plugin: ShuffleWorkerPlugin, + ) -> ShuffleRun: + return ArrayRechunkRun( + worker_for=worker_for, + old=self.old, + new=self.new, + id=self.id, + run_id=run_id, + directory=os.path.join( + plugin.worker.local_directory, + f"shuffle-{self.id}-{run_id}", + ), + executor=plugin._executor, + local_address=plugin.worker.address, + rpc=plugin.worker.rpc, + scheduler=plugin.worker.scheduler, + memory_limiter_disk=plugin.memory_limiter_disk, + memory_limiter_comms=plugin.memory_limiter_comms, + ) def _get_worker_for_hash_sharding( diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 7ff79e2ee53..0a511145bc8 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -16,8 +16,6 @@ ShuffleSpec, barrier_key, id_from_key, - scheduler_state_to_run_spec, - spec_to_scheduler_state, ) from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin @@ -113,7 +111,7 @@ def get(self, id: ShuffleId, worker: str) -> ToPickle[ShuffleRunSpec]: ) # pragma: nocover state = self.active_shuffles[id] state.participating_workers.add(worker) - return ToPickle(scheduler_state_to_run_spec(state)) + return ToPickle(state.run_spec) def get_or_create( self, @@ -129,11 +127,11 @@ def get_or_create( # that the shuffle works as intended and should fail instead. self._raise_if_barrier_unknown(spec.id) self._raise_if_task_not_processing(key) - state: SchedulerShuffleState = spec_to_scheduler_state(spec, self) + state = spec.create_new_run(self) self.active_shuffles[spec.id] = state self._shuffles[spec.id].add(state) state.participating_workers.add(worker) - return ToPickle(scheduler_state_to_run_spec(state)) + return ToPickle(state.run_spec) def _raise_if_barrier_unknown(self, id: ShuffleId) -> None: key = barrier_key(id) diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index ce1e506a88a..8f57c8e8138 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -26,19 +26,15 @@ ) from distributed.shuffle._core import ( NDIndex, - SchedulerShuffleState, ShuffleId, ShuffleRun, - ShuffleRunSpec, ShuffleSpec, barrier_key, get_worker_plugin, - run_spec_to_worker_run, - scheduler_state_to_run_spec, - spec_to_scheduler_state, ) from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof @@ -52,9 +48,6 @@ from dask.dataframe import DataFrame - # circular dependency - from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin - def shuffle_transfer( input: pd.DataFrame, @@ -499,72 +492,35 @@ def _get_assigned_worker(self, id: int) -> str: return self.worker_for[id] -@dataclass -class DataFrameShuffleSpec(ShuffleSpec): +@dataclass(frozen=True) +class DataFrameShuffleSpec(ShuffleSpec[int]): npartitions: int column: str parts_out: set[int] - -@dataclass -class DataFrameShuffleRunSpec(ShuffleRunSpec): - column: str - - -@dataclass(eq=False) -class DataFrameShuffleState(SchedulerShuffleState): - worker_for: dict[int, str] - column: str - - -@spec_to_scheduler_state.register(DataFrameShuffleSpec) # type: ignore[misc] -def _dataframe_spec_to_state( - spec: DataFrameShuffleSpec, plugin: ShuffleSchedulerPlugin -) -> DataFrameShuffleState: - pick_worker = partial(_get_worker_for_range_sharding, spec.npartitions) - - mapping = plugin._pin_output_workers(spec.id, spec.parts_out, pick_worker) - output_workers = set(mapping.values()) - - return DataFrameShuffleState( - id=spec.id, - run_id=next(SchedulerShuffleState._run_id_iterator), - worker_for=mapping, - column=spec.column, - participating_workers=output_workers, - ) - - -@scheduler_state_to_run_spec.register(DataFrameShuffleState) # type: ignore[misc] -def _dataframe_state_to_run_spec(state: DataFrameShuffleState) -> ShuffleRunSpec: - return DataFrameShuffleRunSpec( - id=state.id, - run_id=state.run_id, - worker_for=state.worker_for, - column=state.column, - ) - - -@run_spec_to_worker_run.register(DataFrameShuffleRunSpec) # type: ignore[misc] -def _dataframe_run_spec_to_run( - spec: DataFrameShuffleRunSpec, plugin: ShuffleWorkerPlugin -) -> DataFrameShuffleRun: - return DataFrameShuffleRun( - column=spec.column, - worker_for=spec.worker_for, - id=spec.id, - run_id=spec.run_id, - directory=os.path.join( - plugin.worker.local_directory, - f"shuffle-{spec.id}-{spec.run_id}", - ), - executor=plugin._executor, - local_address=plugin.worker.address, - rpc=plugin.worker.rpc, - scheduler=plugin.worker.scheduler, - memory_limiter_disk=plugin.memory_limiter_disk, - memory_limiter_comms=plugin.memory_limiter_comms, - ) + def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[int, str]: + pick_worker = partial(_get_worker_for_range_sharding, self.npartitions) + return plugin._pin_output_workers(self.id, self.parts_out, pick_worker) + + def initialize_run_on_worker( + self, run_id: int, worker_for: dict[int, str], plugin: ShuffleWorkerPlugin + ) -> ShuffleRun: + return DataFrameShuffleRun( + column=self.column, + worker_for=worker_for, + id=self.id, + run_id=run_id, + directory=os.path.join( + plugin.worker.local_directory, + f"shuffle-{self.id}-{run_id}", + ), + executor=plugin._executor, + local_address=plugin.worker.address, + rpc=plugin.worker.rpc, + scheduler=plugin.worker.scheduler, + memory_limiter_disk=plugin.memory_limiter_disk, + memory_limiter_comms=plugin.memory_limiter_comms, + ) def _get_worker_for_range_sharding( diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index fcc79afa984..bcbc9e29bbe 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -16,7 +16,6 @@ ShuffleRun, ShuffleRunSpec, ShuffleSpec, - run_spec_to_worker_run, ) from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter @@ -289,7 +288,9 @@ async def _( self.worker._ongoing_background_tasks.call_soon(_, self, existing) - shuffle: ShuffleRun = run_spec_to_worker_run(result, self) + shuffle: ShuffleRun = result.spec.initialize_run_on_worker( + result.run_id, result.worker_for, self + ) self.shuffles[shuffle_id] = shuffle self._runs.add(shuffle) return shuffle From 82516788bbcf758c8bdf8b1e0bdd467cf666a77e Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 17 Aug 2023 12:38:09 +0200 Subject: [PATCH 6/7] Docs --- distributed/shuffle/_core.py | 6 +++--- distributed/shuffle/_rechunk.py | 2 +- distributed/shuffle/_shuffle.py | 2 +- distributed/shuffle/_worker_plugin.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 3f0e5614e13..e7dd72c73cf 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -285,16 +285,16 @@ def create_new_run( def _pin_output_workers( self, plugin: ShuffleSchedulerPlugin ) -> dict[_T_partition_id, str]: - """TODO""" + """Pin output tasks to workers and return the mapping of partition ID to worker.""" @abc.abstractmethod - def initialize_run_on_worker( + def create_run_on_worker( self, run_id: int, worker_for: dict[_T_partition_id, str], plugin: ShuffleWorkerPlugin, ) -> ShuffleRun: - """TODO""" + """Create the new shuffle run on the worker.""" @dataclass(eq=False) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index ec5f3f0093c..0a28da5a4b2 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -463,7 +463,7 @@ def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[NDIndex, s self.id, parts_out, _get_worker_for_hash_sharding ) - def initialize_run_on_worker( + def create_run_on_worker( self, run_id: int, worker_for: dict[NDIndex, str], diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 8f57c8e8138..78cec1a740f 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -502,7 +502,7 @@ def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[int, str]: pick_worker = partial(_get_worker_for_range_sharding, self.npartitions) return plugin._pin_output_workers(self.id, self.parts_out, pick_worker) - def initialize_run_on_worker( + def create_run_on_worker( self, run_id: int, worker_for: dict[int, str], plugin: ShuffleWorkerPlugin ) -> ShuffleRun: return DataFrameShuffleRun( diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index bcbc9e29bbe..bc8b1f3d4d1 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -288,7 +288,7 @@ async def _( self.worker._ongoing_background_tasks.call_soon(_, self, existing) - shuffle: ShuffleRun = result.spec.initialize_run_on_worker( + shuffle: ShuffleRun = result.spec.create_run_on_worker( result.run_id, result.worker_for, self ) self.shuffles[shuffle_id] = shuffle From b76bd7a7f6b8a3b2c97bfedc61f06f070ca56e2a Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 17 Aug 2023 12:42:17 +0200 Subject: [PATCH 7/7] Update docs --- distributed/shuffle/_rechunk.py | 11 ++++------- distributed/shuffle/_shuffle.py | 9 +++------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 0a28da5a4b2..d9f3e179753 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -293,15 +293,14 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): This object is responsible for splitting, sending, receiving and combining data shards. - It is entirely agnostic to the distributed system and can perform a shuffle - with other `Shuffle` instances using `rpc` and `broadcast`. + It is entirely agnostic to the distributed system and can perform a rechunk + with other run instances using `rpc``. - The user of this needs to guarantee that only `Shuffle`s of the same unique - `ShuffleID` interact. + The user of this needs to guarantee that only `ArrayRechunkRun`s of the same unique + `ShuffleID` and `run_id` interact. Parameters ---------- - # FIXME worker_for: A mapping partition_id -> worker_address. old: @@ -318,8 +317,6 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): The scratch directory to buffer data in. executor: Thread pool to use for offloading compute. - loop: - The event loop. rpc: A callable returning a PooledRPCCall to contact other Shuffle instances. Typically a ConnectionPool. diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 78cec1a740f..e16340c6399 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -342,14 +342,13 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): data shards. It is entirely agnostic to the distributed system and can perform a shuffle - with other `Shuffle` instances using `rpc` and `broadcast`. + with other run instances using `rpc`. - The user of this needs to guarantee that only `Shuffle`s of the same unique - `ShuffleID` interact. + The user of this needs to guarantee that only `DataFrameShuffleRun`s of the + same unique `ShuffleID` and `run_id` interact. Parameters ---------- - # FIXME worker_for: A mapping partition_id -> worker_address. column: @@ -364,8 +363,6 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): The scratch directory to buffer data in. executor: Thread pool to use for offloading compute. - loop: - The event loop. rpc: A callable returning a PooledRPCCall to contact other Shuffle instances. Typically a ConnectionPool.