-
-
Notifications
You must be signed in to change notification settings - Fork 762
Make P2P shuffle extensible #8096
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
e0adc3a
Refactor and add dispatch
hendrikmakait 82b677d
Adjust error message
hendrikmakait 2cebf93
Minor
hendrikmakait 368f52e
Drop output_workers
hendrikmakait 7283a56
Refactor
hendrikmakait 8251678
Docs
hendrikmakait b76bd7a
Update docs
hendrikmakait File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -96,13 +96,15 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import os | ||
| import pickle | ||
| from collections import defaultdict | ||
| from collections.abc import Callable, Sequence | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from dataclasses import dataclass | ||
| from io import BytesIO | ||
| from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple | ||
| from itertools import product | ||
| from typing import TYPE_CHECKING, Any, NamedTuple | ||
|
|
||
| import dask | ||
| from dask.base import tokenize | ||
|
|
@@ -114,23 +116,21 @@ | |
| NDIndex, | ||
| ShuffleId, | ||
| ShuffleRun, | ||
| ShuffleState, | ||
| ShuffleType, | ||
| barrier_key, | ||
| ShuffleSpec, | ||
| get_worker_plugin, | ||
| ) | ||
| from distributed.shuffle._limiter import ResourceLimiter | ||
| from distributed.shuffle._shuffle import shuffle_barrier | ||
| from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin | ||
| from distributed.shuffle._shuffle import barrier_key, shuffle_barrier | ||
| from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin | ||
| from distributed.sizeof import sizeof | ||
|
|
||
| if TYPE_CHECKING: | ||
| import numpy as np | ||
| import pandas as pd | ||
| from typing_extensions import TypeAlias | ||
|
|
||
| import dask.array as da | ||
|
|
||
|
|
||
| ChunkedAxis: TypeAlias = tuple[float, ...] # chunks must either be an int or NaN | ||
| ChunkedAxes: TypeAlias = tuple[ChunkedAxis, ...] | ||
| NDSlice: TypeAlias = tuple[slice, ...] | ||
|
|
@@ -147,10 +147,7 @@ def rechunk_transfer( | |
| return get_worker_plugin().add_partition( | ||
| input, | ||
| partition_id=input_chunk, | ||
| shuffle_id=id, | ||
| type=ShuffleType.ARRAY_RECHUNK, | ||
| new=new, | ||
| old=old, | ||
| spec=ArrayRechunkSpec(id=id, new=new, old=old), | ||
| ) | ||
| except Exception as e: | ||
| raise RuntimeError(f"rechunk_transfer failed during shuffle {id}") from e | ||
|
|
@@ -296,18 +293,16 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): | |
| This object is responsible for splitting, sending, receiving and combining | ||
| data shards. | ||
|
|
||
| It is entirely agnostic to the distributed system and can perform a shuffle | ||
| with other `Shuffle` instances using `rpc` and `broadcast`. | ||
| It is entirely agnostic to the distributed system and can perform a rechunk | ||
| with other run instances using `rpc`. | |
|
|
||
| The user of this needs to guarantee that only `Shuffle`s of the same unique | ||
| `ShuffleID` interact. | ||
| The user of this needs to guarantee that only `ArrayRechunkRun`s of the same unique | ||
| `ShuffleId` and `run_id` interact. | |
|
|
||
| Parameters | ||
| ---------- | ||
| worker_for: | ||
| A mapping partition_id -> worker_address. | ||
| output_workers: | ||
| A set of all participating worker (addresses). | ||
| old: | ||
| Existing chunking of the array per dimension. | ||
| new: | ||
|
|
@@ -322,8 +317,6 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): | |
| The scratch directory to buffer data in. | ||
| executor: | ||
| Thread pool to use for offloading compute. | ||
| loop: | ||
| The event loop. | ||
| rpc: | ||
| A callable returning a PooledRPCCall to contact other Shuffle instances. | ||
| Typically a ConnectionPool. | ||
|
|
@@ -338,7 +331,6 @@ class ArrayRechunkRun(ShuffleRun[NDIndex, "np.ndarray"]): | |
| def __init__( | ||
| self, | ||
| worker_for: dict[NDIndex, str], | ||
| output_workers: set, | ||
| old: ChunkedAxes, | ||
| new: ChunkedAxes, | ||
| id: ShuffleId, | ||
|
|
@@ -354,7 +346,6 @@ def __init__( | |
| super().__init__( | ||
| id=id, | ||
| run_id=run_id, | ||
| output_workers=output_workers, | ||
| local_address=local_address, | ||
| directory=directory, | ||
| executor=executor, | ||
|
|
@@ -403,7 +394,9 @@ def _repartition_shards(self, data: list[bytes]) -> dict[NDIndex, bytes]: | |
| repartitioned[id].append(shard) | ||
| return {k: pickle.dumps(v) for k, v in repartitioned.items()} | ||
|
|
||
| async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int: | ||
| async def add_partition( | ||
| self, data: np.ndarray, partition_id: NDIndex, **kwargs: Any | ||
| ) -> int: | ||
| self.raise_if_closed() | ||
| if self.transferred: | ||
| raise RuntimeError(f"Cannot add more partitions to {self}") | ||
|
|
@@ -441,47 +434,58 @@ def _() -> dict[str, tuple[NDIndex, bytes]]: | |
| return self.run_id | ||
|
|
||
| async def get_output_partition( | ||
| self, partition_id: NDIndex, key: str, meta: pd.DataFrame | None = None | ||
| self, partition_id: NDIndex, key: str, **kwargs: Any | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this switch to
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Future-proofing and making it completely transparent to the plugin. |
||
| ) -> np.ndarray: | ||
| self.raise_if_closed() | ||
| assert meta is None | ||
| assert self.transferred, "`get_output_partition` called before barrier task" | ||
| if not self.transferred: | ||
| raise RuntimeError("`get_output_partition` called before barrier task") | ||
|
|
||
| await self._ensure_output_worker(partition_id, key) | ||
|
|
||
| await self.flush_receive() | ||
|
|
||
| data = self._read_from_disk(partition_id) | ||
|
|
||
| def _() -> np.ndarray: | ||
| return convert_chunk(data) | ||
|
|
||
| return await self.offload(_) | ||
| return await self.offload(convert_chunk, data) | ||
|
|
||
| def _get_assigned_worker(self, id: NDIndex) -> str: | ||
| return self.worker_for[id] | ||
|
|
||
|
|
||
| @dataclass(eq=False) | ||
| class ArrayRechunkState(ShuffleState): | ||
| type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK | ||
| worker_for: dict[NDIndex, str] | ||
| old: ChunkedAxes | ||
| @dataclass(frozen=True) | ||
| class ArrayRechunkSpec(ShuffleSpec[NDIndex]): | ||
| new: ChunkedAxes | ||
| old: ChunkedAxes | ||
|
|
||
| def to_msg(self) -> dict[str, Any]: | ||
| return { | ||
| "status": "OK", | ||
| "type": ArrayRechunkState.type, | ||
| "run_id": self.run_id, | ||
| "worker_for": self.worker_for, | ||
| "old": self.old, | ||
| "new": self.new, | ||
| "output_workers": self.output_workers, | ||
| } | ||
| def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[NDIndex, str]: | ||
| parts_out = product(*(range(len(c)) for c in self.new)) | ||
| return plugin._pin_output_workers( | ||
| self.id, parts_out, _get_worker_for_hash_sharding | ||
| ) | ||
|
|
||
| def create_run_on_worker( | ||
| self, | ||
| run_id: int, | ||
| worker_for: dict[NDIndex, str], | ||
| plugin: ShuffleWorkerPlugin, | ||
| ) -> ShuffleRun: | ||
| return ArrayRechunkRun( | ||
| worker_for=worker_for, | ||
| old=self.old, | ||
| new=self.new, | ||
| id=self.id, | ||
| run_id=run_id, | ||
| directory=os.path.join( | ||
| plugin.worker.local_directory, | ||
| f"shuffle-{self.id}-{run_id}", | ||
| ), | ||
| executor=plugin._executor, | ||
| local_address=plugin.worker.address, | ||
| rpc=plugin.worker.rpc, | ||
| scheduler=plugin.worker.scheduler, | ||
| memory_limiter_disk=plugin.memory_limiter_disk, | ||
| memory_limiter_comms=plugin.memory_limiter_comms, | ||
| ) | ||
|
|
||
|
|
||
| def get_worker_for_hash_sharding( | ||
| def _get_worker_for_hash_sharding( | ||
| output_partition: NDIndex, workers: Sequence[str] | ||
| ) -> str: | ||
| """Get address of target worker for this output partition using hash sharding""" | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW, this is this mypy bug