diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index b634e270a5d..26953c67b97 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -75,6 +75,10 @@ jobs: if: ${{ matrix.os == 'windows-latest' && matrix.python-version == '3.9' }} run: mamba uninstall ipython + - name: Install dask branch + shell: bash -l {0} + run: python -m pip install --no-deps git+https://github.com/mrocklin/dask@p2p-shuffle + - name: Install shell: bash -l {0} run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cb9e74ed986..20e54aaa3a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -49,3 +49,4 @@ repos: - dask - tornado - zict + - pyarrow diff --git a/continuous_integration/environment-3.10.yaml b/continuous_integration/environment-3.10.yaml index f11fe529aec..a7b7678224f 100644 --- a/continuous_integration/environment-3.10.yaml +++ b/continuous_integration/environment-3.10.yaml @@ -23,6 +23,7 @@ dependencies: - pre-commit - prometheus_client - psutil + - pyarrow=7 - pytest - pytest-cov - pytest-faulthandler diff --git a/continuous_integration/environment-3.8.yaml b/continuous_integration/environment-3.8.yaml index e6cdf0bf0fe..619361a413f 100644 --- a/continuous_integration/environment-3.8.yaml +++ b/continuous_integration/environment-3.8.yaml @@ -24,6 +24,7 @@ dependencies: - pre-commit - prometheus_client - psutil + - pyarrow=7 - pynvml # Only tested here - pytest - pytest-cov diff --git a/continuous_integration/environment-3.9.yaml b/continuous_integration/environment-3.9.yaml index 17a21f1e5fd..95ba8df90d0 100644 --- a/continuous_integration/environment-3.9.yaml +++ b/continuous_integration/environment-3.9.yaml @@ -25,6 +25,8 @@ dependencies: - pre-commit - prometheus_client - psutil + - pyarrow=7 + - pynvml # Only tested here - pytest - pytest-cov - pytest-faulthandler diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index a19e2838f77..18760138bdd 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -1043,7 +1043,7 @@ class SystemTimeseries(DashboardComponent): from ws.metrics["val"] for ws in scheduler.workers.values() divided by nuber of workers. """ - def __init__(self, scheduler, **kwargs): + def __init__(self, scheduler, follow_interval=20000, **kwargs): with log_errors(): self.scheduler = scheduler self.source = ColumnDataSource( @@ -1060,7 +1060,9 @@ def __init__(self, scheduler, **kwargs): update(self.source, self.get_data()) - x_range = DataRange1d(follow="end", follow_interval=20000, range_padding=0) + x_range = DataRange1d( + follow="end", follow_interval=follow_interval, range_padding=0 + ) tools = "reset, xpan, xwheel_zoom" self.bandwidth = figure( @@ -3487,6 +3489,261 @@ def update(self): self.source.data.update(data) +class Shuffling(DashboardComponent): + """Occupancy (in time) per worker""" + + def __init__(self, scheduler, **kwargs): + with log_errors(): + self.scheduler = scheduler + self.source = ColumnDataSource( + { + "worker": [], + "y": [], + "comm_memory": [], + "comm_memory_limit": [], + "comm_buckets": [], + "comm_active": [], + "comm_avg_duration": [], + "comm_avg_size": [], + "comm_read": [], + "comm_written": [], + "comm_color": [], + "disk_memory": [], + "disk_memory_limit": [], + "disk_buckets": [], + "disk_active": [], + "disk_avg_duration": [], + "disk_avg_size": [], + "disk_read": [], + "disk_written": [], + "disk_color": [], + } + ) + self.totals_source = ColumnDataSource( + { + "x": ["Network Send", "Network Receive", "Disk Write", "Disk Read"], + "values": [0, 0, 0, 0], + } + ) + + self.comm_memory = figure( + title="Comms Buffer", + tools="", + toolbar_location="above", + x_range=Range1d(0, 100_000_000), + **kwargs, + ) + self.comm_memory.hbar( + source=self.source, + right="comm_memory", + y="y", + height=0.9, + color="comm_color", + ) + hover = HoverTool( + tooltips=[ + ("Memory Used", "@comm_memory{0.00 b}"), + ("Average Write", "@comm_avg_size{0.00 b}"), + ("# Buckets", "@comm_buckets"), + ("Average Duration", "@comm_avg_duration"), + ], + formatters={"@comm_avg_duration": "datetime"}, + mode="hline", + ) + self.comm_memory.add_tools(hover) + self.comm_memory.x_range.start = 0 + self.comm_memory.x_range.end = 1 + self.comm_memory.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + + self.disk_memory = figure( + title="Disk Buffer", + tools="", + toolbar_location="above", + x_range=Range1d(0, 100_000_000), + **kwargs, + ) + self.disk_memory.yaxis.visible = False + + self.disk_memory.hbar( + source=self.source, + right="disk_memory", + y="y", + height=0.9, + color="disk_color", + ) + + hover = HoverTool( + tooltips=[ + ("Memory Used", "@disk_memory{0.00 b}"), + ("Average Write", "@disk_avg_size{0.00 b}"), + ("# Buckets", "@disk_buckets"), + ("Average Duration", "@disk_avg_duration"), + ], + formatters={"@disk_avg_duration": "datetime"}, + mode="hline", + ) + self.disk_memory.add_tools(hover) + self.disk_memory.xaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + + self.totals = figure( + title="Total movement", + tools="", + toolbar_location="above", + **kwargs, + ) + titles = ["Network Send", "Network Receive", "Disk Write", "Disk Read"] + self.totals = figure( + x_range=titles, + title="Totals", + toolbar_location=None, + tools="", + **kwargs, + ) + + self.totals.vbar( + x="x", + top="values", + width=0.9, + source=self.totals_source, + ) + + self.totals.xgrid.grid_line_color = None + self.totals.y_range.start = 0 + self.totals.yaxis[0].formatter = NumeralTickFormatter(format="0.0 b") + + hover = HoverTool( + tooltips=[("Total", "@values{0.00b}")], + mode="vline", + ) + self.totals.add_tools(hover) + + self.root = row(self.comm_memory, self.disk_memory) + + @without_property_validation + def update(self): + with log_errors(): + input = self.scheduler.extensions["shuffle"].shuffles + if not input: + return + + input = list(input.values())[-1] # TODO: multiple concurrent shuffles + + data = { + "worker": [], + "y": [], + "comm_memory": [], + "comm_memory_limit": [], + "comm_buckets": [], + "comm_active": [], + "comm_avg_duration": [], + "comm_avg_size": [], + "comm_read": [], + "comm_written": [], + "comm_color": [], + "disk_memory": [], + "disk_memory_limit": [], + "disk_buckets": [], + "disk_active": [], + "disk_avg_duration": [], + "disk_avg_size": [], + "disk_read": [], + "disk_written": [], + "disk_color": [], + } + now = time() + + for i, (worker, d) in enumerate(input.items()): + data["y"].append(i) + data["worker"].append(worker) + data["comm_memory"].append(d["comms"]["memory"]) + data["comm_memory_limit"].append(d["comms"]["memory_limit"]) + data["comm_buckets"].append(d["comms"]["buckets"]) + data["comm_active"].append(d["comms"]["active"]) + data["comm_avg_duration"].append( + d["comms"]["diagnostics"].get("avg_duration", 0) + ) + data["comm_avg_size"].append( + d["comms"]["diagnostics"].get("avg_size", 0) + ) + data["comm_read"].append(d["comms"]["read"]) + data["comm_written"].append(d["comms"]["written"]) + try: + if self.scheduler.workers[worker].last_seen < now - 5: + data["comm_color"].append("gray") + elif d["comms"]["active"]: + data["comm_color"].append("green") + elif d["comms"]["memory"] > d["comms"]["memory_limit"]: + data["comm_color"].append("red") + else: + data["comm_color"].append("blue") + except KeyError: + data["comm_color"].append("black") + + data["disk_memory"].append(d["disk"]["memory"]) + data["disk_memory_limit"].append(d["disk"]["memory_limit"]) + data["disk_buckets"].append(d["disk"]["buckets"]) + data["disk_active"].append(d["disk"]["active"]) + data["disk_avg_duration"].append( + d["disk"]["diagnostics"].get("avg_duration", 0) + ) + data["disk_avg_size"].append( + d["disk"]["diagnostics"].get("avg_size", 0) + ) + data["disk_read"].append(d["disk"]["read"]) + data["disk_written"].append(d["disk"]["written"]) + try: + if self.scheduler.workers[worker].last_seen < now - 5: + data["disk_color"].append("gray") + elif d["disk"]["active"]: + data["disk_color"].append("green") + elif d["disk"]["memory"] > d["disk"]["memory_limit"]: + data["disk_color"].append("red") + else: + data["disk_color"].append("blue") + except KeyError: + data["disk_color"].append("black") + + """ + singletons = { + "comm_avg_duration": [ + sum(data["comm_avg_duration"]) / len(data["comm_avg_duration"]) + ], + "comm_avg_size": [ + sum(data["comm_avg_size"]) / len(data["comm_avg_size"]) + ], + "disk_avg_duration": [ + sum(data["disk_avg_duration"]) / len(data["disk_avg_duration"]) + ], + "disk_avg_size": [ + sum(data["disk_avg_size"]) / len(data["disk_avg_size"]) + ], + } + singletons["comm_avg_bandwidth"] = [ + singletons["comm_avg_size"][0] / singletons["comm_avg_duration"][0] + ] + singletons["disk_avg_bandwidth"] = [ + singletons["disk_avg_size"][0] / singletons["disk_avg_duration"][0] + ] + singletons["y"] = [data["y"][-1] / 2] + """ + + totals = { + "x": ["Network Send", "Network Receive", "Disk Write", "Disk Read"], + "values": [ + sum(data["comm_written"]), + sum(data["comm_read"]), + sum(data["disk_written"]), + sum(data["disk_read"]), + ], + } + update(self.totals_source, totals) + + update(self.source, data) + limit = max(data["comm_memory_limit"] + data["disk_memory_limit"]) * 1.2 + self.comm_memory.x_range.end = limit + self.disk_memory.x_range.end = limit + + class SchedulerLogs: def __init__(self, scheduler, start=None): logs = scheduler.get_logs(start=start, timestamps=True) @@ -3531,6 +3788,41 @@ def systemmonitor_doc(scheduler, extra, doc): doc.theme = BOKEH_THEME +def shuffling_doc(scheduler, extra, doc): + with log_errors(): + doc.title = "Dask: Shuffling" + + shuffling = Shuffling(scheduler, width=400, height=400) + workers_memory = WorkersMemory(scheduler, width=400, height=400) + timeseries = SystemTimeseries( + scheduler, width=1600, height=200, follow_interval=3000 + ) + event_loop = EventLoop(scheduler, width=200, height=400) + + add_periodic_callback(doc, shuffling, 200) + add_periodic_callback(doc, workers_memory, 200) + add_periodic_callback(doc, timeseries, 500) + add_periodic_callback(doc, event_loop, 500) + + timeseries.bandwidth.y_range = timeseries.disk.y_range + + doc.add_root( + column( + row( + workers_memory.root, + shuffling.comm_memory, + shuffling.disk_memory, + shuffling.totals, + event_loop.root, + ), + row(column(timeseries.bandwidth, timeseries.disk)), + ) + ) + doc.template = env.get_template("simple.html") + doc.template_variables.update(extra) + doc.theme = BOKEH_THEME + + def stealing_doc(scheduler, extra, doc): with log_errors(): occupancy = Occupancy(scheduler) diff --git a/distributed/dashboard/scheduler.py b/distributed/dashboard/scheduler.py index a7dc0a05331..00894d55092 100644 --- a/distributed/dashboard/scheduler.py +++ b/distributed/dashboard/scheduler.py @@ -37,6 +37,7 @@ individual_profile_server_doc, profile_doc, profile_server_doc, + shuffling_doc, status_doc, stealing_doc, systemmonitor_doc, @@ -49,6 +50,7 @@ applications = { "/system": systemmonitor_doc, + "/shuffle": shuffling_doc, "/stealing": stealing_doc, "/workers": workers_doc, "/events": events_doc, diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 6288ec26377..0d1649de0d5 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -30,6 +30,7 @@ Occupancy, ProcessingHistogram, ProfileServer, + Shuffling, StealingEvents, StealingTimeSeries, SystemMonitor, @@ -1007,6 +1008,21 @@ async def test_prefix_bokeh(s, a, b): assert bokeh_app.prefix == f"/{prefix}" +@gen_cluster(client=True, worker_kwargs={"dashboard": True}) +async def test_shuffling(c, s, a, b): + dd = pytest.importorskip("dask.dataframe") + ss = Shuffling(s) + + df = dask.datasets.timeseries() + df2 = dd.shuffle.shuffle(df, "x", shuffle="p2p").persist() + + start = time() + while not ss.source.data["disk_read"]: + ss.update() + await asyncio.sleep(0.1) + assert time() < start + 5 + + @gen_cluster(client=True, nthreads=[], scheduler_kwargs={"dashboard": True}) async def test_hardware(c, s): plot = Hardware(s) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ddf1e1ad2b9..cc6edecb7ce 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -83,6 +83,7 @@ from distributed.recreate_tasks import ReplayTaskScheduler from distributed.security import Security from distributed.semaphore import SemaphoreExtension +from distributed.shuffle import ShuffleSchedulerExtension from distributed.stealing import WorkStealing from distributed.stories import scheduler_story from distributed.utils import ( @@ -187,6 +188,7 @@ def nogil(func): "events": EventExtension, "amm": ActiveMemoryManagerExtension, "memory_sampler": MemorySamplerExtension, + "shuffle": ShuffleSchedulerExtension, "stealing": WorkStealing, } diff --git a/distributed/shuffle/__init__.py b/distributed/shuffle/__init__.py index 29d5610d373..7a47cba5040 100644 --- a/distributed/shuffle/__init__.py +++ b/distributed/shuffle/__init__.py @@ -2,12 +2,6 @@ from distributed.shuffle.shuffle_extension import ( ShuffleId, ShuffleMetadata, + ShuffleSchedulerExtension, ShuffleWorkerExtension, ) - -__all__ = [ - "rearrange_by_column_p2p", - "ShuffleId", - "ShuffleMetadata", - "ShuffleWorkerExtension", -] diff --git a/distributed/shuffle/arrow.py b/distributed/shuffle/arrow.py new file mode 100644 index 00000000000..f2e757728b7 --- /dev/null +++ b/distributed/shuffle/arrow.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + try: + import pyarrow as pa + except ImportError: + raise ImportError("PyArrow is needed for fast shuffling") + + +def dump_batch(batch, file, schema=None) -> None: + """ + Dump a batch to file, if we're the first, also write the schema + + See Also + -------- + load_arrow + """ + if file.tell() == 0: + file.write(schema.serialize()) + file.write(batch) + + +def load_arrow(file) -> pa.Table: + """Load batched data written to file back out into a table again + + Example + ------- + >>> t = pa.Table.from_pandas(df) # doctest: +SKIP + >>> with open("myfile", mode="wb") as f: # doctest: +SKIP + ... for batch in t.to_batches(): # doctest: +SKIP + ... dump_batch(batch, f, schema=t.schema) # doctest: +SKIP + + >>> with open("myfile", mode="rb") as f: # doctest: +SKIP + ... t = load_arrow(f) # doctest: +SKIP + + See Also + -------- + dump_batch + """ + import pyarrow as pa + + try: + sr = pa.RecordBatchStreamReader(file) + return sr.read_all() + except Exception: + raise EOFError + + +def list_of_buffers_to_table(data: list[pa.Buffer], schema: pa.Schema) -> pa.Table: + """Convert a list of arrow buffers and a schema to an Arrow Table""" + import io + + import pyarrow as pa + + bio = io.BytesIO() + bio.write(schema.serialize()) + for batch in data: + bio.write(batch) + bio.seek(0) + sr = pa.RecordBatchStreamReader(bio) + data = sr.read_all() + bio.close() + return data diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py new file mode 100644 index 00000000000..4307cb275b4 --- /dev/null +++ b/distributed/shuffle/multi_comm.py @@ -0,0 +1,215 @@ +import asyncio +import contextlib +import threading +import time +import weakref +from collections import defaultdict + +from dask.utils import parse_bytes + +from distributed.utils import log_errors + + +class MultiComm: + """Accept, buffer, and send many small messages to many workers + + This takes in lots of small messages destined for remote workers, buffers + those messages in memory, and then sends out batches of them when possible + to different workers. This tries to send larger messages when possible, + while also respecting a memory bound + + **State** + + - shards: dict[str, list[bytes]] + + This is our in-memory buffer of data waiting to be sent to other workers. + + - sizes: dict[str, int] + + The size of each list of shards. We find the largest and send data from that buffer + + State + ----- + + memory_limit: str + A maximum amount of memory to use, like "1 GiB" + max_connections: int + The maximum number of connections to have out at once + max_message_size: str + The maximum size of a single message that we want to send + + Parameters + ---------- + send: callable + How to send a list of shards to a worker + """ + + max_message_size = parse_bytes("2 MiB") + memory_limit = parse_bytes("100 MiB") + max_connections = 10 + _queues: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + total_size = 0 + lock = threading.Lock() + + def __init__( + self, + send=None, + loop=None, + ): + self.send = send + self.shards = defaultdict(list) + self.sizes = defaultdict(int) + self.total_size = 0 + self.total_moved = 0 + self.thread_condition = threading.Condition() + self._futures = set() + self._done = False + self.diagnostics = defaultdict(float) + self._loop = loop or asyncio.get_event_loop() + + self._communicate_future = asyncio.create_task(self.communicate()) + self._exception = None + + @property + def queue(self): + try: + return MultiComm._queues[self._loop] + except KeyError: + queue = asyncio.Queue() + for _ in range(MultiComm.max_connections): + queue.put_nowait(None) + MultiComm._queues[self._loop] = queue + return queue + + def put(self, data: dict): + """ + Put a dict of shards into our buffers + + This is intended to be run from a worker thread, hence the synchronous + nature and the lock. + + If we're out of space then we block in order to enforce backpressure. + """ + if self._exception: + raise self._exception + with self.lock: + for address, shards in data.items(): + size = sum(map(len, shards)) + self.shards[address].extend(shards) + self.sizes[address] += size + self.total_size += size + MultiComm.total_size += size + self.total_moved += size + + del data + + while MultiComm.total_size > self.memory_limit: + with self.time("waiting-on-memory"): + with self.thread_condition: + self.thread_condition.wait(1) # Block until memory calms down + + async def communicate(self): + """ + Continuously find the largest batch and send from there + + We keep ``max_connections`` comms running while we still have any data + as an old comm finishes, we find the next largest buffer, pull off + ``max_message_size`` data from it, and ship it to the target worker. + + We do this until we're done. This coroutine runs in the background. + + See Also + -------- + process: does the actual writing + """ + + while not self._done: + with self.time("idle"): + if not self.shards: + await asyncio.sleep(0.1) + continue + + await self.queue.get() + + with self.lock: + address = max(self.sizes, key=self.sizes.get) + + size = 0 + shards = [] + while size < self.max_message_size: + try: + shard = self.shards[address].pop() + shards.append(shard) + s = len(shard) + size += s + self.sizes[address] -= s + except IndexError: + break + finally: + if not self.shards[address]: + del self.shards[address] + assert not self.sizes[address] + del self.sizes[address] + + assert set(self.sizes) == set(self.shards) + assert shards + future = asyncio.create_task(self.process(address, shards, size)) + del shards + self._futures.add(future) + + async def process(self, address: str, shards: list, size: int): + """Send one message off to a neighboring worker""" + with log_errors(): + + # Consider boosting total_size a bit here to account for duplication + + try: + # while (time.time() // 5 % 4) == 0: + # await asyncio.sleep(0.1) + start = time.time() + try: + with self.time("send"): + await self.send(address, [b"".join(shards)]) + except Exception as e: + self._exception = e + self._done = True + stop = time.time() + self.diagnostics["avg_size"] = ( + 0.95 * self.diagnostics["avg_size"] + 0.05 * size + ) + self.diagnostics["avg_duration"] = 0.98 * self.diagnostics[ + "avg_duration" + ] + 0.02 * (stop - start) + finally: + self.total_size -= size + MultiComm.total_size -= size + with self.thread_condition: + self.thread_condition.notify() + await self.queue.put(None) + + async def flush(self): + """ + We don't expect any more data, wait until everything is flushed through + """ + if self._exception: + await self._communicate_future + await asyncio.gather(*self._futures) + raise self._exception + + while self.shards: + await asyncio.sleep(0.05) + + await asyncio.gather(*self._futures) + self._futures.clear() + + assert not self.total_size + + self._done = True + await self._communicate_future + + @contextlib.contextmanager + def time(self, name: str): + start = time.time() + yield + stop = time.time() + self.diagnostics[name] += stop - start diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/multi_file.py new file mode 100644 index 00000000000..cf323d174ab --- /dev/null +++ b/distributed/shuffle/multi_file.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import asyncio +import contextlib +import os +import pathlib +import pickle +import shutil +import time +import weakref +from collections import defaultdict + +from dask.sizeof import sizeof +from dask.utils import parse_bytes + +from distributed.utils import log_errors + + +class MultiFile: + """Accept, buffer, and write many small objects to many files + + This takes in lots of small objects, writes them to a local directory, and + then reads them back when all writes are complete. It buffers these + objects in memory so that it can optimize disk access for larger writes. + + **State** + + - shards: dict[str, list[bytes]] + + This is our in-memory buffer of data waiting to be sent to other workers. + + - sizes: dict[str, int] + + The size of each list of shards. We find the largest and send data from that buffer + + State + ----- + + memory_limit: str + A maximum amount of memory to use, like "1 GiB" + max_connections: int + The maximum number of connections to have out at once + max_message_size: str + The maximum size of a single message that we want to send + + Parameters + ---------- + directory: pathlib.Path + Where to write and read data. Ideally points to fast disk. + dump: callable + Writes an object to a file, like pickle.dump + load: callable + Reads an object from that file, like pickle.load + send: callable + How to send a list of shards to a worker + sizeof: callable + Measures the size of an object in memory + """ + + memory_limit = parse_bytes("1 GiB") + _queues: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + concurrent_files = 2 + total_size = 0 + + def __init__( + self, + directory, + dump=pickle.dump, + load=pickle.load, + sizeof=sizeof, + loop=None, + ): + self.directory = pathlib.Path(directory) + if not os.path.exists(self.directory): + os.mkdir(self.directory) + self.dump = dump + self.load = load + self.sizeof = sizeof + + self.shards = defaultdict(list) + self.sizes = defaultdict(int) + self.total_size = 0 + self.total_received = 0 + + self.condition = asyncio.Condition() + + self.bytes_written = 0 + self.bytes_read = 0 + + self._done = False + self._futures = set() + self.active = set() + self.diagnostics = defaultdict(float) + + self._communicate_future = asyncio.create_task(self.communicate()) + self._loop = loop or asyncio.get_event_loop() + self._exception = None + + @property + def queue(self): + try: + return MultiFile._queues[self._loop] + except KeyError: + queue = asyncio.Queue() + for _ in range(MultiFile.concurrent_files): + queue.put_nowait(None) + MultiFile._queues[self._loop] = queue + return queue + + async def put(self, data: dict[str, list[object]]): + """ + Writes many objects into the local buffers, blocks until ready for more + + Parameters + ---------- + data: dict + A dictionary mapping destinations to lists of objects that should + be written to that destination + """ + if self._exception: + raise self._exception + + this_size = 0 + for id, shard in data.items(): + size = self.sizeof(shard) + self.shards[id].extend(shard) + self.sizes[id] += size + self.total_size += size + MultiFile.total_size += size + self.total_received += size + this_size += size + + del data, shard + + while MultiFile.total_size > self.memory_limit: + with self.time("waiting-on-memory"): + async with self.condition: + + try: + await asyncio.wait_for( + self.condition.wait(), 1 + ) # Block until memory calms down + except asyncio.TimeoutError: + continue + + async def communicate(self): + """ + Continuously find the largest batch and trigger writes + + We keep ``concurrent_files`` files open, writing while we still have any data + as an old write finishes, we find the next largest buffer, and write + its contents to file. + + We do this until we're done. This coroutine runs in the background. + + See Also + -------- + process: does the actual writing + """ + with log_errors(): + + while not self._done: + with self.time("idle"): + if not self.shards: + await asyncio.sleep(0.1) + continue + + await self.queue.get() + + id = max(self.sizes, key=self.sizes.get) + shards = self.shards.pop(id) + size = self.sizes.pop(id) + + future = asyncio.create_task(self.process(id, shards, size)) + del shards + self._futures.add(future) + async with self.condition: + self.condition.notify() + + async def process(self, id: str, shards: list, size: int): + """Write one buffer to file + + This function was built to offload the disk IO, but since then we've + decided to keep this within the event loop (disk bandwidth should be + prioritized, and writes are typically small enough to not be a big + deal). + + Most of the logic here is about possibly going back to a separate + thread, or about diagnostics. If things don't change much in the + future then we should consider simplifying this considerably and + dropping the write into communicate above. + """ + + with log_errors(): + # Consider boosting total_size a bit here to account for duplication + while id in self.active: + await asyncio.sleep(0.01) + + self.active.add(id) + + start = time.time() + try: + with self.time("write"): + with open( + self.directory / str(id), mode="ab", buffering=100_000_000 + ) as f: + for shard in shards: + self.dump(shard, f) + # os.fsync(f) # TODO: maybe? + except Exception as e: + self._exception = e + self._done = True + + stop = time.time() + + self.diagnostics["avg_size"] = ( + 0.98 * self.diagnostics["avg_size"] + 0.02 * size + ) + self.diagnostics["avg_duration"] = 0.98 * self.diagnostics[ + "avg_duration" + ] + 0.02 * (stop - start) + + self.active.remove(id) + self.bytes_written += size + self.total_size -= size + MultiFile.total_size -= size + async with self.condition: + self.condition.notify() + await self.queue.put(None) + + def read(self, id): + """Read a complete file back into memory""" + if self._exception: + raise self._exception + parts = [] + + try: + with self.time("read"): + with open( + self.directory / str(id), mode="rb", buffering=100_000_000 + ) as f: + while True: + try: + parts.append(self.load(f)) + except EOFError: + break + size = f.tell() + except FileNotFoundError: + raise KeyError(id) + + # TODO: We could consider deleting the file at this point + if parts: + self.bytes_read += size + assert len(parts) == 1 + return parts[0] + else: + raise KeyError(id) + + async def flush(self): + """Wait until all writes are finished""" + if self._exception: + await self._communicate_future + await asyncio.gather(*self._futures) + raise self._exception + while self.shards: + await asyncio.sleep(0.05) + + await asyncio.gather(*self._futures) + if all(future.done() for future in self._futures): + self._futures.clear() + + assert not self.total_size + + self._done = True + + await self._communicate_future + + def close(self): + self._done = True + shutil.rmtree(self.directory) + + def __enter__(self): + return self + + def __exit__(self, exc, typ, traceback): + self.close() + + @contextlib.contextmanager + def time(self, name: str): + start = time.time() + yield + stop = time.time() + self.diagnostics[name] += stop - start diff --git a/distributed/shuffle/shuffle.py b/distributed/shuffle/shuffle.py index c832dd1c366..e054cbc6e64 100644 --- a/distributed/shuffle/shuffle.py +++ b/distributed/shuffle/shuffle.py @@ -63,10 +63,18 @@ def rearrange_by_column_p2p( npartitions = npartitions or df.npartitions token = tokenize(df, column, npartitions) + empty = df._meta.copy() + for c, dt in empty.dtypes.items(): + if dt == object: + empty[c] = empty[c].astype( + "string" + ) # TODO: we fail at non-string object dtypes + empty[column] = empty[column].astype("int64") # TODO: this shouldn't be necesssary + setup = delayed(shuffle_setup, pure=True)( NewShuffleMetadata( ShuffleId(token), - df._meta, + empty, column, npartitions, ) diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index 8f13480b91d..08c26eea8c5 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -1,17 +1,32 @@ from __future__ import annotations import asyncio +import contextlib +import functools import math +import os +import time from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import TYPE_CHECKING, NewType +import toolz + from distributed.protocol import to_serialize +from distributed.shuffle.arrow import dump_batch, list_of_buffers_to_table, load_arrow +from distributed.shuffle.multi_comm import MultiComm +from distributed.shuffle.multi_file import MultiFile from distributed.utils import sync if TYPE_CHECKING: import pandas as pd + try: + import pyarrow as pa + except ImportError: + raise ImportError("PyArrow is needed for fast shuffling") + from distributed.worker import Worker ShuffleId = NewType("ShuffleId", str) @@ -49,8 +64,7 @@ def worker_for(self, output_partition: int) -> str: raise IndexError( f"Output partition {output_partition} does not exist in a shuffle producing {self.npartitions} partitions" ) - i = len(self.workers) * output_partition // self.npartitions - return self.workers[i] + return worker_for(output_partition, self.workers, self.npartitions) def _partition_range(self, worker: str) -> tuple[int, int]: "Get the output partition numbers (inclusive) that a worker will hold" @@ -68,43 +82,137 @@ def npartitions_for(self, worker: str) -> int: class Shuffle: "State for a single active shuffle" - def __init__(self, metadata: ShuffleMetadata, worker: Worker) -> None: + def __init__( + self, + metadata: ShuffleMetadata, + worker: Worker, + executor: ThreadPoolExecutor, + ) -> None: self.metadata = metadata self.worker = worker - self.output_partitions: defaultdict[int, list[pd.DataFrame]] = defaultdict(list) + self.executor = executor + + import pyarrow as pa + + self.multi_file = MultiFile( + dump=functools.partial( + dump_batch, schema=pa.Schema.from_pandas(self.metadata.empty) + ), + load=load_arrow, + directory=os.path.join(self.worker.local_directory, str(self.metadata.id)), + sizeof=lambda L: sum(map(len, L)), + loop=worker.io_loop, + ) + + async def send(address, shards): + return await self.worker.rpc(address).shuffle_receive( + data=to_serialize(shards), + shuffle_id=self.metadata.id, + ) + + self.multi_comm = MultiComm( + send=send, + loop=worker.io_loop, + ) + MultiComm.max_connections = min(len(self.metadata.workers), 10) + + self.diagnostics: dict[str, float] = defaultdict(float) self.output_partitions_left = metadata.npartitions_for(worker.address) self.transferred = False + self.total_recvd = 0 + self.start_time = time.time() + self._exception: Exception | None = None + + @contextlib.contextmanager + def time(self, name: str): + start = time.time() + yield + stop = time.time() + self.diagnostics[name] += stop - start + + async def offload(self, func, *args): + # return func(*args) + return await asyncio.get_event_loop().run_in_executor( + self.executor, + func, + *args, + ) + + def heartbeat(self): + return { + "disk": { + "memory": self.multi_file.total_size, + "buckets": len(self.multi_file.shards), + "written": self.multi_file.bytes_written, + "read": self.multi_file.bytes_read, + "active": len(self.multi_file.active), + "diagnostics": self.multi_file.diagnostics, + "memory_limit": self.multi_file.memory_limit, + }, + "comms": { + "memory": self.multi_comm.total_size, + "buckets": len(self.multi_comm.shards), + "written": self.multi_comm.total_moved, + "read": self.total_recvd, + "active": self.multi_comm.queue.qsize(), # TODO: maybe not built yet + "diagnostics": self.multi_comm.diagnostics, + "memory_limit": self.multi_comm.memory_limit, + }, + "diagnostics": self.diagnostics, + "start": self.start_time, + } - def receive(self, output_partition: int, data: pd.DataFrame) -> None: - assert not self.transferred, "`receive` called after barrier task" - self.output_partitions[output_partition].append(data) - - async def add_partition(self, data: pd.DataFrame) -> None: - assert not self.transferred, "`add_partition` called after barrier task" - tasks = [] - # NOTE: `groupby` blocks the event loop, but it also holds the GIL, - # so we don't bother offloading to a thread. See bpo-7946. - for output_partition, data in data.groupby(self.metadata.column): - # NOTE: `column` must refer to an integer column, which is the output partition number for the row. - # This is always `_partitions`, added by `dask/dataframe/shuffle.py::shuffle`. - addr = self.metadata.worker_for(int(output_partition)) - task = asyncio.create_task( - self.worker.rpc(addr).shuffle_receive( - shuffle_id=self.metadata.id, - output_partition=output_partition, - data=to_serialize(data), - ) + async def receive(self, data: list[pa.Buffer]) -> None: + # This is actually ok. Our local barrier might have finished, + # but barriers on other workers might still be running and sending us + # data + # assert not self.transferred, "`receive` called after barrier task" + if self._exception: + raise self._exception + import pyarrow as pa + + self.total_recvd += sum(map(len, data)) + # An ugly way of turning these batches back into an arrow table + with self.time("cpu"): + data = await self.offload( + list_of_buffers_to_table, + data, + pa.Schema.from_pandas(self.metadata.empty), ) - tasks.append(task) - # TODO Once RerunGroup logic exists (https://github.com/dask/distributed/issues/5403), - # handle errors and cancellation here in a way that lets other workers cancel & clean up their shuffles. - # Without it, letting errors kill the task is all we can do. - await asyncio.gather(*tasks) + groups = await self.offload(split_by_partition, data, self.metadata.column) - def get_output_partition(self, i: int) -> pd.DataFrame: - import pandas as pd + assert len(data) == sum(map(len, groups.values())) + del data + + with self.time("cpu"): + groups = await self.offload( + lambda: { + k: [batch.serialize() for batch in v.to_batches()] + for k, v in groups.items() + } + ) + try: + await self.multi_file.put(groups) + except Exception as e: + self._exception = e + + def add_partition(self, data: pd.DataFrame) -> None: + with self.time("cpu"): + out = split_by_worker( + data, + self.metadata.column, + self.metadata.npartitions, + self.metadata.workers, + ) + assert len(data) == sum(map(len, out.values())) + out = { + k: [b.serialize().to_pybytes() for b in t.to_batches()] + for k, t in out.items() + } + self.multi_comm.put(out) + def get_output_partition(self, i: int) -> pd.DataFrame: assert self.transferred, "`get_output_partition` called before barrier task" assert self.metadata.worker_for(i) == self.worker.address, ( @@ -119,13 +227,13 @@ def get_output_partition(self, i: int) -> pd.DataFrame: ), f"No outputs remaining, but requested output partition {i} on {self.worker.address}." self.output_partitions_left -= 1 + sync(self.worker.loop, self.multi_file.flush) # type: ignore try: - parts = self.output_partitions.pop(i) + df = self.multi_file.read(i) + with self.time("cpu"): + return df.to_pandas() except KeyError: - return self.metadata.empty - - assert parts, f"Empty entry for output partition {i}" - return pd.concat(parts, copy=False) + return self.metadata.empty.head(0) def inputs_done(self) -> None: assert not self.transferred, "`inputs_done` called multiple times" @@ -148,6 +256,7 @@ def __init__(self, worker: Worker) -> None: # Initialize self.worker: Worker = worker self.shuffles: dict[ShuffleId, Shuffle] = {} + self.executor = ThreadPoolExecutor(worker.nthreads) # Handlers ########## @@ -162,27 +271,40 @@ def shuffle_init(self, comm: object, metadata: ShuffleMetadata) -> None: raise ValueError( f"Shuffle {metadata.id!r} is already registered on worker {self.worker.address}" ) - self.shuffles[metadata.id] = Shuffle(metadata, self.worker) + self.shuffles[metadata.id] = Shuffle( + metadata, + self.worker, + self.executor, + ) - def shuffle_receive( + def heartbeat(self): + return {id: shuffle.heartbeat() for id, shuffle in self.shuffles.items()} + + async def shuffle_receive( self, comm: object, shuffle_id: ShuffleId, - output_partition: int, - data: pd.DataFrame, + data: list[bytes], ) -> None: """ Hander: Receive an incoming shard of data from a peer worker. Using an unknown ``shuffle_id`` is an error. """ - self._get_shuffle(shuffle_id).receive(output_partition, data) - - def shuffle_inputs_done(self, comm: object, shuffle_id: ShuffleId) -> None: + shuffle = self._get_shuffle(shuffle_id) + future = asyncio.ensure_future(shuffle.receive(data)) + if ( + shuffle.multi_file.total_size + sum(map(len, data)) + > shuffle.multi_file.memory_limit + ): + await future # backpressure + + async def shuffle_inputs_done(self, comm: object, shuffle_id: ShuffleId) -> None: """ Hander: Inform the extension that all input partitions have been handed off to extensions. Using an unknown ``shuffle_id`` is an error. """ shuffle = self._get_shuffle(shuffle_id) + await shuffle.multi_comm.flush() shuffle.inputs_done() if shuffle.done(): # If the shuffle has no output partitions, remove it now; @@ -241,17 +363,7 @@ async def _create_shuffle( return metadata # NOTE: unused in tasks, just handy for tests def add_partition(self, data: pd.DataFrame, shuffle_id: ShuffleId) -> None: - sync(self.worker.loop, self._add_partition, data, shuffle_id) - - async def _add_partition(self, data: pd.DataFrame, shuffle_id: ShuffleId) -> None: - """ - Task: Hand off an input partition to the ShuffleExtension. - - This will block until the extension is ready to receive another input partition. - - Using an unknown ``shuffle_id`` is an error. - """ - await self._get_shuffle(shuffle_id).add_partition(data) + self._get_shuffle(shuffle_id).add_partition(data=data) def barrier(self, shuffle_id: ShuffleId) -> None: sync(self.worker.loop, self._barrier, shuffle_id) @@ -317,8 +429,8 @@ def get_output_partition( shuffle = self._get_shuffle(shuffle_id) output = shuffle.get_output_partition(output_partition) if shuffle.done(): - # key missing if another thread got to it first self.shuffles.pop(shuffle_id, None) + # key missing if another thread got to it first return output def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: @@ -329,3 +441,92 @@ def _get_shuffle(self, shuffle_id: ShuffleId) -> Shuffle: raise ValueError( f"Shuffle {shuffle_id!r} is not registered on worker {self.worker.address}" ) from None + + def close(self): + self.executor.shutdown() + + +class ShuffleSchedulerExtension: + """ + Shuffle extension for the scheduler + + Today this mostly just collects heartbeat messages for the dashboard, + but in the future it may be responsible for more + + See Also + -------- + ShuffleWorkerExtension + """ + + def __init__(self, scheduler): + self.scheduler = scheduler + self.shuffles = defaultdict(lambda: defaultdict(dict)) + + def heartbeat(self, ws, data): + for shuffle_id, d in data.items(): + self.shuffles[shuffle_id][ws.address].update(d) + + +def worker_for(output_partition: int, workers: list[str], npartitions: int) -> str: + "Get the address of the worker which should hold this output partition number" + i = len(workers) * output_partition // npartitions + return workers[i] + + +def split_by_worker( + df: pd.DataFrame, column: str, npartitions: int, workers: list[str] +) -> dict: + """ + Split data into many arrow batches, partitioned by destination worker + """ + import numpy as np + import pandas as pd + import pyarrow as pa + + grouper = (len(workers) * df[column] // npartitions).astype(df[column].dtype).values + + t = pa.Table.from_pandas(df) + del df + t = t.add_column(len(t.columns), "_worker", [grouper]) + t = t.sort_by("_worker") + + worker = np.asarray(t.select(["_worker"]))[0] + t = t.drop(["_worker"]) + splits = np.where(worker[1:] != worker[:-1])[0] + 1 + splits = np.concatenate([[0], splits]) + + shards = [ + t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits) + ] + shards.append(t.slice(offset=splits[-1], length=None)) + + w_unique = pd.Series(grouper).unique() + w_unique.sort() + + return {workers[w]: shard for w, shard in zip(w_unique, shards)} + + +def split_by_partition( + t: pa.Table, + column: str, +) -> dict: + """ + Split data into many arrow batches, partitioned by final partition + """ + import numpy as np + + partitions = t.select([column]).to_pandas()[column].unique() + partitions.sort() + t = t.sort_by(column) + + partition = np.asarray(t.select([column]))[0] + splits = np.where(partition[1:] != partition[:-1])[0] + 1 + splits = np.concatenate([[0], splits]) + + shards = [ + t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits) + ] + shards.append(t.slice(offset=splits[-1], length=None)) + assert len(t) == sum(map(len, shards)) + assert len(partitions) == len(shards) + return dict(zip(partitions, shards)) diff --git a/distributed/shuffle/tests/test_graph.py b/distributed/shuffle/tests/test_graph.py index 3844ff3db49..d261ee67097 100644 --- a/distributed/shuffle/tests/test_graph.py +++ b/distributed/shuffle/tests/test_graph.py @@ -46,6 +46,7 @@ def test_shuffle_helper(client: Client): def test_basic(client: Client): df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + df["name"] = df["name"].astype("string[python]") shuffled = shuffle(df, "id") (opt,) = dask.optimize(shuffled) @@ -79,6 +80,7 @@ async def test_basic_state(c: Client, s: Scheduler, *workers: Worker): def test_multiple_linear(client: Client): df = dd.demo.make_timeseries(freq="15D", partition_freq="30D") + df["name"] = df["name"].astype("string[python]") s1 = shuffle(df, "id") s1["x"] = s1["x"] + 1 s2 = shuffle(s1, "x") diff --git a/distributed/shuffle/tests/test_multi_comm.py b/distributed/shuffle/tests/test_multi_comm.py new file mode 100644 index 00000000000..d048b1faad8 --- /dev/null +++ b/distributed/shuffle/tests/test_multi_comm.py @@ -0,0 +1,43 @@ +import asyncio +from collections import defaultdict + +import pytest + +from distributed.shuffle.multi_comm import MultiComm + + +@pytest.mark.asyncio +async def test_basic(tmp_path): + d = defaultdict(list) + + async def send(address, shards): + d[address].extend(shards) + + mc = MultiComm(send=send) + mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + await mc.flush() + + assert b"".join(d["x"]) == b"0" * 2000 + assert b"".join(d["y"]) == b"1" * 1000 + + +@pytest.mark.asyncio +async def test_exceptions(tmp_path): + d = defaultdict(list) + + async def send(address, shards): + raise Exception(123) + + mc = MultiComm(send=send) + mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + while not mc._exception: + await asyncio.sleep(0.1) + + with pytest.raises(Exception, match="123"): + mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + with pytest.raises(Exception, match="123"): + await mc.flush() diff --git a/distributed/shuffle/tests/test_multi_file.py b/distributed/shuffle/tests/test_multi_file.py new file mode 100644 index 00000000000..0e042c0cc2e --- /dev/null +++ b/distributed/shuffle/tests/test_multi_file.py @@ -0,0 +1,70 @@ +import asyncio +import os + +import pytest + +from distributed.shuffle.multi_file import MultiFile + + +def dump(data, f): + f.write(data) + + +def load(f): + out = f.read() + if not out: + raise EOFError() + return out + + +@pytest.mark.asyncio +async def test_basic(tmp_path): + with MultiFile(directory=tmp_path, dump=dump, load=load) as mf: + await mf.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + await mf.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + await mf.flush() + + x = mf.read("x") + y = mf.read("y") + + assert x == b"0" * 2000 + assert y == b"1" * 1000 + + assert not os.path.exists(tmp_path) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("count", [2, 100, 1000]) +async def test_many(tmp_path, count): + with MultiFile(directory=tmp_path, dump=dump, load=load) as mf: + d = {i: [str(i).encode() * 100] for i in range(count)} + + for i in range(10): + await mf.put(d) + + await mf.flush() + + for i in d: + out = mf.read(i) + assert out == str(i).encode() * 100 * 10 + + assert not os.path.exists(tmp_path) + + +@pytest.mark.asyncio +async def test_exceptions(tmp_path): + def dump(data, f): + raise Exception(123) + + with MultiFile(directory=tmp_path, dump=dump, load=load) as mf: + await mf.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + while not mf._exception: + await asyncio.sleep(0.1) + + with pytest.raises(Exception, match="123"): + await mf.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + with pytest.raises(Exception, match="123"): + await mf.flush() diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py new file mode 100644 index 00000000000..027d3d24bdc --- /dev/null +++ b/distributed/shuffle/tests/test_shuffle.py @@ -0,0 +1,197 @@ +import asyncio +import io +import shutil +from collections import defaultdict + +import pandas as pd +import pytest + +pa = pytest.importorskip("pyarrow") + +import dask +import dask.dataframe as dd + +from distributed.shuffle.shuffle_extension import ( + dump_batch, + list_of_buffers_to_table, + load_arrow, + split_by_partition, + split_by_worker, +) +from distributed.utils_test import gen_cluster + + +@gen_cluster(client=True, timeout=1000000) +async def test_basic(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + x, y = c.compute([df.x.size, out.x.size]) + x = await x + y = await y + assert x == y + + +@gen_cluster(client=True, timeout=1000000) +async def test_concurrent(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + x = dd.shuffle.shuffle(df, "x", shuffle="p2p") + y = dd.shuffle.shuffle(df, "y", shuffle="p2p") + x, y = c.compute([x.x.size, y.y.size]) + x = await x + y = await y + assert x == y + + +@gen_cluster(client=True) +async def test_bad_disk(c, s, a, b): + + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + while not a.extensions["shuffle"].shuffles: + await asyncio.sleep(0.01) + shutil.rmtree(a.local_directory) + + while not b.extensions["shuffle"].shuffles: + await asyncio.sleep(0.01) + shutil.rmtree(b.local_directory) + with pytest.raises(FileNotFoundError) as e: + out = await c.compute(out) + + assert a.local_directory in str(e.value) or b.local_directory in str(e.value) + + +@pytest.mark.slow +@gen_cluster(client=True) +async def test_crashed_worker(c, s, a, b): + + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + + while ( + len( + [ + ts + for ts in s.tasks.values() + if "shuffle_transfer" in ts.key and ts.state == "memory" + ] + ) + < 3 + ): + await asyncio.sleep(0.01) + await b.close() + + with pytest.raises(Exception) as e: + out = await c.compute(out) + + assert a.address in str(e.value) or b.address in str(e.value) + + +@gen_cluster(client=True) +async def test_heartbeat(c, s, a, b): + await a.heartbeat() + assert not s.extensions["shuffle"].shuffles + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + + while not s.extensions["shuffle"].shuffles: + await asyncio.sleep(0.001) + await a.heartbeat() + + [s] = s.extensions["shuffle"].shuffles.values() + await out + + +def test_processing_chain(): + """ + This is a serial version of the entire compute chain + + In practice this takes place on many different workers. + Here we verify its accuracy in a single threaded situation. + """ + workers = ["a", "b", "c"] + npartitions = 5 + df = pd.DataFrame({"x": range(100), "y": range(100)}) + df["_partitions"] = df.x % npartitions + schema = pa.Schema.from_pandas(df) + + data = split_by_worker(df, "_partitions", npartitions, workers) + assert set(data) == set(workers) + + batches = { + worker: [b.serialize().to_pybytes() for b in t.to_batches()] + for worker, t in data.items() + } + + # Typically we communicate to different workers at this stage + # We then receive them back and reconstute them + + by_worker = { + worker: list_of_buffers_to_table(list_of_batches, schema) + for worker, list_of_batches in batches.items() + } + + # We split them again, and then dump them down to disk + + splits_by_worker = { + worker: split_by_partition(t, "_partitions") for worker, t in by_worker.items() + } + + splits_by_worker = { + worker: { + partition: [batch.serialize() for batch in t.to_batches()] + for partition, t in d.items() + } + for worker, d in splits_by_worker.items() + } + + # No two workers share data from any partition + assert not any( + set(a) & set(b) + for w1, a in splits_by_worker.items() + for w2, b in splits_by_worker.items() + if w1 is not w2 + ) + + # Our simple file system + + filesystem = defaultdict(io.BytesIO) + + for worker, partitions in splits_by_worker.items(): + for partition, batches in partitions.items(): + for batch in batches: + dump_batch(batch, filesystem[partition], schema) + + out = {} + for k, bio in filesystem.items(): + bio.seek(0) + out[k] = load_arrow(bio) + + assert sum(map(len, out.values())) == len(df) diff --git a/distributed/shuffle/tests/test_shuffle_extension.py b/distributed/shuffle/tests/test_shuffle_extension.py index ec50e596561..80c3475e904 100644 --- a/distributed/shuffle/tests/test_shuffle_extension.py +++ b/distributed/shuffle/tests/test_shuffle_extension.py @@ -15,6 +15,9 @@ ShuffleId, ShuffleMetadata, ShuffleWorkerExtension, + split_by_partition, + split_by_worker, + worker_for, ) from distributed.utils_test import gen_cluster @@ -151,137 +154,54 @@ async def test_create(s: Scheduler, *workers: Worker): await exts[0]._create_shuffle(new_metadata) -@gen_cluster([("", 1)] * 4) -async def test_add_partition(s: Scheduler, *workers: Worker): - exts: dict[str, ShuffleWorkerExtension] = { - w.address: w.extensions["shuffle"] for w in workers - } - - new_metadata = NewShuffleMetadata( - ShuffleId("foo"), - pd.DataFrame({"A": [], "partition": []}), - "partition", - 8, - ) - - ext = next(iter(exts.values())) - metadata = await ext._create_shuffle(new_metadata) - partition = pd.DataFrame( +def test_split_by_worker(): + df = pd.DataFrame( { - "A": ["a", "b", "c", "d", "e", "f", "g", "h"], - "partition": [0, 1, 2, 3, 4, 5, 6, 7], + "x": [1, 2, 3, 4, 5], + "_partition": [0, 1, 2, 0, 1], } ) - await ext._add_partition(partition, new_metadata.id) + workers = ["alice", "bob"] + npartitions = 3 - with pytest.raises(ValueError, match="not registered"): - await ext._add_partition(partition, ShuffleId("bar")) + out = split_by_worker(df, "_partition", npartitions, workers) + assert set(out) == {"alice", "bob"} + assert out["alice"].column_names == list(df.columns) - for i, data in partition.groupby(new_metadata.column): - addr = metadata.worker_for(int(i)) - ext = exts[addr] - received = ext.shuffles[metadata.id].output_partitions[int(i)] - assert len(received) == 1 - dd.utils.assert_eq(data, received[0]) + assert sum(map(len, out.values())) == len(df) - # TODO (resilience stage) test failed sends - -@gen_cluster([("", 1)] * 4, client=True) -async def test_barrier(c: Client, s: Scheduler, *workers: Worker): - exts: dict[str, ShuffleWorkerExtension] = { - w.address: w.extensions["shuffle"] for w in workers - } - - new_metadata = NewShuffleMetadata( - ShuffleId("foo"), - pd.DataFrame({"A": [], "partition": []}), - "partition", - 4, - ) - fs = await add_dummy_unpack_keys(new_metadata, c) - - ext = next(iter(exts.values())) - metadata = await ext._create_shuffle(new_metadata) - partition = pd.DataFrame( +def test_split_by_worker_many_workers(): + df = pd.DataFrame( { - "A": ["a", "b", "c"], - "partition": [0, 1, 2], + "x": [1, 2, 3, 4, 5], + "_partition": [5, 7, 5, 0, 1], } ) - await ext._add_partition(partition, metadata.id) - - await ext._barrier(metadata.id) - - # Check scheduler restrictions were set for unpack tasks - for i, key in enumerate(fs): - assert s.tasks[key].worker_restrictions == {metadata.worker_for(i)} + workers = ["a", "b", "c", "d", "e", "f", "g", "h"] + npartitions = 10 - # Check all workers have been informed of the barrier - for addr, ext in exts.items(): - if metadata.npartitions_for(addr): - shuffle = ext.shuffles[metadata.id] - assert shuffle.transferred - assert not shuffle.done() - else: - # No output partitions on this worker; shuffle already cleaned up - assert not ext.shuffles + out = split_by_worker(df, "_partition", npartitions, workers) + assert worker_for(5, workers, npartitions) in out + assert worker_for(0, workers, npartitions) in out + assert worker_for(7, workers, npartitions) in out + assert worker_for(1, workers, npartitions) in out + assert sum(map(len, out.values())) == len(df) -@gen_cluster([("", 1)] * 4, client=True) -async def test_get_partition(c: Client, s: Scheduler, *workers: Worker): - exts: dict[str, ShuffleWorkerExtension] = { - w.address: w.extensions["shuffle"] for w in workers - } - new_metadata = NewShuffleMetadata( - ShuffleId("foo"), - pd.DataFrame({"A": [], "partition": []}), - "partition", - 8, - ) - _ = await add_dummy_unpack_keys(new_metadata, c) +def test_split_by_partition(): + import pyarrow as pa - ext = next(iter(exts.values())) - metadata = await ext._create_shuffle(new_metadata) - p1 = pd.DataFrame( - { - "A": ["a", "b", "c", "d", "e", "f", "g", "h"], - "partition": [0, 1, 2, 3, 4, 5, 6, 6], - } - ) - p2 = pd.DataFrame( + df = pd.DataFrame( { - "A": ["a", "b", "c", "d", "e", "f", "g", "h"], - "partition": [0, 1, 2, 3, 0, 0, 2, 3], + "x": [1, 2, 3, 4, 5], + "_partition": [3, 1, 2, 3, 1], } ) - await asyncio.gather( - ext._add_partition(p1, metadata.id), ext._add_partition(p2, metadata.id) - ) - await ext._barrier(metadata.id) - - for addr, ext in exts.items(): - if metadata.worker_for(0) != addr: - with pytest.raises(AssertionError, match="belongs on"): - ext.get_output_partition(metadata.id, 0) - - full = pd.concat([p1, p2]) - expected_groups = full.groupby("partition") - for output_i in range(metadata.npartitions): - addr = metadata.worker_for(output_i) - ext = exts[addr] - result = ext.get_output_partition(metadata.id, output_i) - try: - expected = expected_groups.get_group(output_i) - except KeyError: - expected = metadata.empty - dd.utils.assert_eq(expected, result) - # ^ NOTE: use `assert_eq` instead of `pd.testing.assert_frame_equal` directly - # to ignore order of the rows (`assert_eq` pre-sorts its inputs). - - # Once all partitions are retrieved, shuffles are cleaned up - for ext in exts.values(): - assert not ext.shuffles - with pytest.raises(ValueError, match="not registered"): - ext.get_output_partition(metadata.id, 0) + t = pa.Table.from_pandas(df) + + out = split_by_partition(t, "_partition") + assert set(out) == {1, 2, 3} + assert out[1].column_names == list(df.columns) + assert sum(map(len, out.values())) == len(df) diff --git a/distributed/worker.py b/distributed/worker.py index 7c5bc61ca15..8de4df24895 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1411,7 +1411,9 @@ async def close( for extension in self.extensions.values(): if hasattr(extension, "close"): - await extension.close() + result = extension.close() + if isawaitable(result): + result = await result if nanny and self.nanny: with self.rpc(self.nanny) as r: